# DisplayWindow.py
import math
import sys

# Import the PyQt6 binding before matplotlib's Qt backend: matplotlib uses a Qt
# binding that is already imported instead of picking one on its own.
from PyQt6.QtCore import Qt
from PyQt6.QtGui import QIcon
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog,
                             QVBoxLayout, QGridLayout,
                             QLabel, QPushButton, QLineEdit, QComboBox, QTabWidget, QCheckBox)

import joblib
import matplotlib
import pandas as pd
import pyqtgraph as pg  # NOTE(review): pyqtgraph names appear unused in this file.
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
from pyqtgraph import PlotWidget, plot
from sklearn.model_selection import train_test_split

import modeling

# Ask matplotlib for its Qt-based Agg backend.
matplotlib.use('QtAgg')

# Dataset used for the plots, the drop-down choices and re-training.
# Rows with any missing value are discarded.
diamonds = pd.read_csv('diamonds.csv').dropna()

# A zero in any of the x / y / z dimension columns marks an unusable record.
mask = diamonds[['x', 'y', 'z']].eq(0).any(axis=1)
diamonds = diamonds.drop(index=diamonds.index[mask])

# Distinct category labels (in order of first appearance) offered by the
# Main tab's drop-downs.
cut, colors, clarity = (list(diamonds[column].unique()) for column in ('cut', 'color', 'clarity'))

# Model used by the Main tab until the Advanced tab retrains it.
# joblib.load unpickles the file, so only load a trusted RF.joblib.
model = joblib.load('RF.joblib')
# NOTE(review): `price` is not read by any code in this file; kept for compatibility.
price = 0


