import sys
import pandas as pd
from PyQt6.QtCore import Qt
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import modeling
from PyQt6.QtGui import QIcon, QMovie
from sklearn.model_selection import train_test_split
import joblib

from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog,
                             QVBoxLayout, QGridLayout,
                             QLabel, QPushButton, QLineEdit, QComboBox, QTabWidget, QCheckBox)

matplotlib.use('QtAgg')

# Dataset backing the dropdowns, the plots and retraining; loaded once at import.
diamonds = pd.read_csv('diamonds.csv').dropna()
# Discard rows where any of the x / y / z dimensions is zero
# (presumably missing measurements -- confirm against the data source).
mask = (diamonds[['x', 'y', 'z']] == 0).any(axis=1)
diamonds = diamonds.loc[~mask]

# Distinct categorical values, in order of first appearance, for the combo boxes.
cut = diamonds["cut"].unique().tolist()
colors = diamonds["color"].unique().tolist()
clarity = diamonds["clarity"].unique().tolist()

# Pretrained model used until the user retrains on the Advanced tab.
model = joblib.load('models/RF.joblib')
price = 0


class MplCanvas(FigureCanvasQTAgg):
    """Qt widget wrapping a matplotlib Figure that holds a single Axes.

    Draw on ``self.axes`` and then call ``draw()`` to refresh the widget.
    ``parent`` is accepted for call-site symmetry but is not forwarded to
    the canvas.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        super().__init__(figure)
        # One subplot filling the whole figure (equivalent to add_subplot(111)).
        self.axes = figure.add_subplot(1, 1, 1)


class MyTabs(QMainWindow):
    """Top-level window that hosts the "Main" and "Advanced" pages as tabs."""

    @staticmethod
    def numericise(df):
        """Replace the enabled categorical columns of *df* with integer codes.

        A column is converted only when its flag in
        ``Advanced.Advanced_selections`` is True. Values missing from a
        mapping become NaN (``Series.map`` semantics). *df* is modified in
        place and also returned.
        """
        # These codes must stay identical to the ones the loaded/trained model
        # was fitted with; changing them silently corrupts predictions.
        encodings = {
            'cut': {'Ideal': 0, 'Premium': 1, 'Very Good': 2, 'Good': 3, 'Fair': 4},
            'color': {'E': 0, 'I': 1, 'J': 2, 'H': 3, 'F': 4, 'G': 5, 'D': 6},
            'clarity': {'SI2': 0, 'SI1': 1, 'VS1': 2, 'VS2': 3, 'VVS2': 4, 'VVS1': 5,
                        'IF': 7, 'I1': 6},
        }
        for column, codes in encodings.items():
            if Advanced.Advanced_selections[column]:
                df[column] = df[column].map(codes)
        return df

    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon("icon.ico"))
        self.setMinimumSize(720, 640)

        # The prediction page and the model/feature configuration page.
        self.calc = MainWindow()
        self.adjust = Advanced()

        tabs = QTabWidget(self)
        for page, title in ((self.calc, "Main"), (self.adjust, "Advanced")):
            tabs.addTab(page, title)
        self.setCentralWidget(tabs)


class MainWindow(QMainWindow):
    """Main tab: collects diamond features, predicts a price and plots it.

    Categorical features (cut, color, clarity) are picked from combo boxes
    populated from the loaded dataset; the numeric features are typed into
    line edits. Which features are used is controlled by the class-level
    ``Advanced.Advanced_selections`` flags.
    """

    def __init__(self):
        super().__init__()

        # Create a central widget
        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)

        # Create a grid layout and set it to the central widget
        self.grid_layout = QGridLayout(central_widget)
        # Feature names keyed by input-field index (also the grid column).
        # The dict's insertion order (carat first) becomes the column order of
        # the DataFrame built in update_values; presumably it mirrors the
        # training data's column order -- confirm before reordering.
        self.menu_items = {3: 'carat', 0: 'cut', 1: 'color', 2: 'clarity', 4: 'depth', 5: 'table', 6: 'x',
                           7: 'y', 8: 'z'}
        self.GUI_selections = {}
        # Last predicted price; initialised so create_graph never hits a missing attribute.
        self.price = 0
        # Scatter canvas showing the dataset with the latest prediction overlaid.
        self.scatt = MplCanvas(self, width=6, height=5, dpi=100)

        # Indices 0-2 are categorical (combo boxes); 3-8 are numeric (line edits).
        self.input_fields = [QComboBox(self), QComboBox(self), QComboBox(self), QLineEdit(self),
                             QLineEdit(self), QLineEdit(self), QLineEdit(self), QLineEdit(self),
                             QLineEdit(self)]
        self.input_fields[0].addItems(cut)
        self.input_fields[1].addItems(colors)
        self.input_fields[2].addItems(clarity)
        # Button that reads all inputs and triggers a prediction.
        self.calculate_button = QPushButton('Calculate Price', self)
        self.calculate_button.clicked.connect(self.update_values)

        # Row 0 holds the feature labels, row 1 the matching input fields.
        for i, val in self.menu_items.items():
            label = QLabel(val, self)
            self.grid_layout.addWidget(label, 0, i)
            self.grid_layout.addWidget(self.input_fields[i], 1, i)

        self.grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2)
        self.grid_layout.addWidget(self.scatt, 3, 0, 1, 9)
        predicted_price_label = QLabel("Predicted Price: ", self)
        predicted_price_label.setStyleSheet("font-size: 20px;")
        self.grid_layout.addWidget(predicted_price_label, 4, 0, 1, 1)
        self.display_price = QLabel(str(0))
        self.display_price.setStyleSheet("font-size: 20px; font-weight: bold; color: purple; font-family: Ubuntu;")
        self.grid_layout.addWidget(self.display_price, 4, 2, 1, 1)

        self.setWindowTitle('Assistance Systems')
        # No self.show(): this widget is embedded as a tab page and the tab
        # widget manages its visibility (same as the Advanced page).

    def _set_input_field_enabled(self, field_name, enabled):
        """Enable or disable the input field for *field_name* and restyle it to match."""
        style = ("background-color: white; color: black;" if enabled
                 else "background-color: grey; color: white;")
        for key, value in self.menu_items.items():
            if value == field_name:
                self.input_fields[key].setEnabled(enabled)
                self.input_fields[key].setStyleSheet(style)

    def disable_input_field(self, field_name):
        """Disable and grey out the input field for *field_name*."""
        self._set_input_field_enabled(field_name, False)

    def enable_input_field(self, field_name):
        """Re-enable the input field for *field_name* with the normal style."""
        self._set_input_field_enabled(field_name, True)

    def update_values(self):
        """Read the enabled inputs, predict a price and refresh the label and plot.

        Shows an error dialog and stops if the feature selection changed
        since the last retrain, if a numeric input is invalid, or if the
        model's prediction raises.
        """
        self.GUI_selections = {}
        if Advanced.checkboxes_changed:
            self.dialog_box("ERROR", "you need to retrain the model before recalculating,"
                                     " parameters have changed")
            return

        for key, value in self.menu_items.items():
            if not Advanced.Advanced_selections[value]:
                continue  # feature excluded from the model on the Advanced tab
            if key < 3:
                self.GUI_selections[value] = self.input_fields[key].currentText()
            else:
                text = self.input_fields[key].text()
                if not self.validate_inputs(text, value):
                    return
                self.GUI_selections[value] = float(text)

        # Convert the selections to a single-row DataFrame for the model.
        X_test = pd.DataFrame(self.GUI_selections, index=[0])
        print(X_test.head())

        X_test = MyTabs.numericise(X_test)
        try:
            self.price = model.predict_price(X_test)
        except Exception as err:
            # An unhandled exception in a PyQt6 slot aborts the application;
            # report the failure to the user instead.
            self.dialog_box("ERROR", f"prediction failed: {err}")
            return
        self.display_price.setText(str(self.price))
        self.create_graph()

    def create_graph(self):
        """Redraw the carat-vs-price scatter with the latest prediction overlaid."""
        axes = self.scatt.axes
        axes.clear()
        axes.set_title("Carats vs Price")
        axes.scatter(diamonds['carat'], diamonds['price'], color='red', s=2, label='carat')
        axes.axhline(y=self.price, color='blue', linestyle='--', label='Predicted Price')
        # 'carat' is absent when that feature is deselected on the Advanced tab.
        carat = self.GUI_selections.get('carat')
        if carat is not None:
            axes.scatter(carat, self.price, color='blue', s=50, label='expected')
            axes.axvline(x=carat, color='blue', linestyle='--', label='carat')
        self.scatt.draw()

    def validate_inputs(self, val, attribute):
        """Return True if *val* is acceptable for *attribute*; otherwise show an error.

        Attributes deselected on the Advanced tab are always accepted. For
        selected attributes, *val* must parse as a finite float; "nan" and
        "inf" parse but are rejected because the model cannot use them.
        """
        import math  # only needed for the finiteness check

        if not Advanced.Advanced_selections[attribute]:
            return True
        try:
            number = float(val)
        except ValueError:
            number = None
        if number is None or not math.isfinite(number):
            self.dialog_box("ERROR", f"{attribute} numerical value required")
            return False
        return True

    def dialog_box(self, title, message):
        """Show a modal dialog with *message* and an OK button; blocks until closed."""
        dlg = QDialog(self)
        dlg.setGeometry(400, 500, 200, 100)
        dlg.setWindowTitle(title)
        layout = QVBoxLayout()
        layout.addWidget(QLabel(message))
        button = QPushButton("OK")
        button.clicked.connect(dlg.reject)
        layout.addWidget(button)
        dlg.setLayout(layout)
        dlg.exec()


class Advanced(QMainWindow):
    """Advanced tab: feature selection, model retraining and exploration plots.

    The class-level attributes are shared, mutable configuration that is also
    read by MyTabs.numericise and MainWindow.
    """

    # Feature name -> whether that feature is used for training and prediction.
    Advanced_selections = {'cut': True, 'color': True, 'clarity': True, 'carat': True, 'depth': True,
                           'table': True, 'x': True, 'y': True, 'z': True}

    # Set when a feature checkbox changes and cleared after a successful retrain;
    # MainWindow.update_values refuses to predict while it is True.
    checkboxes_changed = False

    def __init__(self):
        super().__init__()

        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)
        # Create a grid layout and set it to the central widget
        self.grid_layout = QGridLayout(central_widget)

        # Exploration plots: histogram of the X column, scatter of X vs Y.
        self.histogram = MplCanvas(self, width=6, height=4, dpi=100)
        self.scatt = MplCanvas(self, width=6, height=4, dpi=100)

        self.graph_selector_X = QComboBox(self)
        self.graph_selector_Y = QComboBox(self)
        self.regression_model = QComboBox(self)
        self.R2 = QLabel("R-Squared Error = ")
        self.MSE = QLabel("Mean Squared Error = ")
        self.check_labels = ['cut', 'color', 'clarity', 'carat', 'depth', 'table', 'x', 'y', 'z']
        self.graph_selector_X.addItems((self.check_labels + ['price']))
        self.graph_selector_Y.addItems((self.check_labels + ['price']))
        self.regression_model.addItems(["Random Forest Regressor",
                                        "Linear Regression",
                                        "XGBRegressor",
                                        "Neural Network"])
        self.teach = QPushButton('Re-Train', self)
        self.teach.clicked.connect(self.re_train)
        self.plot_graph = QPushButton('PLOT', self)
        self.plot_graph.clicked.connect(self.create_graph)
        self.checkboxes = []

        self.grid_layout.addWidget(self.graph_selector_X, 3, 1)
        self.grid_layout.addWidget(self.graph_selector_Y, 3, 3)
        self.grid_layout.addWidget(self.regression_model, 2, 1)
        self.grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0)
        self.grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2)
        self.grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1)
        self.grid_layout.addWidget(self.plot_graph, 3, 4)
        self.grid_layout.addWidget(self.teach, 2, 6, 1, 1)
        self.grid_layout.addWidget(self.R2, 2, 4, 1, 2)
        self.grid_layout.addWidget(self.MSE, 2, 2, 1, 2)

        self.grid_layout.addWidget(self.histogram, 0, 0, 1, 4)
        self.grid_layout.addWidget(self.scatt, 0, 4, 1, 4)
        # One checkbox per feature in row 1; the checkbox text is the feature name.
        for i, label in enumerate(self.check_labels):
            checkbox = QCheckBox(label, self)
            checkbox.setChecked(True)
            checkbox.stateChanged.connect(self.handle_checkbox_state)
            self.grid_layout.addWidget(checkbox, 1, i)
            self.checkboxes.append(checkbox)

    def handle_checkbox_state(self):
        """Mark the model stale and sync selections and Main-tab fields with every checkbox.

        Relies on the module-level ``window`` (the MyTabs instance) to reach
        the Main tab.
        """
        Advanced.checkboxes_changed = True
        for checkbox in self.checkboxes:
            feature = checkbox.text()
            included = checkbox.checkState() != Qt.CheckState.Unchecked
            Advanced.Advanced_selections[feature] = included
            if included:
                window.calc.enable_input_field(feature)
            else:
                window.calc.disable_input_field(feature)

    def _fit_model(self):
        """Build training data from the selected features and return a trained, evaluated model."""
        X = diamonds.copy()
        # Drop every feature deselected on this tab.
        for key, value in Advanced.Advanced_selections.items():
            if not value:
                X = X.drop(key, axis=1)
        y = X['price']
        X = X.drop('price', axis=1)
        X = MyTabs.numericise(X)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        trained = modeling.Model(X_train, X_test, y_train, y_test, self.regression_model.currentText())
        trained.train()
        trained.evaluate()
        return trained

    def re_train(self):
        """Retrain the global ``model`` on the selected features and show its metrics.

        On failure an error dialog is shown and the previous model stays in
        use. The button is always restored afterwards.
        """
        global model
        self.teach.setText('In progress')
        self.teach.setStyleSheet("background-color: red; color: white;")
        self.teach.setEnabled(False)
        # Training blocks the GUI thread, so paint the busy state immediately.
        self.teach.repaint()
        try:
            trained = self._fit_model()
        except Exception as err:
            # An unhandled exception in a PyQt6 slot aborts the whole app, and
            # previously left the button stuck in its busy state.
            from PyQt6.QtWidgets import QMessageBox
            QMessageBox.critical(self, "ERROR", f"re-training failed: {err}")
            return
        finally:
            self.teach.setText('Re-Train')
            self.teach.setStyleSheet("background-color: white; color: black;")
            self.teach.setEnabled(True)
        # Publish the new model only after it trained and evaluated successfully.
        model = trained
        self.R2.setText(f"R2 = {str(model.r2)}")
        self.MSE.setText(f"MSE = {str(model.mse)}")
        Advanced.checkboxes_changed = False

    def create_graph(self):
        """Plot a histogram of the selected X column and a scatter of X against Y."""
        x_name = self.graph_selector_X.currentText()
        y_name = self.graph_selector_Y.currentText()
        X = diamonds[x_name]
        Y = diamonds[y_name]

        self.scatt.axes.clear()
        self.histogram.axes.clear()

        self.scatt.axes.scatter(X, Y, color='green', s=2)
        self.histogram.axes.hist(X, bins=50)
        self.scatt.axes.set_title(f"Scatter Plot {x_name} vs "
                                  f"{y_name}")
        self.histogram.axes.set_title(f"Histogram {x_name}")
        self.scatt.axes.set_xlabel(x_name)
        self.scatt.axes.set_ylabel(y_name)
        self.histogram.axes.set_xlabel(x_name)
        self.histogram.axes.set_ylabel("frequency")

        self.scatt.draw()
        self.histogram.draw()


app = QApplication(sys.argv)

# `window` must remain a module-level global: Advanced.handle_checkbox_state
# reaches the Main tab through it.
window = MyTabs()
window.show()
# Propagate the Qt event loop's exit status to the calling process.
sys.exit(app.exec())