Skip to content
Snippets Groups Projects
Commit 1b3ab56d authored by Fadi Gattoussi's avatar Fadi Gattoussi
Browse files

Fixed XGBRegressor

parent 45a53f5b
No related branches found
No related tags found
No related merge requests found
......@@ -22,19 +22,6 @@ diamonds.dropna(inplace=True)
cut = list(diamonds["cut"].unique())
colors = list(diamonds["color"].unique())
clarity = list(diamonds["clarity"].unique())
# diamonds_numerical = diamonds.copy()
# cut_mapping = {cut: i for i, cut in enumerate(diamonds['cut'].unique())}
# color_mapping = {color: i for i, color in enumerate(diamonds['color'].unique())}
# clarity_mapping = {clarity: i for i, clarity in enumerate(diamonds['clarity'].unique())}
# diamonds_numerical['cut'] = diamonds['cut'].map(cut_mapping)
# diamonds_numerical['color'] = diamonds['color'].map(color_mapping)
# diamonds_numerical['clarity'] = diamonds['clarity'].map(clarity_mapping)
# X = diamonds_numerical.drop('price', axis = 1)
# y = diamonds_numerical['price']
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# model = Model(X_train, X_test, y_train, y_test,'RF')
# model.train()
# joblib.dump(model, 'RF.joblib')
model = joblib.load('RF.joblib')
price = 0
......@@ -117,14 +104,14 @@ class MainWindow(QMainWindow):
self.show()
def update_values(self):
self.GUI_selections = {}
for key, value in self.menu_items.items():
if Advanced.Advanced_selections[value]:
if key < 3:
self.GUI_selections[value] = self.input_fields[key].currentText()
else:
if self.validate_inputs(self.input_fields[key].text(), value):
self.GUI_selections[value] = self.input_fields[key].text()
self.GUI_selections[value] = float(self.input_fields[key].text())
else:
return
......@@ -180,8 +167,8 @@ class Advanced(QMainWindow):
self.graph_selector_Y.addItems((self.check_labels + ['price']))
self.regression_model.addItems(["Linear Regression", "XGBRegressor", "Neural Network",
"Random Forest Regressor"])
self.teach = QPushButton('Re-Teach', self)
self.teach.clicked.connect(self.re_teach)
self.teach = QPushButton('Re-Train', self)
self.teach.clicked.connect(self.re_train)
self.plot_graph = QPushButton('PLOT', self)
self.plot_graph.clicked.connect(self.create_graph)
......@@ -214,7 +201,7 @@ class Advanced(QMainWindow):
else:
Advanced.Advanced_selections[checkbox.text()] = True
def re_teach(self):
def re_train(self):
global diamonds
X = diamonds.copy()
for key, value in Advanced.Advanced_selections.items():
......
No preview for this file type
......@@ -21,7 +21,7 @@ class Model:
if self.modelname == "Linear Regression":
self.model = LinearRegression()
elif self.modelname == "XGBRegressor":
self.model = XGBRegressor(random_state=42, n_estimators=100)
self.model = XGBRegressor(random_state=42, n_estimators=100, enable_categorical = True)
elif self.modelname == "Neural Network":
self.model = MLPRegressor(random_state=42, max_iter=500)
elif self.modelname == "Random Forest Regressor":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment