
Trainers

class TabularTrainer

TabularTrainer.__init__

def __init__(model: TabularModel) -> None

Trainer for a TabularModel.

Arguments:

  • model - A TabularModel instance.

TabularTrainer.train

def train(data: RelationalData,
          n_epochs: int,
          batch_size: int = 0,
          lr: float = 0.,
          memory: int = 0,
          valid: Validation | None = None,
          hooks: Sequence[TrainHook] = (),
          accumulate_grad: int = 1) -> None

Train the model on the given dataset.

Arguments:

  • data - A RelationalData object containing the training data.
  • n_epochs - The desired number of training epochs.
  • batch_size - The size of each training batch. If it is not specified (0), the memory argument must be provided so that the batch size can be computed automatically.
  • lr - The learning rate. If it is 0, an optimal learning rate is determined automatically.
  • memory - The available memory, in MB, used to automatically compute an optimal batch size.
  • valid - A Validation object. If None, no validation is performed.
  • hooks - A sequence of custom TrainHook objects.
  • accumulate_grad - The number of gradient accumulation steps. If it is 1, the weights are updated at every step.
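The rules above for batch_size, memory, lr and accumulate_grad can be summarized as a small pre-flight check. This is a minimal sketch, not part of the library: the helper plan_training and the TrainPlan class are hypothetical, and two points are assumptions rather than documented behavior: that an explicit batch_size takes precedence over memory, and that the number of samples per weight update is batch_size times accumulate_grad.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class TrainPlan:
    """Hypothetical summary of how train() arguments would be interpreted."""
    batch_size_mode: str            # "fixed" or "from_memory"
    lr_mode: str                    # "fixed" or "auto"
    effective_batch: Optional[int]  # samples per weight update, if known


def plan_training(batch_size: int = 0, lr: float = 0.0, memory: int = 0,
                  accumulate_grad: int = 1) -> TrainPlan:
    """Validate train()-style arguments (hypothetical helper, not library API)."""
    if batch_size < 0 or memory < 0 or lr < 0:
        raise ValueError("batch_size, memory and lr must be non-negative")
    if accumulate_grad < 1:
        raise ValueError("accumulate_grad must be at least 1")
    if batch_size == 0 and memory == 0:
        # Documented rule: without batch_size, memory (in MB) must be given.
        raise ValueError("provide either batch_size or memory")
    if batch_size > 0:
        # Assumption: an explicit batch_size takes precedence over memory.
        mode, effective = "fixed", batch_size * accumulate_grad
    else:
        # The batch size would be derived from memory, so it is unknown here.
        mode, effective = "from_memory", None
    return TrainPlan(
        batch_size_mode=mode,
        lr_mode="fixed" if lr > 0 else "auto",  # lr == 0 means "determine automatically"
        effective_batch=effective,
    )


print(plan_training(batch_size=32, accumulate_grad=4))
# TrainPlan(batch_size_mode='fixed', lr_mode='auto', effective_batch=128)
```

Calling plan_training() with neither batch_size nor memory raises a ValueError, mirroring the documented requirement that one of the two be supplied.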

class TextTrainer

TextTrainer.__init__

def __init__(model: TextModel) -> None

Trainer for a TextModel.

Arguments:

  • model - A TextModel instance.

TextTrainer.train

def train(data: RelationalData,
          n_epochs: int,
          batch_size: int = 0,
          lr: float = 0.,
          memory: int = 0,
          valid: Validation | None = None,
          hooks: Sequence[TrainHook] = (),
          accumulate_grad: int = 1) -> None

Train the model on the given dataset.

Arguments:

  • data - A RelationalData object containing the training data.
  • n_epochs - The desired number of training epochs.
  • batch_size - The size of each training batch. If it is not specified (0), the memory argument must be provided so that the batch size can be computed automatically.
  • lr - The learning rate. If it is 0, an optimal learning rate is determined automatically.
  • memory - The available memory, in MB, used to automatically compute an optimal batch size.
  • valid - A Validation object. If None, no validation is performed.
  • hooks - A sequence of custom TrainHook objects.
  • accumulate_grad - The number of gradient accumulation steps. If it is 1, the weights are updated at every step.
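Gradient accumulation, as controlled by accumulate_grad, can be illustrated without the library. The sketch below fits a toy one-parameter model y = w * x with plain SGD and updates the weight only once every accumulate_grad batches, using the average of the accumulated batch gradients. The function train_with_accumulation, the toy model, and the choice to flush leftover batches with one final update are assumptions for illustration; the trainers' internal loop is not documented here.

```python
from typing import List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def train_with_accumulation(batches: List[Batch], w0: float, lr: float,
                            accumulate_grad: int = 1) -> Tuple[float, int]:
    """Fit y = w * x with SGD, updating w once every `accumulate_grad` batches.

    Returns the final weight and the number of weight updates performed.
    """
    if accumulate_grad < 1:
        raise ValueError("accumulate_grad must be at least 1")
    w, grad_sum, pending, n_updates = w0, 0.0, 0, 0
    for batch in batches:
        # Gradient of the batch mean squared error mean((w*x - y)^2) w.r.t. w.
        grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
        grad_sum += grad
        pending += 1
        if pending == accumulate_grad:
            # Step with the average of the accumulated batch gradients.
            w -= lr * grad_sum / pending
            grad_sum, pending, n_updates = 0.0, 0, n_updates + 1
    if pending:
        # Assumption: leftover batches at the end still trigger one update.
        w -= lr * grad_sum / pending
        n_updates += 1
    return w, n_updates


batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
print(train_with_accumulation(batches, w0=0.0, lr=0.01, accumulate_grad=1))
print(train_with_accumulation(batches, w0=0.0, lr=0.01, accumulate_grad=2))
```

With equal-sized batches and a mean-reduced loss, accumulating over k batches yields the same update as a single step on one batch k times larger, which is why accumulation is a common way to emulate a larger batch size when memory is limited.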