From 3f021252cc24631bf714c25c084c7c706f6d416d Mon Sep 17 00:00:00 2001 From: Michael <michael.mutote@stud.th-deg.de> Date: Sun, 24 Dec 2023 11:40:02 -0800 Subject: [PATCH] plots completed --- DisplayWindow.py | 112 +++++++++++++++++++++++++++++++---------------- 1 file changed, 75 insertions(+), 37 deletions(-) diff --git a/DisplayWindow.py b/DisplayWindow.py index d4518be..81c72a7 100644 --- a/DisplayWindow.py +++ b/DisplayWindow.py @@ -1,24 +1,25 @@ import sys import pandas as pd from PyQt6.QtCore import Qt -from sklearn.ensemble import RandomForestRegressor - +import matplotlib +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg +from matplotlib.figure import Figure import modeling -from data_processing import DataPreprocessor as dp -from modeling import Model - +from PyQt6.QtGui import QIcon +from sklearn.model_selection import train_test_split +import joblib from pyqtgraph import PlotWidget, plot import pyqtgraph as pg from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog, - QVBoxLayout, QHBoxLayout, QGridLayout, - QLabel, QPushButton, QSlider, QDateTimeEdit, - QLineEdit, QComboBox, QDateEdit, QTabWidget, QCheckBox) -from PyQt6.QtGui import QPalette, QColor, QIcon -from sklearn.model_selection import train_test_split -import joblib + QVBoxLayout, QGridLayout, + QLabel, QPushButton, QLineEdit, QComboBox, QTabWidget, QCheckBox) + +matplotlib.use('QtAgg') diamonds = pd.read_csv('diamonds.csv') diamonds.dropna(inplace=True) +mask = (diamonds['x'] == 0) | (diamonds['y'] == 0) | (diamonds['z'] == 0) +diamonds = diamonds.drop(diamonds[mask].index) cut = list(diamonds["cut"].unique()) colors = list(diamonds["color"].unique()) clarity = list(diamonds["clarity"].unique()) @@ -26,6 +27,13 @@ model = joblib.load('RF.joblib') price = 0 +class MplCanvas(FigureCanvasQTAgg): + def __init__(self, parent=None, width=5, height=4, dpi=100): + fig = Figure(figsize=(width, height), dpi=dpi) + self.axes = fig.add_subplot(111) + super().__init__(fig) + + class MyTabs(QMainWindow): @staticmethod def numericise(df): @@ -64,12 +72,17 @@ class MainWindow(QMainWindow): central_widget = QWidget(self) self.setCentralWidget(central_widget) + # Create a grid layout and set it to the central widget - grid_layout = QGridLayout(central_widget) + self.grid_layout = QGridLayout(central_widget) # List of Menu Items in dict self.menu_items = {3: 'carat', 0: 'cut', 1: 'color', 2: 'clarity', 4: 'depth', 5: 'table', 6: 'x', 7: 'y', 8: 'z'} self.GUI_selections = {} + # create new graphs + self.scatt = MplCanvas(self, width=6, height=5, dpi=100) + self.scatt.axes.set_title("Carats vs Price") + self.scatt.axes.scatter(diamonds['carat'], diamonds['price'], color='red', s=2) # Create labels and input fields self.input_fields = [QComboBox(self), QComboBox(self), QComboBox(self), QLineEdit(self), @@ -85,20 +98,18 @@ class MainWindow(QMainWindow): # Add labels and input fields to the grid layout for i, val in self.menu_items.items(): label = QLabel(val, self) - grid_layout.addWidget(label, 0, i) - grid_layout.addWidget(self.input_fields[i], 1, i) + self.grid_layout.addWidget(label, 0, i) + self.grid_layout.addWidget(self.input_fields[i], 1, i) # Create a plot widget - graphWidget = pg.PlotWidget(self) - graphWidget.setBackground('w') - graphWidget.plot(diamonds['carat'], diamonds['price']) + # Add the plot widget to the grid layout - grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2) - grid_layout.addWidget(graphWidget, 3, 0, 1, 9) - grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1) + self.grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2) + self.grid_layout.addWidget(self.scatt, 3, 0, 1, 9) + self.grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1) self.display_price = QLabel(str(0)) - grid_layout.addWidget(self.display_price, 4, 2, 1, 1) + self.grid_layout.addWidget(self.display_price, 4, 2, 1, 1) self.setWindowTitle('Assistance Systems') self.show() @@ -155,7 +166,13 @@ class Advanced(QMainWindow): central_widget = QWidget(self) self.setCentralWidget(central_widget) # Create a grid layout and set it to the central widget - grid_layout = QGridLayout(central_widget) + self.grid_layout = QGridLayout(central_widget) + + # create new graphs + self.histogram = MplCanvas(self, width=6, height=4, dpi=100) + # --------------- + self.scatt = MplCanvas(self, width=6, height=4, dpi=100) + # end graphs self.graph_selector_X = QComboBox(self) self.graph_selector_Y = QComboBox(self) @@ -174,23 +191,26 @@ class Advanced(QMainWindow): self.checkboxes = [] - grid_layout.addWidget(QLabel("select features to include in Modelling"), 0, 0, 1, 0) - grid_layout.addWidget(self.graph_selector_X, 3, 1) - grid_layout.addWidget(self.graph_selector_Y, 3, 3) - grid_layout.addWidget(self.regression_model, 2, 1) - grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0) - grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2) - grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1) - grid_layout.addWidget(self.plot_graph, 2, 6) - grid_layout.addWidget(self.teach, 2, 7, 1, 2) - grid_layout.addWidget(self.R2, 2, 4, 1, 2) - grid_layout.addWidget(self.MSE, 2, 2, 1, 2) + # grid_layout.addWidget(QLabel("select features to include in Modelling"), 0, 0, 1, 0) + self.grid_layout.addWidget(self.graph_selector_X, 3, 1) + self.grid_layout.addWidget(self.graph_selector_Y, 3, 3) + self.grid_layout.addWidget(self.regression_model, 2, 1) + self.grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0) + self.grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2) + self.grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1) + self.grid_layout.addWidget(self.plot_graph, 2, 6) + self.grid_layout.addWidget(self.teach, 2, 7, 1, 2) + self.grid_layout.addWidget(self.R2, 2, 4, 1, 2) + self.grid_layout.addWidget(self.MSE, 2, 2, 1, 2) + + self.grid_layout.addWidget(self.histogram, 0, 0, 1, 4) + self.grid_layout.addWidget(self.scatt, 0, 4, 1, 4) for i, label in enumerate(self.check_labels): checkbox = QCheckBox(label, self) checkbox.setChecked(True) checkbox.stateChanged.connect(self.handle_checkbox_state) # Connect signal - grid_layout.addWidget(checkbox, 1, i) + self.grid_layout.addWidget(checkbox, 1, i) self.checkboxes.append(checkbox) # Store checkboxes in the list def handle_checkbox_state(self): @@ -216,11 +236,29 @@ class Advanced(QMainWindow): training = model.train() model.evaluate() - self.R2.setText(str(model.r2)) - self.MSE.setText(str(model.mse)) + self.R2.setText(f"R2 = {str(model.r2)}") + self.MSE.setText(f"MSE = {str(model.mse)}") def create_graph(self): - pass + + X = diamonds[self.graph_selector_X.currentText()] + Y = diamonds[self.graph_selector_Y.currentText()] + + self.scatt.axes.clear() + self.histogram.axes.clear() + + self.scatt.axes.scatter(X, Y, color='green', s=2) + self.histogram.axes.hist(X, bins=50) + self.scatt.axes.set_title(f"Scatter Plot {self.graph_selector_X.currentText()} vs " + f"{self.graph_selector_Y.currentText()}" ) + self.histogram.axes.set_title(f"Histogram {self.graph_selector_X.currentText()}") + self.scatt.axes.set_xlabel(self.graph_selector_X.currentText()) + self.scatt.axes.set_ylabel(self.graph_selector_Y.currentText()) + self.histogram.axes.set_xlabel(self.graph_selector_X.currentText()) + self.histogram.axes.set_xlabel("frequency") + + self.scatt.draw() + self.histogram.draw() app = QApplication(sys.argv) -- GitLab