Trainer API Reference

The Trainer class is the core component of Treadmill, orchestrating the entire training process with a clean, intuitive interface.

class treadmill.Trainer(model: Module, config: TrainingConfig, train_dataloader: DataLoader, val_dataloader: DataLoader | None = None, loss_fn: Callable | None = None, metric_fns: Dict[str, Callable] | None = None, callbacks: List[Callback] | None = None)[source]

Bases: object

Main training class that orchestrates the entire training process.

This class provides a clean, modular interface for PyTorch model training with support for validation, callbacks, metrics tracking, and more.

__init__(model: Module, config: TrainingConfig, train_dataloader: DataLoader, val_dataloader: DataLoader | None = None, loss_fn: Callable | None = None, metric_fns: Dict[str, Callable] | None = None, callbacks: List[Callback] | None = None)[source]

Initialize the trainer.

Parameters:
  • model – PyTorch model to train

  • config – Training configuration

  • train_dataloader – Training data loader

  • val_dataloader – Optional validation data loader

  • loss_fn – Loss function (if None, will try to infer from model)

  • metric_fns – Dictionary of metric functions

  • callbacks – List of callbacks for training hooks

train() Dict[str, Any][source]

Execute the complete training loop.

Returns:

Dictionary containing training history and final metrics

fit() Dict[str, Any][source]

Alias for the train() method, provided for sklearn-style compatibility.

Many users expect a fit() method from sklearn/other ML libraries. This method simply calls train() for compatibility.

Returns:

Dictionary containing training history and final metrics

property history: Dict[str, Any] | None

Access training history after training has completed.

This property allows access to training results even if the return value from train() or fit() was not stored in a variable.

Returns:

Dictionary containing training history and final metrics, or None if training hasn’t run yet

Example

trainer = Trainer(...)
trainer.fit()  # Don't store the result

# Access history later
print(f"Total epochs: {trainer.history['total_epochs']}")
print(f"Best accuracy: {trainer.history['best_metrics']['accuracy']:.4f}")

property report: TrainingReport | None

Access comprehensive training report after training has completed.

This property provides detailed information about the training session including model info, configuration, metrics, timing, and more.

Returns:

TrainingReport object with comprehensive training information, or None if training hasn’t completed yet

Example

trainer = Trainer(...)
trainer.fit()

# Access detailed report
print(f"Model: {trainer.report.model_name}")
print(f"Parameters: {trainer.report.total_parameters:,}")
print(f"Training time: {trainer.report.training_time:.1f}s")
print(f"Best loss: {trainer.report.best_metrics['val_loss']:.4f}")

# Convert to dictionary for serialization
report_dict = trainer.report.to_dict()

save_checkpoint(filepath: str, additional_info: Dict | None = None)[source]

Save a training checkpoint.

save_training_checkpoint(epoch: int, loss_value: float)[source]

Save a comprehensive training checkpoint for resume capability.

load_checkpoint(filepath: str, resume_training: bool = True)[source]

Load a training checkpoint.

Class Overview

The Trainer class provides a high-level interface for PyTorch model training with the following key features:

  • Automatic device management: Handles GPU/CPU placement automatically

  • Mixed precision training: Leverages automatic mixed precision for faster training

  • Flexible callbacks: Extensible callback system for custom training logic

  • Rich metrics tracking: Built-in and custom metrics with beautiful output

  • Smart checkpointing: Automatic model saving and restoration

  • Early stopping: Configurable early stopping to prevent overfitting

Constructor

Trainer.__init__(model: Module, config: TrainingConfig, train_dataloader: DataLoader, val_dataloader: DataLoader | None = None, loss_fn: Callable | None = None, metric_fns: Dict[str, Callable] | None = None, callbacks: List[Callback] | None = None)[source]

Initialize the trainer.

Parameters:
  • model – PyTorch model to train

  • config – Training configuration

  • train_dataloader – Training data loader

  • val_dataloader – Optional validation data loader

  • loss_fn – Loss function (if None, will try to infer from model)

  • metric_fns – Dictionary of metric functions

  • callbacks – List of callbacks for training hooks

Parameters in Detail:

  • model (torch.nn.Module): Your PyTorch model to train

  • config (TrainingConfig): Configuration object controlling all training parameters

  • train_dataloader (torch.utils.data.DataLoader): Training data loader

  • val_dataloader (torch.utils.data.DataLoader, optional): Validation data loader for monitoring

  • loss_fn (callable, optional): Loss function. If None, attempts to infer from model

  • metric_fns (dict, optional): Dictionary mapping metric names to functions

  • callbacks (list, optional): List of callback objects for training hooks

Example:

import torch
import torch.nn as nn
from treadmill import Trainer, TrainingConfig

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

config = TrainingConfig(epochs=10, device="auto")

trainer = Trainer(
    model=model,
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=nn.CrossEntropyLoss(),
    metric_fns={'accuracy': accuracy_function}
)

Training Methods

fit()

Trainer.fit() Dict[str, Any][source]

Alias for the train() method, provided for sklearn-style compatibility.

Many users expect a fit() method from sklearn/other ML libraries. This method simply calls train() for compatibility.

Returns:

Dictionary containing training history and final metrics

The primary method for training your model. Executes the complete training loop including:

  • Model training over specified epochs

  • Validation (if validation data provided)

  • Callback execution at appropriate hooks

  • Metrics computation and logging

  • Checkpoint saving

  • Early stopping checks

Returns:

Dictionary containing training history with keys:

  • train_loss: List of training losses per epoch

  • val_loss: List of validation losses per epoch (if validation enabled)

  • train_{metric}: Training metrics per epoch for each custom metric

  • val_{metric}: Validation metrics per epoch for each custom metric

Example:

# Basic training
history = trainer.train()

# Access training history
print(f"Final training loss: {history['train_loss'][-1]}")
print(f"Best validation accuracy: {max(history.get('val_accuracy', [0]))}")

train_epoch()

Executes a single training epoch. Useful for custom training loops or debugging.

Parameters:
  • epoch (int): Current epoch number for logging and callbacks

Returns:

Dictionary containing training metrics for the epoch

Example:

# Custom training loop
for epoch in range(config.epochs):
    train_metrics = trainer.train_epoch(epoch)

    if epoch % 5 == 0:  # Custom validation frequency
        val_metrics = trainer.validate_epoch(epoch)

    # Custom logic here
    if some_condition:  # placeholder for your own stopping criterion
        break

validate_epoch()

Executes validation for a single epoch. Only available if a validation dataloader was provided.

Parameters:
  • epoch (int): Current epoch number for logging and callbacks

Returns:

Dictionary containing validation metrics for the epoch

Example:

# Manual validation
val_metrics = trainer.validate_epoch(epoch=0)
print(f"Validation loss: {val_metrics['loss']}")
print(f"Validation accuracy: {val_metrics.get('accuracy', 'N/A')}")

State Management

save_checkpoint()

Trainer.save_checkpoint(filepath: str, additional_info: Dict | None = None)[source]

Save a training checkpoint.

Saves the current training state including model weights, optimizer state, scheduler state, and training progress.

Parameters:
  • filepath (str): Path where to save the checkpoint

  • additional_info (dict, optional): Extra entries to store alongside the standard checkpoint contents

Example:

# Save checkpoint manually
trainer.save_checkpoint("model_epoch_10.pt")

# Store extra information with the checkpoint
trainer.save_checkpoint("best_model.pt", additional_info={"note": "best model so far"})

load_checkpoint()

Trainer.load_checkpoint(filepath: str, resume_training: bool = True)[source]

Load a training checkpoint.

Loads a previously saved checkpoint, restoring model weights, optimizer state, and training progress.

Parameters:
  • filepath (str): Path to the checkpoint file

  • resume_training (bool, optional): Whether to also restore optimizer state and training progress so training can resume (default: True)

Returns:

Dictionary containing the loaded checkpoint data

Example:

# Resume training from checkpoint
checkpoint_data = trainer.load_checkpoint("model_epoch_10.pt")
print(f"Resuming from epoch {checkpoint_data.get('epoch', 0)}")

# Continue training
trainer.train()

get_history()

Returns the complete training history including losses and metrics for all epochs.

Returns:

Dictionary containing training history with metric names as keys

Example:

history = trainer.get_history()

# Plot training curves
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Training Loss')
if 'val_loss' in history:
    plt.plot(history['val_loss'], label='Validation Loss')
plt.legend()
plt.title('Loss')

plt.subplot(1, 2, 2)
if 'train_accuracy' in history:
    plt.plot(history['train_accuracy'], label='Training Accuracy')
if 'val_accuracy' in history:
    plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title('Accuracy')

plt.show()
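Because get_history() returns a flat dictionary of per-epoch lists, it also drops straight into a pandas DataFrame for ad-hoc analysis. A sketch with made-up numbers standing in for a real history (assumes all metric lists have equal length):

```python
import pandas as pd

# Stand-in for trainer.get_history(); real values come from training
history = {"train_loss": [0.9, 0.6, 0.4], "val_loss": [1.0, 0.7, 0.5]}

df = pd.DataFrame(history)
df.index.name = "epoch"

# Epoch with the lowest validation loss
best_epoch = df["val_loss"].idxmin()
```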

Prediction Methods

predict()

Generates predictions for a given dataset using the trained model.

Parameters:
  • dataloader (torch.utils.data.DataLoader): Data loader to generate predictions for

Returns:

Tensor containing model predictions

Example:

# Generate predictions on test set
test_predictions = trainer.predict(test_loader)

# Convert to class predictions
predicted_classes = torch.argmax(test_predictions, dim=1)

# Calculate accuracy
accuracy = (predicted_classes == test_labels).float().mean()
print(f"Test accuracy: {accuracy:.4f}")

evaluate()

Evaluates the model on a given dataset, computing loss and all configured metrics.

Parameters:
  • dataloader (torch.utils.data.DataLoader): Data loader to evaluate on

Returns:

Dictionary containing evaluation metrics

Example:

# Evaluate on test set
test_results = trainer.evaluate(test_loader)

print(f"Test Results:")
for metric_name, value in test_results.items():
    print(f"  {metric_name}: {value:.4f}")

Properties and Attributes

model

Access to the underlying PyTorch model.

# Access model for inference
model = trainer.model
model.eval()

with torch.no_grad():
    output = model(input_tensor)

config

The training configuration object.

# Check configuration
print(f"Training for {trainer.config.epochs} epochs")
print(f"Using device: {trainer.config.device}")

device

The device (CPU/GPU) being used for training.

print(f"Training on: {trainer.device}")

metrics_tracker

The metrics tracking object that maintains training history.

# Access detailed metrics
tracker = trainer.metrics_tracker
latest_metrics = tracker.get_latest_metrics()

current_epoch

The current epoch number during training.

# Useful in callbacks
print(f"Currently at epoch: {trainer.current_epoch}")
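Reading `trainer.current_epoch` is typical inside a callback. The sketch below shows the shape such a callback might take; the `on_epoch_end(trainer)` hook name and signature are assumptions (see the callbacks documentation for the real interface), and the class is shown standalone with a stand-in trainer object for illustration:

```python
from types import SimpleNamespace

class EpochLogger:
    """Hypothetical callback sketch: assumes an on_epoch_end(trainer)
    hook through which trainer.current_epoch can be read."""

    def __init__(self, every: int = 5):
        self.every = every
        self.seen = []

    def on_epoch_end(self, trainer):
        # Record every epoch; report only every `every`-th one.
        self.seen.append(trainer.current_epoch)
        if trainer.current_epoch % self.every == 0:
            print(f"Reached epoch {trainer.current_epoch}")

# Simulate three epochs with a stand-in trainer object
logger = EpochLogger(every=2)
for epoch in range(3):
    logger.on_epoch_end(SimpleNamespace(current_epoch=epoch))
```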

Advanced Usage Examples

Custom Training Loop

trainer = Trainer(model, config, train_loader, val_loader)
previous_best_loss = float("inf")

for epoch in range(config.epochs):
    # Custom pre-epoch logic
    if epoch % 10 == 0:
        # Adjust learning rate or other parameters
        for param_group in trainer.optimizer.param_groups:
            param_group['lr'] *= 0.9

    # Train one epoch
    train_metrics = trainer.train_epoch(epoch)

    # Custom validation schedule
    if epoch % config.validation_frequency == 0:
        val_metrics = trainer.validate_epoch(epoch)

        # Custom early stopping logic
        if val_metrics['loss'] > previous_best_loss * 1.1:
            print("Custom early stopping triggered!")
            break
        previous_best_loss = min(previous_best_loss, val_metrics['loss'])

    # Custom checkpointing
    if epoch % 20 == 0:
        trainer.save_checkpoint(f"checkpoint_epoch_{epoch}.pt")

Integration with Callbacks

from treadmill.callbacks import EarlyStopping, ModelCheckpoint

# Create custom callbacks
early_stopping = EarlyStopping(
    patience=10,
    min_delta=0.001,
    verbose=True
)

checkpointing = ModelCheckpoint(
    filepath="models/best_model_{epoch:02d}_{val_loss:.4f}.pt",
    save_best_only=True,
    verbose=True
)

trainer = Trainer(
    model=model,
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    callbacks=[early_stopping, checkpointing]
)

# Callbacks are automatically called during training
history = trainer.train()

Multi-GPU Training

# DataParallel for multi-GPU training
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

config = TrainingConfig(
    epochs=50,
    device="cuda",
    mixed_precision=True  # Works with DataParallel
)

trainer = Trainer(model, config, train_loader, val_loader)
trainer.train()

See Also

  • TrainingConfig: Configuration options for training

  • callbacks: Available callbacks and custom callback creation

  • metrics: Built-in metrics and custom metric functions

  • Complete Image Classification Tutorial: End-to-end training walkthrough

  • ../examples/advanced_training: Advanced training examples