An enhanced R6 trainer that mirrors the Python Trainer class in PyHealth. It supports:
- Dynamic steps_per_epoch: the trainer restarts the dataloader as often as needed to reach a target number of steps, as the Python version does.
- Parameter-group-wise weight decay: bias and LayerNorm parameters are excluded from L2 regularisation.
- Gradient clipping.
- An optional progress bar via progressr::progressor(), falling back to simple logging.
- A correctly named additional_outputs collection.
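The first two behaviours can be sketched in base R. This is a minimal sketch, not the package's implementation: make_cycling_iterator() and split_decay_groups() are hypothetical names, plain lists stand in for a torch dataloader and for a named parameter list such as the one returned by model$named_parameters(), and the name patterns "bias" and "LayerNorm" follow the description above (a layer-norm submodule registered under a different name would not match). The grouped result has the list-of-lists shape torch optimizers use for parameter groups; that the package passes it to optimizer_class in exactly this form is an assumption.

```r
# Sketch only: plain R lists stand in for torch dataloaders and parameters.

# Hypothetical helper: returns a closure that yields one batch per call and
# wraps around to the first batch once the list is exhausted.
make_cycling_iterator <- function(batches) {
  stopifnot(length(batches) > 0)
  pos <- 0L
  function() {
    pos <<- pos %% length(batches) + 1L
    batches[[pos]]
  }
}

# Hypothetical helper: split a named list of parameters into a decayed group
# and a non-decayed group (names containing "bias" or "LayerNorm").
split_decay_groups <- function(params, weight_decay) {
  stopifnot(!is.null(names(params)))
  no_decay <- grepl("bias|LayerNorm", names(params))
  list(
    list(params = unname(params[!no_decay]), weight_decay = weight_decay),
    list(params = unname(params[no_decay]), weight_decay = 0)
  )
}

# Three batches but five steps per epoch: the iterator wraps around.
next_batch <- make_cycling_iterator(list("b1", "b2", "b3"))
seen <- vapply(seq_len(5), function(i) next_batch(), character(1))
print(seen)  # "b1" "b2" "b3" "b1" "b2"

# Toy named "parameters"; only fc.weight lands in the decayed group.
params <- list(fc.weight = 1:4, fc.bias = 0, LayerNorm.weight = 1, LayerNorm.bias = 0)
groups <- split_decay_groups(params, weight_decay = 0.01)
vapply(groups, function(g) length(g$params), integer(1))  # 1 3
```

Wrapping around inside a closure keeps the training loop free of end-of-data checks: it simply asks for the next batch steps_per_epoch times.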
Public fields
model: A torch model object.
metrics: A list of metric names.
device: The computation device ("cpu" or "cuda").
exp_path: Path where logs and checkpoints are saved.
Methods
Method new()
Initialize the Trainer.
Usage
Trainer$new(
  model,
  checkpoint_path = NULL,
  metrics = NULL,
  device = NULL,
  enable_logging = TRUE,
  output_path = NULL,
  exp_name = NULL
)
Method train()
Train the model.
Usage
Trainer$train(
  train_dataloader,
  val_dataloader = NULL,
  test_dataloader = NULL,
  epochs = 5,
  optimizer_class = optim_adam,
  optimizer_params = list(lr = 0.001),
  steps_per_epoch = NULL,
  evaluation_steps = 1L,
  weight_decay = 0,
  max_grad_norm = NULL,
  monitor = NULL,
  monitor_criterion = "max",
  load_best_model_at_last = TRUE,
  use_progress_bar = TRUE
)
Arguments
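train_dataloader: Training dataloader.
val_dataloader: Optional validation dataloader.
test_dataloader: Optional test dataloader.
epochs: Number of training epochs.
optimizer_class: Optimizer constructor, for example optim_adam.
optimizer_params: List of arguments passed to the optimizer constructor, such as the learning rate lr.
steps_per_epoch: Optional number of optimisation steps per epoch; the training dataloader is restarted as often as needed to reach it.
evaluation_steps: Steps between evaluations.
weight_decay: Weight decay (L2 penalty) applied to all parameters except bias and LayerNorm parameters.
max_grad_norm: Optional maximum gradient norm for clipping; NULL disables clipping.
monitor: Name of the metric to monitor.
monitor_criterion: "max" if higher values of the monitored metric are better, "min" if lower values are better.
load_best_model_at_last: Whether to load the best model, according to monitor, after training.
use_progress_bar: Whether to show a progress bar during training.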
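The clipping and best-model bookkeeping controlled by max_grad_norm, monitor and monitor_criterion can be sketched the same way. In torch for R the clipping itself is available as nn_utils_clip_grad_norm_(); the hypothetical clip_by_global_norm() below reimplements the global L2-norm rule on plain numeric vectors so it runs without torch, and is_improvement() is likewise a hypothetical name, not part of the package. The validation scores are toy values.

```r
# Sketch only: numeric vectors stand in for parameter gradients.

# Hypothetical helper: rescale all gradients by a common factor so their joint
# L2 norm is at most max_norm; max_norm = NULL disables clipping.
clip_by_global_norm <- function(grads, max_norm = NULL) {
  if (is.null(max_norm)) return(grads)
  total_norm <- sqrt(sum(vapply(grads, function(g) sum(g^2), numeric(1))))
  scale <- max_norm / (total_norm + 1e-6)  # small epsilon avoids division by zero
  if (scale < 1) grads <- lapply(grads, function(g) g * scale)
  grads
}

# Hypothetical helper: does `score` beat `best` under monitor_criterion?
# The first score seen (best = NULL) always counts as an improvement.
is_improvement <- function(score, best, monitor_criterion = c("max", "min")) {
  monitor_criterion <- match.arg(monitor_criterion)
  if (is.null(best)) return(TRUE)
  if (monitor_criterion == "max") score > best else score < best
}

grads <- list(c(3, 0), c(0, 4))            # joint L2 norm is 5
clipped <- clip_by_global_norm(grads, max_norm = 1)
sqrt(sum(unlist(clipped)^2))               # just under 1

# Track the best epoch on toy validation scores; higher is better.
val_scores <- c(0.71, 0.74, 0.73)
best <- NULL
best_epoch <- NA_integer_
for (epoch in seq_along(val_scores)) {
  if (is_improvement(val_scores[epoch], best, "max")) {
    best <- val_scores[epoch]
    best_epoch <- epoch
  }
}
best_epoch                                 # 2
```

Treating NULL as "no clipping" and the first score as an automatic improvement fits the optional nature of max_grad_norm and monitor, but the package may handle these edge cases differently.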
Method inference()
Perform inference on a dataloader.