OAuth2 support
[pti/o2.git] / o2common / authmw / authprov.py
index 11243df..87bbc4e 100644 (file)
 #  limitations under the License.
 
 import ssl
-from o2common.helper import o2logging
 import urllib.request
 import urllib.parse
 import json
-
+from http import HTTPStatus
+from requests import post as requests_post
+from requests.auth import HTTPBasicAuth
+from requests.exceptions import HTTPError
+from jwt import decode as jwt_decode
+from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
+
+from o2common.authmw.exceptions import AuthRequiredExp
+from o2common.authmw.exceptions import AuthFailureExp
 from o2common.config.config import get_auth_provider, get_review_url
 from o2common.config.config import get_reviewer_token
+from o2common.config import conf
+from o2common.helper import o2logging
 
 ssl._create_default_https_context = ssl._create_unverified_context
 logger = o2logging.get_logger(__name__)
 
 
+class OAuthAuthenticationException(Exception):
+    def __init__(self, value):
+        self.value = value
+
+
 class K8SAuthenticaException(Exception):
     def __init__(self, value):
         self.value = value
@@ -45,7 +59,7 @@ class auth_definer():
         if auth_prv_conf == 'k8s':
             self.obj = k8s_auth_provider('k8s')
         else:
-            self.obj = keystone_auth_provider('keystone')
+            self.obj = oauth2_auth_provider('oauth2')
 
     def tokenissue(self):
         return self.obj.tokenissue()
@@ -53,7 +67,6 @@ class auth_definer():
     def sanity_check(self):
         return self.obj.sanity_check()
 
-    # call k8s api
     def authenticate(self, token):
         return self.obj.authenticate(token)
 
@@ -82,6 +95,7 @@ class k8s_auth_provider(auth_definer):
             raise Exception(str(ex))
 
     def authenticate(self, token):
+        ''' Call Kubenetes API to authenticate '''
         reviewer_token = get_reviewer_token()
         tokenreview = {
             "kind": "TokenReview",
@@ -125,18 +139,70 @@ class k8s_auth_provider(auth_definer):
         return True
 
 
-class keystone_auth_provider(auth_definer):
+class oauth2_auth_provider(auth_definer):
     def __init__(self, name):
         self.name = name
 
-    def tokenissue(self, *args1, **args2):
-        pass
+    def _format_public_key(self):
+        public_key_string = """-----BEGIN PUBLIC KEY----- \
+        %s \
+        -----END PUBLIC KEY-----""" % conf.OAUTH2.oauth2_public_key
+        return public_key_string
 
-    def authenticate(self, *args1, **args2):
+    def _verify_jwt_token_introspect(self, token):
+        introspect_endpoint = conf.OAUTH2.oauth2_introspection_endpoint
+        client_id = conf.OAUTH2.oauth2_client_id
+        client_secret = conf.OAUTH2.oauth2_client_secret
+        try:
+            response = requests_post(
+                introspect_endpoint,
+                data={'token': token, 'client_id': client_id},
+                auth=HTTPBasicAuth(client_id, client_secret)
+            )
+        except HTTPError as e:
+            logger.error('OAuth2 jwt token introspect verify failed.')
+            raise Exception(str(e))
+        if response.status_code == HTTPStatus.OK:
+            introspection_data = response.json()
+            if introspection_data.get('active'):
+                logger.info('OAuth2 jwt token introspect result active.')
+                return True
+        logger.info('OAuth2 jwt token introspect verify failed.')
         return False
 
-    def sanity_check(self):
-        pass
+    def _verify_jwt_token(self, token):
+        algorithm = conf.OAUTH2.oauth2_algorithm
+        public_key_string = self._format_public_key()
+        try:
+            options = {"verify_signature": True, "verify_aud": False,
+                       "exp": True}
+            decoded_token = jwt_decode(token, public_key_string,
+                                       algorithms=[algorithm], options=options)
+            logger.info(
+                'Verified Token from client: %s' %
+                decoded_token.get("clientHost"))
+            return True
+        except (ExpiredSignatureError,
+                InvalidTokenError) as e:
+            logger.error(f'OAuth2 jwt token validation failed: {e}')
+            raise AuthFailureExp(
+                'OAuth2 JWT Token Authentication failure.')
+        except Exception as e:
+            raise AuthRequiredExp(str(e))
 
-    def tokenrevoke(self, *args1, **args2):
+    def authenticate(self, token):
+        ''' Call the JWT to authenticate
+
+        If the verify type is introspection, call introspection endpoint to
+        verify the token.
+        If the verify type is jwt, call JWT SDK to verify the token.
+        '''
+        oauth2_verify_type = conf.OAUTH2.oauth2_verify_type
+        if oauth2_verify_type == 'introspection':
+            return self._verify_jwt_token_introspect(token)
+        elif oauth2_verify_type == 'jwt':
+            return self._verify_jwt_token(token)
         return False
+
+    def sanity_check(self):
+        pass