Skip to content
Snippets Groups Projects
Commit dc54365f authored by Fadi Gattoussi's avatar Fadi Gattoussi
Browse files

Added buttons styling for training in progress and disabled fields

parent 70f980ee
No related branches found
No related tags found
No related merge requests found
......@@ -2,12 +2,15 @@ import sys
import pandas as pd
from PyQt6.QtCore import Qt
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import modeling
from PyQt6.QtGui import QIcon
from PyQt6.QtGui import QIcon, QMovie
from sklearn.model_selection import train_test_split
import joblib
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog,
QVBoxLayout, QGridLayout,
QLabel, QPushButton, QLineEdit, QComboBox, QTabWidget, QCheckBox)
......@@ -118,11 +121,13 @@ class MainWindow(QMainWindow):
for key, value in self.menu_items.items():
if value == field_name:
self.input_fields[key].setDisabled(True)
self.input_fields[key].setStyleSheet("background-color: grey; color: white;")
def enable_input_field(self, field_name):
for key, value in self.menu_items.items():
if value == field_name:
self.input_fields[key].setEnabled(True)
self.input_fields[key].setStyleSheet("background-color: white; color: black;")
def update_values(self):
self.GUI_selections = {}
......@@ -153,14 +158,15 @@ class MainWindow(QMainWindow):
def create_graph(self):
self.scatt.axes.clear()
self.scatt.axes.set_title("Carats vs Price")
self.scatt.axes.scatter(diamonds['carat'], diamonds['price'], color='red', s=2, label='carat')
self.scatt.axes.scatter(self.GUI_selections['carat'], self.price, color='blue', s=50, label='expected')
# self.scatt.axes.axhline(y=self.price, color='blue', linestyle='--', label='Predicted Price')
# self.scatt.axes.axhline(x=self.GUI_selections['carat'], color='blue', linestyle='--', label='carat')
# self.scatt.legend()
print("I am error")
self.scatt.axes.axhline(y=self.price, color='blue', linestyle='--', label='Predicted Price')
try:
self.scatt.axes.scatter(self.GUI_selections['carat'], self.price, color='blue', s=50, label='expected')
self.scatt.axes.axvline(x=self.GUI_selections['carat'], color='blue', linestyle='--', label='carat')
except KeyError:
pass
self.scatt.draw()
def validate_inputs(self, val, attribute):
......@@ -206,7 +212,6 @@ class Advanced(QMainWindow):
# ---------------
self.scatt = MplCanvas(self, width=6, height=4, dpi=100)
# end graphs
self.graph_selector_X = QComboBox(self)
self.graph_selector_Y = QComboBox(self)
self.regression_model = QComboBox(self)
......@@ -223,7 +228,6 @@ class Advanced(QMainWindow):
self.teach.clicked.connect(self.re_train)
self.plot_graph = QPushButton('PLOT', self)
self.plot_graph.clicked.connect(self.create_graph)
self.checkboxes = []
# grid_layout.addWidget(QLabel("select features to include in Modelling"), 0, 0, 1, 0)
......@@ -240,7 +244,6 @@ class Advanced(QMainWindow):
self.grid_layout.addWidget(self.histogram, 0, 0, 1, 4)
self.grid_layout.addWidget(self.scatt, 0, 4, 1, 4)
for i, label in enumerate(self.check_labels):
checkbox = QCheckBox(label, self)
checkbox.setChecked(True)
......@@ -258,8 +261,12 @@ class Advanced(QMainWindow):
else:
Advanced.Advanced_selections[checkbox.text()] = True
window.calc.enable_input_field(checkbox.text())
def re_train(self):
def re_train(self):
self.teach.setText('In progress')
self.teach.setStyleSheet("background-color: red; color: white;")
self.teach.repaint()
global diamonds
X = diamonds.copy()
for key, value in Advanced.Advanced_selections.items():
......@@ -268,18 +275,20 @@ class Advanced(QMainWindow):
y = X['price']
X = X.drop('price', axis=1)
X = MyTabs.numericise(X)
self.teach.setEnabled(False)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
global model
model = modeling.Model(X_train, X_test, y_train, y_test, self.regression_model.currentText())
training = model.train()
model.train()
model.evaluate()
self.R2.setText(f"R2 = {str(model.r2)}")
self.MSE.setText(f"MSE = {str(model.mse)}")
Advanced.checkboxes_changed = False
self.teach.setText('Re-Train')
self.teach.setStyleSheet("background-color: white; color: black;")
self.teach.setEnabled(True)
def create_graph(self):
X = diamonds[self.graph_selector_X.currentText()]
Y = diamonds[self.graph_selector_Y.currentText()]
......
No preview for this file type
......@@ -28,10 +28,8 @@ class Model:
self.model = RandomForestRegressor(random_state=42, n_estimators=100)
else:
raise Exception("Model not implemented")
self.model.fit(self.X_train, self.y_train)
print('done')
return True
return
def save_model(self, filename):
dump(self.model, filename)
......
src/loading.gif

363 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment