--- /dev/null
+This folder contains all files related to the training manager.
+#To install the training manager as a package
+pip3 install .
\ No newline at end of file
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+from setuptools import setup, find_packages
+
# Register the training manager distribution with setuptools.
setup(
    # Distribution identity and descriptive metadata.
    name="trainingmgr",
    version="0.1",
    description="AIMLFW Training manager",
    keywords="AIMLWF TM",
    license="Apache 2.0",
    url="https://gerrit.o-ran-sc.org/r/admin/repos/aiml-fw/awmf/tm,general",
    # Author contact details.
    author='SANDEEP KUMAR JAISAWAL',
    author_email='s.jaisawal@samsung.com',
    # Package every discovered package except the test packages.
    packages=find_packages(exclude=["tests.*", "tests"]),
)
\ No newline at end of file
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+version: 1
+formatters:
+ simple:
+ format: '%(asctime)s | %(filename)s %(lineno)s %(funcName)s() | %(levelname)s | %(message)s'
+handlers:
+ console:
+ class: logging.StreamHandler
+ level: DEBUG
+ formatter: simple
+ stream: ext://sys.stdout
+ access_file:
+ class: logging.handlers.RotatingFileHandler
+ level: DEBUG
+ formatter: simple
+ filename: /var/log/training_manager.log
+ maxBytes: 10485760
+ backupCount: 20
+ encoding: utf8
+root:
+ level: DEBUG
+ handlers: [access_file,console]
\ No newline at end of file
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
class APIException(Exception):
    """
    Exception that pairs an HTTP status code with an error message.

    Attributes
    ----------
    code : int
        http status code
    message : str
        text describing what went wrong
    """

    def __init__(self, code, message="exception occured"):
        """
        Parameters
        ----------
        code : int
            http status code
        message : str
            text describing what went wrong; also used as the exception text
        """
        super().__init__(message)
        self.code, self.message = code, message
+
class TMException(Exception):
    """
    Exception type for Training Manager errors.

    Attributes
    ----------
    message : str
        text describing what went wrong
    """

    def __init__(self, message="TM exception occured"):
        """
        Parameters
        ----------
        message : str
            text describing what went wrong; also used as the exception text
        """
        super().__init__(message)
        self.message = message
+
class DBException(Exception):
    """
    Exception type for database related errors.

    Attributes
    ----------
    message : str
        text describing what went wrong
    """

    def __init__(self, message="DB exception occured"):
        """
        Parameters
        ----------
        message : str
            text describing what went wrong; also used as the exception text
        """
        super().__init__(message)
        self.message = message
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+#!/usr/bin/python3
+
+"""tmgr_logger.py
+This module is for Initializing Logger Framework
+"""
+
+import logging
+import logging.config
+import yaml
+
+
class TMLogger(object):# pylint: disable=too-few-public-methods
    """
    Loads a YAML logging configuration and exposes the resulting logger.

    Attributes:
        logger: logger created after the configuration is applied, or None
            when the configuration file could not be found.
        LogLevel: value of root.level from the configuration, or None when
            the configuration file could not be found.
    """

    def __init__(self, conf_file):
        """
        The constructor for TMLogger class.

        Parameters:
            conf_file: path of the YAML file holding a logging dictConfig
                dictionary.
        """
        # Define both attributes up front: a missing file is only reported on
        # stdout below, and callers must be able to detect that outcome with
        # an "is None" check instead of hitting an AttributeError later.
        self.logger = None
        self.LogLevel = None

        try:
            with open(conf_file, 'r') as file:
                log_config = yaml.safe_load(file)
            logging.config.dictConfig(log_config)
            self.LogLevel = log_config["root"]["level"]
            self.logger = logging.getLogger(__name__)
        except FileNotFoundError as err:
            # Logging is not configured at this point, so fall back to print.
            print("error opening yaml config file")
            print(err)

    @property
    def get_logger(self):
        """
        Function for giving logger instance to the caller of the function
        Args:None
        Returns:
            logger: logger handle to be used in other modules, or None when
                the configuration file was not found
        """
        return self.logger

    @property
    def get_logLevel(self):
        """
        Function for giving the configured root log level
        Args:None
        Returns:
            root.level value read from the configuration, or None when the
            configuration file was not found
        """
        return self.LogLevel
