
Multi table - BasketballMen dataset

In the following we present an example script that uses the aindo.rdml library to generate synthetic data in the multi-table case, based on the BasketballMen dataset. The script loads the three tables, defines the relational schema, trains a tabular model, generates synthetic data, and finally produces an evaluation report together with some additional privacy statistics.

import argparse
import json
from pathlib import Path

import pandas as pd
import torch
from aindo.rdml.eval import compute_privacy_stats, report
from aindo.rdml.relational import Column, ForeignKey, PrimaryKey, RelationalData, Schema, Table
from aindo.rdml.synth import TabularDataset, TabularModel, TabularPreproc, TabularTrainer, Validation

def example_basket(
    data_dir: Path,
    output_dir: Path,
    n_epochs: int | None,
    n_steps: int | None,
    valid_each: int,
    device: str | torch.device | None,
    memory: int,
    quick: bool,
) -> None:
    # Load data and define the schema
    data = {
        'players': pd.read_csv(data_dir / 'players.csv'),
        'season': pd.read_csv(data_dir / 'season.csv'),
        'all_star': pd.read_csv(data_dir / 'all_star.csv'),
    }
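    # NOTE: the keys of this dict must match the table names declared in the Schema below.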
    schema = Schema(
        players=Table(
            playerID=PrimaryKey(),
            pos=Column.CATEGORICAL,
            height=Column.NUMERIC,
            weight=Column.NUMERIC,
            college=Column.CATEGORICAL,
            race=Column.CATEGORICAL,
            birthCity=Column.CATEGORICAL,
            birthState=Column.CATEGORICAL,
            birthCountry=Column.CATEGORICAL,
        ),
        season=Table(
            playerID=ForeignKey(parent='players'),
            year=Column.INTEGER,
            stint=Column.INTEGER,
            tmID=Column.CATEGORICAL,
            lgID=Column.CATEGORICAL,
            GP=Column.INTEGER,
            points=Column.INTEGER,
            GS=Column.INTEGER,
            assists=Column.INTEGER,
            steals=Column.INTEGER,
            minutes=Column.INTEGER,
        ),
        all_star=Table(
            playerID=ForeignKey(parent='players'),
            conference=Column.CATEGORICAL,
            league_id=Column.CATEGORICAL,
            points=Column.INTEGER,
            rebounds=Column.INTEGER,
            assists=Column.INTEGER,
            blocks=Column.INTEGER,
        ),
    )
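    # The PrimaryKey in 'players' and the ForeignKeys in 'season' and 'all_star'
    # define the one-to-many relations between the parent table and its children.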
    data = RelationalData(data=data, schema=schema)
    if quick:
        # In a quick run, keep only 20% of the data
        _, data = data.split(ratio=0.2)
    # Define and fit the preprocessor
    preproc = TabularPreproc.from_schema(schema=schema).fit(data=data)

    # Split the data into train, validation, and test sets
    split_ratio = 0.1
    data_train_valid, data_test = data.split(ratio=split_ratio)
    data_train, data_valid = data_train_valid.split(ratio=split_ratio)
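    # With split_ratio = 0.1, 10% of the data is held out as a test set,
    # and 10% of the remainder (9% of the total) is used for validation.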
    # Build the model
    model = TabularModel.build(preproc=preproc, size='tiny' if quick else 'small')
    model.device = device  # If device is None, it is set to CUDA when available, otherwise to CPU
    # Train the model
    dataset_train = TabularDataset.from_data(data=data_train, preproc=preproc, on_disk=True)
    dataset_valid = TabularDataset.from_data(data=data_valid, preproc=preproc)
    trainer = TabularTrainer(model=model)
    trainer.train(
        dataset=dataset_train,
        n_epochs=n_epochs,
        n_steps=n_steps,
        memory=memory,
        valid=Validation(
            dataset=dataset_valid,
            early_stop='normal',
            save_best=output_dir / 'best.pt',
            tensorboard=output_dir / 'tb',
            each=valid_each,
            trigger='step',
        ),
    )
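    # As configured above, validation runs every `valid_each` training steps
    # ('trigger' is 'step'), the best-performing weights are saved to 'best.pt',
    # and training may stop early.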
    # Generate synthetic data
    data_synth = model.generate(
        n_samples=data['players'].shape[0],
        batch_size=512,
    )
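    # n_samples sets the number of synthetic rows of the root table 'players'
    # (here equal to the original count); the rows of 'season' and 'all_star'
    # are generated along with their parent rows.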
    data_synth.to_csv(output_dir / 'synth')  # Save each synthetic table as a CSV file
    # Compute and save a PDF report
    report(
        data_train=data_train,
        data_test=data_test,
        data_synth=data_synth,
        path=output_dir / 'report.pdf',
    )
    # Compute extra privacy stats and collect some results
    privacy_stats = compute_privacy_stats(
        data_train=data_train,
        data_synth=data_synth,
    )
    privacy_stats_out = {
        t: {
            'privacy_score': ps.privacy_score,
            'privacy_score_std': ps.privacy_score_std,
            '%_points_at_risk': ps.risk * 100,
        }
        for t, ps in privacy_stats.items()
    }
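    # For each table: the privacy score, its standard deviation, and ps.risk,
    # a fraction converted here to a percentage of training points at risk.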
    # Save the privacy stats to a JSON file
    with open(output_dir / 'privacy_stats.json', mode='w', encoding='utf-8') as f:
        json.dump(privacy_stats_out, f)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('data_dir', type=Path, help="The directory containing the BasketballMen dataset")
    parser.add_argument('output_dir', type=Path, help="The output directory")
    parser.add_argument(
        '--n', '-n', type=int, default=1000,
        help="Number of training epochs (or steps, if the --steps flag is used)",
    )
    parser.add_argument('--steps', '-s', action='store_true', help="Count --n in steps instead of epochs")
    parser.add_argument('--valid-each', '-v', type=int, default=200, help="Number of steps between validations")
    parser.add_argument('--device', '-g', default=None, help="Training device")
    parser.add_argument('--memory', '-m', type=int, default=4096, help="Available memory (MB)")
    parser.add_argument(
        '--quick', '-q', action='store_true',
        help="Perform a quick test run, with reduced data and a smaller model",
    )
    args = parser.parse_args()
    example_basket(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        n_epochs=None if args.steps else args.n,
        n_steps=args.n if args.steps else None,
        valid_each=args.valid_each,
        device=args.device,
        memory=args.memory,
        quick=args.quick,
    )
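
Assuming the script is saved as example_basket.py (the file name is illustrative) and the three CSV files sit in a local basket/ directory, a quick test run can be launched with: python example_basket.py basket out --quick. A longer run counting training steps instead of epochs could look like: python example_basket.py basket out -n 2000 --steps.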