Trainers
class TabularTrainer
TabularTrainer.__init__
Trainer for a TabularModel.
Arguments:
model - A TabularModel instance.
dp_budget - The (eps, delta)-budget for differentially private (DP) training. If None (the default), training is not differentially private. Available only for single-table datasets.
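The dp_budget contract can be illustrated with a small, self-contained sketch. The helper validate_dp_budget, the DPBudget alias and the n_tables parameter are hypothetical and not part of the library. The range checks on eps and delta follow the usual (ε, δ)-DP conventions and are assumptions, not documented behavior.

```python
from typing import Optional, Tuple

# Hypothetical alias: an (epsilon, delta) pair, as described for dp_budget.
DPBudget = Tuple[float, float]


def validate_dp_budget(dp_budget: Optional[DPBudget], n_tables: int) -> bool:
    """Return True if training should be differentially private.

    Mirrors the documented contract: None disables DP, and a budget is only
    accepted for single-table datasets. The numeric range checks are the
    usual (eps, delta)-DP conventions, assumed here rather than taken from
    the library.
    """
    if dp_budget is None:
        return False  # default: non-private training
    eps, delta = dp_budget
    if eps <= 0:
        raise ValueError("epsilon must be positive")
    if not 0 <= delta < 1:
        raise ValueError("delta must be in [0, 1)")
    if n_tables != 1:
        raise ValueError("DP training is available only for single-table datasets")
    return True
```

For example, validate_dp_budget(None, 3) returns False (non-private training), while validate_dp_budget((1.0, 1e-5), 2) raises ValueError because the dataset has more than one table.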
TabularTrainer.train
Train the tabular model with the input dataset.
Arguments:
dataset - The training data, as a TabularDataset object.
n_epochs - The number of training epochs. Exactly one of n_epochs and n_steps must be provided.
n_steps - The number of training steps. Exactly one of n_epochs and n_steps must be provided.
batch_size - The size of a batch of data during training. If it is not specified, the memory argument must be provided.
lr - The learning rate. If it is 0, the optimal learning rate is determined automatically.
memory - The available memory in MB, used to automatically compute the optimal batch size.
valid - A Validation object. If None, no validation is performed.
hooks - A sequence of custom TrainHook objects.
accumulate_grad - The number of gradient accumulation steps. If equal to 1, the weights are updated at each step.
dp_step - Data for the differentially private step. Must be provided if and only if the trainer has a DP budget.
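The constraints among the train arguments can be summarized as a single validation function. This is a sketch written against the rules listed for TabularTrainer.train, not the trainer's actual code; check_train_args and its has_dp_budget flag are hypothetical names introduced for illustration.

```python
from typing import Any, Optional


def check_train_args(
    n_epochs: Optional[int] = None,
    n_steps: Optional[int] = None,
    batch_size: Optional[int] = None,
    memory: Optional[int] = None,
    has_dp_budget: bool = False,  # hypothetical: whether the trainer was built with a dp_budget
    dp_step: Optional[Any] = None,
) -> None:
    """Raise ValueError if the documented argument constraints are violated."""
    # Exactly one of n_epochs and n_steps must be given.
    if (n_epochs is None) == (n_steps is None):
        raise ValueError("provide exactly one of n_epochs and n_steps")
    # memory (in MB) is needed to derive the batch size when none is given.
    if batch_size is None and memory is None:
        raise ValueError("provide batch_size or memory")
    # dp_step must be given if and only if the trainer has a DP budget.
    if has_dp_budget != (dp_step is not None):
        raise ValueError("dp_step must be provided iff the trainer has a dp_budget")
```

Under these rules, check_train_args(n_epochs=10, batch_size=32) passes, whereas passing both n_epochs and n_steps, or neither batch_size nor memory, raises ValueError.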
class TextTrainer
TextTrainer.__init__
Trainer for a TextModel.
Arguments:
model - A TextModel instance.
TextTrainer.train
Train the text model with the input dataset.
Arguments:
dataset - The training data, as a TextDataset object.
n_epochs - The number of training epochs. Exactly one of n_epochs and n_steps must be provided.
n_steps - The number of training steps. Exactly one of n_epochs and n_steps must be provided.
batch_size - The size of a batch of data during training. If it is not specified, the memory argument must be provided.
lr - The learning rate. If it is 0, the optimal learning rate is determined automatically.
memory - The available memory in MB, used to automatically compute the optimal batch size.
valid - A Validation object. If None, no validation is performed.
hooks - A sequence of custom TrainHook objects.
accumulate_grad - The number of gradient accumulation steps. If equal to 1, the weights are updated at each step.
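Gradient accumulation, controlled by accumulate_grad in both trainers, can be illustrated with a toy scalar SGD loop. This is a generic sketch of the technique, not the trainers' implementation. The function name is hypothetical; gradients are averaged over each accumulation window (an assumption, since some implementations sum them instead); and a trailing partial window is discarded.

```python
from typing import Sequence, Tuple


def sgd_with_accumulation(
    grads: Sequence[float],
    lr: float,
    accumulate_grad: int,
    w0: float = 0.0,
) -> Tuple[float, int]:
    """Toy scalar SGD illustrating gradient accumulation.

    grads[i] stands for the gradient computed on the i-th batch (step).
    Gradients are summed over `accumulate_grad` consecutive steps and their
    mean is applied in a single weight update. Steps that do not fill a
    complete accumulation window are discarded in this sketch.
    """
    if accumulate_grad < 1:
        raise ValueError("accumulate_grad must be >= 1")
    w = w0
    n_updates = 0
    buffer = 0.0  # running sum of gradients in the current window
    for step, g in enumerate(grads, start=1):
        buffer += g
        if step % accumulate_grad == 0:
            # One weight update per full window, using the mean gradient.
            w -= lr * buffer / accumulate_grad
            buffer = 0.0
            n_updates += 1
    return w, n_updates
```

With grads = [1.0, 3.0, 2.0, 2.0] and lr = 0.5, accumulate_grad=1 performs four updates and ends at w = -4.0, while accumulate_grad=2 performs two updates and ends at w = -2.0. Fewer, larger-batch updates are the point of accumulation: each update sees the gradients of batch_size × accumulate_grad samples without holding them in memory at once.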