Add jwt-proxy functionality
[nonrtric.git] / service-exposure / utils / generatejwt / generatejwt.go
1 // -
2 //   ========================LICENSE_START=================================
3 //   O-RAN-SC
4 //   %%
5 //   Copyright (C) 2022: Nordix Foundation
6 //   %%
7 //   Licensed under the Apache License, Version 2.0 (the "License");
8 //   you may not use this file except in compliance with the License.
9 //   You may obtain a copy of the License at
10 //
11 //        http://www.apache.org/licenses/LICENSE-2.0
12 //
13 //   Unless required by applicable law or agreed to in writing, software
14 //   distributed under the License is distributed on an "AS IS" BASIS,
15 //   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 //   See the License for the specific language governing permissions and
17 //   limitations under the License.
18 //   ========================LICENSE_END===================================
19 //
20 package generatejwt
21
22 import (
23         "crypto/rsa"
24         "crypto/x509"
25         "encoding/pem"
26         "fmt"
27         "github.com/dgrijalva/jwt-go"
28         "io/ioutil"
29         "log"
30         "time"
31 )
32
33 type JWT struct {
34         privateKey []byte
35         publicKey  []byte
36 }
37
38 func NewJWT(privateKey []byte, publicKey []byte) JWT {
39         return JWT{
40                 privateKey: privateKey,
41                 publicKey:  publicKey,
42         }
43 }
44
45 func readFile(file string) []byte {
46         key, err := ioutil.ReadFile(file)
47         if err != nil {
48                 log.Fatalln(err)
49         }
50         return key
51 }
52
53 func (j JWT) createWithKey(ttl time.Duration, content interface{}, client, realm string) (string, error) {
54         key, err := jwt.ParseRSAPrivateKeyFromPEM(j.privateKey)
55         if err != nil {
56                 return "", fmt.Errorf("create: parse key: %w", err)
57         }
58
59         now := time.Now().UTC()
60
61         claims := make(jwt.MapClaims)
62         claims["dat"] = content             // Our custom data.
63         claims["exp"] = now.Add(ttl).Unix() // The expiration time after which the token must be disregarded.
64         claims["iat"] = now.Unix()          // The time at which the token was issued.
65         claims["nbf"] = now.Unix()          // The time before which the token must be disregarded.
66         claims["jti"] = "myJWTId" + fmt.Sprint(now.UnixNano())
67         claims["sub"] = client
68         claims["iss"] = client
69         claims["aud"] = realm
70
71         token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
72         tokenString, err := token.SignedString(key)
73         if err != nil {
74                 return "", fmt.Errorf("create: sign token: %w", err)
75         }
76
77         return tokenString, nil
78 }
79
80 func createWithSecret(ttl time.Duration, content interface{}, client, realm, secret string) (string, error) {
81         now := time.Now().UTC()
82
83         claims := make(jwt.MapClaims)
84         claims["dat"] = content             // Our custom data.
85         claims["exp"] = now.Add(ttl).Unix() // The expiration time after which the token must be disregarded.
86         claims["iat"] = now.Unix()          // The time at which the token was issued.
87         claims["nbf"] = now.Unix()          // The time before which the token must be disregarded.
88         claims["jti"] = "myJWTId" + fmt.Sprint(now.UnixNano())
89         claims["sub"] = client
90         claims["iss"] = client
91         claims["aud"] = realm
92
93         token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(secret))
94         if err != nil {
95                 return "", fmt.Errorf("create: sign token: %w", err)
96         }
97
98         return token, nil
99 }
100
101 func (j JWT) Validate(token string) (interface{}, error) {
102         key, err := jwt.ParseRSAPublicKeyFromPEM(j.publicKey)
103         if err != nil {
104                 return "", fmt.Errorf("validate: parse key: %w", err)
105         }
106
107         tok, err := jwt.Parse(token, func(jwtToken *jwt.Token) (interface{}, error) {
108                 if _, ok := jwtToken.Method.(*jwt.SigningMethodRSA); !ok {
109                         return nil, fmt.Errorf("unexpected method: %s", jwtToken.Header["alg"])
110                 }
111
112                 return key, nil
113         })
114         if err != nil {
115                 return nil, fmt.Errorf("validate: %w", err)
116         }
117
118         claims, ok := tok.Claims.(jwt.MapClaims)
119         if !ok || !tok.Valid {
120                 return nil, fmt.Errorf("validate: invalid")
121         }
122
123         return claims["dat"], nil
124 }
125
126 func createPublicKeyFromPrivateKey(privkey_bytes []byte) []byte {
127         block, _ := pem.Decode([]byte(privkey_bytes))
128         var privateKey *rsa.PrivateKey
129         pkcs1, err := x509.ParsePKCS1PrivateKey(block.Bytes)
130         if err != nil {
131                 pkcs8, err := x509.ParsePKCS8PrivateKey(block.Bytes)
132                 privateKey = pkcs8.(*rsa.PrivateKey)
133                 if err != nil {
134                         log.Fatal(err)
135                 }
136         } else {
137                 privateKey = pkcs1
138         }
139
140         publicKey := &privateKey.PublicKey
141
142         pubkey_bytes, err := x509.MarshalPKIXPublicKey(publicKey)
143         if err != nil {
144                 log.Fatal(err)
145         }
146
147         pubkey_pem := pem.EncodeToMemory(
148                 &pem.Block{
149                         Type:  "PUBLIC KEY",
150                         Bytes: pubkey_bytes,
151                 },
152         )
153         return pubkey_pem
154 }
155
156 func CreateJWT(privateKeyFile, secret, client, realm string) string {
157         if secret == "" {
158                 prvKey := readFile(privateKeyFile)
159                 pubKey := createPublicKeyFromPrivateKey(prvKey)
160
161                 jwtToken := NewJWT(prvKey, pubKey)
162
163                 // 1. Create a new JWT token.
164                 tok, err := jwtToken.createWithKey(time.Hour, "Can be anything", client, realm)
165                 if err != nil {
166                         log.Fatalln(err)
167                 }
168
169                 // 2. Validate an existing JWT token.
170                 _, err = jwtToken.Validate(tok)
171                 if err != nil {
172                         log.Fatalln(err)
173                 }
174                 return tok
175         } else {
176                 // 1. Create a new JWT token.
177                 tok, err := createWithSecret(time.Hour, "Can be anything", client, realm, secret)
178                 if err != nil {
179                         log.Fatalln(err)
180                 }
181                 return tok
182         }
183
184 }