+
\ No newline at end of file
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+"""
+This module is for loading training manager configuration.
+"""
+
+from os import getenv
+from trainingmgr.common.tmgr_logger import TMLogger
+
+
class TrainingMgrConfig:
    """
    This class contains methods for getting configuration variables.

    The values are read from environment variables when the object is
    created; is_config_loaded_properly() reports whether all of them
    (and the logger) are available.
    """

    def __init__(self):
        """
        This constructor fills the configuration variables from the
        environment and creates the logger from "common/conf_log.yaml".

        A variable that is not set is stored as None instead of raising, so
        that is_config_loaded_properly() can report the problem.
        """
        self.__kf_adapter_port = self._read_env('KF_ADAPTER_PORT')
        self.__kf_adapter_ip = self._read_env('KF_ADAPTER_IP')

        self.__data_extraction_port = self._read_env('DATA_EXTRACTION_API_PORT')
        self.__data_extraction_ip = self._read_env('DATA_EXTRACTION_API_IP')

        self.__my_port = self._read_env('TRAINING_MANAGER_PORT')
        self.__my_ip = self._read_env('TRAINING_MANAGER_IP')

        self.__ps_user = self._read_env('PS_USER')
        self.__ps_password = self._read_env('PS_PASSWORD')
        self.__ps_ip = self._read_env('PS_IP')
        self.__ps_port = self._read_env('PS_PORT')

        # NOTE(review): the path is relative, so it is resolved against the
        # current working directory of the process - confirm this is intended.
        self.tmgr_logger = TMLogger("common/conf_log.yaml")
        # Tolerate a TMLogger that could not create its logger; the missing
        # logger is then reported by is_config_loaded_properly().
        self.__logger = getattr(self.tmgr_logger, 'logger', None)

    @staticmethod
    def _read_env(name):
        """
        Read environment variable `name` and strip trailing whitespace.

        Returns:
            the stripped value, or None when the variable is not set
        """
        value = getenv(name)
        return None if value is None else value.rstrip()

    @property
    def kf_adapter_port(self):
        """
        Function for getting port number where kf adapter is accessible

        Args:None

        Returns:
            port number where kf adapter is accessible
        """
        return self.__kf_adapter_port

    @property
    def kf_adapter_ip(self):
        """
        Function for getting ip address or service name where kf adapter is accessible

        Args:None

        Returns:
            ip address or service name where kf adapter is accessible
        """
        return self.__kf_adapter_ip

    @property
    def data_extraction_port(self):
        """
        Function for getting port number where data extraction module is accessible

        Args:None

        Returns:
            port number where data extraction module is accessible
        """
        return self.__data_extraction_port

    @property
    def data_extraction_ip(self):
        """
        Function for getting ip address or service name where data extraction module is accessible

        Args:None

        Returns:
            ip address or service name where data extraction module is accessible
        """
        return self.__data_extraction_ip

    @property
    def my_port(self):
        """
        Function for getting port number where training manager is running

        Args:None

        Returns:
            port number where training manager is running
        """
        return self.__my_port

    @property
    def my_ip(self):
        """
        Function for getting ip address where training manager is running

        Args:None

        Returns:
            ip address where training manager is running
        """
        return self.__my_ip

    @property
    def logger(self):
        """
        Function for getting logger instance.

        Args:None

        Returns:
            logger instance, or None when it could not be created.
        """
        return self.__logger

    @property
    def ps_user(self):
        """
        Function for getting postgres db's user.

        Args:None

        Returns:
            postgres db's user.
        """
        return self.__ps_user

    @property
    def ps_password(self):
        """
        Function for getting postgres db's password.

        Args:None

        Returns:
            postgres db's password.
        """
        return self.__ps_password

    @property
    def ps_ip(self):
        """
        Function for getting ip address or service name where postgres db is accessible

        Args:None

        Returns:
            ip address or service name where postgres db is accessible
        """
        return self.__ps_ip

    @property
    def ps_port(self):
        """
        Function for getting port number where postgres db is accessible

        Args:None

        Returns:
            port number where postgres db is accessible
        """
        return self.__ps_port

    def is_config_loaded_properly(self):
        """
        This function checks whether every configuration value is available.

        Returns:
            True when none of the environment-backed values nor the logger
            is None, otherwise False.
        """
        required_values = (
            self.__kf_adapter_ip, self.__kf_adapter_port,
            self.__data_extraction_ip, self.__data_extraction_port,
            self.__my_port, self.__ps_ip, self.__ps_port, self.__ps_user,
            self.__ps_password, self.__my_ip, self.__logger,
        )
        return all(value is not None for value in required_values)
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+"""
+Training manager main operations.
+"""
+
+import json
+import requests
+
def data_extraction_start(training_config_obj, trainingjob_name, feature_list, query_filter,
                          datalake_source, _measurement, bucket):
    """
    This function calls the data extraction module to start data extraction for the
    trainingjob_name training and returns the response of that call.

    The request is a POST to http://<data_extraction_ip>:<data_extraction_port>/feature-groups
    with a JSON body holding three keys: 'source' (datalake_source, with a generated
    query added when it has an 'InfluxSource' entry), 'transform' (one SQLTransform
    step built from feature_list and query_filter) and 'sink' (a CassandraSink whose
    CollectionName is trainingjob_name).

    The caller's datalake_source dictionary is not modified.

    Returns:
        response object returned by requests.post
    """
    logger = training_config_obj.logger
    logger.debug('training manager is calling data extraction for '+trainingjob_name)
    data_extraction_ip = training_config_obj.data_extraction_ip
    data_extraction_port = training_config_obj.data_extraction_port
    url = 'http://'+str(data_extraction_ip)+':'+str(data_extraction_port)+'/feature-groups'
    logger.debug(url)

    source_spec = datalake_source
    if 'InfluxSource' in datalake_source:
        flux_query = ('from(bucket:"' + bucket + '") |> '
                      'range(start: 0, stop: now()) '
                      '|> filter(fn: (r) => r._measurement == "' + _measurement + '") '
                      '|> pivot(rowKey:["_time"], '
                      'columnKey: ["_field"], '
                      'valueColumn: "_value")')
        # Work on copies so the generated query is not written back into the
        # dictionary owned by the caller.
        influx_spec = dict(datalake_source['InfluxSource'])
        influx_spec['query'] = flux_query
        source_spec = dict(datalake_source)
        source_spec['InfluxSource'] = influx_spec

    request_body = {
        'source': source_spec,
        'transform': [{
            'operation': "SQLTransform",
            'FeatureList': feature_list,
            'SQLFilter': query_filter,
        }],
        'sink': {'CassandraSink': {'CollectionName': trainingjob_name}},
    }

    payload = json.dumps(request_body)
    logger.debug(payload)

    # NOTE(review): no timeout is passed, so this call waits for as long as
    # the remote side keeps the connection open - confirm that is acceptable.
    response = requests.post(url,
                             data=payload,
                             headers={'content-type': 'application/json',
                                      'Accept-Charset': 'UTF-8'})
    return response
