The Trainer Loop Feature

Writing custom training loops can be repetitive. The Trainer class abstracts the boilerplate of zeroing gradients, forward passes, loss calculation, backpropagation, and optimization steps.
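For comparison, the manual loop that the Trainer abstracts away looks roughly like this. This is an illustrative sketch, not the library's internals: `getBatch`, `lossFn`, and the exact method names on `model` and `optimizer` are placeholders.

```lua
-- Hypothetical hand-written loop; names are illustrative, not Gradien API.
for step = 1, numSteps do
    local X, Y = getBatch()             -- fetch the next batch
    optimizer:zeroGrad()                -- clear accumulated gradients
    local pred = model:forward(X)       -- forward pass
    local loss, grad = lossFn(pred, Y)  -- loss value and its gradient
    model:backward(grad)                -- backpropagation
    optimizer:step()                    -- parameter update
end
```

The Trainer wraps exactly this sequence, adding metric tracking, periodic reporting, and callback dispatch around it.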

Basic Usage

```lua
local trainer = Gradien.Trainer.new({
    model = myModel,
    optimizerFactory = function(params) return Gradien.Optim.Adam(params) end,
    loss = Gradien.NN.Losses.mse_backward,

    -- Optional
    metric = function(pred, target) return 0 end,
    reportEvery = 10,
    callbacks = {
        onStep = function(ctx) print(ctx.loss) end
    }
})
```

```lua
trainer:fit(function()
    -- Return a function that returns (X, Y) batches
    return MyDataLoader()
end, {
    epochs = 50,
    stepsPerEpoch = 100
})
```
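The loader factory passed to `fit` must return a function that produces a batch on each call. A minimal in-memory loader might look like the following sketch; `MyDataLoader` and the preloaded batch tensors are assumptions for illustration.

```lua
-- Illustrative loader: cycles through a fixed list of preloaded (X, Y) pairs.
local function MyDataLoader()
    local batches = { {X1, Y1}, {X2, Y2} } -- placeholder tensors
    local i = 0
    return function()
        i = i % #batches + 1
        return batches[i][1], batches[i][2]
    end
end
```

Because `fit` calls the factory itself, the loader's cursor (`i` here) is reset for each call, which is convenient when the same factory is reused across epochs or runs.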

Classification Trainer

For classification tasks, Trainer.newClassification provides sensible defaults: Cross Entropy loss and an Accuracy metric.

```lua
local clsTrainer = Gradien.Trainer.newClassification({
    model = myModel,
    optimizerFactory = myOptFactory,
    callbacks = myCallbacks
}, {
    smoothing = 0.1 -- Label smoothing
})
```

Callbacks

The trainer supports a rich callback system to hook into the training process.

```lua
callbacks = {
    onStep = function(ctx)
        -- ctx: { epoch, step, loss, metric, model, optimizer }
    end,
    onEpochEnd = function(ctx)
        print("End of epoch:", ctx.epoch)
    end,
    onBest = function(ctx)
        print("New best metric:", ctx.metric)
    end
}
```
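
The ctx fields documented for onStep are enough to build simple behaviors such as throttled logging or best-value checkpointing. A hedged sketch follows; the `saveModel` helper is hypothetical, and the assumption that onBest receives the same `model` field as onStep may not hold in your version of the library.

```lua
-- Illustrative callbacks using the documented ctx fields.
local callbacks = {
    onStep = function(ctx)
        -- Log every 100 steps instead of every step.
        if ctx.step % 100 == 0 then
            print(string.format("epoch %d step %d loss %.4f",
                ctx.epoch, ctx.step, ctx.loss))
        end
    end,
    onBest = function(ctx)
        -- Checkpoint when the tracked metric improves.
        -- saveModel(ctx.model) -- hypothetical checkpoint helper
        print("checkpointing at metric", ctx.metric)
    end
}
```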