Add oAuth2 for subscription and registration with SMO
[pti/o2.git] / o2common / service / command / handler.py
1 # Copyright (C) 2022 Wind River Systems, Inc.
2 #
3 #  Licensed under the Apache License, Version 2.0 (the "License");
4 #  you may not use this file except in compliance with the License.
5 #  You may obtain a copy of the License at
6 #
7 #      http://www.apache.org/licenses/LICENSE-2.0
8 #
9 #  Unless required by applicable law or agreed to in writing, software
10 #  distributed under the License is distributed on an "AS IS" BASIS,
11 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 #  See the License for the specific language governing permissions and
13 #  limitations under the License.
14
15 import os
16 import requests
17 import json
18 import http.client
19 import ssl
20 from requests_oauthlib import OAuth2Session
21 from oauthlib.oauth2 import LegacyApplicationClient
22 from requests.packages.urllib3.util.retry import Retry
23 from requests.adapters import HTTPAdapter
24 from requests.exceptions import RequestException, SSLError
25
26 from o2common.helper import o2logging
27 from o2common.config import config
28
29 logger = o2logging.get_logger(__name__)
30
31
32 def get_http_conn(callbackurl):
33     conn = http.client.HTTPConnection(callbackurl)
34     return conn
35
36
37 # with default CA
38 def get_https_conn_default(callbackurl):
39     sslctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
40     sslctx.check_hostname = True
41     sslctx.verify_mode = ssl.CERT_REQUIRED
42     sslctx.load_default_certs()
43     conn = http.client.HTTPSConnection(callbackurl, context=sslctx)
44     return conn
45
46
47 # with self signed ca
48 def get_https_conn_selfsigned(callbackurl):
49     sslctx = ssl.create_default_context(
50         purpose=ssl.Purpose.SERVER_AUTH)
51     smo_ca_path = config.get_smo_ca_config_path()
52     sslctx.load_verify_locations(smo_ca_path)
53     sslctx.check_hostname = False
54     sslctx.verify_mode = ssl.CERT_REQUIRED
55     conn = http.client.HTTPSConnection(callbackurl, context=sslctx)
56     return conn
57
58
59 class SMOClient:
60     def __init__(self, client_id=None, token_url=None, username=None,
61                  password=None, scope=None, retries=3, use_oauth=False):
62         self.client_id = client_id
63         self.token_url = token_url
64         self.username = username
65         self.password = password
66         self.scope = scope if scope else []
67         self.use_oauth = use_oauth
68         self.retries = retries
69
70         if self.use_oauth:
71             if not all([self.client_id, self.token_url, self.username,
72                         self.password]):
73                 raise ValueError(
74                     'client_id, token_url, username, and password ' +
75                     'must be provided when use_oauth is True.')
76
77             # Set OAUTHLIB_INSECURE_TRANSPORT environment variable
78             # if token_url uses http
79             if 'http://' in self.token_url:
80                 os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = '1'
81
82             # Create a LegacyApplicationClient for handling password flow
83             client = LegacyApplicationClient(client_id=self.client_id)
84             self.session = OAuth2Session(client=client)
85
86             # Check if token_url uses https and set SSL verification
87             if 'https://' in self.token_url:
88                 ca_path = config.get_smo_ca_config_path()
89                 if os.path.exists(ca_path):
90                     self.session.verify = ca_path
91                 else:
92                     self.session.verify = True
93
94             # Fetch the access token
95             self.fetch_token(self.session.verify)
96         else:
97             self.session = requests.Session()
98
99         # Create a Retry object for handling retries
100         retry_strategy = Retry(
101             total=retries,
102             backoff_factor=1,
103             status_forcelist=[429, 500, 502, 503, 504],
104             allowed_methods=["HEAD", "GET", "OPTIONS", "POST"]
105         )
106         adapter = HTTPAdapter(max_retries=retry_strategy)
107         self.session.mount("https://", adapter)
108         self.session.mount("http://", adapter)
109
110     def fetch_token(self, verify):
111         try:
112             self.session.fetch_token(
113                 token_url=self.token_url,
114                 username=self.username,
115                 password=self.password,
116                 client_id=self.client_id,
117                 verify=verify
118             )
119         except SSLError:
120             # If SSLError is raised, try again with verify=False
121             logger.warning('The SSLError occurred')
122             if verify is not False:
123                 self.fetch_token(verify=False)
124
125     def handle_post_data(self, resp):
126         if resp.status_code >= 200 and resp.status_code < 300:
127             return True
128         logger.error('Response code is: {}'.format(resp.status_code))
129         # TODO: write the status to extension db table.
130         return False
131
132     def post(self, url, data, retries=1):
133         if not all([url, data]):
134             raise ValueError(
135                 'url, data must be provided when call the post.')
136
137         # Check if token_url uses https and set SSL verification
138         if 'https://' in url:
139             ca_path = config.get_smo_ca_config_path()
140             if os.path.exists(ca_path):
141                 self.session.verify = ca_path
142             else:
143                 self.session.verify = True
144
145         if retries is None:
146             retries = self.retries
147
148         for _ in range(retries):
149             try:
150                 response = self.session.post(
151                     url, data=json.dumps(data))
152                 response.raise_for_status()
153                 return self.handle_post_data(response)
154             except (SSLError, RequestException) as e:
155                 logger.warning(f'Error occurred: {e}')
156                 pass
157         raise Exception(
158             f"POST request to {url} failed after {retries} retries.")