From 3f021252cc24631bf714c25c084c7c706f6d416d Mon Sep 17 00:00:00 2001
From: Michael <michael.mutote@stud.th-deg.de>
Date: Sun, 24 Dec 2023 11:40:02 -0800
Subject: [PATCH] plots completed

---
 DisplayWindow.py | 112 +++++++++++++++++++++++++++++++----------------
 1 file changed, 75 insertions(+), 37 deletions(-)

diff --git a/DisplayWindow.py b/DisplayWindow.py
index d4518be..81c72a7 100644
--- a/DisplayWindow.py
+++ b/DisplayWindow.py
@@ -1,24 +1,25 @@
 import sys
 import pandas as pd
 from PyQt6.QtCore import Qt
-from sklearn.ensemble import RandomForestRegressor
-
+import matplotlib
+from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
+from matplotlib.figure import Figure
 import modeling
-from data_processing import DataPreprocessor as dp
-from modeling import Model
-
+from PyQt6.QtGui import QIcon
+from sklearn.model_selection import train_test_split
+import joblib
 from pyqtgraph import PlotWidget, plot
 import pyqtgraph as pg
 from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog,
-                             QVBoxLayout, QHBoxLayout, QGridLayout,
-                             QLabel, QPushButton, QSlider, QDateTimeEdit,
-                             QLineEdit, QComboBox, QDateEdit, QTabWidget, QCheckBox)
-from PyQt6.QtGui import QPalette, QColor, QIcon
-from sklearn.model_selection import train_test_split
-import joblib
+                             QVBoxLayout, QGridLayout,
+                             QLabel, QPushButton, QLineEdit, QComboBox, QTabWidget, QCheckBox)
+
+matplotlib.use('QtAgg')
 
 diamonds = pd.read_csv('diamonds.csv')
 diamonds.dropna(inplace=True)
+mask = (diamonds['x'] == 0) | (diamonds['y'] == 0) | (diamonds['z'] == 0)
+diamonds = diamonds.drop(diamonds[mask].index)
 cut = list(diamonds["cut"].unique())
 colors = list(diamonds["color"].unique())
 clarity = list(diamonds["clarity"].unique())
@@ -26,6 +27,13 @@ model = joblib.load('RF.joblib')
 price = 0
 
 
+class MplCanvas(FigureCanvasQTAgg):
+    def __init__(self, parent=None, width=5, height=4, dpi=100):
+        fig = Figure(figsize=(width, height), dpi=dpi)
+        self.axes = fig.add_subplot(111)
+        super().__init__(fig)
+
+
 class MyTabs(QMainWindow):
     @staticmethod
     def numericise(df):
@@ -64,12 +72,17 @@ class MainWindow(QMainWindow):
         central_widget = QWidget(self)
         self.setCentralWidget(central_widget)
 
+
         # Create a grid layout and set it to the central widget
-        grid_layout = QGridLayout(central_widget)
+        self.grid_layout = QGridLayout(central_widget)
         # List of Menu Items in dict
         self.menu_items = {3: 'carat', 0: 'cut', 1: 'color', 2: 'clarity', 4: 'depth', 5: 'table', 6: 'x',
                            7: 'y', 8: 'z'}
         self.GUI_selections = {}
+        # create new graphs
+        self.scatt = MplCanvas(self, width=6, height=5, dpi=100)
+        self.scatt.axes.set_title("Carats vs Price")
+        self.scatt.axes.scatter(diamonds['carat'], diamonds['price'], color='red', s=2)
 
         # Create labels and input fields
         self.input_fields = [QComboBox(self), QComboBox(self), QComboBox(self), QLineEdit(self),
@@ -85,20 +98,18 @@ class MainWindow(QMainWindow):
         # Add labels and input fields to the grid layout
         for i, val in self.menu_items.items():
             label = QLabel(val, self)
-            grid_layout.addWidget(label, 0, i)
-            grid_layout.addWidget(self.input_fields[i], 1, i)
+            self.grid_layout.addWidget(label, 0, i)
+            self.grid_layout.addWidget(self.input_fields[i], 1, i)
 
         # Create a plot widget
-        graphWidget = pg.PlotWidget(self)
-        graphWidget.setBackground('w')
-        graphWidget.plot(diamonds['carat'], diamonds['price'])
+
 
         # Add the plot widget to the grid layout
-        grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2)
-        grid_layout.addWidget(graphWidget, 3, 0, 1, 9)
-        grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1)
+        self.grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2)
+        self.grid_layout.addWidget(self.scatt, 3, 0, 1, 9)
+        self.grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1)
         self.display_price = QLabel(str(0))
-        grid_layout.addWidget(self.display_price, 4, 2, 1, 1)
+        self.grid_layout.addWidget(self.display_price, 4, 2, 1, 1)
 
         self.setWindowTitle('Assistance Systems')
         self.show()
@@ -155,7 +166,13 @@ class Advanced(QMainWindow):
         central_widget = QWidget(self)
         self.setCentralWidget(central_widget)
         # Create a grid layout and set it to the central widget
-        grid_layout = QGridLayout(central_widget)
+        self.grid_layout = QGridLayout(central_widget)
+
+        # create new graphs
+        self.histogram = MplCanvas(self, width=6, height=4, dpi=100)
+        # ---------------
+        self.scatt = MplCanvas(self, width=6, height=4, dpi=100)
+        # end graphs
 
         self.graph_selector_X = QComboBox(self)
         self.graph_selector_Y = QComboBox(self)
@@ -174,23 +191,26 @@ class Advanced(QMainWindow):
 
         self.checkboxes = []
 
-        grid_layout.addWidget(QLabel("select features to include in Modelling"), 0, 0, 1, 0)
-        grid_layout.addWidget(self.graph_selector_X, 3, 1)
-        grid_layout.addWidget(self.graph_selector_Y, 3, 3)
-        grid_layout.addWidget(self.regression_model, 2, 1)
-        grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0)
-        grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2)
-        grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1)
-        grid_layout.addWidget(self.plot_graph, 2, 6)
-        grid_layout.addWidget(self.teach, 2, 7, 1, 2)
-        grid_layout.addWidget(self.R2, 2, 4, 1, 2)
-        grid_layout.addWidget(self.MSE, 2, 2, 1, 2)
+        # grid_layout.addWidget(QLabel("select features to include in Modelling"), 0, 0, 1, 0)
+        self.grid_layout.addWidget(self.graph_selector_X, 3, 1)
+        self.grid_layout.addWidget(self.graph_selector_Y, 3, 3)
+        self.grid_layout.addWidget(self.regression_model, 2, 1)
+        self.grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0)
+        self.grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2)
+        self.grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1)
+        self.grid_layout.addWidget(self.plot_graph, 2, 6)
+        self.grid_layout.addWidget(self.teach, 2, 7, 1, 2)
+        self.grid_layout.addWidget(self.R2, 2, 4, 1, 2)
+        self.grid_layout.addWidget(self.MSE, 2, 2, 1, 2)
+
+        self.grid_layout.addWidget(self.histogram, 0, 0, 1, 4)
+        self.grid_layout.addWidget(self.scatt, 0, 4, 1, 4)
 
         for i, label in enumerate(self.check_labels):
             checkbox = QCheckBox(label, self)
             checkbox.setChecked(True)
             checkbox.stateChanged.connect(self.handle_checkbox_state)  # Connect signal
-            grid_layout.addWidget(checkbox, 1, i)
+            self.grid_layout.addWidget(checkbox, 1, i)
             self.checkboxes.append(checkbox)  # Store checkboxes in the list
 
     def handle_checkbox_state(self):
@@ -216,11 +236,29 @@ class Advanced(QMainWindow):
         training = model.train()
         model.evaluate()
 
-        self.R2.setText(str(model.r2))
-        self.MSE.setText(str(model.mse))
+        self.R2.setText(f"R2 = {str(model.r2)}")
+        self.MSE.setText(f"MSE = {str(model.mse)}")
 
     def create_graph(self):
-        pass
+
+        X = diamonds[self.graph_selector_X.currentText()]
+        Y = diamonds[self.graph_selector_Y.currentText()]
+
+        self.scatt.axes.clear()
+        self.histogram.axes.clear()
+
+        self.scatt.axes.scatter(X, Y, color='green', s=2)
+        self.histogram.axes.hist(X, bins=50)
+        self.scatt.axes.set_title(f"Scatter Plot {self.graph_selector_X.currentText()} vs "
+                                  f"{self.graph_selector_Y.currentText()}" )
+        self.histogram.axes.set_title(f"Histogram {self.graph_selector_X.currentText()}")
+        self.scatt.axes.set_xlabel(self.graph_selector_X.currentText())
+        self.scatt.axes.set_ylabel(self.graph_selector_Y.currentText())
+        self.histogram.axes.set_xlabel(self.graph_selector_X.currentText())
+        self.histogram.axes.set_xlabel("frequency")
+
+        self.scatt.draw()
+        self.histogram.draw()
 
 
 app = QApplication(sys.argv)
-- 
GitLab