1 # ==================================================================================
2 # Copyright (c) 2020 HCL Technologies Limited.
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 # ==================================================================================
22 class modelling(object):
23 def __init__(self, data):
24 """ Separating UEID and timestamp features to be mapped later after prediction """
25 self.time = data.MeasTimestampRF
27 self.data = data.drop(['UEID', 'MeasTimestampRF'], axis=1)
29 def predict(self, name):
31 Load the saved model and map the predicted category into Category field.
32 Map UEID, MeasTimestampRF with the predicted result.
34 model = joblib.load('ad/' + name)
35 pred = model.predict(self.data)
36 data = self.data.copy()
37 le = joblib.load('ad/LabelEncoder')
38 data['Category'] = le.inverse_transform(pred)
39 data['MeasTimestampRF'] = self.time
40 data['UEID'] = self.id
46 If the category of UEID is present in the segment file, it is considered as normal(0)
47 otherwise, the sample is considered as anomaly.
49 with open("ad/ue_seg.json", "r") as json_data:
50 segment = json.loads(json_data.read())
53 if df.loc[i, 'Category'] in segment[str(df.loc[i, 'UEID'])]:
62 Extract all the unique UEID
63 Call Predict method to get the final data for the randomly selected UEID
65 ue_list = df.UEID.unique() # Extract unique UEIDs
66 ue = random.choice(ue_list) # Randomly selected the ue list
67 df = df[df['UEID'] == ue]
69 db_df = db.predict('RF') # Calls predict module and store the result into db_df
71 db_df['Anomaly'] = compare(db_df)