import sys
import pandas as pd
from PyQt6.QtCore import Qt
from sklearn.ensemble import RandomForestRegressor

import modeling
from data_processing import DataPreprocessor as dp
from modeling import Model

from pyqtgraph import PlotWidget, plot
import pyqtgraph as pg
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QDialog,
                             QVBoxLayout, QHBoxLayout, QGridLayout,
                             QLabel, QPushButton, QSlider, QDateTimeEdit,
                             QLineEdit, QComboBox, QDateEdit, QTabWidget, QCheckBox)
from PyQt6.QtGui import QPalette, QColor, QIcon
from sklearn.model_selection import train_test_split
import joblib

# Load the diamonds dataset once at import time, discarding rows that contain
# missing values.
diamonds = pd.read_csv('diamonds.csv').dropna()

# Distinct labels of each categorical column, in order of first appearance.
# They populate the combo boxes on the main page.
cut, colors, clarity = (
    list(diamonds[column].unique()) for column in ("cut", "color", "clarity")
)

# Recipe that was used to train and persist the model loaded below; kept for
# reference so 'RF.joblib' can be regenerated:
# diamonds_numerical = diamonds.copy()
# cut_mapping = {cut: i for i, cut in enumerate(diamonds['cut'].unique())}
# color_mapping = {color: i for i, color in enumerate(diamonds['color'].unique())}
# clarity_mapping = {clarity: i for i, clarity in enumerate(diamonds['clarity'].unique())}
# diamonds_numerical['cut'] = diamonds['cut'].map(cut_mapping)
# diamonds_numerical['color'] = diamonds['color'].map(color_mapping)
# diamonds_numerical['clarity'] = diamonds['clarity'].map(clarity_mapping)
# X = diamonds_numerical.drop('price', axis = 1)
# y = diamonds_numerical['price']
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# model = Model(X_train, X_test, y_train, y_test,'RF')
# model.train()
# joblib.dump(model, 'RF.joblib')

# Active predictor shared by the whole module. Advanced.re_teach rebinds it.
# NOTE(review): joblib.load unpickles the file, so only load model files from
# a trusted source.
model = joblib.load('RF.joblib')

# NOTE(review): not referenced by any active code in this file; looks like a
# leftover.
price = 0


class MyTabs(QMainWindow):
    """Top-level window that hosts the "Main" and "Advanced" pages as tabs."""

    # Integer code assigned to every label of each categorical column.
    # NOTE(review): presumably these codes have to match the encoding the
    # loaded model was trained with -- confirm before changing any of them.
    _CATEGORY_CODES = {
        'cut': {'Ideal': 0, 'Premium': 1, 'Very Good': 2, 'Good': 3, 'Fair': 4},
        'color': {'E': 0, 'I': 1, 'J': 2, 'H': 3, 'F': 4, 'G': 5, 'D': 6},
        'clarity': {'SI2': 0, 'SI1': 1, 'VS1': 2, 'VS2': 3, 'VVS2': 4, 'VVS1': 5,
                    'IF': 7, 'I1': 6},
    }

    @staticmethod
    def numericise(df):
        """Replace the categorical labels in ``df`` with their integer codes.

        A column is converted only when it is enabled in
        ``Advanced.Advanced_selections`` and actually present in ``df``, so a
        frame that lacks one of the columns no longer raises ``KeyError``.
        Labels that are missing from the code table are turned into NaN by
        ``Series.map``.

        Args:
            df: DataFrame that may contain 'cut', 'color' and/or 'clarity'.

        Returns:
            The same DataFrame object; it is modified in place.
        """
        for column, codes in MyTabs._CATEGORY_CODES.items():
            if Advanced.Advanced_selections[column] and column in df.columns:
                df[column] = df[column].map(codes)
        return df

    def __init__(self):
        """Build the two pages and place them in a tab widget."""
        super().__init__()
        self.setWindowIcon(QIcon("icon.ico"))
        self.setMinimumSize(720, 640)

        # The two pages of the application.
        self.calc = MainWindow()
        self.adjust = Advanced()

        # Create a tab widget
        tab_widget = QTabWidget(self)
        tab_widget.addTab(self.calc, "Main")
        tab_widget.addTab(self.adjust, "Advanced")

        # Set the central widget
        self.setCentralWidget(tab_widget)


