diff --git a/app.py b/app.py index ce765c6f11f76d8d8610c0ef0a0642cc3d96ee6d..0510bee10f750920df06d788c0c58005402aaebc 100644 --- a/app.py +++ b/app.py @@ -6,47 +6,47 @@ from sklearn.preprocessing import LabelEncoder import joblib # For loading Scikit-Learn models # Load the dataset -try: - df = pd.read_csv("FakeDataset.csv") -except FileNotFoundError: - df = None - st.error("Dataset not found. Please ensure 'FakeDataset.csv' is in the correct path.") +df = pd.read_csv("FakeDataset.csv") # Load Scikit-Learn models -try: - model_1 = joblib.load("model_1.pkl") # Model for satisfaction prediction - model_2 = joblib.load("model_2.pkl") # Model for flight delay prediction -except FileNotFoundError: - st.error("Model files not found. Please ensure 'model_1.pkl' and 'model_2.pkl' are in the correct path.") +model_1 = joblib.load("model_1.pkl") # Model for satisfaction prediction +model_2 = joblib.load("model_2.pkl") # Model for flight delay prediction # Dynamically create the encoders if they don't exist country_code_encoder = None airport_encoder = None -if df is not None: - if not os.path.exists("country_code_encoder.pkl"): - st.warning("Country code encoder file not found. Creating a new encoder dynamically...") - if "Airport Country Code" in df.columns: - country_code_encoder = LabelEncoder() - df["Airport Country Code"] = country_code_encoder.fit_transform(df["Airport Country Code"]) - joblib.dump(country_code_encoder, "country_code_encoder.pkl") - st.success("Country code encoder created and saved as 'country_code_encoder.pkl'.") - else: - st.error("The dataset is missing the 'Airport Country Code' column.") +if not os.path.exists("country_code_encoder.pkl"): + st.warning("Country code encoder file not found. Creating a new encoder dynamically...") + if "Airport Country Code" in df.columns: + # Create a new encoder using the dataset + country_code_encoder = LabelEncoder() + df["Airport Country Code"] = country_code_encoder.fit_transform(df["Airport Country Code"]) + + # Save the encoder for future use + joblib.dump(country_code_encoder, "country_code_encoder.pkl") + st.success("Country code encoder has been created and saved as 'country_code_encoder.pkl'.") else: - country_code_encoder = joblib.load("country_code_encoder.pkl") - - if not os.path.exists("airport_encoder.pkl"): - st.warning("Airport encoder file not found. Creating a new encoder dynamically...") - if "Arrival Airport" in df.columns: - airport_encoder = LabelEncoder() - df["Arrival Airport"] = airport_encoder.fit_transform(df["Arrival Airport"]) - joblib.dump(airport_encoder, "airport_encoder.pkl") - st.success("Airport encoder created and saved as 'airport_encoder.pkl'.") - else: - st.error("The dataset is missing the 'Arrival Airport' column.") + st.error("The dataset is missing the 'Airport Country Code' column.") +else: + # Load the encoder if it already exists + country_code_encoder = joblib.load("country_code_encoder.pkl") + +if not os.path.exists("airport_encoder.pkl"): + st.warning("Airport encoder file not found. Creating a new encoder dynamically...") + if "Arrival Airport" in df.columns: + # Create a new encoder using the dataset + airport_encoder = LabelEncoder() + df["Arrival Airport"] = airport_encoder.fit_transform(df["Arrival Airport"]) + + # Save the encoder for future use + joblib.dump(airport_encoder, "airport_encoder.pkl") + st.success("Airport encoder has been created and saved as 'airport_encoder.pkl'.") else: - airport_encoder = joblib.load("airport_encoder.pkl") + st.error("The dataset is missing the 'Arrival Airport' column.") +else: + # Load the encoder if it already exists + airport_encoder = joblib.load("airport_encoder.pkl") # Initialize Streamlit st.title("Airline Chatbot & Data Explorer with Predictions") @@ -57,19 +57,114 @@ st.image("Chatbot_image.gif", caption="Welcome to the Airline Chatbot & Data Exp # Add navigation menu = st.sidebar.selectbox("Select a Feature", ["Chatbot", "Data Explorer", "Visualizations", "Predictions"]) -# Functions for different sections +def show_chatbot_section(): + st.subheader("Chat with Airline Assistant") + + # Initialize session states for conversation flow + if "greeted" not in st.session_state: + st.session_state.greeted = False + if "help_query" not in st.session_state: + st.session_state.help_query = False + if "details_fetched" not in st.session_state: + st.session_state.details_fetched = False + if "farewell" not in st.session_state: + st.session_state.farewell = False + + # Reset button functionality + if st.button("Reset Chatbot"): + st.session_state.greeted = False + st.session_state.help_query = False + st.session_state.details_fetched = False + st.session_state.farewell = False + st.success("The chatbot has been reset. How can I help you?") + + # Greeting and initial input + if not st.session_state.greeted: + user_input = st.text_input("Start by saying 'Hey', 'Good Morning', or 'Hello':") + if user_input: + if any(greet in user_input.lower() for greet in ["hey", "good morning", "hello", "hi", "hey there", "good afternoon", "good evening", "greetings", "good day", "howdy", "sup", "what's up"]): + st.write("Hello! How can I assist you today?") + st.session_state.greeted = True + else: + st.warning("Please start with a greeting like 'Hey' or 'Good Morning'.") + + # Respond to user query + if st.session_state.greeted and not st.session_state.help_query: + user_query = st.text_input("What would you like assistance with? (e.g., 'I want to know my flight details'):") + if user_query: + if any(keyword in user_query.lower() for keyword in ["details", "my details", "know my details", "flight details", "flight status", "my flight details", "get details", "get my details", "flight information", "flight info"]): + st.success("Understood, let's proceed.") + st.session_state.help_query = True + else: + st.info("Let me know how I can assist you further.") + + # Passenger ID and Last Name functionality + if st.session_state.help_query and not st.session_state.details_fetched: + passenger_id = st.text_input("Please enter your Passenger ID:") + last_name = st.text_input("Please enter your Last Name:") + + if st.button("Get Flight Details"): + if not passenger_id: + st.warning("Please enter your Passenger ID.") + elif not last_name: + st.warning("Please enter your Last Name.") + else: + result = df[df.applymap(lambda x: passenger_id.lower() in str(x).lower() if pd.notnull(x) else False).any(axis=1)] + result = result[result["Last Name"].str.lower() == last_name.lower()] + if not result.empty: + st.success("**Here are your flight details:**") + for index, row in result.iterrows(): + st.markdown(f"**Passenger ID:** <code>{row['Passenger ID']}</code>", unsafe_allow_html=True) + st.markdown(f"**First Name:** <code>{row['First Name']}</code>", unsafe_allow_html=True) + st.markdown(f"**Last Name:** <code>{row['Last Name']}</code>", unsafe_allow_html=True) + st.markdown(f"**Gender:** <code>{row['Gender']}</code>", unsafe_allow_html=True) + st.markdown(f"**Age:** <code>{row['Age']}</code>", unsafe_allow_html=True) + st.markdown(f"**Nationality:** <code>{row['Nationality']}</code>", unsafe_allow_html=True) + st.markdown(f"**Airport Name:** <code>{row['Airport Name']}</code>", unsafe_allow_html=True) + st.markdown(f"**Departure Date:** <code>{row['Departure Date']}</code>", unsafe_allow_html=True) + st.markdown(f"**Arrival Airport:** <code>{row['Arrival Airport']}</code>", unsafe_allow_html=True) + st.markdown(f"**Pilot Name:** <code>{row['Pilot Name']}</code>", unsafe_allow_html=True) + st.markdown(f"**Flight Status:** <code>{row['Flight Status']}</code>", unsafe_allow_html=True) + st.write("---") + st.session_state.details_fetched = True + else: + st.error("No matching records found. Please check the Passenger ID or Last Name.") + + # Further assistance or goodbye + if st.session_state.details_fetched and not st.session_state.farewell: + further_assistance = st.text_input("Do you need any further assistance? (Yes/No)") + if further_assistance: + if "no" in further_assistance.lower() or "goodbye" in further_assistance.lower(): + st.success("Happy to assist you!") + st.session_state.farewell = True + elif "yes" in further_assistance.lower(): + st.info("Please describe your next query or ask for more details.") + else: + st.warning("Please respond with 'Yes' or 'No'.") + + # Reset message after farewell + if st.session_state.farewell: + st.info("Press the **Reset Chatbot** button to restart the conversation and ask for other details.") + + +def show_data_explorer_section(): + st.subheader("Explore the Airline Dataset") + + # Display dataset columns + columns = st.multiselect("Select columns to view:", options=df.columns.tolist(), default=df.columns.tolist()) + st.dataframe(df[columns]) + + # Filter by Flight Status + flight_statuses = st.multiselect("Filter by Flight Status:", options=df["Flight Status"].unique().tolist(), default=df["Flight Status"].unique().tolist()) + filtered_data = df[df["Flight Status"].isin(flight_statuses)] + st.write(f"Filtered Records: {len(filtered_data)}") + st.dataframe(filtered_data) + + +# Updated Visualization Function def show_visualizations_section(df): st.subheader("Visualize the Airline Dataset") - # Flight Status Bar Chart - status_counts = df["Flight Status"].value_counts().reset_index() - status_counts.columns = ["Flight Status", "Count"] - status_chart = alt.Chart(status_counts).mark_bar().encode( - x=alt.X("Flight Status", title="Flight Status"), - y=alt.Y("Count", title="Count"), - tooltip=["Flight Status", "Count"] - ).properties(title="Flight Status Distribution") - st.altair_chart(status_chart, use_container_width=True) - + # Passenger Nationality Pie Chart nationality_counts = df["Nationality"].value_counts().reset_index().head(10) nationality_counts.columns = ["Nationality", "Count"] @@ -88,15 +183,115 @@ def show_visualizations_section(df): ).properties(title="Passenger Age Distribution") st.altair_chart(age_histogram, use_container_width=True) + # Flight Status Bar Chart + status_counts = df["Flight Status"].value_counts().reset_index() + status_counts.columns = ["Flight Status", "Count"] + status_chart = alt.Chart(status_counts).mark_bar().encode( + x=alt.X("Flight Status", title="Flight Status"), + y=alt.Y("Count", title="Count"), + tooltip=["Flight Status", "Count"] + ).properties(title="Flight Status Distribution") + st.altair_chart(status_chart, use_container_width=True) + +def show_predictions_section(): + st.subheader("Make Predictions with Scikit-Learn Models") + + # Input fields for Model 1 (Passenger Satisfaction Prediction) + st.write("### Passenger Satisfaction Prediction") + age = st.slider("Age", min_value=18, max_value=100, value=30) + gender = st.selectbox("Gender", ["Male", "Female"]) + flight_class = st.selectbox("Flight Class", ["Economy", "Business", "First Class"]) + flight_duration = st.number_input("Flight Duration (hours)", min_value=0.0, value=2.0) + + # Map gender and flight class to numerical values for model input + gender_map = {"Male": 0, "Female": 1} + flight_class_map = {"Economy": 0, "Business": 1, "First Class": 2} + + gender_num = gender_map[gender] + flight_class_num = flight_class_map[flight_class] + + # Model 1 Prediction + if st.button("Predict Satisfaction"): + satisfaction_input = pd.DataFrame( + [[age, gender_num, flight_class_num, flight_duration]], + columns=["Age", "Gender", "Flight Class", "Flight Duration"] + ) + satisfaction_prediction = model_1.predict(satisfaction_input) + st.write(f"Predicted Passenger Satisfaction: {'Satisfied' if satisfaction_prediction[0] == 1 else 'Not Satisfied'}") + + # Generate live graph for Satisfaction Prediction + satisfaction_data = pd.DataFrame({ + "Category": ["Not Satisfied", "Satisfied"], + "Prediction": [1 - satisfaction_prediction[0], satisfaction_prediction[0]], + }) + satisfaction_chart = alt.Chart(satisfaction_data).mark_bar().encode( + x="Category", + y="Prediction", + color="Category" + ).properties(title="Passenger Satisfaction Prediction") + st.altair_chart(satisfaction_chart, use_container_width=True) + + st.write("---") # Separator + + # Input fields for Model 2 (Flight Delay Prediction) + st.write("### Flight Delay Prediction") + airport_country_code = st.text_input("Airport Country Code (e.g., US)").upper() + arrival_airport = st.text_input("Arrival Airport Code (e.g., JFK)").upper() + departure_time = st.number_input("Departure Time (24-hour format, e.g., 13.5 for 1:30 PM)", min_value=0.0, max_value=24.0) + + if not country_code_encoder or not airport_encoder: + st.error("Encoders are not available. Predictions cannot proceed.") + return + + if st.button("Predict Delay"): + try: + # Encode inputs + if airport_country_code not in country_code_encoder.classes_: + st.error(f"Invalid Airport Country Code: {airport_country_code}") + return + + if arrival_airport not in airport_encoder.classes_: + st.error(f"Invalid Arrival Airport Code: {arrival_airport}") + return + + airport_country_code_encoded = country_code_encoder.transform([airport_country_code])[0] + arrival_airport_encoded = airport_encoder.transform([arrival_airport])[0] + + # Prepare input for prediction as a DataFrame + delay_input = pd.DataFrame( + [[airport_country_code_encoded, arrival_airport_encoded, departure_time]], + columns=["Airport Country Code", "Arrival Airport", "Departure Time"] + ) + + # Perform prediction + delay_prediction = model_2.predict(delay_input) + + st.write(f"Predicted Flight Delay: {'Delayed' if delay_prediction[0] == 1 else 'On Time'}") + + # Generate live graph for Delay Prediction + delay_data = pd.DataFrame({ + "Category": ["On Time", "Delayed"], + "Prediction": [1 - delay_prediction[0], delay_prediction[0]], + }) + delay_chart = alt.Chart(delay_data).mark_bar().encode( + x="Category", + y="Prediction", + color="Category" + ).properties(title="Flight Delay Prediction") + st.altair_chart(delay_chart, use_container_width=True) + except ValueError as e: + st.error(f"Error: {e}. Please ensure the input values are valid and match the training data.") + # Main Navigation if menu == "Chatbot": - st.subheader("Chatbot functionality to be implemented...") + show_chatbot_section() elif menu == "Data Explorer": - st.subheader("Data Explorer functionality to be implemented...") + show_data_explorer_section() elif menu == "Visualizations": if df is not None: - show_visualizations_section(df) + show_visualizations_section(df) # Pass the dataframe as an argument else: st.error("Dataset not loaded. Please check the file path.") elif menu == "Predictions": - st.subheader("Prediction functionality to be implemented...") + show_predictions_section() +