From: SANDEEP KUMAR JAISAWAL Date: Mon, 17 Oct 2022 05:35:42 +0000 (+0530) Subject: Common files for training manager X-Git-Tag: 1.0.0~31 X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=commitdiff_plain;h=9ffbb7c1e5433b8e0911ff3f17a911dfc5375daa;p=aiml-fw%2Fawmf%2Ftm.git Common files for training manager Issue-Id: AIMLFW-2 Signed-off-by: SANDEEP KUMAR JAISAWAL Change-Id: Id04477e324428c9241a9e700cfed07ff0c490a0e --- diff --git a/README.md b/README.md new file mode 100644 index 0000000..1db81c3 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +This folder contains all files related to training manager. +#To install training manager as package +pip3 install . \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..226af0c --- /dev/null +++ b/setup.py @@ -0,0 +1,31 @@ +# ================================================================================== +# +# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ================================================================================== + +from setuptools import setup, find_packages + +setup( + name="trainingmgr", + version="0.1", + packages=find_packages(exclude=["tests.*", "tests"]), + author='SANDEEP KUMAR JAISAWAL', + author_email='s.jaisawal@samsung.com', + description="AIMLFW Training manager", + url="https://gerrit.o-ran-sc.org/r/admin/repos/aiml-fw/awmf/tm,general", + keywords="AIMLWF TM", + license="Apache 2.0", +) \ No newline at end of file diff --git a/trainingmgr/__init__.py b/trainingmgr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trainingmgr/common/__init__.py b/trainingmgr/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trainingmgr/common/conf_log.yaml b/trainingmgr/common/conf_log.yaml new file mode 100644 index 0000000..957c663 --- /dev/null +++ b/trainingmgr/common/conf_log.yaml @@ -0,0 +1,38 @@ +# ================================================================================== +# +# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ================================================================================== +version: 1 +formatters: + simple: + format: '%(asctime)s | %(filename)s %(lineno)s %(funcName)s() | %(levelname)s | %(message)s' +handlers: + console: + class: logging.StreamHandler + level: DEBUG + formatter: simple + stream: ext://sys.stdout + access_file: + class: logging.handlers.RotatingFileHandler + level: DEBUG + formatter: simple + filename: /var/log/training_manager.log + maxBytes: 10485760 + backupCount: 20 + encoding: utf8 +root: + level: DEBUG + handlers: [access_file,console] \ No newline at end of file diff --git a/trainingmgr/common/exceptions_utls.py b/trainingmgr/common/exceptions_utls.py new file mode 100644 index 0000000..948edd5 --- /dev/null +++ b/trainingmgr/common/exceptions_utls.py @@ -0,0 +1,90 @@ +# ================================================================================== +# +# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#
# ==================================================================================

class APIException(Exception):
    """
    A class used to represent an API exception.

    Attributes
    ----------
    code : int
        http status code
    message : str
        a formatted string describing the exception
    """

    def __init__(self, code, message="exception occured"):
        """
        Parameters
        ----------
        code : int
            http status code
        message : str
            a formatted string describing the exception
        """
        self.code = code
        self.message = message
        super().__init__(self.message)


class TMException(Exception):
    """
    A class used to represent a Training Manager exception.

    Attributes
    ----------
    message : str
        a formatted string describing the exception
    """

    def __init__(self, message="TM exception occured"):
        """
        Parameters
        ----------
        message : str
            a formatted string describing the exception
        """
        self.message = message
        super().__init__(self.message)


class DBException(Exception):
    """
    A class used to represent a DB related exception.

    Attributes
    ----------
    message : str
        a formatted string describing the exception
    """

    def __init__(self, message="DB exception occured"):
        """
        Parameters
        ----------
        message : str
            a formatted string describing the exception
        """
        self.message = message
        super().__init__(self.message)


# ----------------------------------------------------------------------------------
# Patch metadata and license text carried on this line of the commitdiff,
# kept verbatim as comments:
# diff --git a/trainingmgr/common/tmgr_logger.py b/trainingmgr/common/tmgr_logger.py
# new file mode 100644
# index 0000000..82f6c27
# --- /dev/null
# +++ b/trainingmgr/common/tmgr_logger.py
# @@ -0,0 +1,65 @@
# ==================================================================================
#
# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==================================================================================

#!/usr/bin/python3

"""tmgr_logger.py
This module is for Initializing Logger Framework
"""

import logging
import logging.config
import yaml


class TMLogger(object):# pylint: disable=too-few-public-methods
    """
    Load a YAML logging configuration and expose the resulting logger.

    Attributes:
        LogLevel: value of ``root.level`` read from the configuration, or
            None when the configuration file could not be opened.
        logger: logger named after this module, or None when the
            configuration file could not be opened.
    """

    def __init__(self, conf_file):
        """
        Configure the logging framework from a YAML file.

        Parameters:
            conf_file: path of a YAML file holding a ``logging.config``
                dictionary schema with a ``root.level`` entry.

        A missing file is reported on stdout and is not raised; in that case
        ``logger`` and ``LogLevel`` stay None.
        """
        # Define both attributes before anything can fail, so that reading
        # them after a missing-file error yields None instead of raising
        # AttributeError.
        self.LogLevel = None
        self.logger = None

        try:
            with open(conf_file, 'r') as file:
                log_config = yaml.safe_load(file.read())
                logging.config.dictConfig(log_config)
                self.LogLevel = log_config["root"]["level"]
                self.logger = logging.getLogger(__name__)
        except FileNotFoundError as err:
            print("error opening yaml config file")
            print(err)

    @property
    def get_logger(self):
        """
        Property giving the logger instance to the caller.

        Returns:
            logger: logger handle to be used in other modules
        """
        return self.logger

    @property
    def get_logLevel(self):
        """
        Property giving the root log level read from the configuration.

        Returns:
            the ``root.level`` value of the loaded configuration
        """
        return self.LogLevel


# \ No newline at end of file
# ----------------------------------------------------------------------------------
# Patch metadata and license text carried on this line of the commitdiff,
# kept verbatim as comments:
# diff --git a/trainingmgr/common/trainingmgr_config.py b/trainingmgr/common/trainingmgr_config.py
# new file mode 100644
# index 0000000..4b98958
# --- /dev/null
# +++ b/trainingmgr/common/trainingmgr_config.py
# @@ -0,0 +1,199 @@
# ==================================================================================
#
# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==================================================================================

"""
This module is for loading training manager configuration.
"""

from os import getenv
from trainingmgr.common.tmgr_logger import TMLogger


def _stripped_env(name):
    """Return environment variable *name* without trailing whitespace, or None if unset."""
    value = getenv(name)
    if value is None:
        return None
    return value.rstrip()