class MainWindow(QMainWindow):
    """Page with one input per diamond attribute, a carat/price plot and the
    predicted price."""

    def __init__(self):
        """Create the input widgets, the plot and the price label."""
        super().__init__()

        # Create a central widget
        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)

        # Create a grid layout and set it to the central widget
        grid_layout = QGridLayout(central_widget)

        # Maps an index into self.input_fields (also used as the grid column)
        # to the attribute name. The insertion order of this dict is the
        # column order of the DataFrame built in update_values.
        # NOTE(review): presumably that order has to match the column order
        # the model was trained on -- confirm before reordering.
        self.menu_items = {3: 'carat', 0: 'cut', 1: 'color', 2: 'clarity', 4: 'depth', 5: 'table', 6: 'x',
                           7: 'y', 8: 'z'}
        # Attribute values used for the most recent successful prediction.
        self.GUI_selections = {}

        # Combo boxes for the three categorical attributes, line edits for
        # the six numeric ones.
        self.input_fields = [QComboBox(self), QComboBox(self), QComboBox(self), QLineEdit(self),
                             QLineEdit(self), QLineEdit(self), QLineEdit(self), QLineEdit(self),
                             QLineEdit(self)]
        self.input_fields[0].addItems(cut)
        self.input_fields[1].addItems(colors)
        self.input_fields[2].addItems(clarity)

        # Button that reads all inputs and triggers the prediction.
        self.calculate_button = QPushButton('Calculate Price', self)
        self.calculate_button.clicked.connect(self.update_values)

        # Row 0 holds the attribute names, row 1 the matching input widgets.
        for i, val in self.menu_items.items():
            label = QLabel(val, self)
            grid_layout.addWidget(label, 0, i)
            grid_layout.addWidget(self.input_fields[i], 1, i)

        # Plot of price against carat for the loaded dataset.
        graphWidget = pg.PlotWidget(self)
        graphWidget.setBackground('w')
        graphWidget.plot(diamonds['carat'], diamonds['price'])

        grid_layout.addWidget(self.calculate_button, 2, 0, 1, 2)
        grid_layout.addWidget(graphWidget, 3, 0, 1, 9)
        grid_layout.addWidget(QLabel("Predicted Price: ", self), 4, 0, 1, 1)
        self.display_price = QLabel(str(0))
        grid_layout.addWidget(self.display_price, 4, 2, 1, 1)

        self.setWindowTitle('Assistance Systems')
        # The page does not show() itself (same as Advanced): it is displayed
        # by the tab widget it gets embedded in. Showing it here would open it
        # as a separate top-level window first.

    def update_values(self):
        """Read the enabled inputs, predict the price and display it.

        The selection is rebuilt from scratch on every call, so a feature
        that has been disabled on the Advanced tab since the previous
        prediction is not carried over. Numeric inputs are converted to
        ``float``. If a numeric input is invalid, an error dialog is shown
        and neither ``GUI_selections`` nor the displayed price is changed.
        """
        selections = {}
        for key, value in self.menu_items.items():
            if not Advanced.Advanced_selections[value]:
                continue
            field = self.input_fields[key]
            if isinstance(field, QComboBox):
                selections[value] = field.currentText()
                continue
            text = field.text()
            if not self.validate_inputs(text, value):
                return
            selections[value] = float(text)
        self.GUI_selections = selections

        # Single-row frame holding the feature values entered by the user.
        X_test = pd.DataFrame(self.GUI_selections, index=[0])
        print(X_test.head())

        X_test = MyTabs.numericise(X_test)

        self.display_price.setText(str(model.predict_price(X_test)))

    def validate_inputs(self, val, attribute):
        """Check that ``val`` is a usable number for ``attribute``.

        Args:
            val: Text entered by the user.
            attribute: Name of the attribute the text belongs to.

        Returns:
            True when the attribute is disabled in
            ``Advanced.Advanced_selections`` or when ``val`` parses as a
            finite float. Otherwise an error dialog is shown and False is
            returned.
        """
        import math

        if not Advanced.Advanced_selections[attribute]:
            return True
        try:
            number = float(val)
        except ValueError:
            number = None
        # float() also accepts "nan" and "inf", which are not usable inputs.
        if number is not None and math.isfinite(number):
            return True
        self._show_input_error(attribute)
        return False

    def _show_input_error(self, attribute):
        """Show a modal dialog saying that ``attribute`` needs a number."""
        dlg = QDialog(self)
        dlg.setGeometry(400, 500, 200, 100)
        dlg.setWindowTitle("ERROR")
        layout = QVBoxLayout()
        layout.addWidget(QLabel(f"{attribute} numerical value required"))
        button = QPushButton("OK")
        button.clicked.connect(dlg.reject)
        layout.addWidget(button)
        dlg.setLayout(layout)
        dlg.exec()