class MplCanvas(FigureCanvasQTAgg):
    """Qt widget wrapping a matplotlib Figure that holds a single Axes.

    The Axes is available as ``self.axes``. ``width``/``height`` (inches) and
    ``dpi`` are forwarded to ``Figure``. ``parent`` is accepted for call-site
    compatibility but is not used.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        super().__init__(Figure(figsize=(width, height), dpi=dpi))
        # The canvas base class stores the figure it was given as self.figure.
        self.axes = self.figure.add_subplot(1, 1, 1)


class MyTabs(QMainWindow):
    """Top-level window with a "Main" (prediction) tab and an "Advanced" tab."""

    @staticmethod
    def numericise(df):
        """Replace categorical cut/color/clarity labels in *df* with integer codes.

        Only columns whose feature is enabled in ``Advanced.Advanced_selections``
        are converted; disabled features are not present in the frames passed
        here. *df* is modified in place and also returned. Labels missing from a
        mapping become NaN (``Series.map`` behaviour).

        The same codes are applied at training time (``Advanced.re_train``) and
        prediction time (``MainWindow.update_values``), so they must stay in sync.
        """
        encodings = (
            ('cut', {'Ideal': 0, 'Premium': 1, 'Very Good': 2, 'Good': 3, 'Fair': 4}),
            ('color', {'E': 0, 'I': 1, 'J': 2, 'H': 3, 'F': 4, 'G': 5, 'D': 6}),
            ('clarity', {'SI2': 0, 'SI1': 1, 'VS1': 2, 'VS2': 3, 'VVS2': 4, 'VVS1': 5,
                         'IF': 7, 'I1': 6}),
        )
        for column, codes in encodings:
            if Advanced.Advanced_selections[column]:
                df[column] = df[column].map(codes)
        return df

    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon("icon.ico"))
        self.setMinimumSize(720, 640)

        # One page per tab: the price calculator and the advanced settings.
        self.calc = MainWindow()
        self.adjust = Advanced()

        tabs = QTabWidget(self)
        for page, title in ((self.calc, "Main"), (self.adjust, "Advanced")):
            tabs.addTab(page, title)

        # The tab widget fills the whole main window.
        self.setCentralWidget(tabs)


class MainWindow(QMainWindow):
    """"Main" tab: collects one diamond's features and shows the predicted price.

    Grid layout: row 0 feature names, row 1 inputs (drop-downs for cut/color/
    clarity, line edits for the numeric features), row 2 the "Calculate Price"
    button, row 3 a carat-vs-price scatter of the dataset, row 4 the prediction.
    """

    def __init__(self):
        super().__init__()

        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)
        self.grid_layout = QGridLayout(central_widget)

        # Maps grid column / input_fields index -> feature name. The insertion
        # order (carat first) becomes the column order of the DataFrame built in
        # update_values; presumably it mirrors the training column order — confirm
        # before reordering.
        self.menu_items = {3: 'carat', 0: 'cut', 1: 'color', 2: 'clarity', 4: 'depth', 5: 'table', 6: 'x',
                           7: 'y', 8: 'z'}
        self.GUI_selections = {}

        # Dataset overview shown below the inputs.
        self.scatt = MplCanvas(self, width=6, height=5, dpi=100)
        self.scatt.axes.set_title("Carats vs Price")
        self.scatt.axes.scatter(diamonds['carat'], diamonds['price'], color='red', s=2)

        # Indices 0-2 are drop-downs (categorical), 3-8 free-text numeric fields.
        self.input_fields = [QComboBox(self), QComboBox(self), QComboBox(self), QLineEdit(self),
                             QLineEdit(self), QLineEdit(self), QLineEdit(self), QLineEdit(self),
                             QLineEdit(self)]
        self.input_fields[0].addItems(cut)
        self.input_fields[1].addItems(colors)
        self.input_fields[2].addItems(clarity)

        self.calculate_button = QPushButton('Calculate Price', self)
        self.calculate_button.clicked.connect(self.update_values)

        # Row 0: feature names, row 1: the matching input widgets.
        for i, val in self.menu_items.items():
            label = QLabel(val, self)
            self.grid_layout.addWidget(label, 0, i)
            self.grid_layout.addWidget(self.input_fields[i], 1, i)

        self.grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2)
        self.grid_layout.addWidget(self.scatt, 3, 0, 1, 9)
        self.grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1)
        self.display_price = QLabel(str(0))
        self.grid_layout.addWidget(self.display_price, 4, 2, 1, 1)

        self.setWindowTitle('Assistance Systems')
        # No self.show() here: MyTabs embeds this widget as a tab page, and
        # showing a parentless widget would make it a separate top-level window
        # until it is re-parented into the tab widget.

    def update_values(self):
        """Read the enabled inputs, predict a price with the global model, show it.

        Only features enabled in ``Advanced.Advanced_selections`` are read. If an
        enabled numeric field is invalid, ``validate_inputs`` shows an error
        dialog and no prediction is made.
        """
        self.GUI_selections = {}
        for key, value in self.menu_items.items():
            if not Advanced.Advanced_selections[value]:
                continue
            field = self.input_fields[key]
            if key < 3:
                # Categorical feature: take the drop-down's current label.
                self.GUI_selections[value] = field.currentText()
            elif self.validate_inputs(field.text(), value):
                self.GUI_selections[value] = float(field.text())
            else:
                return

        # One-row frame; columns follow menu_items' insertion order.
        X_test = MyTabs.numericise(pd.DataFrame(self.GUI_selections, index=[0]))
        # NOTE(review): before any re-training, `model` is whatever RF.joblib
        # contains; it must provide predict_price() too — confirm.
        self.display_price.setText(str(model.predict_price(X_test)))

    def validate_inputs(self, val, attribute):
        """Return True if *val* is acceptable text for feature *attribute*.

        Disabled features (per ``Advanced.Advanced_selections``) are always
        accepted. Enabled ones must parse as a finite float; otherwise a modal
        error dialog is shown and False is returned.
        """
        if not Advanced.Advanced_selections[attribute]:
            return True
        try:
            number = float(val)
        except ValueError:
            number = None
        # float() also accepts "nan" / "inf", which are not usable feature values.
        if number is not None and math.isfinite(number):
            return True
        self._show_input_error(attribute)
        return False

    def _show_input_error(self, attribute):
        """Show a modal dialog telling the user *attribute* needs a number."""
        dlg = QDialog(self)
        dlg.setGeometry(400, 500, 200, 100)
        dlg.setWindowTitle("ERROR")
        layout = QVBoxLayout()
        layout.addWidget(QLabel(f"{attribute} numerical value required"))
        button = QPushButton("OK")
        button.clicked.connect(dlg.reject)
        layout.addWidget(button)
        dlg.setLayout(layout)
        dlg.exec()


class Advanced(QMainWindow):
    """"Advanced" tab: feature selection, data plots and model re-training.

    ``Advanced_selections`` is class-level state read by the Main tab and by
    ``MyTabs.numericise``. It maps each feature name to whether that feature
    is used for prediction and training.
    """

    Advanced_selections = {'cut': True, 'color': True, 'clarity': True, 'carat': True, 'depth': True,
                           'table': True, 'x': True, 'y': True, 'z': True}

    def __init__(self):
        super().__init__()

        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)
        self.grid_layout = QGridLayout(central_widget)

        # Plot canvases, redrawn by create_graph().
        self.histogram = MplCanvas(self, width=6, height=4, dpi=100)
        self.scatt = MplCanvas(self, width=6, height=4, dpi=100)

        self.graph_selector_X = QComboBox(self)
        self.graph_selector_Y = QComboBox(self)
        self.regression_model = QComboBox(self)
        # Metric labels, updated by re_train().
        self.R2 = QLabel("R-Squared Error = ")
        self.MSE = QLabel("Mean Squared Error = ")
        # One checkbox per feature; 'price' is the target and can only be plotted.
        self.check_labels = ['cut', 'color', 'clarity', 'carat', 'depth', 'table', 'x', 'y', 'z']
        self.graph_selector_X.addItems((self.check_labels + ['price']))
        self.graph_selector_Y.addItems((self.check_labels + ['price']))
        # The selected text is passed to modeling.Model by re_train().
        self.regression_model.addItems(["Linear Regression", "XGBRegressor", "Neural Network",
                                        "Random Forest Regressor"])
        self.teach = QPushButton('Re-Train', self)
        self.teach.clicked.connect(self.re_train)
        self.plot_graph = QPushButton('PLOT', self)
        self.plot_graph.clicked.connect(self.create_graph)
        self.checkboxes = []

        # Row 0: plots, row 1: feature checkboxes, row 2: model controls and
        # metrics, row 3: plot-axis selectors.
        self.grid_layout.addWidget(self.graph_selector_X, 3, 1)
        self.grid_layout.addWidget(self.graph_selector_Y, 3, 3)
        self.grid_layout.addWidget(self.regression_model, 2, 1)
        self.grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0)
        self.grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2)
        self.grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1)
        self.grid_layout.addWidget(self.plot_graph, 2, 6)
        self.grid_layout.addWidget(self.teach, 2, 7, 1, 2)
        self.grid_layout.addWidget(self.R2, 2, 4, 1, 2)
        self.grid_layout.addWidget(self.MSE, 2, 2, 1, 2)

        self.grid_layout.addWidget(self.histogram, 0, 0, 1, 4)
        self.grid_layout.addWidget(self.scatt, 0, 4, 1, 4)
        for i, label in enumerate(self.check_labels):
            checkbox = QCheckBox(label, self)
            checkbox.setChecked(True)
            checkbox.stateChanged.connect(self.handle_checkbox_state)
            self.grid_layout.addWidget(checkbox, 1, i)
            # self.checkboxes stays index-aligned with self.check_labels.
            self.checkboxes.append(checkbox)

    def handle_checkbox_state(self):
        """Copy every checkbox's checked state into ``Advanced_selections``."""
        # Key by check_labels rather than checkbox.text(). Some desktop styles
        # (e.g. KDE) insert '&' accelerators into widget text, which would add
        # stray keys instead of updating the feature's entry.
        for name, checkbox in zip(self.check_labels, self.checkboxes):
            Advanced.Advanced_selections[name] = checkbox.isChecked()

    def re_train(self):
        """Retrain the global prediction model on the enabled features.

        Steps:
        - drop disabled feature columns and encode categoricals with
          ``MyTabs.numericise``;
        - make an 80/20 train/test split (random_state=42);
        - train and evaluate a ``modeling.Model`` of the type chosen in the
          combo box;
        - display its R2 and MSE.

        The module-level ``model`` is only rebound after training and
        evaluation have finished.
        """
        global model
        disabled = [name for name, enabled in Advanced.Advanced_selections.items() if not enabled]
        X = diamonds.drop(columns=disabled)
        y = X['price']
        X = MyTabs.numericise(X.drop('price', axis=1))
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        candidate = modeling.Model(X_train, X_test, y_train, y_test, self.regression_model.currentText())
        candidate.train()
        candidate.evaluate()
        # Publish only a fully trained and evaluated model to the Main tab.
        model = candidate

        self.R2.setText(f"R2 = {str(model.r2)}")
        self.MSE.setText(f"MSE = {str(model.mse)}")

    def create_graph(self):
        """Redraw the X-PLOT histogram and the X-vs-Y scatter from ``diamonds``."""
        x_name = self.graph_selector_X.currentText()
        y_name = self.graph_selector_Y.currentText()
        X = diamonds[x_name]
        Y = diamonds[y_name]

        self.scatt.axes.clear()
        self.histogram.axes.clear()

        self.scatt.axes.scatter(X, Y, color='green', s=2)
        self.histogram.axes.hist(X, bins=50)
        self.scatt.axes.set_title(f"Scatter Plot {x_name} vs "
                                  f"{y_name}")
        self.histogram.axes.set_title(f"Histogram {x_name}")
        self.scatt.axes.set_xlabel(x_name)
        self.scatt.axes.set_ylabel(y_name)
        self.histogram.axes.set_xlabel(x_name)
        self.histogram.axes.set_ylabel("frequency")

        self.scatt.draw()
        self.histogram.draw()


# Launch the GUI only when run as a script, not when the module is imported.
if __name__ == "__main__":
    app = QApplication(sys.argv)

    window = MyTabs()
    window.show()
    # Propagate the Qt event loop's return code as the process exit status.
    sys.exit(app.exec())