Multi table - BasketballMen dataset
In the following we present an example script that uses the aindo.rdml library to generate synthetic data in the multi-table case. We make use of the BasketballMen dataset.
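The dataset consists of three tables: players, which acts as the root table, and season and all_star, which are both children of players through the playerID foreign key, as defined in the schema below.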
import argparse
import json
from pathlib import Path

import pandas as pd
import torch

from aindo.rdml.eval import compute_privacy_stats, report
from aindo.rdml.relational import Column, ForeignKey, PrimaryKey, RelationalData, Schema, Table
from aindo.rdml.synth import TabularDataset, TabularModel, TabularPreproc, TabularTrainer, Validation
def example_basket(
    data_dir: Path,
    output_dir: Path,
    n_epochs: int | None,
    n_steps: int | None,
    valid_each: int,
    device: str | torch.device | None,
    memory: int,
    quick: bool,
) -> None:
    # Load data and define schema
    data = {
        "players": pd.read_csv(data_dir / "players.csv"),
        "season": pd.read_csv(data_dir / "season.csv"),
        "all_star": pd.read_csv(data_dir / "all_star.csv"),
    }
    schema = Schema(
        players=Table(
            playerID=PrimaryKey(),
            pos=Column.CATEGORICAL,
            height=Column.NUMERIC,
            weight=Column.NUMERIC,
            college=Column.CATEGORICAL,
            race=Column.CATEGORICAL,
            birthCity=Column.CATEGORICAL,
            birthState=Column.CATEGORICAL,
            birthCountry=Column.CATEGORICAL,
        ),
        season=Table(
            playerID=ForeignKey(parent="players"),
            year=Column.INTEGER,
            stint=Column.INTEGER,
            tmID=Column.CATEGORICAL,
            lgID=Column.CATEGORICAL,
            GP=Column.INTEGER,
            points=Column.INTEGER,
            GS=Column.INTEGER,
            assists=Column.INTEGER,
            steals=Column.INTEGER,
            minutes=Column.INTEGER,
        ),
        all_star=Table(
            playerID=ForeignKey(parent="players"),
            conference=Column.CATEGORICAL,
            league_id=Column.CATEGORICAL,
            points=Column.INTEGER,
            rebounds=Column.INTEGER,
            assists=Column.INTEGER,
            blocks=Column.INTEGER,
        ),
    )
    data = RelationalData(data=data, schema=schema)
    if quick:
        # For a quick test run, keep only a small subset of the data
        _, data = data.split(ratio=0.2)
    # Define preprocessor
    preproc = TabularPreproc.from_schema(schema=schema).fit(data=data)
    # Split data
    split_ratio = 0.1
    data_train_valid, data_test = data.split(ratio=split_ratio)
    data_train, data_valid = data_train_valid.split(ratio=split_ratio)
    # Build model
    model = TabularModel.build(preproc=preproc, size="tiny" if quick else "small")
    model.device = device  # If device is None, it is set to CUDA when available, otherwise to CPU
    # Train the model
    dataset_train = TabularDataset.from_data(data=data_train, preproc=preproc, on_disk=True)
    dataset_valid = TabularDataset.from_data(data=data_valid, preproc=preproc)
    trainer = TabularTrainer(model=model)
    trainer.train(
        dataset=dataset_train,
        n_epochs=n_epochs,
        n_steps=n_steps,
        memory=memory,
        valid=Validation(
            dataset=dataset_valid,
            early_stop="normal",
            save_best=output_dir / "best.pt",
            tensorboard=output_dir / "tb",
            each=valid_each,
            trigger="step",
        ),
    )
    # Generate synthetic data
    data_synth = model.generate(
        n_samples=data["players"].shape[0],
        batch_size=512,
    )
    data_synth.to_csv(output_dir / "synth")
    # Compute the evaluation report and save it as a PDF
    report(
        data_train=data_train,
        data_test=data_test,
        data_synth=data_synth,
        path=output_dir / "report.pdf",
    )
    # Compute extra privacy stats and save a summary to JSON
    privacy_stats = compute_privacy_stats(
        data_train=data_train,
        data_synth=data_synth,
    )
    privacy_stats_out = {
        t: {
            "privacy_score": ps.privacy_score,
            "privacy_score_std": ps.privacy_score_std,
            "%_points_at_risk": ps.risk * 100,
        }
        for t, ps in privacy_stats.items()
    }
    with open(output_dir / "privacy_stats.json", mode="w", encoding="utf-8") as f:
        json.dump(privacy_stats_out, f)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir", type=Path, help="The directory where to find the 'basket' dataset")
    parser.add_argument("output_dir", type=Path, help="The output directory")
    parser.add_argument(
        "--n",
        "-n",
        type=int,
        default=1000,
        help="Training epochs (or steps if the --steps flag is used)",
    )
    parser.add_argument("--steps", "-s", action="store_true", help="Use steps instead of epochs")
    parser.add_argument("--valid-each", "-v", type=int, default=200, help="# steps between validations")
    parser.add_argument("--device", "-g", default=None, help="Training device")
    parser.add_argument("--memory", "-m", type=int, default=4096, help="Available memory (MB)")
    parser.add_argument(
        "--quick", "-q", action="store_true", help="Perform a quick test run, with reduced data and a small model"
    )
    args = parser.parse_args()
    example_basket(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        n_epochs=None if args.steps else args.n,
        n_steps=args.n if args.steps else None,
        valid_each=args.valid_each,
        device=args.device,
        memory=args.memory,
        quick=args.quick,
    )
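The script takes the dataset and output directories as positional arguments, plus the optional flags defined above. As an illustrative invocation (the file name example_basket.py and the paths are placeholders, not part of the library):

    python example_basket.py path/to/basket path/to/output -n 1000 --device cuda

Adding the --quick flag performs a fast sanity check, training a smaller model on a reduced portion of the data.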