class Advanced(QMainWindow):
    """Page for choosing the model features and the regression model, and for
    re-training the module-level ``model``."""

    # Feature name -> whether the feature is used for training and prediction.
    # Class-level state: it is read by MyTabs.numericise and MainWindow and
    # updated by handle_checkbox_state.
    Advanced_selections = {'cut': True, 'color': True, 'clarity': True, 'carat': True, 'depth': True,
                           'table': True, 'x': True, 'y': True, 'z': True}

    def __init__(self):
        """Create the feature checkboxes, the selectors and the buttons."""
        super().__init__()

        central_widget = QWidget(self)
        self.setCentralWidget(central_widget)
        # Create a grid layout and set it to the central widget
        grid_layout = QGridLayout(central_widget)

        self.graph_selector_X = QComboBox(self)
        self.graph_selector_Y = QComboBox(self)
        self.regression_model = QComboBox(self)
        self.R2 = QLabel("R-Squared Error = ")
        self.MSE = QLabel("Mean Squared Error = ")
        self.check_labels = ['cut', 'color', 'clarity', 'carat', 'depth', 'table', 'x', 'y', 'z']
        self.graph_selector_X.addItems((self.check_labels + ['price']))
        self.graph_selector_Y.addItems((self.check_labels + ['price']))
        self.regression_model.addItems(["Linear Regression", "XGBRegressor", "Neural Network",
                                        "Random Forest Regressor"])
        self.teach = QPushButton('Re-Teach', self)
        self.teach.clicked.connect(self.re_teach)
        self.plot_graph = QPushButton('PLOT', self)
        self.plot_graph.clicked.connect(self.create_graph)

        self.checkboxes = []

        # A column span of -1 is the documented way to let a widget extend to
        # the right edge of the grid.
        grid_layout.addWidget(QLabel("select features to include in Modelling"), 0, 0, 1, -1)
        grid_layout.addWidget(self.graph_selector_X, 3, 1)
        grid_layout.addWidget(self.graph_selector_Y, 3, 3)
        grid_layout.addWidget(self.regression_model, 2, 1)
        grid_layout.addWidget(QLabel("X-PLOT", self), 3, 0)
        grid_layout.addWidget(QLabel("Y-PLOT", self), 3, 2)
        grid_layout.addWidget(QLabel("Regression Model : ", self), 2, 0, 1, 1)
        grid_layout.addWidget(self.plot_graph, 2, 6)
        grid_layout.addWidget(self.teach, 2, 7, 1, 2)
        grid_layout.addWidget(self.R2, 2, 4, 1, 2)
        grid_layout.addWidget(self.MSE, 2, 2, 1, 2)

        # One checkbox per feature on row 1, all ticked initially.
        for i, label in enumerate(self.check_labels):
            checkbox = QCheckBox(label, self)
            checkbox.setChecked(True)
            checkbox.stateChanged.connect(self.handle_checkbox_state)  # Connect signal
            grid_layout.addWidget(checkbox, 1, i)
            self.checkboxes.append(checkbox)  # Store checkboxes in the list

    def handle_checkbox_state(self):
        """Copy the state of every checkbox into ``Advanced_selections``.

        A feature is disabled only when its checkbox is unchecked.
        """
        for checkbox in self.checkboxes:
            enabled = checkbox.checkState() != Qt.CheckState.Unchecked
            Advanced.Advanced_selections[checkbox.text()] = enabled

    def re_teach(self):
        """Train a new model on the enabled features and show its metrics.

        The model type is the text currently selected in the regression-model
        combo box. The module-level ``model`` is replaced only after training
        and evaluation have returned, so a failure while fitting leaves the
        previous model in place. The metric labels keep their captions.
        """
        global model

        excluded = [name for name, selected in Advanced.Advanced_selections.items()
                    if not selected]
        y = diamonds['price']
        # drop() returns a new frame, so numericise does not touch `diamonds`.
        X = MyTabs.numericise(diamonds.drop(columns=excluded + ['price']))
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        candidate = modeling.Model(X_train, X_test, y_train, y_test,
                                   self.regression_model.currentText())
        candidate.train()
        # NOTE(review): r2 and mse are presumably set by evaluate() -- confirm
        # in modeling.Model.
        candidate.evaluate()
        model = candidate

        self.R2.setText(f"R-Squared Error = {candidate.r2}")
        self.MSE.setText(f"Mean Squared Error = {candidate.mse}")

    def create_graph(self):
        """Slot of the PLOT button; not implemented yet, so it does nothing."""


# Script entry point. These statements run whenever the module is executed or
# imported. The QApplication is created before any widget is constructed.
app = QApplication(sys.argv)

# Build the tabbed main window, show it and start the Qt event loop.
window = MyTabs()
window.show()
app.exec()