class TrainingMgrConfig:
    """
    This class contains methods for getting configuration variables.
    """

    def __init__(self):
        """
        Fill the configuration variables from the process environment and
        set up the logger.

        An unset environment variable is stored as None (it no longer raises
        AttributeError), so that is_config_loaded_properly() can report it.
        """
        self.__kf_adapter_port = _stripped_env('KF_ADAPTER_PORT')
        self.__kf_adapter_ip = _stripped_env('KF_ADAPTER_IP')

        self.__data_extraction_port = _stripped_env('DATA_EXTRACTION_API_PORT')
        self.__data_extraction_ip = _stripped_env('DATA_EXTRACTION_API_IP')

        self.__my_port = _stripped_env('TRAINING_MANAGER_PORT')
        self.__my_ip = _stripped_env('TRAINING_MANAGER_IP')

        self.__ps_user = _stripped_env('PS_USER')
        self.__ps_password = _stripped_env('PS_PASSWORD')
        self.__ps_ip = _stripped_env('PS_IP')
        self.__ps_port = _stripped_env('PS_PORT')

        # NOTE(review): this path is relative, so it is resolved against the
        # current working directory of the process - confirm this is intended.
        self.tmgr_logger = TMLogger("common/conf_log.yaml")
        # Read the attribute defensively: a logger object that failed to load
        # its configuration may not expose a usable ``logger``.
        self.__logger = getattr(self.tmgr_logger, 'logger', None)

    @property
    def kf_adapter_port(self):
        """
        Function for getting port number where kf adapter is accessible

        Args:None

        Returns:
            port number where kf adapter is accessible
        """
        return self.__kf_adapter_port

    @property
    def kf_adapter_ip(self):
        """
        Function for getting ip address or service name where kf adapter is accessible

        Args:None

        Returns:
            ip address or service name where kf adapter is accessible
        """
        return self.__kf_adapter_ip

    @property
    def data_extraction_port(self):
        """
        Function for getting port number where data extraction module is accessible

        Args:None

        Returns:
            port number where data extraction module is accessible
        """
        return self.__data_extraction_port

    @property
    def data_extraction_ip(self):
        """
        Function for getting ip address or service name where data extraction module is accessible

        Args:None

        Returns:
            ip address or service name where data extraction module is accessible
        """
        return self.__data_extraction_ip

    @property
    def my_port(self):
        """
        Function for getting port number where training manager is running

        Args:None

        Returns:
            port number where training manager is running
        """
        return self.__my_port

    @property
    def my_ip(self):
        """
        Function for getting ip address where training manager is running

        Args:None

        Returns:
            ip address where training manager is running
        """
        return self.__my_ip

    @property
    def logger(self):
        """
        Function for getting logger instance.

        Args:None

        Returns:
            logger instance.
        """
        return self.__logger

    @property
    def ps_user(self):
        """
        Function for getting postgres db's user.

        Args:None

        Returns:
            postgres db's user.
        """
        return self.__ps_user

    @property
    def ps_password(self):
        """
        Function for getting postgres db's password.

        Args:None

        Returns:
            postgres db's password.
        """
        return self.__ps_password

    @property
    def ps_ip(self):
        """
        Function for getting ip address or service name where postgres db is accessible

        Args:None

        Returns:
            ip address or service name where postgres db is accessible
        """
        return self.__ps_ip

    @property
    def ps_port(self):
        """
        Function for getting port number where postgres db is accessible

        Args:None

        Returns:
            port number where postgres db is accessible
        """
        return self.__ps_port

    def is_config_loaded_properly(self):
        """
        Check whether every environment variable and the logger got a value.

        Returns:
            True if all configuration values are set, False if any is None.
        """
        required = (self.__kf_adapter_ip, self.__kf_adapter_port,
                    self.__data_extraction_ip, self.__data_extraction_port,
                    self.__my_port, self.__ps_ip, self.__ps_port, self.__ps_user,
                    self.__ps_password, self.__my_ip, self.__logger)
        return all(var is not None for var in required)


# ----------------------------------------------------------------------------------
# Patch metadata and license text carried on this line of the commitdiff,
# kept verbatim as comments:
# diff --git a/trainingmgr/common/trainingmgr_operations.py b/trainingmgr/common/trainingmgr_operations.py
# new file mode 100644
# index 0000000..29a626b
# --- /dev/null
# +++ b/trainingmgr/common/trainingmgr_operations.py
# @@ -0,0 +1,111 @@
# ==================================================================================
#
# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==================================================================================

"""
Training manager main operations.
"""

import copy
import json
import requests

_JSON_HEADERS = {'content-type': 'application/json',
                 'Accept-Charset': 'UTF-8'}


