Trainer API Reference
The Trainer class is the core component of Treadmill, orchestrating the entire training process with a clean, intuitive interface.
- class treadmill.Trainer(model: Module, config: TrainingConfig, train_dataloader: DataLoader, val_dataloader: DataLoader | None = None, loss_fn: Callable | None = None, metric_fns: Dict[str, Callable] | None = None, callbacks: List[Callback] | None = None)[source]
Bases: object
Main training class that orchestrates the entire training process.
This class provides a clean, modular interface for PyTorch model training with support for validation, callbacks, metrics tracking, and more.
- __init__(model: Module, config: TrainingConfig, train_dataloader: DataLoader, val_dataloader: DataLoader | None = None, loss_fn: Callable | None = None, metric_fns: Dict[str, Callable] | None = None, callbacks: List[Callback] | None = None)[source]
Initialize the trainer.
- Parameters:
model – PyTorch model to train
config – Training configuration
train_dataloader – Training data loader
val_dataloader – Optional validation data loader
loss_fn – Loss function (if None, will try to infer from model)
metric_fns – Dictionary of metric functions
callbacks – List of callbacks for training hooks
- train() Dict[str, Any][source]
Execute the complete training loop.
- Returns:
Dictionary containing training history and final metrics
- fit() Dict[str, Any][source]
Alias for train() method for sklearn-style compatibility.
Many users expect a fit() method from sklearn/other ML libraries. This method simply calls train() for compatibility.
- Returns:
Dictionary containing training history and final metrics
- property history: Dict[str, Any] | None
Access training history after training has completed.
This property allows access to training results even if the return value from train() or fit() was not stored in a variable.
- Returns:
Dictionary containing training history and final metrics, or None if training hasn’t run yet
Example
trainer = Trainer(...)
trainer.fit()  # Don't store the result

# Access history later
print(f"Total epochs: {trainer.history['total_epochs']}")
print(f"Best accuracy: {trainer.history['best_metrics']['accuracy']:.4f}")
- property report: TrainingReport | None
Access comprehensive training report after training has completed.
This property provides detailed information about the training session including model info, configuration, metrics, timing, and more.
- Returns:
TrainingReport object with comprehensive training information, or None if training hasn’t completed yet
Example
trainer = Trainer(...)
trainer.fit()

# Access detailed report
print(f"Model: {trainer.report.model_name}")
print(f"Parameters: {trainer.report.total_parameters:,}")
print(f"Training time: {trainer.report.training_time:.1f}s")
print(f"Best loss: {trainer.report.best_metrics['val_loss']:.4f}")

# Convert to dictionary for serialization
report_dict = trainer.report.to_dict()
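Because to_dict() returns a plain dictionary, the report can be persisted with the standard library. The field values below are illustrative sample data standing in for a real report, not actual output:

```python
import json

# In practice: report_dict = trainer.report.to_dict() after training.
# Illustrative sample values standing in for a real report:
report_dict = {"model_name": "Sequential", "total_parameters": 101770}

# Write the report to disk as JSON for later inspection
with open("training_report.json", "w") as f:
    json.dump(report_dict, f, indent=2)
```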
- save_checkpoint(filepath: str, additional_info: Dict | None = None)[source]
Save a training checkpoint.
Class Overview
The Trainer class provides a high-level interface for PyTorch model training with the following key features:
Automatic device management: Handles GPU/CPU placement automatically
Mixed precision training: Leverages automatic mixed precision for faster training
Flexible callbacks: Extensible callback system for custom training logic
Rich metrics tracking: Built-in and custom metrics with beautiful output
Smart checkpointing: Automatic model saving and restoration
Early stopping: Configurable early stopping to prevent overfitting
Constructor
- Trainer.__init__(model: Module, config: TrainingConfig, train_dataloader: DataLoader, val_dataloader: DataLoader | None = None, loss_fn: Callable | None = None, metric_fns: Dict[str, Callable] | None = None, callbacks: List[Callback] | None = None)[source]
Initialize the trainer.
- Parameters:
model – PyTorch model to train
config – Training configuration
train_dataloader – Training data loader
val_dataloader – Optional validation data loader
loss_fn – Loss function (if None, will try to infer from model)
metric_fns – Dictionary of metric functions
callbacks – List of callbacks for training hooks
Parameters in Detail:
model (torch.nn.Module): Your PyTorch model to train
config (TrainingConfig): Configuration object controlling all training parameters
train_dataloader (torch.utils.data.DataLoader): Training data loader
val_dataloader (torch.utils.data.DataLoader, optional): Validation data loader for monitoring
loss_fn (callable, optional): Loss function. If None, attempts to infer from model
metric_fns (dict, optional): Dictionary mapping metric names to functions
callbacks (list, optional): List of callback objects for training hooks
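The sketch below shows the kind of callable metric_fns expects; the (predictions, targets) call signature is an assumption about how Treadmill invokes metrics per batch, so verify it against the metrics documentation:

```python
import torch

# Assumed signature: metrics are presumed to be called as
# metric_fn(predictions, targets); check the metrics docs to confirm.
def accuracy_function(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """Fraction of correct class predictions, given raw logits and class indices."""
    predicted_classes = torch.argmax(predictions, dim=1)
    return (predicted_classes == targets).float().mean().item()
```

This is the accuracy_function referenced in the constructor example below.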
Example:
import torch
import torch.nn as nn
from treadmill import Trainer, TrainingConfig
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

config = TrainingConfig(epochs=10, device="auto")

trainer = Trainer(
    model=model,
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=nn.CrossEntropyLoss(),
    metric_fns={'accuracy': accuracy_function}
)
Training Methods
fit()
- Trainer.fit() Dict[str, Any][source]
Alias for train() method for sklearn-style compatibility.
Many users expect a fit() method from sklearn/other ML libraries. This method simply calls train() for compatibility.
- Returns:
Dictionary containing training history and final metrics
train()
- Trainer.train() Dict[str, Any][source]
The primary method for training your model. Executes the complete training loop, including:
Model training over specified epochs
Validation (if validation data provided)
Callback execution at appropriate hooks
Metrics computation and logging
Checkpoint saving
Early stopping checks
- Returns:
Dictionary containing training history with keys:
train_loss: List of training losses per epoch
val_loss: List of validation losses per epoch (if validation enabled)
train_{metric}: Training metrics per epoch for each custom metric
val_{metric}: Validation metrics per epoch for each custom metric
Example:
# Basic training
history = trainer.train()
# Access training history
print(f"Final training loss: {history['train_loss'][-1]}")
print(f"Best validation accuracy: {max(history.get('val_accuracy', [0]))}")
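Beyond the final values, the per-epoch lists make it easy to locate the best epoch. A small sketch using sample history values in the shape documented above:

```python
# Sample history in the documented shape (values are illustrative only).
history = {
    "train_loss": [0.9, 0.6, 0.4, 0.35],
    "val_loss": [1.0, 0.7, 0.5, 0.55],
}

# Best epoch = index of the lowest validation loss (0-based)
best_epoch = min(range(len(history["val_loss"])), key=history["val_loss"].__getitem__)
print(best_epoch, history["val_loss"][best_epoch])  # -> 2 0.5
```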
train_epoch()
Executes a single training epoch. Useful for custom training loops or debugging.
- Parameters:
epoch (int): Current epoch number for logging and callbacks
- Returns:
Dictionary containing training metrics for the epoch
Example:
# Custom training loop
for epoch in range(config.epochs):
    train_metrics = trainer.train_epoch(epoch)

    if epoch % 5 == 0:  # Custom validation frequency
        val_metrics = trainer.validate_epoch(epoch)

        # Custom logic here
        if some_condition:
            break
validate_epoch()
Executes validation for a single epoch. Only available if a validation dataloader was provided.
- Parameters:
epoch (int): Current epoch number for logging and callbacks
- Returns:
Dictionary containing validation metrics for the epoch
Example:
# Manual validation
val_metrics = trainer.validate_epoch(epoch=0)
print(f"Validation loss: {val_metrics['loss']}")
print(f"Validation accuracy: {val_metrics.get('accuracy', 'N/A')}")
State Management
save_checkpoint()
- Trainer.save_checkpoint(filepath: str, additional_info: Dict | None = None)[source]
Save a training checkpoint.
Saves the current training state including model weights, optimizer state, scheduler state, and training progress.
- Parameters:
filepath (str): Path where to save the checkpoint
additional_info (dict, optional): Extra information to store alongside the training state
Example:
# Save checkpoint manually
trainer.save_checkpoint("model_epoch_10.pt")
# Save with extra metadata
trainer.save_checkpoint("best_model.pt", additional_info={"epoch": 10})
load_checkpoint()
- Trainer.load_checkpoint(filepath: str, resume_training: bool = True)[source]
Load a training checkpoint.
Loads a previously saved checkpoint, restoring model weights, optimizer state, and training progress.
- Parameters:
filepath (str): Path to the checkpoint file
resume_training (bool, optional): Whether to restore optimizer state and training progress for resuming (default: True)
- Returns:
Dictionary containing the loaded checkpoint data
Example:
# Resume training from checkpoint
checkpoint_data = trainer.load_checkpoint("model_epoch_10.pt")
print(f"Resuming from epoch {checkpoint_data.get('epoch', 0)}")
# Continue training
trainer.train()
get_history()
Returns the complete training history including losses and metrics for all epochs.
- Returns:
Dictionary containing training history with metric names as keys
Example:
history = trainer.get_history()

# Plot training curves
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history['train_loss'], label='Training Loss')
if 'val_loss' in history:
    plt.plot(history['val_loss'], label='Validation Loss')
plt.legend()
plt.title('Loss')

plt.subplot(1, 2, 2)
if 'train_accuracy' in history:
    plt.plot(history['train_accuracy'], label='Training Accuracy')
if 'val_accuracy' in history:
    plt.plot(history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title('Accuracy')

plt.show()
Prediction Methods
predict()
Generates predictions for a given dataset using the trained model.
- Parameters:
dataloader (torch.utils.data.DataLoader): Data to generate predictions for
- Returns:
Tensor containing model predictions
Example:
# Generate predictions on test set
test_predictions = trainer.predict(test_loader)
# Convert to class predictions
predicted_classes = torch.argmax(test_predictions, dim=1)
# Calculate accuracy
accuracy = (predicted_classes == test_labels).float().mean()
print(f"Test accuracy: {accuracy:.4f}")
evaluate()
Evaluates the model on a given dataset, computing loss and all configured metrics.
- Parameters:
dataloader (torch.utils.data.DataLoader): Data to evaluate on
- Returns:
Dictionary containing evaluation metrics
Example:
# Evaluate on test set
test_results = trainer.evaluate(test_loader)
print("Test Results:")
for metric_name, value in test_results.items():
    print(f"  {metric_name}: {value:.4f}")
Properties and Attributes
model
Access to the underlying PyTorch model.
# Access model for inference
model = trainer.model
model.eval()
with torch.no_grad():
    output = model(input_tensor)
config
The training configuration object.
# Check configuration
print(f"Training for {trainer.config.epochs} epochs")
print(f"Using device: {trainer.config.device}")
device
The device (CPU/GPU) being used for training.
print(f"Training on: {trainer.device}")
metrics_tracker
The metrics tracking object that maintains training history.
# Access detailed metrics
tracker = trainer.metrics_tracker
latest_metrics = tracker.get_latest_metrics()
current_epoch
The current epoch number during training.
# Useful in callbacks
print(f"Currently at epoch: {trainer.current_epoch}")
Advanced Usage Examples
Custom Training Loop
trainer = Trainer(model, config, train_loader, val_loader)

for epoch in range(config.epochs):
    # Custom pre-epoch logic
    if epoch % 10 == 0:
        # Adjust learning rate or other parameters
        for param_group in trainer.optimizer.param_groups:
            param_group['lr'] *= 0.9

    # Train one epoch
    train_metrics = trainer.train_epoch(epoch)

    # Custom validation schedule
    if epoch % config.validation_frequency == 0:
        val_metrics = trainer.validate_epoch(epoch)

        # Custom early stopping logic
        if val_metrics['loss'] > previous_best_loss * 1.1:
            print("Custom early stopping triggered!")
            break

    # Custom checkpointing
    if epoch % 20 == 0:
        trainer.save_checkpoint(f"checkpoint_epoch_{epoch}.pt")
Integration with Callbacks
from treadmill.callbacks import EarlyStopping, ModelCheckpoint
# Create custom callbacks
early_stopping = EarlyStopping(
    patience=10,
    min_delta=0.001,
    verbose=True
)

checkpointing = ModelCheckpoint(
    filepath="models/best_model_{epoch:02d}_{val_loss:.4f}.pt",
    save_best_only=True,
    verbose=True
)

trainer = Trainer(
    model=model,
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    callbacks=[early_stopping, checkpointing]
)
# Callbacks are automatically called during training
history = trainer.train()
Multi-GPU Training
# DataParallel for multi-GPU training
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

config = TrainingConfig(
    epochs=50,
    device="cuda",
    mixed_precision=True  # Works with DataParallel
)

trainer = Trainer(model, config, train_loader, val_loader)
trainer.train()
See Also
TrainingConfig: Configuration options for training
callbacks: Available callbacks and custom callback creation
metrics: Built-in metrics and custom metric functions
Complete Image Classification Tutorial: Complete training tutorial
../examples/advanced_training: Advanced training examples