# Web-page extraction residue (GitLab page chrome), kept only as comments:
#   Skip to content
#   Snippets Groups Projects
#   main.py 13.1 KiB
#   Newer Older
import math
import sys

# NOTE: the PyQt6 imports stay ahead of the matplotlib Qt backend import on purpose;
# matplotlib picks up a Qt binding that is already imported, so reordering these
# lines can make it select (or look for) a different binding.
from PyQt6.QtCore import Qt
from PyQt6.QtGui import QIcon, QMovie
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog,
                             QVBoxLayout, QGridLayout,
                             QLabel, QPushButton, QLineEdit, QComboBox, QTabWidget, QCheckBox)
import joblib
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import pandas as pd
from sklearn.model_selection import train_test_split

import modeling

# Select the Qt backend for all matplotlib canvases created below.
matplotlib.use('QtAgg')

# Load the data set and discard rows with missing values.
diamonds = pd.read_csv('diamonds.csv').dropna()

# Discard rows in which any of the x / y / z columns is recorded as zero.
mask = (diamonds[['x', 'y', 'z']] == 0).any(axis=1)
diamonds = diamonds.drop(diamonds.index[mask])

# Distinct category values, in order of first appearance; they fill the combo boxes.
cut = diamonds["cut"].unique().tolist()
colors = diamonds["color"].unique().tolist()
clarity = diamonds["clarity"].unique().tolist()

# Model used for predictions until Advanced.re_train() rebinds this name.
model = joblib.load('models/RF.joblib')

# Module-level placeholder; the classes visible in this file keep the prediction
# on MainWindow.price instead.
price = 0


# (blame-view residue) Michael Mutote's avatar
# (blame-view residue) Michael Mutote committed
class MplCanvas(FigureCanvasQTAgg):
    """Qt canvas that owns one matplotlib figure with a single set of axes.

    The axes are exposed as ``self.axes`` so callers can plot on them and then call
    ``draw()`` on the canvas. ``parent`` is accepted but not used.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Create a ``width`` x ``height`` figure at ``dpi`` and attach it to the canvas."""
        figure = Figure(figsize=(width, height), dpi=dpi)
        # One subplot filling the whole figure.
        self.axes = figure.add_subplot(1, 1, 1)
        super().__init__(figure)


class MyTabs(QMainWindow):
    """Top-level window that hosts the "Main" and "Advanced" pages as tabs."""

    @staticmethod
    def numericise(df):
        """Replace the categorical columns of ``df`` with their integer codes.

        Only columns currently enabled in ``Advanced.Advanced_selections`` are
        converted. ``df`` is modified in place and also returned.
        """
        codes_by_column = {
            'cut': {'Ideal': 0, 'Premium': 1, 'Very Good': 2, 'Good': 3, 'Fair': 4},
            'color': {'E': 0, 'I': 1, 'J': 2, 'H': 3, 'F': 4, 'G': 5, 'D': 6},
            'clarity': {'SI2': 0, 'SI1': 1, 'VS1': 2, 'VS2': 3, 'VVS2': 4, 'VVS1': 5,
                        'IF': 7, 'I1': 6},
        }
        for column, codes in codes_by_column.items():
            if Advanced.Advanced_selections[column]:
                df[column] = df[column].map(codes)
        return df

    def __init__(self):
        """Create both pages and install them in a tab widget as the central widget."""
        super().__init__()
        self.setWindowIcon(QIcon("icon.ico"))
        self.setMinimumSize(720, 640)

        # The pages are kept as attributes; Advanced reaches the Main page
        # through ``window.calc``.
        self.calc = MainWindow()
        self.adjust = Advanced()

        tabs = QTabWidget(self)
        for page, title in ((self.calc, "Main"), (self.adjust, "Advanced")):
            tabs.addTab(page, title)

        self.setCentralWidget(tabs)