+
def data_extraction_status(trainingjob_name,training_config_obj):
    """
    Ask the data extraction module for the data extraction status of the
    trainingjob_name training and return the response of that GET request.
    """
    log = training_config_obj.logger
    log.debug('training manager is calling data extraction for '+trainingjob_name)

    # http://<data_extraction_ip>:<data_extraction_port>/task-status/<trainingjob_name>
    endpoint = 'http://%s:%s/task-status/' % (training_config_obj.data_extraction_ip,
                                              training_config_obj.data_extraction_port)
    status_url = endpoint + trainingjob_name
    log.debug(status_url)

    return requests.get(status_url)
+
def training_start(training_config_obj, dict_data, trainingjob_name):
    """
    Ask the kf_adapter module to start the pipeline of the trainingjob_name
    training and return the response of that POST request.

    dict_data is sent as the JSON body of the request.
    """
    log = training_config_obj.logger
    log.debug('training manager is calling kf_adapter for pipeline run for '+trainingjob_name)

    body = json.dumps(dict_data)
    log.debug('training manager will send to kf_adapter: '+body)

    # http://<kf_adapter_ip>:<kf_adapter_port>/trainingjobs/<trainingjob_name>/execution
    base = 'http://%s:%s/trainingjobs/' % (training_config_obj.kf_adapter_ip,
                                           training_config_obj.kf_adapter_port)
    execution_url = base + trainingjob_name + '/execution'
    log.debug(execution_url)

    request_headers = {'content-type': 'application/json',
                       'Accept-Charset': 'UTF-8'}
    return requests.post(execution_url, data=body, headers=request_headers)
--- /dev/null
+# ==================================================================================
+#
+# Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==================================================================================
+
+"""
+This file contains Training management utility functions
+"""
+import json
+from flask_api import status
+import requests
+from trainingmgr.db.common_db_fun import change_in_progress_to_failed_by_latest_version, \
+ get_field_by_latest_version, change_field_of_latest_version, \
+ get_latest_version_trainingjob_name,get_all_versions_info_by_name
+
+from trainingmgr.constants.states import States
+from trainingmgr.common.exceptions_utls import APIException,TMException,DBException
+
def response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk):
    """
    Notify the subscriber of a training job about its result.

    When the latest version of trainingjob_name has a non-empty notification url
    stored in the db, a JSON document with the result is posted to that url.
    The post must be answered with status 200 and content type application/json.

    Any failure while notifying marks the in-progress steps of the latest version
    as failed and is reported as APIException with status 500.

    Returns:
        tuple of result dictionary and status code: ({"result": message}, code)
        when is_success is true, otherwise ({"Exception": message}, code).
    """
    logger.debug(" ".join(["Training job result:", str(code), message, str(is_success)]))

    try:
        #TODO DB query optimization, all data to fetch in one call
        url_rows = get_field_by_latest_version(trainingjob_name, ps_db_obj, "notification_url")
        # No row at all is handled like an empty url: nobody to notify.
        notification_url = url_rows[0][0] if url_rows else ''
        if notification_url != '':
            model_url = get_field_by_latest_version(trainingjob_name, ps_db_obj, "model_url")[0][0]
            version = get_latest_version_trainingjob_name(trainingjob_name, ps_db_obj)
            metrics = get_metrics(trainingjob_name, version, mm_sdk)

            payload = {"result": "failed", "trainingjob_name": trainingjob_name}
            if is_success:
                payload = {
                    "result": "success", "model_url": model_url,
                    "trainingjob_name": trainingjob_name, "metrics": metrics
                }

            response = requests.post(notification_url,
                                     data=json.dumps(payload),
                                     headers={
                                         'content-type': 'application/json',
                                         'Accept-Charset': 'UTF-8'
                                     })
            acknowledged = (response.headers['content-type'] == "application/json"
                            and response.status_code == status.HTTP_200_OK)
            if not acknowledged:
                raise TMException("Failed to notify the subscribed url " + trainingjob_name)
    except Exception as err:
        change_in_progress_to_failed_by_latest_version(trainingjob_name, ps_db_obj)
        raise APIException(status.HTTP_500_INTERNAL_SERVER_ERROR,
                           str(err) + "(trainingjob name is " + trainingjob_name + ")") from None

    result_key = "result" if is_success else "Exception"
    return {result_key: message}, code
+
+
def check_key_in_dictionary(fields, dictionary):
    '''
    This function returns True when every string from the fields list is present
    in dictionary as a key, otherwise it returns False.
    '''
    return all(field_name in dictionary for field_name in fields)
+
def get_one_word_status(steps_state):
    """
    This function reduces steps_state (step -> state name) to a one word status
    (also called overall_status) and returns it.

    FAILED wins as soon as any step failed; NOT_STARTED or FINISHED is returned
    only when every step is in that state; everything else is IN_PROGRESS.
    """
    states = [steps_state[step] for step in steps_state]

    if any(state == States.FAILED.name for state in states):
        return States.FAILED.name

    # An empty steps_state matches the first uniform check, i.e. NOT_STARTED.
    for uniform_state in (States.NOT_STARTED, States.FINISHED):
        if all(state == uniform_state.name for state in states):
            return uniform_state.name

    return States.IN_PROGRESS.name
+
+
def check_trainingjob_data(trainingjob_name, json_data):
    """
    This function validates the json_data dictionary and returns a tuple which
    contains the values of the required keys of json_data.

    Validation fails when a required key is missing or when "arguments" is not
    a dictionary.

    Returns:
        tuple (feature_list, description, pipeline_name, experiment_name,
        arguments, query_filter, enable_versioning, pipeline_version,
        datalake_source, _measurement, bucket)

    Raises:
        APIException: with status 400 when validation fails; the message names
            the missing fields.
    """
    required_fields = ("feature_list", "pipeline_version",
                       "pipeline_name", "experiment_name",
                       "arguments", "enable_versioning",
                       "datalake_source", "description",
                       "query_filter", "_measurement",
                       "bucket")
    try:
        missing_fields = [field for field in required_fields if field not in json_data]
        if missing_fields:
            raise TMException("check_trainingjob_data- supplied data doesn't have "
                              "all the required fields: " + ", ".join(missing_fields))

        description = json_data["description"]
        feature_list = json_data["feature_list"]
        pipeline_name = json_data["pipeline_name"]
        experiment_name = json_data["experiment_name"]
        arguments = json_data["arguments"]

        if not isinstance(arguments, dict):
            raise TMException("Please pass arguments as dictionary for " + trainingjob_name)
        query_filter = json_data["query_filter"]
        enable_versioning = json_data["enable_versioning"]
        pipeline_version = json_data["pipeline_version"]
        datalake_source = json_data["datalake_source"]
        _measurement = json_data["_measurement"]
        bucket = json_data["bucket"]
    except Exception as err:
        # Every validation problem is reported to the API caller as a bad request.
        raise APIException(status.HTTP_400_BAD_REQUEST,
                           str(err)) from None
    return (feature_list, description, pipeline_name, experiment_name,
            arguments, query_filter, enable_versioning, pipeline_version,
            datalake_source, _measurement, bucket)
+
+
def get_one_key(dictionary):
    '''
    this function returns one key of dictionary: the last one produced when
    iterating over it, or None when dictionary is empty.
    '''
    keys = list(dictionary)
    return keys[-1] if keys else None
+
+
def get_metrics(trainingjob_name, version, mm_sdk):
    """
    Download metrics from the object database and return them as a JSON string
    if metrics are present, otherwise return the "No data available" string for
    the <trainingjob_name, version> trainingjob.

    Raises:
        TMException: when checking for or downloading the metrics fails, or
            when the download yields no metrics although "metrics.json" is
            reported as present.
    """
    try:
        present = mm_sdk.check_object(trainingjob_name, version, "metrics.json")
        if present:
            metrics = mm_sdk.get_metrics(trainingjob_name, version)
            # Check the downloaded object itself: json.dumps never returns
            # None (it turns None into "null"), so checking its result could
            # not detect a failed download.
            if metrics is None:
                raise Exception("Problem while downloading metrics")
            data = json.dumps(metrics)
        else:
            data = "No data available"
    except Exception as err:
        errMsg = str(err)
        raise TMException ( "Problem while downloading metric" + errMsg) from err
    return data
+
+
def _evict_dataextraction_job(lock, dataextraction_job_cache, trainingjob_name, logger):
    """
    Remove trainingjob_name from dataextraction_job_cache while holding lock;
    a missing entry is logged and otherwise ignored.
    """
    with lock:
        try:
            dataextraction_job_cache.pop(trainingjob_name)
        except KeyError as key_err:
            logger.error("The training job key doesn't exist in DATAEXTRACTION_JOBS_CACHE: %s",
                         key_err)


def handle_async_feature_engineering_status_exception_case(lock, dataextraction_job_cache, code,
                                                           message, logger, is_success,
                                                           trainingjob_name, ps_db_obj, mm_sdk):
    """
    This function changes the IN_PROGRESS state to FAILED, calls response_for_training
    and finally removes trainingjob_name from dataextraction_job_cache.

    Errors raised by the first two steps are logged and not propagated; the cache
    entry is removed in every case.
    """
    try:
        change_in_progress_to_failed_by_latest_version(trainingjob_name, ps_db_obj)
        response_for_training(code, message, logger, is_success, trainingjob_name, ps_db_obj, mm_sdk)
    except Exception as err:
        logger.error("Failed in handle_async_feature_engineering_status_exception_case%s", err)
    finally:
        #Post success/failure handle,process next item from DATAEXTRACTION_JOBS_CACHE
        _evict_dataextraction_job(lock, dataextraction_job_cache, trainingjob_name, logger)
+
def validate_trainingjob_name(trainingjob_name, ps_db_obj):
    """
    This function returns True if the db lookup of all versions of the given
    trainingjob_name yields a non-empty result, otherwise it returns False.

    Raises:
        DBException: when the db lookup fails.
    """
    try:
        versions = get_all_versions_info_by_name(trainingjob_name, ps_db_obj)
    except Exception as err:
        raise DBException("Could not get info from db for " + trainingjob_name + "," + str(err))
    return bool(versions)
\ No newline at end of file