Skip to content
Snippets Groups Projects
Synthcity.ipynb 9 KiB
Newer Older
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "299dec31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sdv.metadata import SingleTableMetadata\n",
    "from sdmetrics.reports.single_table import QualityReport\n",
    "from sdmetrics.reports.single_table import DiagnosticReport\n",
    "from table_evaluator import TableEvaluator\n",
    "import matplotlib.pyplot as plt\n",
    "from sdmetrics.single_column import StatisticSimilarity\n",
    "import math\n",
    "from sdmetrics.single_column import RangeCoverage\n",
    "from sdmetrics.visualization import get_column_plot\n",
    "import os\n",
    "import plotly.io as py\n",
    "import string\n",
    "\n",
    "from synthcity.plugins import Plugins\n",
    "\n",
    "#Plugins(categories=[\"generic\", \"privacy\"]).list() #uncomment to see a list of model for generating data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6127a704",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one preprocessed dataset. Point DATASET_PATH at the file to model; available files:\n",
    "#   'Datasets/Preprocessed_Datasets/benign.csv'\n",
    "#   'Datasets/Preprocessed_Datasets/bot_attacks.csv'\n",
    "#   'Datasets/Preprocessed_Datasets/bruteforce_attacks.csv'\n",
    "#   'Datasets/Preprocessed_Datasets/doS_attacks.csv'\n",
    "#   'Datasets/Preprocessed_Datasets/infilteration_attacks.csv'\n",
    "# NOTE: the output paths used in later cells are named after the dataset (e.g. LOICHTTP);\n",
    "# rename them to match the dataset selected here.\n",
    "DATASET_PATH = 'Datasets/Preprocessed_Datasets/benign.csv'\n",
    "MAX_ROWS = 300000  # upper bound on the number of real rows kept\n",
    "\n",
    "# Without an active read_csv call this cell fails with a NameError on a fresh kernel.\n",
    "real_data = pd.read_csv(DATASET_PATH)\n",
    "\n",
    "print(real_data.shape)\n",
    "print(real_data.Label.unique())\n",
    "\n",
    "# If bruteforce_attacks or doS_attacks is used, uncomment the line below and set the label of the attack to keep.\n",
    "#real_data = real_data[real_data.Label == 'SSH-Bruteforce']  # change according to the dataset\n",
    "real_data = real_data.iloc[:MAX_ROWS, :]\n",
    "print(real_data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c41b506-aa5c-419c-8a49-f246de3ee6ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model plugin taken from https://github.com/vanderschaarlab/synthcity/tree/main\n",
    "# To train ADSGAN instead, replace \"rtvae\" with \"adsgan\" in the call below.\n",
    "\n",
    "model_kwargs = dict(\n",
    "    n_iter=500,\n",
    "    lr=0.0001,\n",
    "    batch_size=300,\n",
    "    decoder_n_layers_hidden=4,\n",
    "    encoder_n_layers_hidden=4,\n",
    ")\n",
    "syn_model = Plugins().get(\"rtvae\", **model_kwargs)\n",
    "\n",
    "syn_model.fit(real_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2510624-8cfc-480b-88ee-e2234117cb50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the synthetic data and save it to disk.\n",
    "output_csv = 'RTVAE_Results/LOICHTTP.csv'  # name it after the loaded dataset\n",
    "\n",
    "# to_csv does not create missing folders, so make sure the target folder exists first.\n",
    "os.makedirs(os.path.dirname(output_csv), exist_ok=True)\n",
    "\n",
    "synthetic_data = syn_model.generate(300000).dataframe()\n",
    "synthetic_data.to_csv(output_csv, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d895ace2-e2d9-4ee0-886a-742aebcbd6c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data_info(df):\n",
    "    \"\"\"Build the categorical/continuous column lists and the SDV metadata of a dataframe.\n",
    "\n",
    "    Args:\n",
    "        df (pandas.DataFrame): The input dataframe; it is expected to contain a 'Label' column.\n",
    "\n",
    "    Returns:\n",
    "        list: The categorical column names (only 'Label').\n",
    "        list: The continuous column names (every other column of df).\n",
    "        metadata: The SingleTableMetadata of the dataframe. For more information visit https://docs.sdv.dev/sdv/reference/metadata-spec/single-table-metadata-json\n",
    "    \"\"\"\n",
    "    categorical_columns = ['Label']\n",
    "    continuous_columns = [column for column in df.columns if column not in categorical_columns]\n",
    "\n",
    "    # Auto-detect the metadata, then set the sdtype of every column explicitly.\n",
    "    metadata = SingleTableMetadata()\n",
    "    metadata.detect_from_dataframe(df)\n",
    "\n",
    "    for column in categorical_columns:\n",
    "        metadata.update_column(column_name=column, sdtype='categorical')\n",
    "\n",
    "    for column in continuous_columns:\n",
    "        metadata.update_column(column_name=column, sdtype='numerical')\n",
    "\n",
    "    # Validate the metadata itself, then the dataframe it was built from\n",
    "    # (the df argument, not the notebook-level real_data).\n",
    "    metadata.validate()\n",
    "    metadata.validate_data(data=df)\n",
    "\n",
    "    return categorical_columns, continuous_columns, metadata\n",
    "\n",
    "\n",
    "categorical_columns, continuous_columns, metadata = get_data_info(real_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90330684-4fed-4571-9026-4cb04250e475",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual evaluation of the synthetic data with table_evaluator\n",
    "# (cumulative sums per feature and distributions).\n",
    "table_evaluator = TableEvaluator(\n",
    "    real_data,\n",
    "    synthetic_data,\n",
    "    cat_cols=categorical_columns,\n",
    ")\n",
    "table_evaluator.visual_evaluation()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8833c4f-ec42-413f-96ec-34516401ec8b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build, save and visualize the quality report (column shapes and column pair trends).\n",
    "# The report takes the metadata as a dict. Convert only when needed, so that re-running\n",
    "# this cell does not call .to_dict() on an already converted dict.\n",
    "if not isinstance(metadata, dict):\n",
    "    metadata = metadata.to_dict()\n",
    "\n",
    "# Saving does not create missing folders, so make sure the output folder exists.\n",
    "os.makedirs('RTVAE_Results/LOICHTTP', exist_ok=True)\n",
    "\n",
    "my_report = QualityReport()\n",
    "my_report.generate(real_data, synthetic_data, metadata)\n",
    "my_report.save(filepath='RTVAE_Results/LOICHTTP/quality.pkl')\n",
    "my_report.get_visualization(property_name='Column Pair Trends')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a510048-8949-45a4-b9d7-543b211fc710",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build and save the diagnostic report (data validity).\n",
    "# The report takes the metadata as a dict; convert here if that has not happened yet.\n",
    "metadata_dict = metadata if isinstance(metadata, dict) else metadata.to_dict()\n",
    "\n",
    "# Save inside the 'RTVAE_Results/LOICHTTP' folder. 'RTVAE_Results/LOICHTTP.csv' is the\n",
    "# synthetic CSV file, not a folder, so it cannot be the parent directory of the report.\n",
    "os.makedirs('RTVAE_Results/LOICHTTP', exist_ok=True)\n",
    "\n",
    "my_report = DiagnosticReport()\n",
    "my_report.generate(real_data, synthetic_data, metadata_dict)\n",
    "my_report.save(filepath='RTVAE_Results/LOICHTTP/diagnostic.pkl')\n",
    "#my_report.get_visualization('Data Validity')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "668f18aa-5a36-4a72-8c47-5c549a7a5a86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Statistic similarity metric (median), averaged over the columns.\n",
    "# The metric applies to numerical/datetime columns only, so columns of any other dtype\n",
    "# (e.g. a text Label) are skipped instead of raising an error.\n",
    "metric_columns = real_data.select_dtypes(include=['number', 'datetime']).columns\n",
    "\n",
    "sstest = [\n",
    "    StatisticSimilarity.compute(\n",
    "        real_data=real_data[column],\n",
    "        synthetic_data=synthetic_data[column],\n",
    "        statistic='median'\n",
    "    )\n",
    "    for column in metric_columns\n",
    "]\n",
    "\n",
    "df = pd.DataFrame(sstest, columns=['SS Test'])\n",
    "\n",
    "print(df['SS Test'].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "142c89a4-7977-49a0-8aed-69edf12ea07b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Range coverage metric, averaged over the columns.\n",
    "# The metric applies to numerical/datetime columns only, so columns of any other dtype\n",
    "# (e.g. a text Label) are skipped instead of raising an error.\n",
    "metric_columns = real_data.select_dtypes(include=['number', 'datetime']).columns\n",
    "\n",
    "range_coverage = [\n",
    "    RangeCoverage.compute(\n",
    "        real_data=real_data[column],\n",
    "        synthetic_data=synthetic_data[column]\n",
    "    )\n",
    "    for column in metric_columns\n",
    "]\n",
    "df = pd.DataFrame(range_coverage, columns=['Range Coverage'])\n",
    "\n",
    "print(df['Range Coverage'].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62ec2b27-5262-4906-a1b3-eb755c7dc0da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool real and synthetic rows, then count what remains once incomplete and\n",
    "# duplicated rows are removed.\n",
    "pooled = pd.concat([real_data, synthetic_data], axis=0)\n",
    "print(pooled.shape)\n",
    "df = pooled.dropna().drop_duplicates()\n",
    "print(df.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53ff9fbd-1632-4d5f-84ce-f64d13305a9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the real-vs-synthetic distribution plot of each column\n",
    "def sanitize_column_name(column_name):\n",
    "    \"\"\"Return column_name without the characters that are not ASCII letters, digits, spaces or one of -_.()\n",
    "\n",
    "    Used to turn a column name into a file name.\n",
    "    \"\"\"\n",
    "    valid_chars = \"-_.() %s%s\" % (string.ascii_letters, string.digits)\n",
    "    return ''.join(c for c in column_name if c in valid_chars)\n",
    "\n",
    "\n",
    "# Folder that receives one PNG per column; change the location accordingly.\n",
    "pics_dir = 'RTVAE_Results/LOICHTTP/Pics'\n",
    "# write_image does not create missing folders, so make sure the folder exists first.\n",
    "os.makedirs(pics_dir, exist_ok=True)\n",
    "\n",
    "for i in real_data.columns:\n",
    "    fig = get_column_plot(\n",
    "        real_data=real_data,\n",
    "        synthetic_data=synthetic_data,\n",
    "        column_name=i,\n",
    "        plot_type='bar'\n",
    "    )\n",
    "\n",
    "    sanitized_column_name = sanitize_column_name(i)\n",
    "\n",
    "    py.write_image(fig, os.path.join(pics_dir, f\"{sanitized_column_name}.png\"))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}