Advanced Usage Example
This example demonstrates sophisticated training techniques using Treadmill’s advanced features. It is aimed at experienced practitioners who want to leverage all of Treadmill’s capabilities.
Overview
Advanced techniques covered:

- Custom callbacks and training hooks
- Learning rate scheduling and optimization
- Advanced data loading and augmentation
- Mixed precision training and gradient accumulation
- Custom metrics and monitoring
- Model checkpointing and resumption
- Multi-GPU training setup
- Advanced evaluation and analysis
Estimated time: 30-45 minutes
Prerequisites
pip install -e ".[full]"
pip install wandb tensorboard # Optional: for advanced logging
Advanced Image Classification
Let’s build a sophisticated training pipeline for CIFAR-100 with all advanced features:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import os
from treadmill import Trainer, TrainingConfig, OptimizerConfig, SchedulerConfig
from treadmill.callbacks import (
EarlyStopping, ModelCheckpoint, ReduceLROnPlateau,
Callback
)
# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
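If you train on a GPU and want runs to be as reproducible as possible, you can additionally seed CUDA and constrain cuDNN (this trades a little speed for determinism):

# Optional: stricter reproducibility for GPU runs (slightly slower)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False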
Step 1: Advanced Data Pipeline
class AdvancedDataPipeline:
"""Advanced data loading with sophisticated augmentation."""
def __init__(self, data_dir='./data', batch_size=128, num_workers=4):
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.setup_transforms()
self.load_data()
def setup_transforms(self):
"""Create sophisticated augmentation pipelines."""
# Advanced training augmentations
self.train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(15),
transforms.ColorJitter(
brightness=0.2,
contrast=0.2,
saturation=0.2,
hue=0.1
),
transforms.RandomAffine(0, translate=(0.1, 0.1)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.5071, 0.4867, 0.4408],
std=[0.2675, 0.2565, 0.2761]
),
# Advanced augmentations
transforms.RandomErasing(p=0.1, scale=(0.02, 0.33))
])
# Test transform (no augmentation)
self.test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.5071, 0.4867, 0.4408],
std=[0.2675, 0.2565, 0.2761]
)
])
def load_data(self):
"""Load CIFAR-100 with train/val split."""
# Load full datasets
train_dataset = torchvision.datasets.CIFAR100(
root=self.data_dir, train=True, download=True,
transform=self.train_transform
)
test_dataset = torchvision.datasets.CIFAR100(
root=self.data_dir, train=False, download=True,
transform=self.test_transform
)
# Create train/validation split
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
self.train_dataset, self.val_dataset = torch.utils.data.random_split(
train_dataset, [train_size, val_size],
generator=torch.Generator().manual_seed(42)
)
# Update validation dataset transform
self.val_dataset.dataset = torchvision.datasets.CIFAR100(
root=self.data_dir, train=True, download=False,
transform=self.test_transform
)
self.test_dataset = test_dataset
print(f"Dataset loaded:")
print(f" Training: {len(self.train_dataset)} samples")
print(f" Validation: {len(self.val_dataset)} samples")
print(f" Test: {len(self.test_dataset)} samples")
print(f" Classes: 100")
def get_loaders(self):
"""Get data loaders with advanced settings."""
train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True,
drop_last=True # For batch norm stability
)
val_loader = DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True
)
test_loader = DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True
)
return train_loader, val_loader, test_loader
# Create data pipeline
data_pipeline = AdvancedDataPipeline(batch_size=128)
train_loader, val_loader, test_loader = data_pipeline.get_loaders()
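Before building the model, it is worth pulling a single batch to confirm shapes and the normalized value range:

# Peek at one training batch to verify shapes and value ranges
images, labels = next(iter(train_loader))
print(f"Batch images: {tuple(images.shape)}")   # expected: (128, 3, 32, 32)
print(f"Batch labels: {tuple(labels.shape)}")   # expected: (128,)
print(f"Pixel range after normalization: [{images.min():.2f}, {images.max():.2f}]")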
Step 2: Advanced Model Architecture
class ResidualBlock(nn.Module):
"""Residual block with batch normalization."""
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
residual = self.shortcut(x)
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += residual
out = F.relu(out)
return out
class AdvancedCNN(nn.Module):
"""Advanced CNN with residual connections and modern techniques."""
def __init__(self, num_classes=100, dropout_rate=0.3):
super().__init__()
# Initial convolution
self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(3, 2, 1)
# Residual blocks
self.layer1 = self._make_layer(64, 64, 2, stride=1)
self.layer2 = self._make_layer(64, 128, 2, stride=2)
self.layer3 = self._make_layer(128, 256, 2, stride=2)
self.layer4 = self._make_layer(256, 512, 2, stride=2)
# Global average pooling
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
# Classifier with dropout
self.classifier = nn.Sequential(
nn.Dropout(dropout_rate),
nn.Linear(512, 256),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(256, num_classes)
)
self._initialize_weights()
def _make_layer(self, in_channels, out_channels, blocks, stride):
layers = []
layers.append(ResidualBlock(in_channels, out_channels, stride))
for _ in range(1, blocks):
layers.append(ResidualBlock(out_channels, out_channels))
return nn.Sequential(*layers)
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# Create model
model = AdvancedCNN(num_classes=100, dropout_rate=0.3)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model created with {count_parameters(model):,} parameters")
Step 3: Custom Callbacks
class CustomLoggingCallback(Callback):
"""Custom callback for advanced logging and monitoring."""
def __init__(self, log_dir='./logs'):
self.log_dir = log_dir
self.metrics_history = defaultdict(list)
os.makedirs(log_dir, exist_ok=True)
def on_epoch_end(self, trainer, epoch, metrics, **kwargs):
# Log metrics
for name, value in metrics.items():
self.metrics_history[name].append(value)
# Log to file
with open(f"{self.log_dir}/training_log.txt", "a") as f:
f.write(f"Epoch {epoch}: {metrics}\n")
# Advanced logging every 10 epochs
if epoch % 10 == 0:
self._advanced_logging(trainer, epoch, metrics)
def _advanced_logging(self, trainer, epoch, metrics):
"""Advanced logging with model analysis."""
# Log gradient norms
total_norm = 0
for p in trainer.model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** (1. / 2)
print(f" Gradient norm: {total_norm:.4f}")
# Log learning rate
if trainer.scheduler:
current_lr = trainer.optimizer.param_groups[0]['lr']
print(f" Learning rate: {current_lr:.6f}")
class WarmupCallback(Callback):
"""Learning rate warmup callback."""
def __init__(self, warmup_epochs=5, base_lr=0.001):
self.warmup_epochs = warmup_epochs
self.base_lr = base_lr
def on_epoch_start(self, trainer, epoch, **kwargs):
if epoch < self.warmup_epochs:
# Linear warmup
lr = self.base_lr * (epoch + 1) / self.warmup_epochs
for param_group in trainer.optimizer.param_groups:
param_group['lr'] = lr
print(f" Warmup LR: {lr:.6f}")
Step 4: Advanced Training Configuration
# Advanced optimizer configuration
optimizer_config = OptimizerConfig(
optimizer_class="AdamW",
lr=0.001,
weight_decay=0.01,
params={
"betas": (0.9, 0.999),
"eps": 1e-8,
"amsgrad": True # Use AMSGrad variant
}
)
# Learning rate scheduler configuration
scheduler_config = SchedulerConfig(
scheduler_class="CosineAnnealingLR",
params={
"T_max": 200, # Maximum number of iterations
"eta_min": 1e-6 # Minimum learning rate
}
)
# Advanced training configuration
config = TrainingConfig(
# Training parameters
epochs=200,
device="auto",
# Performance optimizations
mixed_precision=True,
accumulate_grad_batches=2, # Effective batch size = 128 * 2 = 256
grad_clip_norm=1.0, # Gradient clipping
# Validation and monitoring
validation_frequency=1,
log_frequency=50,
# Early stopping (generous for long training)
early_stopping_patience=30,
early_stopping_min_delta=0.0001,
# Checkpointing
checkpoint_dir="./checkpoints/advanced_cifar100",
save_best_model=True,
save_last_model=True,
checkpoint_frequency=10, # Save every 10 epochs
# Optimizer and scheduler
optimizer=optimizer_config,
scheduler=scheduler_config
)
Step 5: Advanced Metrics
class AdvancedMetrics:
"""Collection of advanced metrics."""
@staticmethod
def accuracy(predictions, targets):
pred_classes = torch.argmax(predictions, dim=1)
return (pred_classes == targets).float().mean().item()
@staticmethod
def top_k_accuracy(predictions, targets, k=5):
_, top_k_preds = torch.topk(predictions, k, dim=1)
targets_expanded = targets.view(-1, 1).expand_as(top_k_preds)
correct = (top_k_preds == targets_expanded).any(dim=1)
return correct.float().mean().item()
@staticmethod
def precision_at_k(predictions, targets, k=5):
_, top_k_preds = torch.topk(predictions, k, dim=1)
correct = (top_k_preds == targets.view(-1, 1)).float()
return correct.sum(dim=1).mean().item() / k
@staticmethod
def confidence_score(predictions):
probabilities = F.softmax(predictions, dim=1)
max_probs = torch.max(probabilities, dim=1)[0]
return max_probs.mean().item()
# Create metrics dictionary
custom_metrics = {
'accuracy': AdvancedMetrics.accuracy,
'top5_accuracy': lambda p, t: AdvancedMetrics.top_k_accuracy(p, t, k=5),
'precision_at_5': lambda p, t: AdvancedMetrics.precision_at_k(p, t, k=5),
'confidence': lambda p, t: AdvancedMetrics.confidence_score(p)
}
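Each metric takes raw logits and integer targets. A quick sanity check on a fake batch verifies they run and return sensible values before wiring them into the trainer:

# Sanity-check the metric functions on random logits and targets
dummy_logits = torch.randn(8, 100)
dummy_targets = torch.randint(0, 100, (8,))
for name, fn in custom_metrics.items():
    print(f"{name}: {fn(dummy_logits, dummy_targets):.4f}")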
Step 6: Advanced Callbacks Setup
# Create advanced callbacks
callbacks = [
# Learning rate warmup
WarmupCallback(warmup_epochs=5, base_lr=0.001),
# Custom logging
CustomLoggingCallback(log_dir='./logs/advanced_training'),
# Early stopping with validation loss
EarlyStopping(
monitor='val_loss',
patience=30,
min_delta=0.0001,
verbose=True,
mode='min'
),
# Model checkpointing
ModelCheckpoint(
        filepath='./checkpoints/advanced_cifar100/best_model_{epoch:03d}_{val_accuracy:.4f}.pt',
monitor='val_accuracy',
save_best_only=True,
mode='max',
verbose=True,
save_top_k=3 # Keep top 3 models
),
# Reduce learning rate on plateau
ReduceLROnPlateau(
monitor='val_loss',
factor=0.5,
patience=10,
min_lr=1e-7,
verbose=True
)
]
Step 7: Advanced Loss Functions
class LabelSmoothingLoss(nn.Module):
"""Label smoothing loss for better generalization."""
def __init__(self, num_classes, smoothing=0.1):
super().__init__()
self.num_classes = num_classes
self.smoothing = smoothing
self.confidence = 1.0 - smoothing
def forward(self, predictions, targets):
log_probs = F.log_softmax(predictions, dim=1)
# Create smoothed targets
true_dist = torch.zeros_like(log_probs)
true_dist.fill_(self.smoothing / (self.num_classes - 1))
true_dist.scatter_(1, targets.unsqueeze(1), self.confidence)
return torch.mean(torch.sum(-true_dist * log_probs, dim=1))
# Use label smoothing loss
loss_fn = LabelSmoothingLoss(num_classes=100, smoothing=0.1)
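If you are on PyTorch 1.10 or newer, the built-in cross-entropy loss supports label smoothing directly, so the custom class above can be swapped for a one-liner:

# Equivalent built-in alternative (PyTorch >= 1.10)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)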
Step 8: Training with All Advanced Features
# Create advanced trainer
trainer = Trainer(
model=model,
config=config,
train_dataloader=train_loader,
val_dataloader=val_loader,
loss_fn=loss_fn,
metric_fns=custom_metrics,
callbacks=callbacks
)
# Print training setup
print("🚀 Advanced Training Setup:")
print(f" Model: {type(model).__name__} ({count_parameters(model):,} params)")
print(f" Optimizer: {config.optimizer.optimizer_class.__name__}")
print(f" Scheduler: {config.scheduler.scheduler_class.__name__}")
print(f" Mixed Precision: {config.mixed_precision}")
print(f" Gradient Accumulation: {config.accumulate_grad_batches}")
print(f" Device: {trainer.device}")
print(f" Callbacks: {len(callbacks)}")
print("-" * 80)
# Start advanced training
history = trainer.train()
print("✅ Advanced training completed!")
Step 9: Advanced Evaluation and Analysis
def advanced_evaluation(trainer, test_loader, class_names=None):
"""Comprehensive model evaluation with advanced metrics."""
print("📊 Comprehensive Model Evaluation")
print("=" * 50)
# Basic evaluation
test_results = trainer.evaluate(test_loader)
print(f"Test Results:")
for metric_name, value in test_results.items():
print(f" {metric_name.replace('_', ' ').title()}: {value:.4f}")
# Detailed predictions analysis
all_predictions = []
all_targets = []
all_confidences = []
trainer.model.eval()
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(trainer.device), target.to(trainer.device)
output = trainer.model(data)
# Get predictions and confidence
probs = F.softmax(output, dim=1)
pred = torch.argmax(output, dim=1)
confidence = torch.max(probs, dim=1)[0]
all_predictions.extend(pred.cpu().numpy())
all_targets.extend(target.cpu().numpy())
all_confidences.extend(confidence.cpu().numpy())
# Convert to numpy arrays
predictions = np.array(all_predictions)
targets = np.array(all_targets)
confidences = np.array(all_confidences)
# Detailed analysis
print(f"\n🔍 Detailed Analysis:")
print(f" Average Confidence: {confidences.mean():.4f}")
print(f" Confidence Std: {confidences.std():.4f}")
print(f" Low Confidence Samples (<0.5): {(confidences < 0.5).sum()}")
print(f" High Confidence Samples (>0.9): {(confidences > 0.9).sum()}")
# Per-class accuracy
unique_classes = np.unique(targets)
print(f"\n📈 Per-Class Performance (showing top 10 and bottom 10):")
class_accuracies = []
for cls in unique_classes:
mask = targets == cls
if mask.sum() > 0:
acc = (predictions[mask] == targets[mask]).mean()
class_accuracies.append((cls, acc))
# Sort by accuracy
class_accuracies.sort(key=lambda x: x[1], reverse=True)
print(" Best performing classes:")
for cls, acc in class_accuracies[:10]:
print(f" Class {cls:2d}: {acc:.4f}")
print(" Worst performing classes:")
for cls, acc in class_accuracies[-10:]:
print(f" Class {cls:2d}: {acc:.4f}")
return {
'predictions': predictions,
'targets': targets,
'confidences': confidences,
'class_accuracies': class_accuracies
}
# Run advanced evaluation
eval_results = advanced_evaluation(trainer, test_loader)
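The matplotlib import at the top can now be put to use: CustomLoggingCallback accumulated per-epoch metrics, so we can plot the training curves. Exactly which series appear depends on what Treadmill reports each epoch:

# Plot every scalar series recorded by CustomLoggingCallback
logging_cb = next(cb for cb in callbacks if isinstance(cb, CustomLoggingCallback))
plt.figure(figsize=(10, 6))
for name, values in logging_cb.metrics_history.items():
    if values and isinstance(values[0], (int, float)):  # skip non-scalar entries
        plt.plot(values, label=name)
plt.xlabel("Epoch")
plt.ylabel("Metric value")
plt.title("Training history")
plt.legend()
plt.tight_layout()
plt.savefig("./logs/advanced_training/training_curves.png")
plt.show()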
Step 10: Model Optimization and Export
def optimize_model_for_inference(model, example_input):
"""Optimize model for production inference."""
# Set to evaluation mode
model.eval()
# Trace the model for optimization
traced_model = torch.jit.trace(model, example_input)
# Optimize for inference
traced_model = torch.jit.optimize_for_inference(traced_model)
return traced_model
def export_model(model, filepath, example_input=None):
"""Export model in multiple formats."""
print(f"📦 Exporting model to {filepath}")
# Save complete checkpoint
checkpoint = {
'model_state_dict': model.state_dict(),
'model_class': type(model).__name__,
'model_config': {
'num_classes': 100,
'dropout_rate': 0.3
},
'training_history': history,
'performance_metrics': eval_results
}
torch.save(checkpoint, f"{filepath}_complete.pt")
# Save optimized model for inference
if example_input is not None:
optimized_model = optimize_model_for_inference(model, example_input)
torch.jit.save(optimized_model, f"{filepath}_optimized.pt")
# Save just the state dict (smaller file)
torch.save(model.state_dict(), f"{filepath}_weights.pt")
print("✅ Model export completed!")
# Export the trained model
example_input = torch.randn(1, 3, 32, 32).to(trainer.device)
export_model(trainer.model, "./models/advanced_cifar100", example_input)
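To verify the export, load the optimized TorchScript artifact back (no Python class definition needed) and run a single inference:

# Reload the TorchScript artifact and run one inference to verify the export
loaded = torch.jit.load("./models/advanced_cifar100_optimized.pt", map_location="cpu")
loaded.eval()
with torch.no_grad():
    probs = F.softmax(loaded(torch.randn(1, 3, 32, 32)), dim=1)
print(f"Top-1 class: {int(probs.argmax(dim=1))}, confidence: {probs.max():.4f}")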
Advanced Training Summary
🎯 Advanced Features Used:
✅ Data Pipeline:
- Sophisticated augmentation strategies
- Efficient data loading with multiple workers
- Advanced normalization and preprocessing

✅ Model Architecture:
- Residual connections for better gradient flow
- Batch normalization for training stability
- Proper weight initialization
- Dropout for regularization

✅ Training Optimization:
- Mixed precision training (faster + less memory)
- Gradient accumulation (larger effective batch size)
- Gradient clipping (training stability)
- Label smoothing (better generalization)

✅ Learning Rate Management:
- Warmup for stable training start
- Cosine annealing scheduling
- Reduce on plateau for fine-tuning

✅ Monitoring and Callbacks:
- Custom logging and metrics tracking
- Advanced early stopping
- Multiple model checkpointing
- Performance monitoring

✅ Evaluation and Analysis:
- Comprehensive metrics (accuracy, top-k, confidence)
- Per-class performance analysis
- Model confidence analysis
- Production-ready export
📊 Expected Results:
With this advanced setup, you should achieve:
- CIFAR-100 top-1 accuracy: roughly 70-75% (random chance is 1% across 100 classes)
- Training stability: smooth convergence curves
- Generalization: good test performance
- Efficiency: fast training with mixed precision
🚀 Production Readiness:
The trained model is ready for production with:
- Optimized inference models
- Complete checkpoints for resuming
- Comprehensive performance metrics
- Export in multiple formats
This advanced example demonstrates how Treadmill scales from simple scripts to production-ready training pipelines while maintaining clean, readable code! 🏃♀️➡️
Next Steps
Explore multi-GPU training with DataParallel (see the sketch at the end of this list)
Try distributed training with DistributedDataParallel
Implement custom optimizers and schedulers
Add Weights & Biases integration for experiment tracking
Deploy models with TorchServe or TensorRT
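For the first item above, a minimal single-node sketch with nn.DataParallel is shown below; DistributedDataParallel is generally preferred for performance, and whether the Treadmill Trainer needs extra configuration for either is something to confirm against its documentation.

# Minimal single-node multi-GPU sketch: DataParallel splits each batch across GPUs
if torch.cuda.device_count() > 1:
    parallel_model = nn.DataParallel(model).to("cuda")
    # Pass parallel_model to the Trainer instead of model; consider scaling
    # batch_size with torch.cuda.device_count() to keep each GPU busy.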