import sys
import pandas as pd
from PyQt6.QtCore import Qt
import matplotlib
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import modeling
from PyQt6.QtGui import QIcon
from sklearn.model_selection import train_test_split
import joblib
from pyqtgraph import PlotWidget, plot
import pyqtgraph as pg
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog,
                             QVBoxLayout, QGridLayout,
                             QLabel, QPushButton, QLineEdit, QComboBox, QTabWidget, QCheckBox)

matplotlib.use('QtAgg')

# Dataset shared by every tab: read once at import time, then drop rows with
# missing values and rows where any of the x/y/z dimensions is zero.
diamonds = pd.read_csv('diamonds.csv').dropna()
mask = diamonds[['x', 'y', 'z']].eq(0).any(axis=1)
diamonds = diamonds.loc[~mask]
# Distinct labels (in order of first appearance) for the Main tab's drop-downs.
cut = diamonds["cut"].unique().tolist()
colors = diamonds["color"].unique().tolist()
clarity = diamonds["clarity"].unique().tolist()
# Model used for predictions until the Advanced tab re-trains and replaces it.
model = joblib.load('RF.joblib')
price = 0  # NOTE(review): not read anywhere in this file


class MplCanvas(FigureCanvasQTAgg):
    """Qt widget wrapping a Matplotlib figure that has a single Axes, exposed as ``axes``.

    *width* and *height* give the figure size in inches and *dpi* its
    resolution. *parent* is accepted but ignored; in this file the canvas
    gets its Qt parent when it is added to a layout.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        canvas_figure = Figure(dpi=dpi, figsize=(width, height))
        self.axes = canvas_figure.add_subplot(1, 1, 1)
        super().__init__(canvas_figure)


class MyTabs(QMainWindow):
    """Top-level window hosting the "Main" (prediction) and "Advanced" (modelling) tabs."""

    # Integer codes for the categorical diamond attributes. The model is
    # trained (re_train) and queried (update_values) through numericise, so
    # both sides see the same codes. The codes presumably also match those
    # used to fit RF.joblib; confirm before changing any value.
    _CATEGORY_CODES = {
        'cut': {'Ideal': 0, 'Premium': 1, 'Very Good': 2, 'Good': 3, 'Fair': 4},
        'color': {'E': 0, 'I': 1, 'J': 2, 'H': 3, 'F': 4, 'G': 5, 'D': 6},
        'clarity': {'SI2': 0, 'SI1': 1, 'VS1': 2, 'VS2': 3, 'VVS2': 4, 'VVS1': 5,
                    'IF': 7, 'I1': 6},
    }

    @staticmethod
    def numericise(df):
        """Return a copy of *df* with the enabled categorical columns replaced by integer codes.

        Only the columns whose feature is switched on in
        ``Advanced.Advanced_selections`` are converted, and those columns must
        be present in *df*. Labels missing from the code table become NaN
        (``Series.map`` semantics). The caller's frame is left untouched.
        """
        df = df.copy()
        for column, codes in MyTabs._CATEGORY_CODES.items():
            if Advanced.Advanced_selections[column]:
                df[column] = df[column].map(codes)
        return df

    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon("icon.ico"))
        self.setMinimumSize(720, 640)
        # The two pages: price prediction and feature selection/re-training.
        self.calc = MainWindow()
        self.adjust = Advanced()

        # Create a tab widget
        tab_widget = QTabWidget(self)
        tab_widget.addTab(self.calc, "Main")
        tab_widget.addTab(self.adjust, "Advanced")

        # Set the central widget
        self.setCentralWidget(tab_widget)


class MainWindow(QMainWindow):
    """"Main" tab: collects one diamond's attributes and shows the model's predicted price."""

    def __init__(self):
        super().__init__()

        # Create a central widget
        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)

        # Create a grid layout and set it to the central widget
        self.grid_layout = QGridLayout(central_widget)
        # Grid column -> feature name. The insertion order (carat first) fixes
        # the column order of the DataFrame built in update_values. It is
        # presumably chosen to match the training column order, so confirm
        # before reordering.
        self.menu_items = {3: 'carat', 0: 'cut', 1: 'color', 2: 'clarity', 4: 'depth', 5: 'table', 6: 'x',
                           7: 'y', 8: 'z'}
        # Feature name -> value from the last "Calculate Price" click.
        self.GUI_selections = {}
        # Overview scatter plot of the whole dataset.
        self.scatt = MplCanvas(self, width=6, height=5, dpi=100)
        self.scatt.axes.set_title("Carats vs Price")
        self.scatt.axes.scatter(diamonds['carat'], diamonds['price'], color='red', s=2)

        # One input per grid column: drop-downs for the categorical features
        # (columns 0-2), free-text fields for the numeric ones (columns 3-8).
        self.input_fields = [QComboBox(self), QComboBox(self), QComboBox(self), QLineEdit(self),
                             QLineEdit(self), QLineEdit(self), QLineEdit(self), QLineEdit(self),
                             QLineEdit(self)]
        self.input_fields[0].addItems(cut)
        self.input_fields[1].addItems(colors)
        self.input_fields[2].addItems(clarity)
        # Button that reads all inputs and runs the prediction.
        self.calculate_button = QPushButton('Calculate Price', self)
        self.calculate_button.clicked.connect(self.update_values)

        # Row 0: feature names; row 1: the matching input widgets.
        for column, name in self.menu_items.items():
            self.grid_layout.addWidget(QLabel(name, self), 0, column)
            self.grid_layout.addWidget(self.input_fields[column], 1, column)

        self.grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2)
        self.grid_layout.addWidget(self.scatt, 3, 0, 1, 9)
        self.grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1)
        self.display_price = QLabel(str(0))
        self.grid_layout.addWidget(self.display_price, 4, 2, 1, 1)

        self.setWindowTitle('Assistance Systems')
        # No self.show() here. MyTabs embeds this widget as a tab page, and
        # showing it while it is still parentless pops it up as a separate
        # top-level window until addTab() reparents it.

    def update_values(self):
        """Read the enabled inputs, run the model and display the predicted price.

        Connected to the "Calculate Price" button. Invalid input and failed
        predictions are reported in an error dialog instead of propagating,
        because an exception escaping a PyQt6 slot terminates the application.
        """
        self.GUI_selections = {}
        for key, value in self.menu_items.items():
            if not Advanced.Advanced_selections[value]:
                continue  # feature switched off on the Advanced tab
            field = self.input_fields[key]
            if isinstance(field, QComboBox):
                self.GUI_selections[value] = field.currentText()
            else:
                text = field.text()
                if not self.validate_inputs(text, value):
                    return
                self.GUI_selections[value] = float(text)

        if not self.GUI_selections:
            self._show_error("Select at least one feature on the Advanced tab.")
            return

        # One-row frame; columns follow menu_items' insertion order.
        features = MyTabs.numericise(pd.DataFrame(self.GUI_selections, index=[0]))

        try:
            predicted = model.predict_price(features)
        except Exception as err:  # Qt slot boundary: report instead of aborting the app
            self._show_error(f"Prediction failed: {err}\n"
                             "If the feature selection changed, re-train the model first.")
            return
        self.display_price.setText(str(predicted))

    def validate_inputs(self, val, attribute):
        """Return True if *val* is acceptable for the numeric feature *attribute*.

        Features switched off on the Advanced tab are always accepted. For an
        enabled feature, *val* must parse as a finite float. ``float()``
        also accepts "nan"/"inf", which are not meaningful measurements, so
        those are rejected too. On rejection an error dialog is shown and
        False is returned.
        """
        import math

        if not Advanced.Advanced_selections[attribute]:
            return True
        try:
            number = float(val)
        except ValueError:
            number = None
        if number is not None and math.isfinite(number):
            return True
        self._show_error(f"{attribute} numerical value required")
        return False

    def _show_error(self, message):
        """Show *message* in a modal "ERROR" dialog with a single OK button."""
        dlg = QDialog(self)
        dlg.setGeometry(400, 500, 200, 100)
        dlg.setWindowTitle("ERROR")
        layout = QVBoxLayout(dlg)
        layout.addWidget(QLabel(message))
        button = QPushButton("OK")
        button.clicked.connect(dlg.reject)
        layout.addWidget(button)
        dlg.exec()


class Advanced(QMainWindow):
    """"Advanced" tab: feature selection, exploratory plots and model re-training."""

    # Feature name -> whether it is used for training and prediction. It is
    # stored on the class (not the instance) so MainWindow and
    # MyTabs.numericise can read it directly. handle_checkbox_state keeps it
    # in sync with the checkboxes.
    Advanced_selections = {'cut': True, 'color': True, 'clarity': True, 'carat': True, 'depth': True,
                           'table': True, 'x': True, 'y': True, 'z': True}

    def __init__(self):
        super().__init__()

        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)
        # Create a grid layout and set it to the central widget
        self.grid_layout = QGridLayout(central_widget)

        # Histogram (left) and scatter plot (right); create_graph fills them in.
        self.histogram = MplCanvas(self, width=6, height=4, dpi=100)
        self.scatt = MplCanvas(self, width=6, height=4, dpi=100)

        self.graph_selector_X = QComboBox(self)
        self.graph_selector_Y = QComboBox(self)
        self.regression_model = QComboBox(self)
        self.R2 = QLabel("R-Squared Error = ")
        self.MSE = QLabel("Mean Squared Error = ")
        self.check_labels = ['cut', 'color', 'clarity', 'carat', 'depth', 'table', 'x', 'y', 'z']
        self.graph_selector_X.addItems((self.check_labels + ['price']))
        self.graph_selector_Y.addItems((self.check_labels + ['price']))
        self.regression_model.addItems(["Linear Regression", "XGBRegressor", "Neural Network",
                                        "Random Forest Regressor"])
        self.teach = QPushButton('Re-Train', self)
        self.teach.clicked.connect(self.re_train)
        self.plot_graph = QPushButton('PLOT', self)
        self.plot_graph.clicked.connect(self.create_graph)

        self.checkboxes = []

        # Row 2: model controls and metrics; row 3: plot-axis selectors.
        self.grid_layout.addWidget(self.graph_selector_X, 3, 1)
        self.grid_layout.addWidget(self.graph_selector_Y, 3, 3)
        self.grid_layout.addWidget(self.regression_model, 2, 1)
        self.grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0)
        self.grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2)
        self.grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1)
        self.grid_layout.addWidget(self.plot_graph, 2, 6)
        self.grid_layout.addWidget(self.teach, 2, 7, 1, 2)
        self.grid_layout.addWidget(self.R2, 2, 4, 1, 2)
        self.grid_layout.addWidget(self.MSE, 2, 2, 1, 2)

        # Row 0: the two plots side by side.
        self.grid_layout.addWidget(self.histogram, 0, 0, 1, 4)
        self.grid_layout.addWidget(self.scatt, 0, 4, 1, 4)

        # Row 1: one checkbox per feature, all enabled initially.
        for column, label in enumerate(self.check_labels):
            checkbox = QCheckBox(label, self)
            checkbox.setChecked(True)
            checkbox.stateChanged.connect(self.handle_checkbox_state)
            self.grid_layout.addWidget(checkbox, 1, column)
            self.checkboxes.append(checkbox)

    def handle_checkbox_state(self):
        """Copy every checkbox's state into ``Advanced_selections`` (anything but Unchecked counts as on)."""
        for checkbox in self.checkboxes:
            Advanced.Advanced_selections[checkbox.text()] = (
                checkbox.checkState() != Qt.CheckState.Unchecked)

    def re_train(self):
        """Train the selected regressor on the enabled features and show its R2/MSE.

        The module-level ``model`` used by the Main tab is replaced only after
        training and evaluation both succeed. On failure the previous model is
        kept and the error is shown in a dialog, because an exception escaping
        a PyQt6 slot would terminate the application.
        """
        global model
        selections = Advanced.Advanced_selections
        if not any(selections.values()):
            self._show_error("Select at least one feature before re-training.")
            return

        # Drop the switched-off features, then split off the target column.
        disabled = [name for name, enabled in selections.items() if not enabled]
        X = diamonds.drop(columns=disabled)
        y = X['price']
        X = MyTabs.numericise(X.drop(columns='price'))
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        try:
            candidate = modeling.Model(X_train, X_test, y_train, y_test,
                                       self.regression_model.currentText())
            candidate.train()
            candidate.evaluate()
        except Exception as err:  # Qt slot boundary: keep the old model, report the error
            self._show_error(f"Re-training failed: {err}")
            return
        model = candidate

        self.R2.setText(f"R2 = {str(candidate.r2)}")
        self.MSE.setText(f"MSE = {str(candidate.mse)}")

    def create_graph(self):
        """Draw a scatter of the selected X vs Y columns and a 50-bin histogram of X."""
        x_name = self.graph_selector_X.currentText()
        y_name = self.graph_selector_Y.currentText()
        X = diamonds[x_name]
        Y = diamonds[y_name]

        self.scatt.axes.clear()
        self.histogram.axes.clear()

        self.scatt.axes.scatter(X, Y, color='green', s=2)
        self.histogram.axes.hist(X, bins=50)
        self.scatt.axes.set_title(f"Scatter Plot {x_name} vs "
                                  f"{y_name}")
        self.histogram.axes.set_title(f"Histogram {x_name}")
        self.scatt.axes.set_xlabel(x_name)
        self.scatt.axes.set_ylabel(y_name)
        self.histogram.axes.set_xlabel(x_name)
        self.histogram.axes.set_ylabel("frequency")

        self.scatt.draw()
        self.histogram.draw()

    def _show_error(self, message):
        """Show *message* in a modal "ERROR" dialog with a single OK button."""
        dlg = QDialog(self)
        dlg.setGeometry(400, 500, 200, 100)
        dlg.setWindowTitle("ERROR")
        layout = QVBoxLayout(dlg)
        layout.addWidget(QLabel(message))
        button = QPushButton("OK")
        button.clicked.connect(dlg.reject)
        layout.addWidget(button)
        dlg.exec()


# Launch the GUI only when run as a script, not when this module is imported.
if __name__ == "__main__":
    app = QApplication(sys.argv)

    window = MyTabs()
    window.show()
    # Pass Qt's event-loop exit status on to the shell.
    sys.exit(app.exec())