X-Git-Url: https://gerrit.o-ran-sc.org/r/gitweb?a=blobdiff_plain;f=src%2Fqptrain.py;h=946c4302504a94d65228df411780c48d10fcd346;hb=6b85eed29d60f343b0712d757c114bc6207ccdbe;hp=0ce3438dc2bf2b42b37022c4a9c06d3d5217ea02;hpb=2cabb3bf5345a223f13482df7d1b5ccba34b3eda;p=ric-app%2Fqp.git diff --git a/src/qptrain.py b/src/qptrain.py index 0ce3438..946c430 100644 --- a/src/qptrain.py +++ b/src/qptrain.py @@ -20,6 +20,7 @@ from mdclogpy import Logger from exceptions import DataNotMatchError from sklearn.metrics import mean_squared_error from math import sqrt +import pandas as pd import joblib import warnings warnings.filterwarnings("ignore") @@ -80,10 +81,14 @@ class PROCESS(object): def constant(self): val = True df = self.data.copy() - df = df.drop_duplicates().dropna() - df = df.loc[:, (df != 0).any(axis=0)] - if len(df) >= 10: - val = False + df = df[db.thptparam] + df = df.drop_duplicates() + df = df.loc[:, df.apply(pd.Series.nunique) != 1] + if df is not None: + df = df.dropna() + df = df.loc[:, (df != 0).any(axis=0)] + if len(df) >= 10: + val = False return val def evaluate_var(self, X, lag):