import argparse
import json
from pathlib import Path

import pandas as pd
import torch

from aindo.rdml.eval import compute_privacy_stats, report
from aindo.rdml.relational import Column, RelationalData, Schema, Table
from aindo.rdml.synth import TabularDataset, TabularModel, TabularPreproc, TabularTrainer, Validation


def main(  # hypothetical name: the original function name did not survive extraction
    data_dir: Path, output_dir: Path,
    n_epochs: int | None, n_steps: int | None,
    valid_each: int, device: str | torch.device | None,
    memory: int, quick: bool,
) -> None:
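    """End-to-end aindo.rdml example on the UCI 'adult' dataset: define a schema,
    train a tabular generative model, generate synthetic data, and evaluate it."""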
    # Load data and define schema
    schema = Schema(  # the exact Schema/Table constructor layout is assumed here
        tables={'adult': Table(columns={
            'age': Column.INTEGER,  # assumed: leading UCI Adult column, missing from the fragment
            'workclass': Column.CATEGORICAL,
            'fnlwgt': Column.INTEGER,
            'education': Column.CATEGORICAL,
            'education-num': Column.CATEGORICAL,
            'marital-status': Column.CATEGORICAL,
            'occupation': Column.CATEGORICAL,
            'relationship': Column.CATEGORICAL,
            'race': Column.CATEGORICAL,
            'sex': Column.CATEGORICAL,
            'capital-gain': Column.INTEGER,
            'capital-loss': Column.INTEGER,
            'hours-per-week': Column.INTEGER,
            'native-country': Column.CATEGORICAL,
            'income': Column.CATEGORICAL,  # assumed: the UCI Adult target column
        })},
    )
    data = {'adult': pd.read_csv(
        data_dir / 'adult.csv',  # assumed file name
        names=list(schema.tables['adult'].columns),
    )}
    data = RelationalData(data=data, schema=schema)
    if quick:
        # Quick test run: keep only a 20% subset of the data
        _, data = data.split(ratio=0.2)
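
    # Fit the preprocessor on the data before splitting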
    preproc = TabularPreproc.from_schema(schema=schema).fit(data=data)
    # Hold out a test set, then split the rest into train and validation sets
    split_ratio = 0.1  # assumed value: the original constant was lost
    data_train_valid, data_test = data.split(ratio=split_ratio)
    data_train, data_valid = data_train_valid.split(ratio=split_ratio)
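
    # Build the generative model: a smaller preset is used for quick test runs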
    model = TabularModel.build(preproc=preproc, size='tiny' if quick else 'small')
    model.device = device  # if device is None, CUDA is used when available, otherwise CPU
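    # Prepare train/validation datasets; on_disk=True presumably keeps the larger training set out of memory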
    dataset_train = TabularDataset.from_data(data=data_train, preproc=preproc, on_disk=True)
    dataset_valid = TabularDataset.from_data(data=data_valid, preproc=preproc)
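
    # Train the model, validating every `valid_each` steps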
    trainer = TabularTrainer(model=model)
    trainer.train(  # call structure assumed, reconstructed around the surviving kwargs
        dataset=dataset_train,
        n_epochs=n_epochs, n_steps=n_steps, memory=memory,
        valid=Validation(
            dataset=dataset_valid, each=valid_each,
            save_best=output_dir / 'best.pt',
            tensorboard=output_dir / 'tb',
        ),
    )
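    # save_best presumably keeps the checkpoint with the best validation loss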

    # Generate synthetic data
    data_synth = model.generate(
        n_samples=data['adult'].shape[0],
    )
    data_synth.to_csv(output_dir / 'synth')

    # Compute and print PDF report
    report(  # argument names assumed, except for `path`
        data_source=data_train_valid, data_synth=data_synth, data_test=data_test,
        path=output_dir / 'report.pdf',
    )

    # Compute extra privacy stats and print some results (argument names assumed)
    privacy_stats = compute_privacy_stats(data_source=data_train_valid, data_synth=data_synth)
    privacy_stats_out = {
        t: {
            'privacy_score': ps.privacy_score,
            'privacy_score_std': ps.privacy_score_std,
            '%_points_at_risk': ps.risk * 100,
        }
        for t, ps in privacy_stats.items()
    }
    print(privacy_stats_out)
    with open(output_dir / 'privacy_stats.json', mode='w', encoding='utf-8') as f:
        json.dump(privacy_stats_out, f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('data_dir', type=Path, help="The directory where to find the 'adult' dataset")
    parser.add_argument('output_dir', type=Path, help="The output directory")
    parser.add_argument(
        '--n', '-n', type=int, default=1000,
        help="Training epochs (or steps if the --steps flag is used)",
    )
    parser.add_argument('--steps', '-s', action='store_true', help="Use steps instead of epochs")
    parser.add_argument('--valid-each', '-v', type=int, default=200, help="# steps between validations")
    parser.add_argument('--device', '-g', default=None, help="Training device")
    parser.add_argument('--memory', '-m', type=int, default=4096, help="Available memory (MB)")
    parser.add_argument(
        '--quick', '-q', action='store_true',
        help="Perform a quick test run, with reduced data and a small model",
    )
    args = parser.parse_args()
    main(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        n_epochs=None if args.steps else args.n,
        n_steps=args.n if args.steps else None,
        valid_each=args.valid_each,
        device=args.device, memory=args.memory, quick=args.quick,
    )
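
# Example invocation (hypothetical script name and paths):
#   python adult_example.py ./data ./out --n 2000 --steps --device cuda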