class MainWindow(QMainWindow):
    """Page behind the "Main" tab: collects diamond attributes and predicts a price.

    Features switched off in ``Advanced.Advanced_selections`` are left out when the
    model input is assembled. The prediction is shown in a label and marked on a
    carat-vs-price scatter plot of the data set.
    """

    def __init__(self):
        """Build the input row, the calculate button, the plot canvas and the price label."""
        super().__init__()

        # Create a central widget and give it a grid layout.
        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)
        self.grid_layout = QGridLayout(central_widget)

        # Key = index into self.input_fields and grid column; value = feature name.
        # update_values() iterates this dict, so its insertion order (carat first)
        # becomes the column order of the frame handed to the model.
        self.menu_items = {3: 'carat', 0: 'cut', 1: 'color', 2: 'clarity', 4: 'depth', 5: 'table', 6: 'x',
                           7: 'y', 8: 'z'}
        self.GUI_selections = {}
        # Initialised here so create_graph() never meets a missing attribute.
        self.price = 0

        self.scatt = MplCanvas(self, width=6, height=5, dpi=100)

        # Indices 0-2: combo boxes for cut / color / clarity; 3-8: free-text numeric fields.
        self.input_fields = [QComboBox(self), QComboBox(self), QComboBox(self), QLineEdit(self),
                             QLineEdit(self), QLineEdit(self), QLineEdit(self), QLineEdit(self),
                             QLineEdit(self)]
        self.input_fields[0].addItems(cut)
        self.input_fields[1].addItems(colors)
        self.input_fields[2].addItems(clarity)

        # Button that reads all inputs and triggers the prediction.
        self.calculate_button = QPushButton('Calculate Price', self)
        self.calculate_button.clicked.connect(self.update_values)

        # Row 0: feature labels, row 1: the matching input widgets.
        for i, val in self.menu_items.items():
            label = QLabel(val, self)
            self.grid_layout.addWidget(label, 0, i)
            self.grid_layout.addWidget(self.input_fields[i], 1, i)

        # Row 2: button, row 3: plot spanning all nine columns, row 4: price read-out.
        self.grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2)
        self.grid_layout.addWidget(self.scatt, 3, 0, 1, 9)
        predicted_price_label = QLabel("Predicted Price: ", self)
        predicted_price_label.setStyleSheet("font-size: 20px;")
        self.grid_layout.addWidget(predicted_price_label, 4, 0, 1, 1)
        self.display_price = QLabel(str(0))
        self.display_price.setStyleSheet("font-size: 20px; font-weight: bold; color: purple; font-family: Ubuntu;")
        self.grid_layout.addWidget(self.display_price, 4, 2, 1, 1)

        self.setWindowTitle('Assistance Systems')
        self.show()

    def _fields_named(self, field_name):
        """Return the input widgets whose feature name equals ``field_name``."""
        return [self.input_fields[key] for key, value in self.menu_items.items() if value == field_name]

    def disable_input_field(self, field_name):
        """Grey out and disable the input widget of feature ``field_name``."""
        for field in self._fields_named(field_name):
            field.setDisabled(True)
            field.setStyleSheet("background-color: grey; color: white;")

    def enable_input_field(self, field_name):
        """Re-enable the input widget of feature ``field_name`` and restore its colours."""
        for field in self._fields_named(field_name):
            field.setEnabled(True)
            field.setStyleSheet("background-color: white; color: black;")

    def update_values(self):
        """Read the enabled inputs, predict the price and refresh label and plot.

        Shows an error dialog and returns early when the feature selection changed
        since the last training run, or when a numeric field fails validation.
        """
        self.GUI_selections = {}

        # The flag may not exist yet if no checkbox has ever been toggled; treat
        # that as "nothing changed" instead of failing with AttributeError.
        if getattr(Advanced, 'checkboxes_changed', False):
            self.dialog_box("ERROR", "you need to retrain the model before recalculating,"
                                     " parameters have changed")
            return

        for key, value in self.menu_items.items():
            if not Advanced.Advanced_selections[value]:
                continue
            field = self.input_fields[key]
            if isinstance(field, QComboBox):
                self.GUI_selections[value] = field.currentText()
                continue
            text = field.text()
            if not self.validate_inputs(text, value):
                return
            self.GUI_selections[value] = float(text)

        # Single-row frame; columns follow the insertion order of GUI_selections.
        X_test = pd.DataFrame(self.GUI_selections, index=[0])
        X_test = MyTabs.numericise(X_test)
        self.price = model.predict_price(X_test)
        self.display_price.setText(str(self.price))
        self.create_graph()

    def create_graph(self):
        """Redraw the carat-vs-price scatter plot with the current prediction marked."""
        axes = self.scatt.axes
        axes.clear()
        axes.set_title("Carats vs Price")
        axes.scatter(diamonds['carat'], diamonds['price'], color='red', s=2, label='carat')
        axes.axhline(y=self.price, color='blue', linestyle='--', label='Predicted Price')
        # 'carat' is absent when that feature is switched off on the Advanced page;
        # then only the horizontal price line is drawn.
        if 'carat' in self.GUI_selections:
            carat = self.GUI_selections['carat']
            axes.scatter(carat, self.price, color='blue', s=50, label='expected')
            axes.axvline(x=carat, color='blue', linestyle='--', label='carat')
        self.scatt.draw()

    def validate_inputs(self, val, attribute):
        """Return True when ``val`` is usable as the numeric value of ``attribute``.

        A deselected attribute always passes. Otherwise ``val`` must parse as a
        finite float; if it does not, an error dialog is shown and False is returned.
        """
        if not Advanced.Advanced_selections[attribute]:
            return True
        try:
            number = float(val)
        except ValueError:
            number = None
        # float() also accepts "nan" and "inf", which are not usable feature values.
        if number is None or not math.isfinite(number):
            self.dialog_box("ERROR", f"{attribute} numerical value required")
            return False
        return True

    def dialog_box(self, title, message):
        """Show a modal dialog with ``message`` and an OK button; blocks until closed."""
        dlg = QDialog(self)
        dlg.setGeometry(400, 500, 200, 100)
        dlg.setWindowTitle(title)
        layout = QVBoxLayout()
        layout.addWidget(QLabel(message))
        button = QPushButton("OK")
        button.clicked.connect(dlg.reject)
        layout.addWidget(button)
        dlg.setLayout(layout)
        dlg.exec()