def data_extraction_start(training_config_obj, trainingjob_name, feature_list, query_filter,
                          datalake_source, _measurement, bucket, timeout=None):
    """
    Ask the data extraction module to start extracting data for a training job.

    Parameters:
        training_config_obj: object exposing ``logger``, ``data_extraction_ip``
            and ``data_extraction_port``.
        trainingjob_name: name of the training job; used as the sink collection name.
        feature_list: value sent as ``FeatureList`` of the SQL transform.
        query_filter: value sent as ``SQLFilter`` of the SQL transform.
        datalake_source: source description; when it has an ``InfluxSource``
            entry, a Flux query is generated into that entry.
        _measurement: measurement name inserted into the generated Flux query.
        bucket: bucket name inserted into the generated Flux query.
        timeout: optional timeout forwarded to ``requests.post``; the default
            None waits indefinitely, as before.

    Returns:
        the response object returned by ``requests.post``.
    """
    logger = training_config_obj.logger
    logger.debug('training manager is calling data extraction for %s', trainingjob_name)
    data_extraction_ip = training_config_obj.data_extraction_ip
    data_extraction_port = training_config_obj.data_extraction_port
    url = 'http://'+str(data_extraction_ip)+':'+str(data_extraction_port)+'/feature-groups'
    logger.debug(url)

    # Work on a copy: the query is written into the source description and the
    # caller's object must not be modified as a side effect.
    source_description = copy.deepcopy(datalake_source)
    if 'InfluxSource' in source_description:
        # NOTE(review): bucket and _measurement are inserted into the Flux
        # query unescaped - confirm they never come from untrusted input.
        source_description['InfluxSource']['query'] = (
            'from(bucket:"' + bucket + '") |> '
            'range(start: 0, stop: now()) '
            '|> filter(fn: (r) => r._measurement == "' + _measurement + '") '
            '|> pivot(rowKey:["_time"], '
            'columnKey: ["_field"], '
            'valueColumn: "_value")')

    dictionary = {
        'source': source_description,
        'transform': [{
            'operation': "SQLTransform",
            'FeatureList': feature_list,
            'SQLFilter': query_filter,
        }],
        'sink': {'CassandraSink': {'CollectionName': trainingjob_name}},
    }

    payload = json.dumps(dictionary)
    logger.debug(payload)

    response = requests.post(url,
                             data=payload,
                             headers=dict(_JSON_HEADERS),
                             timeout=timeout)
    return response


def data_extraction_status(trainingjob_name, training_config_obj, timeout=None):
    """
    Ask the data extraction module for the extraction status of a training job.

    Parameters:
        trainingjob_name: name of the training job; appended to the status URL.
        training_config_obj: object exposing ``logger``, ``data_extraction_ip``
            and ``data_extraction_port``.
        timeout: optional timeout forwarded to ``requests.get``; the default
            None waits indefinitely, as before.

    Returns:
        the response object returned by ``requests.get``.
    """
    logger = training_config_obj.logger
    logger.debug('training manager is calling data extraction for %s', trainingjob_name)
    data_extraction_ip = training_config_obj.data_extraction_ip
    data_extraction_port = training_config_obj.data_extraction_port
    url = 'http://'+str(data_extraction_ip)+':'+str(data_extraction_port)+\
          '/task-status/'+trainingjob_name
    logger.debug(url)
    response = requests.get(url, timeout=timeout)
    return response


def training_start(training_config_obj, dict_data, trainingjob_name, timeout=None):
    """
    Ask the kf_adapter module to start the pipeline run of a training job.

    Parameters:
        training_config_obj: object exposing ``logger``, ``kf_adapter_ip``
            and ``kf_adapter_port``.
        dict_data: JSON-serialisable request body sent to kf_adapter.
        trainingjob_name: name of the training job; part of the request URL.
        timeout: optional timeout forwarded to ``requests.post``; the default
            None waits indefinitely, as before.

    Returns:
        the response object returned by ``requests.post``.
    """
    logger = training_config_obj.logger
    payload = json.dumps(dict_data)
    logger.debug('training manager is calling kf_adapter for pipeline run for %s',
                 trainingjob_name)
    logger.debug('training manager will send to kf_adapter: %s', payload)
    kf_adapter_ip = training_config_obj.kf_adapter_ip
    kf_adapter_port = training_config_obj.kf_adapter_port
    url = 'http://'+str(kf_adapter_ip)+':'+str(kf_adapter_port)+\
          '/trainingjobs/' + trainingjob_name + '/execution'
    logger.debug(url)
    response = requests.post(url,
                             data=payload,
                             headers=dict(_JSON_HEADERS),
                             timeout=timeout)

    return response


# ----------------------------------------------------------------------------------
# Patch metadata and license text carried on this line of the commitdiff,
# kept verbatim as comments:
# diff --git a/trainingmgr/common/trainingmgr_util.py b/trainingmgr/common/trainingmgr_util.py
# new file mode 100644
# index 0000000..f758f56
# --- /dev/null
# +++ b/trainingmgr/common/trainingmgr_util.py
# @@ -0,0 +1,223 @@
# ==================================================================================
#
# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==================================================================================

"""
This file contains Training management utility functions
"""
import json
from flask_api import status
import requests
from trainingmgr.db.common_db_fun import change_in_progress_to_failed_by_latest_version, \
    get_field_by_latest_version, change_field_of_latest_version, \
    get_latest_version_trainingjob_name,get_all_versions_info_by_name

from trainingmgr.constants.states import States
from trainingmgr.common.exceptions_utls import APIException,TMException,DBException

# Keys that a training job request body must contain.
_REQUIRED_TRAININGJOB_FIELDS = ("feature_list", "pipeline_version",
                                "pipeline_name", "experiment_name",
                                "arguments", "enable_versioning",
                                "datalake_source", "description",
                                "query_filter", "_measurement",
                                "bucket")


def response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk):
    """
    Post training job completion, notify the subscriber who provided a
    notification url during training job creation about the result.

    Parameters:
        code: status code returned to the caller as second tuple element.
        message: text returned to the caller inside the result dictionary.
        logger: logger used for the debug trace.
        is_success: True when the training job succeeded.
        trainingjob_name: name of the training job.
        ps_db_obj: database handle passed through to the db helper functions.
        mm_sdk: model management handle passed through to get_metrics().

    Returns:
        tuple of (result dictionary, code); the dictionary key is "result" on
        success and "Exception" otherwise.

    Raises:
        APIException: (HTTP 500) when looking up the job or notifying the
            subscriber fails; the job is marked as failed first.
    """
    logger.debug("Training job result: %s %s %s", code, message, is_success)

    try:
        #TODO DB query optimization, all data to fetch in one call
        notif_url_result = get_field_by_latest_version(trainingjob_name, ps_db_obj,
                                                       "notification_url")
        if notif_url_result:
            notification_url = notif_url_result[0][0]
            if notification_url != '':
                model_url_result = get_field_by_latest_version(trainingjob_name, ps_db_obj,
                                                               "model_url")
                model_url = model_url_result[0][0]
                version = get_latest_version_trainingjob_name(trainingjob_name, ps_db_obj)
                metrics = get_metrics(trainingjob_name, version, mm_sdk)

                if is_success:
                    req_json = {
                        "result": "success", "model_url": model_url,
                        "trainingjob_name": trainingjob_name, "metrics": metrics
                    }
                else:
                    req_json = {"result": "failed", "trainingjob_name": trainingjob_name}

                response = requests.post(notification_url,
                                         data=json.dumps(req_json),
                                         headers={
                                             'content-type': 'application/json',
                                             'Accept-Charset': 'UTF-8'
                                         })
                # Compare only the media type, so that a valid answer such as
                # "application/json; charset=utf-8" is not treated as a failure.
                content_type = response.headers.get('content-type', '')
                media_type = content_type.split(';', 1)[0].strip().lower()
                if (media_type != "application/json"
                        or response.status_code != status.HTTP_200_OK):
                    err_msg = "Failed to notify the subscribed url " + trainingjob_name
                    raise TMException(err_msg)
    except Exception as err:
        change_in_progress_to_failed_by_latest_version(trainingjob_name, ps_db_obj)
        raise APIException(status.HTTP_500_INTERNAL_SERVER_ERROR,
                           str(err) + "(trainingjob name is " + trainingjob_name + ")") from err
    if is_success:
        return {"result": message}, code
    return {"Exception": message}, code


def check_key_in_dictionary(fields, dictionary):
    '''
    Return True if every string of the fields list is present in dictionary
    as a key, otherwise return False.
    '''
    return all(field_name in dictionary for field_name in fields)


def get_one_word_status(steps_state):
    """
    Convert steps_state to a one word status (also called overall_status).

    Returns FAILED if any step failed, NOT_STARTED if all steps are not
    started (this includes an empty steps_state), FINISHED if all steps are
    finished, and IN_PROGRESS otherwise.
    """
    states = [steps_state[step] for step in steps_state]
    if States.FAILED.name in states:
        return States.FAILED.name
    if all(state == States.NOT_STARTED.name for state in states):
        return States.NOT_STARTED.name
    if all(state == States.FINISHED.name for state in states):
        return States.FINISHED.name
    return States.IN_PROGRESS.name


def check_trainingjob_data(trainingjob_name, json_data):
    """
    Validate the json_data dictionary of a training job request.

    Returns:
        tuple (feature_list, description, pipeline_name, experiment_name,
        arguments, query_filter, enable_versioning, pipeline_version,
        datalake_source, _measurement, bucket) taken from json_data.

    Raises:
        APIException: (HTTP 400) when a required field is missing or
            "arguments" is not a dictionary.
    """
    try:
        if not check_key_in_dictionary(_REQUIRED_TRAININGJOB_FIELDS, json_data):
            raise TMException("check_trainingjob_data- supplied data doesn't have " + \
                              "all the required fields ")

        description = json_data["description"]
        feature_list = json_data["feature_list"]
        pipeline_name = json_data["pipeline_name"]
        experiment_name = json_data["experiment_name"]
        arguments = json_data["arguments"]

        if not isinstance(arguments, dict):
            raise TMException("Please pass agruments as dictionary for " + trainingjob_name)
        query_filter = json_data["query_filter"]
        enable_versioning = json_data["enable_versioning"]
        pipeline_version = json_data["pipeline_version"]
        datalake_source = json_data["datalake_source"]
        _measurement = json_data["_measurement"]
        bucket = json_data["bucket"]
    except Exception as err:
        raise APIException(status.HTTP_400_BAD_REQUEST,
                           str(err)) from err
    return (feature_list, description, pipeline_name, experiment_name,
            arguments, query_filter, enable_versioning, pipeline_version,
            datalake_source, _measurement, bucket)


def get_one_key(dictionary):
    '''
    Return one key of dictionary (the last one in iteration order), or None
    when the dictionary is empty.
    '''
    only_key = None
    for key in dictionary:
        only_key = key
    return only_key


def get_metrics(trainingjob_name, version, mm_sdk):
    """
    Download metrics from the object database.

    Returns:
        the metrics serialised as a JSON string when "metrics.json" is present
        for the training job version, otherwise the string "No data available".

    Raises:
        TMException: when checking for or downloading the metrics fails, or
            when the download yields no data.
    """
    try:
        present = mm_sdk.check_object(trainingjob_name, version, "metrics.json")
        if present:
            metrics = mm_sdk.get_metrics(trainingjob_name, version)
            # Test the downloaded object itself: json.dumps() never returns
            # None, so a check after serialisation could not detect this.
            if metrics is None:
                raise Exception("Problem while downloading metrics")
            data = json.dumps(metrics)
        else:
            data = "No data available"
    except Exception as err:
        errMsg = str(err)
        raise TMException("Problem while downloading metric" + errMsg) from err
    return data


def handle_async_feature_engineering_status_exception_case(lock, dataextraction_job_cache, code,
                                                           message, logger, is_success,
                                                           trainingjob_name, ps_db_obj, mm_sdk):
    """
    Change IN_PROGRESS state to FAILED state, call response_for_training and
    remove trainingjob_name from dataextraction_job_cache.

    Errors of the first two steps are logged and not raised; the cache entry
    is removed in every case, while holding lock.
    """
    try:
        change_in_progress_to_failed_by_latest_version(trainingjob_name, ps_db_obj)
        response_for_training(code, message, logger, is_success, trainingjob_name,
                              ps_db_obj, mm_sdk)
    except Exception as err:
        logger.error("Failed in handle_async_feature_engineering_status_exception_case" + str(err))
    finally:
        #Post success/failure handle,process next item from DATAEXTRACTION_JOBS_CACHE
        with lock:
            try:
                dataextraction_job_cache.pop(trainingjob_name)
            except KeyError as key_err:
                logger.error("The training job key doesn't exist in DATAEXTRACTION_JOBS_CACHE: "
                             + str(key_err))


def validate_trainingjob_name(trainingjob_name, ps_db_obj):
    """
    Return True if the given trainingjob_name exists in db, otherwise False.

    Raises:
        DBException: when the database lookup fails.
    """
    try:
        results = get_all_versions_info_by_name(trainingjob_name, ps_db_obj)
    except Exception as err:
        errMsg = str(err)
        raise DBException("Could not get info from db for " + trainingjob_name + "," + errMsg) from err
    return bool(results)


# \ No newline at end of file