import argparse
import json
from pathlib import Path

import pandas as pd
import torch

from aindo.rdml.eval import compute_privacy_stats, report
from aindo.rdml.relational import Column, ForeignKey, PrimaryKey, RelationalData, Schema, Table
from aindo.rdml.synth import TabularDataset, TabularModel, TabularPreproc, TabularTrainer, Validation


def main(
    data_dir: Path,
    output_dir: Path,
    n_epochs: int | None,
    n_steps: int | None,
    valid_each: int,
    device: str | torch.device | None,
    memory: int,
    quick: bool,
) -> None:
    # Load data and define schema
    data = {
        'players': pd.read_csv(data_dir / 'players.csv'),
        'season': pd.read_csv(data_dir / 'season.csv'),
        'all_star': pd.read_csv(data_dir / 'all_star.csv'),
    }
    schema = Schema(
        players=Table(
            playerID=PrimaryKey(),
            college=Column.CATEGORICAL,
            birthCity=Column.CATEGORICAL,
            birthState=Column.CATEGORICAL,
            birthCountry=Column.CATEGORICAL,
            # Remaining 'players' columns are declared in the same way
        ),
        season=Table(
            playerID=ForeignKey(parent='players'),
            # Remaining 'season' columns are declared in the same way
        ),
        all_star=Table(
            playerID=ForeignKey(parent='players'),
            conference=Column.CATEGORICAL,
            league_id=Column.CATEGORICAL,
        ),
    )
    data = RelationalData(data=data, schema=schema)
    if quick:
        # Quick test run: keep only a 20% subset of the data
        _, data = data.split(ratio=0.2)

    # Fit the preprocessor, then split into train/validation/test sets
    preproc = TabularPreproc.from_schema(schema=schema).fit(data=data)
    split_ratio = 0.1  # Hold-out fraction (the original value is assumed here)
    data_train_valid, data_test = data.split(ratio=split_ratio)
    data_train, data_valid = data_train_valid.split(ratio=split_ratio)
    # Build the model
    model = TabularModel.build(preproc=preproc, size='tiny' if quick else 'small')
    model.device = device  # A device of None means CUDA if available, otherwise CPU
    # Train the model. The keyword names in the train() and Validation() calls
    # are inferred from the surviving fragment; check them against your version.
    dataset_train = TabularDataset.from_data(data=data_train, preproc=preproc, on_disk=True)
    dataset_valid = TabularDataset.from_data(data=data_valid, preproc=preproc)
    trainer = TabularTrainer(model=model)
    trainer.train(
        dataset=dataset_train,
        n_epochs=n_epochs,
        n_steps=n_steps,
        memory=memory,
        valid=Validation(
            dataset=dataset_valid,
            each=valid_each,
            save_best=output_dir / 'best.pt',
            tensorboard=output_dir / 'tb',
        ),
    )
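
    # While training, the metrics written under the tensorboard path configured
    # above can be followed live with: tensorboard --logdir <output_dir>/tb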
    # Generate synthetic data: as many synthetic players as in the real data
    data_synth = model.generate(
        n_samples=data['players'].shape[0],
    )
    data_synth.to_csv(output_dir / 'synth')
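
    # Note (assumption): given the extension-less path above, to_csv is expected
    # to write one CSV file per table ('players', 'season', 'all_star') under it.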
    # Compute and print PDF report (the data_* argument names are assumed;
    # check the aindo.rdml.eval.report signature)
    report(
        data_train=data_train,
        data_test=data_test,
        data_synth=data_synth,
        path=output_dir / 'report.pdf',
    )
    # Compute extra privacy stats and print some results
    privacy_stats = compute_privacy_stats(
        data_train=data_train,  # Argument names assumed, as in report() above
        data_synth=data_synth,
    )
    privacy_stats_out = {
        t: {
            'privacy_score': ps.privacy_score,
            'privacy_score_std': ps.privacy_score_std,
            '%_points_at_risk': ps.risk * 100,
        }
        for t, ps in privacy_stats.items()
    }
    print(privacy_stats_out)
    with open(output_dir / 'privacy_stats.json', mode='w', encoding='utf-8') as f:
        json.dump(privacy_stats_out, f)
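

# A minimal programmatic run, bypassing the CLI below. The paths are
# placeholders and the values mirror the argparse defaults:
#
#     main(
#         data_dir=Path('data/basket'), output_dir=Path('out'),
#         n_epochs=1000, n_steps=None, valid_each=200,
#         device=None, memory=4096, quick=False,
#     )
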
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('data_dir', type=Path, help="The directory containing the 'basket' dataset")
    parser.add_argument('output_dir', type=Path, help="The output directory")
    parser.add_argument(
        '--n', '-n', type=int, default=1000,
        help="Training epochs (or steps if the --steps flag is used)",
    )
    parser.add_argument('--steps', '-s', action='store_true', help="Use steps instead of epochs")
    parser.add_argument('--valid-each', '-v', type=int, default=200, help="Number of steps between validations")
    parser.add_argument('--device', '-g', default=None, help="Training device")
    parser.add_argument('--memory', '-m', type=int, default=4096, help="Available memory (MB)")
    parser.add_argument(
        '--quick', '-q', action='store_true',
        help="Perform a quick test run, with reduced data and a small model",
    )
    args = parser.parse_args()
    main(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        n_epochs=None if args.steps else args.n,
        n_steps=args.n if args.steps else None,
        valid_each=args.valid_each,
        device=args.device,
        memory=args.memory,
        quick=args.quick,
    )
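
# Example invocations (the script name and paths are placeholders):
#   python basket_synth.py data/basket out --n 2000 --steps --device cuda
#   python basket_synth.py data/basket out --quick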