class Advanced(QMainWindow):
    """Page behind the "Advanced" tab: feature selection, re-training and data plots."""

    # Feature name -> whether it is part of the model input. Class-level state that
    # MainWindow and MyTabs.numericise read directly.
    Advanced_selections = {'cut': True, 'color': True, 'clarity': True, 'carat': True, 'depth': True,
                           'table': True, 'x': True, 'y': True, 'z': True}
    # True from the moment a checkbox is toggled until the next successful re_train().
    # Declared here so readers never hit AttributeError before the first toggle.
    checkboxes_changed = False

    def __init__(self):
        """Build the two plot canvases, the selectors, the buttons and the feature checkboxes."""
        super().__init__()

        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)
        self.grid_layout = QGridLayout(central_widget)

        # Plot canvases: histogram of the X selection and X-vs-Y scatter plot.
        self.histogram = MplCanvas(self, width=6, height=4, dpi=100)
        self.scatt = MplCanvas(self, width=6, height=4, dpi=100)

        self.graph_selector_X = QComboBox(self)
        self.graph_selector_Y = QComboBox(self)
        self.regression_model = QComboBox(self)
        self.R2 = QLabel("R-Squared Error = ")
        self.MSE = QLabel("Mean Squared Error = ")
        self.check_labels = ['cut', 'color', 'clarity', 'carat', 'depth', 'table', 'x', 'y', 'z']
        plottable_columns = self.check_labels + ['price']
        self.graph_selector_X.addItems(plottable_columns)
        self.graph_selector_Y.addItems(plottable_columns)
        self.regression_model.addItems(["Random Forest Regressor",
                                        "Linear Regression",
                                        "XGBRegressor",
                                        "Neural Network"])
        self.teach = QPushButton('Re-Train', self)
        self.teach.clicked.connect(self.re_train)
        self.plot_graph = QPushButton('PLOT', self)
        self.plot_graph.clicked.connect(self.create_graph)
        self.checkboxes = []

        # Row 3: plot selectors and the PLOT button.
        self.grid_layout.addWidget(self.graph_selector_X, 3, 1)
        self.grid_layout.addWidget(self.graph_selector_Y, 3, 3)
        self.grid_layout.addWidget(self.regression_model, 2, 1)
        self.grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0)
        self.grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2)
        self.grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1)
        self.grid_layout.addWidget(self.plot_graph, 3, 4)
        # Row 2: model choice, metric labels and the Re-Train button.
        self.grid_layout.addWidget(self.teach, 2, 6, 1, 1)
        self.grid_layout.addWidget(self.R2, 2, 4, 1, 2)
        self.grid_layout.addWidget(self.MSE, 2, 2, 1, 2)

        # Row 0: the two canvases side by side.
        self.grid_layout.addWidget(self.histogram, 0, 0, 1, 4)
        self.grid_layout.addWidget(self.scatt, 0, 4, 1, 4)

        # Row 1: one checkbox per feature.
        for i, label in enumerate(self.check_labels):
            checkbox = QCheckBox(label, self)
            # Check the box before connecting the signal so this initial state
            # change does not invoke handle_checkbox_state().
            checkbox.setChecked(True)
            checkbox.stateChanged.connect(self.handle_checkbox_state)
            self.grid_layout.addWidget(checkbox, 1, i)
            self.checkboxes.append(checkbox)

    def handle_checkbox_state(self):
        """Copy every checkbox state into ``Advanced_selections`` and update the Main page.

        Also raises ``checkboxes_changed`` so that predictions are refused until the
        model has been re-trained.
        """
        Advanced.checkboxes_changed = True
        for checkbox in self.checkboxes:
            feature = checkbox.text()
            selected = checkbox.checkState() != Qt.CheckState.Unchecked
            Advanced.Advanced_selections[feature] = selected
            # ``window`` is the module-level MyTabs instance; ``calc`` is its Main page.
            if selected:
                window.calc.enable_input_field(feature)
            else:
                window.calc.disable_input_field(feature)

    def re_train(self):
        """Train the selected regression model on the enabled features and show its metrics.

        Rebinds the module-level ``model``. ``checkboxes_changed`` is cleared only
        when training and evaluation complete; the button's text and style are
        restored even if they raise.
        """
        global model
        self.teach.setText('In progress')
        self.teach.setStyleSheet("background-color: red; color: white;")
        # Force the "In progress" state onto the screen before the blocking training call.
        self.teach.repaint()
        try:
            excluded = [name for name, selected in Advanced.Advanced_selections.items() if not selected]
            # Work on a copy so the shared data set keeps its original columns and values.
            X = diamonds.copy().drop(columns=excluded)
            y = X['price']
            X = X.drop(columns='price')
            X = MyTabs.numericise(X)
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
            model = modeling.Model(X_train, X_test, y_train, y_test, self.regression_model.currentText())
            model.evaluate()
            self.R2.setText(f"R2 = {model.r2!s}")
            self.MSE.setText(f"MSE = {model.mse!s}")
            Advanced.checkboxes_changed = False
        finally:
            self.teach.setText('Re-Train')
            self.teach.setStyleSheet("background-color: white; color: black;")
            self.teach.setEnabled(True)

    def create_graph(self):
        """Plot a histogram of the X selection and a scatter plot of X against Y."""
        x_name = self.graph_selector_X.currentText()
        y_name = self.graph_selector_Y.currentText()
        X = diamonds[x_name]
        Y = diamonds[y_name]

        self.scatt.axes.clear()
        self.histogram.axes.clear()

        self.scatt.axes.scatter(X, Y, color='green', s=2)
        self.histogram.axes.hist(X, bins=50)
        self.scatt.axes.set_title(f"Scatter Plot {x_name} vs "
                                  f"{y_name}")
        self.histogram.axes.set_title(f"Histogram {x_name}")
        self.scatt.axes.set_xlabel(x_name)
        self.scatt.axes.set_ylabel(y_name)
        self.histogram.axes.set_xlabel(x_name)
        self.histogram.axes.set_ylabel("frequency")

        self.scatt.draw()
        self.histogram.draw()
# (blame-view residue) Michael Mutote's avatar
# (blame-view residue) Michael Mutote committed


# Create the application object, then build and show the tabbed main window.
app = QApplication(sys.argv)

# Advanced.handle_checkbox_state reaches the Main page through this module-level
# name (``window.calc``), so it must keep the name ``window``.
window = MyTabs()
window.show()
# NOTE(review): no app.exec() call is visible in this part of the file; confirm the
# event loop is started below, otherwise the script ends right after showing the window.