X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=blobdiff_plain;f=src%2Fqptrain.py;fp=qp%2Fqptrain.py;h=0ce3438dc2bf2b42b37022c4a9c06d3d5217ea02;hb=30a9743fdfd0ef62164c7ea74a4a120cb1c86852;hp=2071ac35859e11b11be901b29e7c95d6ef1476fd;hpb=48aa41171b8ca141df1f14341ea0ec2fde745af5;p=ric-app%2Fqp.git diff --git a/qp/qptrain.py b/src/qptrain.py similarity index 55% rename from qp/qptrain.py rename to src/qptrain.py index 2071ac3..0ce3438 100644 --- a/qp/qptrain.py +++ b/src/qptrain.py @@ -16,11 +16,15 @@ from statsmodels.tsa.api import VAR from statsmodels.tsa.stattools import adfuller +from mdclogpy import Logger +from exceptions import DataNotMatchError +from sklearn.metrics import mean_squared_error +from math import sqrt import joblib +import warnings +warnings.filterwarnings("ignore") - -class DataNotMatchError(Exception): - pass +logger = Logger(name=__name__) class PROCESS(object): @@ -29,6 +33,14 @@ class PROCESS(object): self.diff = 0 self.data = data + def input_data(self): + try: + self.data = self.data[db.thptparam] + self.data = self.data.fillna(method='bfill') + except DataNotMatchError: + logger.error('Parameters Downlink throughput and Uplink throughput does not exist in provided data') + self.data = None + def adfuller_test(self, series, thresh=0.05, verbose=False): """ADFuller test for Stationarity of given series and return True or False""" r = adfuller(series, autolag='AIC') @@ -62,39 +74,68 @@ class PROCESS(object): return df def process(self): - """ Filter throughput parameters, call make_stationary() to check for Stationarity time series - """ - df = self.data.copy() - try: - df = df[['pdcpBytesDl', 'pdcpBytesUl']] - except DataNotMatchError: - print('Parameters pdcpBytesDl, pdcpBytesUl does not exist in provided data') - self.data = None - self.data = df.loc[:, (df != 0).any(axis=0)] + self.input_data() self.make_stationary() # check for Stationarity and make the Time Series Stationary - def valid(self): - val = False - if self.data is not None: - df = self.data.copy() - df = df.loc[:, (df != 0).any(axis=0)] - if len(df) != 0 and df.shape[1] == 2: - val = True + def constant(self): + val = True + df = self.data.copy() + df = df.drop_duplicates().dropna() + df = df.loc[:, (df != 0).any(axis=0)] + if len(df) >= 10: + val = False return val + def evaluate_var(self, X, lag): + # prepare training dataset + train_size = int(len(X) * 0.75) + train, test = X[0:train_size], X[train_size:] + # make predictions + model = VAR(train) + model_fit = model.fit(lag) + predictions = model_fit.forecast(y=train.values, steps=len(test)) + # calculate out of sample error + rmse = sqrt(mean_squared_error(test, predictions)) + return rmse + + def optimize_lag(self, df): + lag = range(1, 20, 1) + df = df.astype('float32') + best_score, best_lag = float("inf"), None + for l in lag: + try: + rmse = self.evaluate_var(df, l) + if rmse < best_score: + best_score, best_lag = rmse, l + except ValueError as v: + print(v) + # print('Best VAR%s RMSE=%.3f' % (best_lag, best_score)) + return best_lag + -def train(db, cid): +def train_cid(cid): """ Read the input file(based on cell id received from the main program) call process() to forecast the downlink and uplink of the input cell id Make a VAR model, call the fit method with the desired lag order. """ - db.read_data(meas='liveCell', cellid=cid) + # print(f'Training for {cid}') + db.read_data(cellid=cid, limit=4800) md = PROCESS(db.data) - md.process() - if md.valid(): + if md.data is not None and not md.constant(): + md.process() + lag = md.optimize_lag(md.data) model = VAR(md.data) # Make a VAR model - model_fit = model.fit(10) # call fit method with lag order - file_name = 'qp/'+cid.replace('/', '') - with open(file_name, 'wb') as f: - joblib.dump(model_fit, f) # Save the model with the cell id name + try: + model_fit = model.fit(lag) # call fit method with lag order + file_name = 'src/'+cid.replace('/', '') + with open(file_name, 'wb') as f: + joblib.dump(model_fit, f) # Save the model with the cell id name + except ValueError as v: + print("****************************************", v) + + +def train(database, cid): + global db + db = database + train_cid(cid)