Skip to content
Snippets Groups Projects
Commit 1a322564 authored by Sayed Saeedi's avatar Sayed Saeedi
Browse files

Codes for ADSGAN and RTVAE

parent 47c8160a
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id:299dec31 tags:
``` python
import pandas as pd
import numpy as np
from sdv.metadata import SingleTableMetadata
from sdmetrics.reports.single_table import QualityReport
from sdmetrics.reports.single_table import DiagnosticReport
from table_evaluator import TableEvaluator
import matplotlib.pyplot as plt
from sdmetrics.single_column import StatisticSimilarity
import math
from sdmetrics.single_column import RangeCoverage
from sdmetrics.visualization import get_column_plot
import os
import plotly.io as py
import string
from synthcity.plugins import Plugins
#Plugins(categories=["generic", "privacy"]).list() #uncomment to see a list of model for generating data
```
%% Cell type:code id:6127a704 tags:
``` python
# Load one preprocessed dataset: uncomment exactly one of the read_csv lines
# below so that `real_data` is defined before the rest of this cell runs.
# real_data = pd.read_csv('Datasets/Preprocessed_Datasets/benign.csv')
# real_data = pd.read_csv('Datasets/Preprocessed_Datasets/bot_attacks.csv')
# real_data = pd.read_csv('Datasets/Preprocessed_Datasets/bruteforce_attacks.csv')
# real_data = pd.read_csv('Datasets/Preprocessed_Datasets/doS_attacks.csv')
# real_data = pd.read_csv('Datasets/Preprocessed_Datasets/infilteration_attacks.csv')

# Show the size of the loaded table and the labels it contains.
print(real_data.shape)
print(real_data['Label'].unique())

# For the bruteforce or DoS datasets, uncomment the filter below and change
# the label to the attack class that should be kept.
# real_data = real_data[real_data.Label == 'SSH-Bruteforce']

# Keep only the first 300000 rows (all columns) and report the new size.
real_data = real_data.iloc[:300000]
print(real_data.shape)
```
%% Cell type:code id:9c41b506-aa5c-419c-8a49-f246de3ee6ae tags:
``` python
# Generator plugin taken from synthcity:
# https://github.com/vanderschaarlab/synthcity/tree/main
# To train ADSGAN instead of RTVAE, change the plugin name "rtvae" to "adsgan".
syn_model = Plugins().get(
    "rtvae",
    n_iter=500,
    lr=0.0001,
    batch_size=300,
    decoder_n_layers_hidden=4,
    encoder_n_layers_hidden=4,
)
# Fit the generator on the real dataset.
syn_model.fit(real_data)
```
%% Cell type:code id:c2510624-8cfc-480b-88ee-e2234117cb50 tags:
``` python
# Request 300000 synthetic samples from the fitted model as a dataframe.
synthetic_data = syn_model.generate(300000).dataframe()

# to_csv does not create missing parent folders, so make sure the results
# folder exists before writing.
os.makedirs('RTVAE_Results', exist_ok=True)

# Save the synthetic data; name the file after the dataset that was loaded.
synthetic_data.to_csv('RTVAE_Results/LOICHTTP.csv', index=False)
```
%% Cell type:code id:d895ace2-e2d9-4ee0-886a-742aebcbd6c0 tags:
``` python
def get_data_info(df, categorical_columns=None):
    """Create the column lists and the validated SDV metadata of a dataframe.

    Every column named in ``categorical_columns`` is marked as categorical;
    every other column of ``df`` is marked as numerical.

    Args:
        df (pandas.DataFrame): The input dataframe containing continuous and
            categorical values.
        categorical_columns (list[str] | None): Names of the columns to treat
            as categorical. Defaults to ``['Label']``.

    Returns:
        list: The list of categorical column names.
        list: The list of continuous column names (all remaining columns).
        SingleTableMetadata: The metadata of the dataframe. For more
            information visit
            https://docs.sdv.dev/sdv/reference/metadata-spec/single-table-metadata-json
    """
    # Copy the caller's list so the returned list is independent of it.
    if categorical_columns is None:
        categorical_columns = ['Label']
    else:
        categorical_columns = list(categorical_columns)

    # Everything that is not categorical is treated as continuous.
    continuous_columns = [
        column for column in df.columns if column not in categorical_columns
    ]

    # Start from auto-detected metadata, then force the intended sdtypes.
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(df)
    for column in categorical_columns:
        metadata.update_column(column_name=column, sdtype='categorical')
    for column in continuous_columns:
        metadata.update_column(column_name=column, sdtype='numerical')

    # Validate the metadata itself, then validate it against the dataframe
    # that was passed in (not against a module-level variable).
    metadata.validate()
    metadata.validate_data(data=df)
    return categorical_columns, continuous_columns, metadata
# Build the column lists and the validated metadata for the real dataset.
categorical_columns, continuous_columns, metadata = get_data_info(real_data)
```
%% Cell type:code id:90330684-4fed-4571-9026-4cb04250e475 tags:
``` python
# Visual comparison of real and synthetic data with table_evaluator
# (cumulative sums per feature and distributions).
table_evaluator = TableEvaluator(
    real_data,
    synthetic_data,
    cat_cols=categorical_columns,
)
table_evaluator.visual_evaluation()
```
%% Cell type:code id:d8833c4f-ec42-413f-96ec-34516401ec8b tags:
``` python
# Build, save, and visualize the quality report
# (column shapes and column pair trends).

# The report is given the metadata as a dict. Convert only when it has not
# been converted yet, so re-running this cell does not call to_dict() on a
# plain dict.
if not isinstance(metadata, dict):
    metadata = metadata.to_dict()

my_report = QualityReport()
my_report.generate(real_data, synthetic_data, metadata)

# Make sure the output folder exists before saving the report into it.
os.makedirs('RTVAE_Results/LOICHTTP', exist_ok=True)
my_report.save(filepath='RTVAE_Results/LOICHTTP/quality.pkl')

my_report.get_visualization(property_name='Column Pair Trends')
```
%% Cell type:code id:9a510048-8949-45a4-b9d7-543b211fc710 tags:
``` python
# Build and save the diagnostic report (data validity).

# The report is given the metadata as a dict. Convert only when it has not
# been converted yet, so this cell works whether or not it was converted
# earlier in the session.
if not isinstance(metadata, dict):
    metadata = metadata.to_dict()

my_report = DiagnosticReport()
my_report.generate(real_data, synthetic_data, metadata)

# Save next to the quality report. 'RTVAE_Results/LOICHTTP.csv' is the
# synthetic CSV file, not a folder, so it cannot be used as the parent
# directory of the pickle.
os.makedirs('RTVAE_Results/LOICHTTP', exist_ok=True)
my_report.save(filepath='RTVAE_Results/LOICHTTP/diagnostic.pkl')
# my_report.get_visualization('Data Validity')
```
%% Cell type:code id:668f18aa-5a36-4a72-8c47-5c549a7a5a86 tags:
``` python
# Statistic similarity metric: compare the median of each column between the
# real and the synthetic data, then report the mean score.

# A median cannot be computed for non-numeric values (for example a
# string-valued 'Label' column), so only numeric columns are scored.
numeric_columns = real_data.select_dtypes(include='number').columns

sstest = [
    StatisticSimilarity.compute(
        real_data=real_data[column],
        synthetic_data=synthetic_data[column],
        statistic='median',
    )
    for column in numeric_columns
]
df = pd.DataFrame(sstest, columns=['SS Test'])
print(df['SS Test'].mean())
```
%% Cell type:code id:142c89a4-7977-49a0-8aed-69edf12ea07b tags:
``` python
# Range coverage metric: score each column on how much of the real value
# range the synthetic data covers, then report the mean score.

# A value range cannot be computed for non-numeric values (for example a
# string-valued 'Label' column), so only numeric columns are scored.
numeric_columns = real_data.select_dtypes(include='number').columns

range_coverage = [
    RangeCoverage.compute(
        real_data=real_data[column],
        synthetic_data=synthetic_data[column],
    )
    for column in numeric_columns
]
df = pd.DataFrame(range_coverage, columns=['Range Coverage'])
print(df['Range Coverage'].mean())
```
%% Cell type:code id:62ec2b27-5262-4906-a1b3-eb755c7dc0da tags:
``` python
# Stack the real and synthetic rows, then compare the row count before and
# after removing rows with missing values and duplicated rows.
df = pd.concat((real_data, synthetic_data), axis=0)
print(df.shape)
df = df.dropna().drop_duplicates()
print(df.shape)
```
%% Cell type:code id:53ff9fbd-1632-4d5f-84ce-f64d13305a9b tags:
``` python
#Saving the distribution of each column
def sanitize_column_name(column_name):
    """Return ``column_name`` reduced to characters that are safe in a file name.

    Args:
        column_name: The column label. Labels that are not strings (for
            example integer column labels) are converted with ``str`` first.

    Returns:
        str: The label with every character removed that is not an ASCII
            letter, a digit, a space, or one of ``-_.()``.
    """
    # A set gives constant-time membership tests for each character.
    valid_chars = frozenset("-_.() " + string.ascii_letters + string.digits)
    return ''.join(c for c in str(column_name) if c in valid_chars)
# Save the real-vs-synthetic bar plot of every column as a PNG file.
# Change the output location accordingly.
pics_dir = 'RTVAE_Results/LOICHTTP/Pics'
# write_image needs the target folder to exist, so create it up front.
os.makedirs(pics_dir, exist_ok=True)

for i in real_data.columns:
    fig = get_column_plot(
        real_data=real_data,
        synthetic_data=synthetic_data,
        column_name=i,
        plot_type='bar',
    )
    # Column names may contain characters that are not valid in file names.
    sanitized_column_name = sanitize_column_name(i)
    py.write_image(fig, os.path.join(pics_dir, f"{sanitized_column_name}.png"))
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment