import argparse
import json
from pathlib import Path

import pandas as pd
import torch

from aindo.rdml.eval import compute_privacy_stats, report
from aindo.rdml.relational import Column, RelationalData, Schema, Table
from aindo.rdml.synth import TabularDataset, TabularModel, TabularPreproc, TabularTrainer, Validation


# The function name and the full signature are reconstructed from the call in __main__ below;
# only the `device` parameter annotation survives from the original.
def main(
    data_dir: Path,
    output_dir: Path,
    n_epochs: int | None,
    n_steps: int | None,
    valid_each: int,
    device: str | torch.device | None,
    memory: int,
    quick: bool,
) -> None:
    # Load data and define schema.
    # NOTE: the Schema/Table wrapper and the first/last columns ("age", "income") are
    # reconstructed from the standard UCI 'adult' column list; the column mapping is
    # unpacked from a dict because some column names contain hyphens.
    schema = Schema(
        adult=Table(
            **{
                "age": Column.INTEGER,
                "workclass": Column.CATEGORICAL,
                "fnlwgt": Column.INTEGER,
                "education": Column.CATEGORICAL,
                "education-num": Column.CATEGORICAL,
                "marital-status": Column.CATEGORICAL,
                "occupation": Column.CATEGORICAL,
                "relationship": Column.CATEGORICAL,
                "race": Column.CATEGORICAL,
                "sex": Column.CATEGORICAL,
                "capital-gain": Column.INTEGER,
                "capital-loss": Column.INTEGER,
                "hours-per-week": Column.INTEGER,
                "native-country": Column.CATEGORICAL,
                "income": Column.CATEGORICAL,
            }
        ),
    )
    data = {
        "adult": pd.read_csv(
            data_dir / "adult.data",  # file name assumed (standard UCI 'adult' data file)
            names=list(schema.tables["adult"].columns),
            skipinitialspace=True,
        ),
    }
    data = RelationalData(data=data, schema=schema)
    if quick:  # the `if quick` guard is inferred from the --quick help text below
        _, data = data.split(ratio=0.2)

    # Fit the preprocessor on the full data, then split into train / validation / test
    preproc = TabularPreproc.from_schema(schema=schema).fit(data=data)
    split_ratio = 0.1  # assumed value, not shown in the original
    data_train_valid, data_test = data.split(ratio=split_ratio)
    data_train, data_valid = data_train_valid.split(ratio=split_ratio)

    # Build the model: a small architecture, or a tiny one for quick test runs
    model = TabularModel.build(preproc=preproc, size="tiny" if quick else "small")
    model.device = device  # a device of None means CUDA will be used if available, otherwise CPU
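    # For reference, a None device typically resolves as sketched below; this illustrates
    # the behavior described in the comment above, not the library's actual internals:
    #   device = "cuda" if torch.cuda.is_available() else "cpu"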

    # Train the model, validating every `valid_each` steps
    dataset_train = TabularDataset.from_data(data=data_train, preproc=preproc, on_disk=True)
    dataset_valid = TabularDataset.from_data(data=data_valid, preproc=preproc)
    trainer = TabularTrainer(model=model)
    trainer.train(
        # The shape of this call is reconstructed; argument names other than
        # save_best and tensorboard are assumptions.
        dataset=dataset_train,
        n_epochs=n_epochs,
        n_steps=n_steps,
        memory=memory,
        valid=Validation(
            dataset=dataset_valid,
            each=valid_each,
            save_best=output_dir / "best.pt",
            tensorboard=output_dir / "tb",
        ),
    )

    # Generate synthetic data, one synthetic row for each row of the original table
    data_synth = model.generate(
        n_samples=data["adult"].shape[0],
    )
    data_synth.to_csv(output_dir / "synth")
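    # to_csv on the generated relational data presumably writes one CSV file per table
    # under output_dir / "synth" (exact output layout assumed).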

    # Compute and print PDF report
    report(
        # Argument names other than `path` are assumptions
        data_source=data_test,
        data_synth=data_synth,
        path=output_dir / "report.pdf",
    )

    # Compute extra privacy stats and print some results
    privacy_stats = compute_privacy_stats(
        # Argument names are assumptions, consistent with the report() call above
        data_source=data_train,
        data_synth=data_synth,
    )
    privacy_stats_out = {
        t: {
            "privacy_score": ps.privacy_score,
            "privacy_score_std": ps.privacy_score_std,
            "%_points_at_risk": ps.risk * 100,
        }
        for t, ps in privacy_stats.items()
    }
    print(json.dumps(privacy_stats_out, indent=2))
    with open(output_dir / "privacy_stats.json", mode="w", encoding="utf-8") as f:
        json.dump(privacy_stats_out, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir", type=Path, help="The directory where to find the 'adult' dataset")
    parser.add_argument("output_dir", type=Path, help="The output directory")
    parser.add_argument(
        "-n", type=int, default=100, help="Training epochs (or steps if the --steps flag is used)"
    )  # this argument is reconstructed from its help text and the args.n references below; the default is assumed
    parser.add_argument("--steps", "-s", action="store_true", help="Use steps instead of epochs")
    parser.add_argument("--valid-each", "-v", type=int, default=200, help="Number of steps between validations")
    parser.add_argument("--device", "-g", default=None, help="Training device")
    parser.add_argument("--memory", "-m", type=int, default=4096, help="Available memory (MB)")
    parser.add_argument(
        "--quick", "-q", action="store_true", help="Perform a quick test run, with reduced data and a small model"
    )
    args = parser.parse_args()
    main(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        n_epochs=None if args.steps else args.n,
        n_steps=args.n if args.steps else None,
        valid_each=args.valid_each,
        device=args.device,
        memory=args.memory,
        quick=args.quick,
    )
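
# Example invocations (illustrative; the script name is a placeholder):
#   python adult.py ./data ./out -n 2000 --steps --valid-each 200 --device cuda
#   python adult.py ./data ./out -n 5 --quick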