Skip to content
Snippets Groups Projects
Commit 3f021252 authored by Michael Mutote's avatar Michael Mutote
Browse files

plots completed

parent 1b3ab56d
No related branches found
No related tags found
No related merge requests found
import sys
import pandas as pd
from PyQt6.QtCore import Qt
from sklearn.ensemble import RandomForestRegressor
import matplotlib
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import modeling
from data_processing import DataPreprocessor as dp
from modeling import Model
from PyQt6.QtGui import QIcon
from sklearn.model_selection import train_test_split
import joblib
from pyqtgraph import PlotWidget, plot
import pyqtgraph as pg
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog,
QVBoxLayout, QHBoxLayout, QGridLayout,
QLabel, QPushButton, QSlider, QDateTimeEdit,
QLineEdit, QComboBox, QDateEdit, QTabWidget, QCheckBox)
from PyQt6.QtGui import QPalette, QColor, QIcon
from sklearn.model_selection import train_test_split
import joblib
QVBoxLayout, QGridLayout,
QLabel, QPushButton, QLineEdit, QComboBox, QTabWidget, QCheckBox)
matplotlib.use('QtAgg')
diamonds = pd.read_csv('diamonds.csv')
diamonds.dropna(inplace=True)
mask = (diamonds['x'] == 0) | (diamonds['y'] == 0) | (diamonds['z'] == 0)
diamonds = diamonds.drop(diamonds[mask].index)
cut = list(diamonds["cut"].unique())
colors = list(diamonds["color"].unique())
clarity = list(diamonds["clarity"].unique())
......@@ -26,6 +27,13 @@ model = joblib.load('RF.joblib')
price = 0
class MplCanvas(FigureCanvasQTAgg):
def __init__(self, parent=None, width=5, height=4, dpi=100):
fig = Figure(figsize=(width, height), dpi=dpi)
self.axes = fig.add_subplot(111)
super().__init__(fig)
class MyTabs(QMainWindow):
@staticmethod
def numericise(df):
......@@ -64,12 +72,17 @@ class MainWindow(QMainWindow):
central_widget = QWidget(self)
self.setCentralWidget(central_widget)
# Create a grid layout and set it to the central widget
grid_layout = QGridLayout(central_widget)
self.grid_layout = QGridLayout(central_widget)
# List of Menu Items in dict
self.menu_items = {3: 'carat', 0: 'cut', 1: 'color', 2: 'clarity', 4: 'depth', 5: 'table', 6: 'x',
7: 'y', 8: 'z'}
self.GUI_selections = {}
# create new graphs
self.scatt = MplCanvas(self, width=6, height=5, dpi=100)
self.scatt.axes.set_title("Carats vs Price")
self.scatt.axes.scatter(diamonds['carat'], diamonds['price'], color='red', s=2)
# Create labels and input fields
self.input_fields = [QComboBox(self), QComboBox(self), QComboBox(self), QLineEdit(self),
......@@ -85,20 +98,18 @@ class MainWindow(QMainWindow):
# Add labels and input fields to the grid layout
for i, val in self.menu_items.items():
label = QLabel(val, self)
grid_layout.addWidget(label, 0, i)
grid_layout.addWidget(self.input_fields[i], 1, i)
self.grid_layout.addWidget(label, 0, i)
self.grid_layout.addWidget(self.input_fields[i], 1, i)
# Create a plot widget
graphWidget = pg.PlotWidget(self)
graphWidget.setBackground('w')
graphWidget.plot(diamonds['carat'], diamonds['price'])
# Add the plot widget to the grid layout
grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2)
grid_layout.addWidget(graphWidget, 3, 0, 1, 9)
grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1)
self.grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2)
self.grid_layout.addWidget(self.scatt, 3, 0, 1, 9)
self.grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1)
self.display_price = QLabel(str(0))
grid_layout.addWidget(self.display_price, 4, 2, 1, 1)
self.grid_layout.addWidget(self.display_price, 4, 2, 1, 1)
self.setWindowTitle('Assistance Systems')
self.show()
......@@ -155,7 +166,13 @@ class Advanced(QMainWindow):
central_widget = QWidget(self)
self.setCentralWidget(central_widget)
# Create a grid layout and set it to the central widget
grid_layout = QGridLayout(central_widget)
self.grid_layout = QGridLayout(central_widget)
# create new graphs
self.histogram = MplCanvas(self, width=6, height=4, dpi=100)
# ---------------
self.scatt = MplCanvas(self, width=6, height=4, dpi=100)
# end graphs
self.graph_selector_X = QComboBox(self)
self.graph_selector_Y = QComboBox(self)
......@@ -174,23 +191,26 @@ class Advanced(QMainWindow):
self.checkboxes = []
grid_layout.addWidget(QLabel("select features to include in Modelling"), 0, 0, 1, 0)
grid_layout.addWidget(self.graph_selector_X, 3, 1)
grid_layout.addWidget(self.graph_selector_Y, 3, 3)
grid_layout.addWidget(self.regression_model, 2, 1)
grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0)
grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2)
grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1)
grid_layout.addWidget(self.plot_graph, 2, 6)
grid_layout.addWidget(self.teach, 2, 7, 1, 2)
grid_layout.addWidget(self.R2, 2, 4, 1, 2)
grid_layout.addWidget(self.MSE, 2, 2, 1, 2)
# grid_layout.addWidget(QLabel("select features to include in Modelling"), 0, 0, 1, 0)
self.grid_layout.addWidget(self.graph_selector_X, 3, 1)
self.grid_layout.addWidget(self.graph_selector_Y, 3, 3)
self.grid_layout.addWidget(self.regression_model, 2, 1)
self.grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0)
self.grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2)
self.grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1)
self.grid_layout.addWidget(self.plot_graph, 2, 6)
self.grid_layout.addWidget(self.teach, 2, 7, 1, 2)
self.grid_layout.addWidget(self.R2, 2, 4, 1, 2)
self.grid_layout.addWidget(self.MSE, 2, 2, 1, 2)
self.grid_layout.addWidget(self.histogram, 0, 0, 1, 4)
self.grid_layout.addWidget(self.scatt, 0, 4, 1, 4)
for i, label in enumerate(self.check_labels):
checkbox = QCheckBox(label, self)
checkbox.setChecked(True)
checkbox.stateChanged.connect(self.handle_checkbox_state) # Connect signal
grid_layout.addWidget(checkbox, 1, i)
self.grid_layout.addWidget(checkbox, 1, i)
self.checkboxes.append(checkbox) # Store checkboxes in the list
def handle_checkbox_state(self):
......@@ -216,11 +236,29 @@ class Advanced(QMainWindow):
training = model.train()
model.evaluate()
self.R2.setText(str(model.r2))
self.MSE.setText(str(model.mse))
self.R2.setText(f"R2 = {str(model.r2)}")
self.MSE.setText(f"MSE = {str(model.mse)}")
def create_graph(self):
pass
X = diamonds[self.graph_selector_X.currentText()]
Y = diamonds[self.graph_selector_Y.currentText()]
self.scatt.axes.clear()
self.histogram.axes.clear()
self.scatt.axes.scatter(X, Y, color='green', s=2)
self.histogram.axes.hist(X, bins=50)
self.scatt.axes.set_title(f"Scatter Plot {self.graph_selector_X.currentText()} vs "
f"{self.graph_selector_Y.currentText()}" )
self.histogram.axes.set_title(f"Histogram {self.graph_selector_X.currentText()}")
self.scatt.axes.set_xlabel(self.graph_selector_X.currentText())
self.scatt.axes.set_ylabel(self.graph_selector_Y.currentText())
self.histogram.axes.set_xlabel(self.graph_selector_X.currentText())
self.histogram.axes.set_xlabel("frequency")
self.scatt.draw()
self.histogram.draw()
app = QApplication(sys.argv)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment