From d1a050a48c60d2c862c17c6d2241e93ad396a98c Mon Sep 17 00:00:00 2001 From: Fadi <fadi.gattoussi@stud.th-deg.de> Date: Sun, 24 Dec 2023 22:08:03 +0100 Subject: [PATCH] Added Disabling of input fields when checkboxes are unchecked --- DisplayWindow.py | 49 +++++++++++++++++++-------- training3.ipynb | 86 ++++++++++++++++++++++-------------------------- 2 files changed, 75 insertions(+), 60 deletions(-) diff --git a/DisplayWindow.py b/DisplayWindow.py index 8a6f34a..210cbe2 100644 --- a/DisplayWindow.py +++ b/DisplayWindow.py @@ -8,8 +8,6 @@ import modeling from PyQt6.QtGui import QIcon from sklearn.model_selection import train_test_split import joblib -from pyqtgraph import PlotWidget, plot -import pyqtgraph as pg from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog, QVBoxLayout, QGridLayout, QLabel, QPushButton, QLineEdit, QComboBox, QTabWidget, QCheckBox) @@ -107,24 +105,38 @@ class MainWindow(QMainWindow): # Add the plot widget to the grid layout self.grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2) self.grid_layout.addWidget(self.scatt, 3, 0, 1, 9) - self.grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1) + predicted_price_label = QLabel("Predicted Price: ", self) + predicted_price_label.setStyleSheet("font-size: 20px;") + self.grid_layout.addWidget(predicted_price_label, 4, 0, 1, 1) self.display_price = QLabel(str(0)) + self.display_price.setStyleSheet("font-size: 20px; font-weight: bold; color: purple; font-family: Ubuntu;") self.grid_layout.addWidget(self.display_price, 4, 2, 1, 1) self.setWindowTitle('Assistance Systems') self.show() + def disable_input_field(self, field_name): + for key, value in self.menu_items.items(): + if value == field_name: + self.input_fields[key].setDisabled(True) + + def enable_input_field(self, field_name): + for key, value in self.menu_items.items(): + if value == field_name: + self.input_fields[key].setEnabled(True) + def update_values(self): self.GUI_selections = {} - for key, value in self.menu_items.items(): - if Advanced.Advanced_selections[value]: - if key < 3: - self.GUI_selections[value] = self.input_fields[key].currentText() - else: - if self.validate_inputs(self.input_fields[key].text(), value): - self.GUI_selections[value] = float(self.input_fields[key].text()) + if Advanced.checkboxes_changed: + for key, value in self.menu_items.items(): + if Advanced.Advanced_selections[value]: + if key < 3: + self.GUI_selections[value] = self.input_fields[key].currentText() else: - return + if self.validate_inputs(self.input_fields[key].text(), value): + self.GUI_selections[value] = float(self.input_fields[key].text()) + else: + return # convert GUI_selections to dataframe X_test = pd.DataFrame(self.GUI_selections, index=[0]) @@ -159,6 +171,8 @@ class MainWindow(QMainWindow): class Advanced(QMainWindow): Advanced_selections = {'cut': True, 'color': True, 'clarity': True, 'carat': True, 'depth': True, 'table': True, 'x': True, 'y': True, 'z': True} + + checkboxes_changed = False def __init__(self): super().__init__() @@ -182,8 +196,10 @@ class Advanced(QMainWindow): self.check_labels = ['cut', 'color', 'clarity', 'carat', 'depth', 'table', 'x', 'y', 'z'] self.graph_selector_X.addItems((self.check_labels + ['price'])) self.graph_selector_Y.addItems((self.check_labels + ['price'])) - self.regression_model.addItems(["Linear Regression", "XGBRegressor", "Neural Network", - "Random Forest Regressor"]) + self.regression_model.addItems(["Random Forest Regressor", + "Linear Regression", + "XGBRegressor", + "Neural Network"]) self.teach = QPushButton('Re-Train', self) self.teach.clicked.connect(self.re_train) self.plot_graph = QPushButton('PLOT', self) @@ -214,12 +230,16 @@ class Advanced(QMainWindow): self.checkboxes.append(checkbox) # Store checkboxes in the list def handle_checkbox_state(self): + Advanced.checkboxes_changed = False for i, checkbox in enumerate(self.checkboxes): state = checkbox.checkState() if state == Qt.CheckState.Unchecked: Advanced.Advanced_selections[checkbox.text()] = False + window.calc.disable_input_field(checkbox.text()) else: Advanced.Advanced_selections[checkbox.text()] = True + window.calc.enable_input_field(checkbox.text()) + def re_train(self): global diamonds @@ -238,6 +258,7 @@ class Advanced(QMainWindow): self.R2.setText(f"R2 = {str(model.r2)}") self.MSE.setText(f"MSE = {str(model.mse)}") + Advanced.checkboxes_changed = True def create_graph(self): @@ -259,7 +280,7 @@ class Advanced(QMainWindow): self.scatt.draw() self.histogram.draw() - + app = QApplication(sys.argv) diff --git a/training3.ipynb b/training3.ipynb index 4c5e3c5..6478f3a 100644 --- a/training3.ipynb +++ b/training3.ipynb @@ -1039,80 +1039,74 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean Squared Error: 292726.7655844448\n", + "R^2 Score: 0.9817410390026241\n" + ] + } + ], + "source": [ + "# random forest\n", + "from sklearn.ensemble import RandomForestRegressor\n", + "model3 = RandomForestRegressor(n_estimators=100, random_state=42)\n", + "model3.fit(X_train, y_train)\n", + "\n", + "from sklearn.metrics import mean_squared_error, r2_score\n", + "y_pred3 = model3.predict(X_test)\n", + "mse3 = mean_squared_error(y_test, y_pred3)\n", + "r23 = r2_score(y_test, y_pred3)\n", + "print(\"Mean Squared Error:\", mse3)\n", + "print(\"R^2 Score:\", r23)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/home/fadi/Desktop/dev/sas2/clone2/ws-23-sas-02/sas2/lib/python3.10/site-packages/sklearn/base.py:465: UserWarning: X does not have valid feature names, but LinearRegression was fitted with feature names\n", + "/home/fadi/Desktop/dev/sas2/clone2/ws-23-sas-02/sas2/lib/python3.10/site-packages/sklearn/base.py:465: UserWarning: X does not have valid feature names, but RandomForestRegressor was fitted with feature names\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ - "array([3672.24864574])" + "array([374.44])" ] }, - "execution_count": 28, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "test_diamond = [[\n", - " 0.86,\n", - " 3,\n", - " 0,\n", - " 61,\n", - " 58,\n", - " 6.15,\n", - " 6.12,\n", - " 3.74\n", - "]]\n", - "test_price = model.predict(test_diamond)\n", + "test_diamond = [[0.23, 0, 0, 61.5, 55.0, 3.95, 3.98, 2.43]]\n", + "test_price = model3.predict(test_diamond)\n", "test_price\n" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'y_pred3' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[29], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m model3\u001b[38;5;241m.\u001b[39mfit(X_train, y_train)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msklearn\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmetrics\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m mean_squared_error, r2_score\n\u001b[0;32m----> 7\u001b[0m mse3 \u001b[38;5;241m=\u001b[39m mean_squared_error(y_test, \u001b[43my_pred3\u001b[49m)\n\u001b[1;32m 8\u001b[0m r23 \u001b[38;5;241m=\u001b[39m r2_score(y_test, y_pred3)\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMean Squared Error:\u001b[39m\u001b[38;5;124m\"\u001b[39m, mse3)\n", - "\u001b[0;31mNameError\u001b[0m: name 'y_pred3' is not defined" - ] - } - ], + "outputs": [], "source": [ - "# random forest\n", - "from sklearn.ensemble import RandomForestRegressor\n", - "model3 = RandomForestRegressor(n_estimators=100, random_state=42)\n", - "model3.fit(X_train, y_train)\n", + "import numpy as np\n", "\n", - "from sklearn.metrics import mean_squared_error, r2_score\n", - "y_pred3 = model3.predict(X_test)\n", - "mse3 = mean_squared_error(y_test, y_pred3)\n", - "r23 = r2_score(y_test, y_pred3)\n", - "print(\"Mean Squared Error:\", mse3)\n", - "print(\"R^2 Score:\", r23)" + "data = np.array([[0.23, 0, 0, 61.5, 55.0, 3.95, 3.98, 2.43]])\n", + "print(data)\n", + "0.23\t0\t0\t61.5\t55.0\t3.95\t3.98\t2.43" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { -- GitLab