Complete Image Classification Tutorial
This comprehensive tutorial will guide you through building a complete image classification system using Treadmill. We’ll cover everything from data preparation to model evaluation, using the CIFAR-10 dataset as our example.
Tutorial Overview
What You’ll Learn: - Setting up data pipelines for image classification - Building and training CNN architectures - Advanced training techniques and optimizations - Model evaluation and analysis - Best practices for real-world deployment
Prerequisites: - Basic PyTorch knowledge - Understanding of convolutional neural networks - Python programming experience
Estimated Time: 45-60 minutes
Step 1: Environment Setup
First, let’s ensure you have all the necessary dependencies:
# Install Treadmill with full dependencies
pip install -e ".[full]"
# Additional packages for this tutorial
pip install matplotlib seaborn
Now let’s import all the necessary libraries:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
# Treadmill imports
from treadmill import Trainer, TrainingConfig, OptimizerConfig
from treadmill.callbacks import EarlyStopping, ModelCheckpoint
from treadmill.metrics import MetricsTracker
# Seed both the torch and numpy RNGs so augmentation, splits and
# initialization are reproducible across runs
torch.manual_seed(42)
np.random.seed(42)
Step 2: Dataset Preparation
CIFAR-10 Dataset Overview
CIFAR-10 consists of 60,000 32×32 color images in 10 classes: - Airplane, Automobile, Bird, Cat, Deer - Dog, Frog, Horse, Ship, Truck
Let’s load and explore the dataset:
# The ten CIFAR-10 class labels, index-aligned with the dataset's targets
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Training transform: light augmentation (flip / small rotation / color
# jitter) followed by normalization with CIFAR-10 per-channel mean/std
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Validation/test transform: normalization only (no augmentation)
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# Download (if needed) and load the CIFAR-10 train and test splits
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=val_transform
)

# Split the 50k training images 80/20 into train/validation subsets
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(train_dataset, [train_size, val_size])

# Point the validation Subset at a second CIFAR-10 instance that applies
# val_transform, so validation images are NOT augmented. The Subset keeps
# its original indices, so it still selects the same held-out images;
# train_subset continues to use the augmented dataset.
val_subset.dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=False, transform=val_transform
)

print(f"Training samples: {len(train_subset)}")
print(f"Validation samples: {len(val_subset)}")
print(f"Test samples: {len(test_dataset)}")
Data Exploration
Let’s visualize some samples to understand our data better:
def visualize_samples(dataset, num_samples=16, title="Sample Images"):
    """Show a 4x4 grid of randomly chosen, denormalized dataset images."""
    grid, cells = plt.subplots(4, 4, figsize=(10, 10))
    grid.suptitle(title, fontsize=16)
    # CIFAR-10 channel statistics, used to undo Normalize for display.
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)
    for slot in range(num_samples):
        # Pick a random sample each iteration
        sample_idx = np.random.randint(0, len(dataset))
        img, lbl = dataset[sample_idx]
        # Invert normalization and clip into the displayable [0, 1] range
        img = torch.clamp(img * std + mean, 0, 1)
        row, col = divmod(slot, 4)
        panel = cells[row, col]
        panel.imshow(img.permute(1, 2, 0))
        panel.set_title(f"{class_names[lbl]}")
        panel.axis('off')
    plt.tight_layout()
    plt.show()

# Visualize training samples
visualize_samples(train_subset, title="Training Samples")
Class Distribution Analysis
def analyze_class_distribution(dataset, title="Class Distribution"):
    """Count samples per class, render a labelled bar chart, return the counts."""
    counts = torch.zeros(10)
    # Full pass over the dataset; tqdm reports progress
    for _, lbl in tqdm(dataset, desc="Analyzing distribution"):
        counts[lbl] += 1

    plt.figure(figsize=(12, 6))
    rects = plt.bar(class_names, counts, color='skyblue', edgecolor='navy')
    plt.title(title)
    plt.xlabel('Classes')
    plt.ylabel('Number of Samples')
    plt.xticks(rotation=45)
    # Annotate each bar with its exact count just above the bar top
    for rect, n in zip(rects, counts):
        plt.text(rect.get_x() + rect.get_width() / 2, rect.get_height() + 50,
                 f'{int(n)}', ha='center', va='bottom')
    plt.tight_layout()
    plt.show()
    return counts

train_distribution = analyze_class_distribution(train_subset, "Training Set Class Distribution")
Create Data Loaders
# DataLoader hyperparameters — tune these for your hardware
batch_size = 128  # Adjust based on your GPU memory
num_workers = 4  # Adjust based on your CPU cores

# Training loader: reshuffled every epoch for SGD
train_loader = DataLoader(
    train_subset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True  # Faster GPU transfer
)

# Validation loader: fixed order, no shuffling needed
val_loader = DataLoader(
    val_subset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

# Test loader: fixed order as well
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")
Step 3: Model Architecture
Custom CNN Architecture
Let’s build a modern CNN architecture with best practices:
class CIFAR10CNN(nn.Module):
    """
    VGG-style CNN for CIFAR-10 classification.

    Four sequential convolutional blocks double the channel width
    (3 -> 64 -> 128 -> 256 -> 512) while max pooling halves the spatial
    resolution (32x32 -> 2x2); global average pooling then feeds a small
    dropout-regularized MLP head.

    Features:
    - Batch normalization after every convolution for stable training
    - Dropout in the classifier head for regularization
    - Adaptive average pooling for flexible input sizes
    - He (Kaiming) weight initialization

    Args:
        num_classes: number of output logits (default 10 for CIFAR-10).
        dropout_rate: dropout probability applied twice in the head.
    """

    def __init__(self, num_classes=10, dropout_rate=0.3):
        super(CIFAR10CNN, self).__init__()
        # First block: 32x32 -> 16x16.
        # bias=False on all convs: the following BatchNorm supplies the shift.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Second block: 16x16 -> 8x8
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Third block: 8x8 -> 4x4
        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Fourth block: 4x4 -> 2x2
        self.block4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Global average pooling collapses 2x2 maps to 1x1; MLP head follows
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )
        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize model weights using He (Kaiming) initialization."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm starts as identity: scale 1, shift 0
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Plain sequential pipeline — note: no skip connections
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)  # Flatten to (batch, 512)
        x = self.classifier(x)
        return x
# Create model instance with the default head configuration
model = CIFAR10CNN(num_classes=10, dropout_rate=0.3)

# Display a per-layer summary (requires the third-party `torchsummary` package)
from torchsummary import summary
summary(model, input_size=(3, 32, 32))
Model Analysis
def count_parameters(model):
    """Print and return the (trainable, non-trainable) parameter counts."""
    totals = {True: 0, False: 0}
    # Bucket every parameter tensor by its requires_grad flag
    for param in model.parameters():
        totals[param.requires_grad] += param.numel()
    trainable, non_trainable = totals[True], totals[False]
    print(f"Trainable parameters: {trainable:,}")
    print(f"Non-trainable parameters: {non_trainable:,}")
    print(f"Total parameters: {trainable + non_trainable:,}")
    return trainable, non_trainable
count_parameters(model)
Step 4: Custom Metrics
Let’s define comprehensive metrics for evaluating our model:
def accuracy(predictions, targets):
    """Fraction of samples whose argmax prediction equals the target."""
    hits = torch.argmax(predictions, dim=1).eq(targets)
    return hits.float().mean().item()
def top_k_accuracy(predictions, targets, k=3):
    """Fraction of samples whose target is among the k highest logits."""
    top_idx = torch.topk(predictions, k, dim=1).indices
    # Broadcast targets against the k candidate columns
    in_top_k = (top_idx == targets.unsqueeze(1)).any(dim=1)
    return in_top_k.float().mean().item()
def per_class_accuracy(predictions, targets, num_classes=10):
    """
    Mean of per-class accuracies (macro-averaged accuracy).

    Each class contributes correct/total for that class; classes absent
    from `targets` contribute ~0 to the mean (totals are padded by 1e-8
    to avoid division by zero), matching the original per-sample loop.

    Args:
        predictions: (N, C) logits or probabilities.
        targets: (N,) integer class labels in [0, num_classes).
        num_classes: number of classes averaged over (default 10).

    Returns:
        Python float: mean per-class accuracy.
    """
    pred_classes = torch.argmax(predictions, dim=1)
    # Vectorized tally (replaces the original O(n) Python loop):
    # bincount over targets -> per-class totals; bincount over the targets
    # of correctly predicted samples -> per-class hit counts.
    class_total = torch.bincount(targets, minlength=num_classes).float()
    correct_mask = pred_classes == targets
    class_correct = torch.bincount(targets[correct_mask], minlength=num_classes).float()
    # Epsilon avoids division by zero for empty classes
    class_acc = class_correct / (class_total + 1e-8)
    return class_acc.mean().item()
# Metric functions handed to the Trainer; each maps (predictions, targets)
# to a Python float for logging. top3 is a curried top_k_accuracy.
custom_metrics = {
    'accuracy': accuracy,
    'top3_accuracy': lambda p, t: top_k_accuracy(p, t, k=3),
    'per_class_acc': per_class_accuracy
}
Step 5: Training Configuration
Advanced Training Setup
# Optimizer configuration (decoupled weight decay via AdamW)
optimizer_config = OptimizerConfig(
    optimizer_class="AdamW",  # AdamW often works better than Adam
    lr=0.001,
    weight_decay=0.01,  # L2 regularization
    params={
        "betas": (0.9, 0.999),
        "eps": 1e-8
    }
)

# Training configuration with all optimizations enabled
config = TrainingConfig(
    # Basic training parameters
    epochs=100,
    device="auto",  # Auto-detect GPU/CPU
    # Performance optimizations
    mixed_precision=True,  # Faster training on modern GPUs
    accumulate_grad_batches=1,  # Simulate larger batch sizes when > 1
    grad_clip_norm=1.0,  # Gradient clipping
    # Validation and logging
    validation_frequency=1,  # Validate every epoch
    log_frequency=50,  # Log every 50 batches
    # Early stopping configuration
    early_stopping_patience=15,  # Stop if no improvement for 15 epochs
    early_stopping_min_delta=0.001,  # Minimum improvement threshold
    # Checkpointing
    checkpoint_dir="./checkpoints/cifar10_cnn",
    save_best_model=True,
    save_last_model=True,
    # Optimizer and scheduler
    optimizer=optimizer_config,
    # scheduler will be added after creating trainer
)
Advanced Callbacks
from treadmill.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Early stopping: halt training when validation loss stops improving
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    min_delta=0.001,
    verbose=True,
    mode='min'
)

# Checkpointing: save only when validation accuracy reaches a new best
model_checkpoint = ModelCheckpoint(
    filepath='./checkpoints/cifar10_cnn/best_model_{epoch:02d}_{val_acc:.4f}.pt',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=True
)

# Halve the learning rate after 5 stagnant epochs, with a floor of 1e-7
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-7,
    verbose=True
)

callbacks = [early_stopping, model_checkpoint, reduce_lr]
Step 6: Training Process
Initialize and Train
# Assemble the trainer: model, config, data, loss, metrics and callbacks
trainer = Trainer(
    model=model,
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=nn.CrossEntropyLoss(label_smoothing=0.1),  # Label smoothing
    metric_fns=custom_metrics,
    callbacks=callbacks
)

# Print training information
print("🚀 Starting CIFAR-10 Image Classification Training")
print(f"📊 Dataset: {len(train_subset)} train, {len(val_subset)} val, {len(test_dataset)} test")
print(f"🏗️ Model: {sum(p.numel() for p in model.parameters() if p.requires_grad):,} parameters")
print(f"💾 Device: {trainer.device}")
print(f"⚡ Mixed precision: {config.mixed_precision}")
print("-" * 80)

# Train the model; returns a dict of per-epoch metric histories
history = trainer.train()
print("🎉 Training completed!")
Training Visualization
def plot_training_history(history, save_path=None, lr_history=None):
    """
    Plot loss, accuracy, top-3 accuracy and learning-rate curves.

    Args:
        history: dict of metric name -> list of per-epoch values.
            'train_loss' is required; 'val_loss', 'train_accuracy',
            'val_accuracy', 'train_top3_accuracy' and 'val_top3_accuracy'
            are plotted when present.
        save_path: optional path; when given, the figure is saved as PNG.
        lr_history: optional list of per-epoch learning rates. When omitted,
            falls back to the module-level `trainer.lr_history` if one exists
            (the original implicit behaviour), otherwise shows a placeholder.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training History', fontsize=16, fontweight='bold')

    # Loss plot
    axes[0, 0].plot(history['train_loss'], label='Training Loss', linewidth=2)
    if 'val_loss' in history:
        axes[0, 0].plot(history['val_loss'], label='Validation Loss', linewidth=2)
    axes[0, 0].set_title('Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy plot
    if 'train_accuracy' in history:
        axes[0, 1].plot(history['train_accuracy'], label='Training Accuracy', linewidth=2)
    if 'val_accuracy' in history:
        axes[0, 1].plot(history['val_accuracy'], label='Validation Accuracy', linewidth=2)
    axes[0, 1].set_title('Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Top-3 accuracy plot
    if 'train_top3_accuracy' in history:
        axes[1, 0].plot(history['train_top3_accuracy'], label='Training Top-3', linewidth=2)
    if 'val_top3_accuracy' in history:
        axes[1, 0].plot(history['val_top3_accuracy'], label='Validation Top-3', linewidth=2)
    axes[1, 0].set_title('Top-3 Accuracy')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Top-3 Accuracy')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Learning rate plot. Explicit argument wins; otherwise fall back to the
    # global `trainer` (guarded — the original raised NameError when this
    # function was reused in a module without a `trainer`).
    if lr_history is None:
        try:
            lr_history = trainer.lr_history if hasattr(trainer, 'lr_history') else None
        except NameError:
            lr_history = None
    if lr_history:
        axes[1, 1].plot(lr_history, linewidth=2, color='red')
        axes[1, 1].set_title('Learning Rate')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Learning Rate')
        axes[1, 1].set_yscale('log')
        axes[1, 1].grid(True, alpha=0.3)
    else:
        axes[1, 1].text(0.5, 0.5, 'Learning Rate\nHistory\nNot Available',
                        ha='center', va='center', transform=axes[1, 1].transAxes)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Plot training history
plot_training_history(history, save_path='training_history.png')
Step 7: Model Evaluation
Comprehensive Test Set Evaluation
# Run the trained model over the held-out test set once
print("📊 Evaluating on test set...")
test_results = trainer.evaluate(test_loader)

# Pretty-print each metric, e.g. 'top3_accuracy' -> 'Top3 Accuracy'
print("\n🎯 Test Set Results:")
for metric_name, value in test_results.items():
    print(f" {metric_name.replace('_', ' ').title()}: {value:.4f}")
Detailed Classification Report
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
def detailed_evaluation(trainer, test_loader, class_names):
    """Generate detailed evaluation including confusion matrix and classification report.

    Runs the trained model over `test_loader`, then produces:
    - a sklearn classification report (printed; returned as a dict),
    - a confusion-matrix heatmap saved to 'confusion_matrix.png',
    - a per-class accuracy bar chart saved to 'per_class_accuracy.png'.

    Returns:
        (report_dict, confusion_matrix, per_class_accuracy_array)
    """
    # Collect predictions and ground-truth labels over the whole loader
    all_predictions = []
    all_targets = []
    trainer.model.eval()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(tqdm(test_loader, desc="Generating predictions")):
            data, target = data.to(trainer.device), target.to(trainer.device)
            output = trainer.model(data)
            pred = torch.argmax(output, dim=1)
            all_predictions.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)

    # Classification report: dict form is returned, text form is printed
    report = classification_report(
        all_targets, all_predictions,
        target_names=class_names,
        output_dict=True
    )
    print("\n📋 Detailed Classification Report:")
    print(classification_report(all_targets, all_predictions, target_names=class_names))

    # Confusion matrix (rows = true labels, columns = predictions)
    cm = confusion_matrix(all_targets, all_predictions)

    # Plot confusion matrix as an annotated heatmap
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Per-class accuracy: diagonal (hits) over row sums (class totals)
    per_class_acc = cm.diagonal() / cm.sum(axis=1)
    plt.figure(figsize=(12, 6))
    bars = plt.bar(class_names, per_class_acc, color='lightcoral', edgecolor='darkred')
    plt.title('Per-Class Accuracy')
    plt.xlabel('Classes')
    plt.ylabel('Accuracy')
    plt.xticks(rotation=45)
    plt.ylim(0, 1)
    # Add value labels on bars
    for bar, acc in zip(bars, per_class_acc):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                 f'{acc:.3f}', ha='center', va='bottom')
    plt.tight_layout()
    plt.savefig('per_class_accuracy.png', dpi=300, bbox_inches='tight')
    plt.show()
    return report, cm, per_class_acc

# Generate detailed evaluation
report, confusion_mat, per_class_acc = detailed_evaluation(trainer, test_loader, class_names)
Error Analysis
def analyze_misclassifications(trainer, test_loader, class_names, num_examples=16):
    """Collect up to `num_examples` misclassified test images and plot them
    in a 4x4 grid with true label, predicted label and model confidence.

    NOTE(review): the grid shows at most 16 examples; if fewer are found,
    trailing axes stay empty.
    """
    misclassified = []
    trainer.model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(trainer.device), target.to(trainer.device)
            output = trainer.model(data)
            pred = torch.argmax(output, dim=1)
            # Find misclassified examples in this batch
            mask = pred != target
            if mask.any():
                for i in range(data.size(0)):
                    if mask[i] and len(misclassified) < num_examples:
                        # Confidence = softmax probability of the (wrong) predicted class
                        confidence = F.softmax(output[i], dim=0)[pred[i]].item()
                        misclassified.append({
                            'image': data[i].cpu(),
                            'true_label': target[i].item(),
                            'pred_label': pred[i].item(),
                            'confidence': confidence
                        })
            # Stop scanning further batches once enough examples were gathered
            if len(misclassified) >= num_examples:
                break

    # Visualize misclassified examples
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    fig.suptitle('Misclassified Examples', fontsize=16, fontweight='bold')
    for i, example in enumerate(misclassified[:16]):
        ax = axes[i // 4, i % 4]
        # Denormalize with CIFAR-10 statistics for display
        image = example['image']
        mean = torch.tensor([0.4914, 0.4822, 0.4465])
        std = torch.tensor([0.2023, 0.1994, 0.2010])
        image = image * std.view(3, 1, 1) + mean.view(3, 1, 1)
        image = torch.clamp(image, 0, 1)
        ax.imshow(image.permute(1, 2, 0))
        ax.set_title(f"True: {class_names[example['true_label']]}\n"
                     f"Pred: {class_names[example['pred_label']]}\n"
                     f"Conf: {example['confidence']:.2f}", fontsize=10)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig('misclassified_examples.png', dpi=300, bbox_inches='tight')
    plt.show()

# Analyze misclassifications
analyze_misclassifications(trainer, test_loader, class_names)
Step 8: Model Interpretation
Feature Map Visualization
def visualize_feature_maps(model, data_loader, layer_name, num_images=4):
    """Visualize feature maps from a specific layer.

    Registers a forward hook on `layer_name` (looked up via
    model.named_modules()), feeds the first `num_images` batches' leading
    image one at a time, and shows up to 16 channel activations per image
    in a 2x8 grid.
    """
    # Activations captured by the hook; one tensor appended per forward pass
    feature_maps = []

    def hook_fn(module, input, output):
        feature_maps.append(output)

    # Register hook on the requested submodule; KeyError if the name is wrong
    target_layer = dict(model.named_modules())[layer_name]
    hook = target_layer.register_forward_hook(hook_fn)

    model.eval()
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(data_loader):
            if batch_idx >= num_images:
                break
            data = data.to(next(model.parameters()).device)
            _ = model(data[:1])  # Process one image at a time

            # Visualize the first few channels of the latest forward pass
            fmaps = feature_maps[-1][0]  # First (only) image in the batch
            fig, axes = plt.subplots(2, 8, figsize=(16, 4))
            fig.suptitle(f'Feature Maps from {layer_name} - Image {batch_idx + 1}')
            for i in range(min(16, fmaps.shape[0])):
                ax = axes[i // 8, i % 8]
                ax.imshow(fmaps[i].cpu(), cmap='viridis')
                ax.set_title(f'Filter {i}')
                ax.axis('off')
            plt.tight_layout()
            plt.show()

    # Remove the hook so later forwards stop accumulating activations
    hook.remove()

# Visualize feature maps from different layers
# visualize_feature_maps(model, test_loader, 'block1.0', num_images=2)
Step 9: Model Saving and Deployment
Save Final Model
# Bundle weights, constructor kwargs, config, history and metadata into one
# checkpoint so inference code can rebuild the model without this script.
# NOTE(review): assumes ./models already exists — torch.save does not create
# directories; confirm or mkdir beforehand.
final_model_path = './models/cifar10_cnn_final.pt'
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': {
        'num_classes': 10,
        'dropout_rate': 0.3
    },
    'training_config': config,
    'training_history': history,
    'test_results': test_results,
    'class_names': class_names
}, final_model_path)
print(f"✅ Model saved to {final_model_path}")
Load and Use Saved Model
def load_trained_model(model_path, device='cpu'):
    """Load a trained model for inference.

    Returns (model, checkpoint): the rebuilt CIFAR10CNN in eval mode, and
    the raw checkpoint dict (history, class names, configs, ...).

    NOTE(review): torch.load unpickles arbitrary objects — only load
    checkpoints from trusted sources.
    """
    checkpoint = torch.load(model_path, map_location=device)
    # Rebuild the architecture from the saved constructor kwargs
    model_config = checkpoint['model_config']
    model = CIFAR10CNN(**model_config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, checkpoint

# Example usage
# loaded_model, checkpoint = load_trained_model(final_model_path, device='cuda')
# class_names = checkpoint['class_names']
Inference Function
def predict_image(model, image, class_names, transform=None, device='cpu', top_k=5):
    """
    Predict the top-k classes for a single image.

    Args:
        model: classification network returning raw logits of shape (1, C).
        image: a CHW tensor (assumed already normalized) or a PIL image.
        class_names: list mapping class index -> human-readable name.
        transform: preprocessing applied to non-tensor inputs; defaults to
            ToTensor + CIFAR-10 normalization.
        device: device the model lives on.
        top_k: number of highest-probability classes to return (default 5,
            matching the original hard-coded behaviour).

    Returns:
        List of {'class': name, 'probability': float} dicts, sorted by
        descending probability.
    """
    model.eval()
    with torch.no_grad():
        if isinstance(image, torch.Tensor):
            if len(image.shape) == 3:
                image = image.unsqueeze(0)  # Add batch dimension
        else:
            # Assume PIL Image. Build the default transform lazily so the
            # tensor path never touches torchvision transforms.
            if transform is None:
                transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
                ])
            image = transform(image).unsqueeze(0)

        image = image.to(device)
        output = model(image)
        probabilities = F.softmax(output, dim=1)

        # Top-k classes by probability
        top_probs, top_indices = torch.topk(probabilities, k=top_k)

        predictions = []
        for prob, idx in zip(top_probs[0], top_indices[0]):
            predictions.append({
                'class': class_names[idx],
                'probability': prob.item()
            })

    return predictions

# Example inference
# predictions = predict_image(model, test_image, class_names, device=trainer.device)
# print("Top predictions:")
# for pred in predictions:
#     print(f"  {pred['class']}: {pred['probability']:.4f}")
Summary and Best Practices
What We Accomplished:
✅ Built a complete image classification pipeline ✅ Implemented modern CNN architecture with best practices ✅ Used advanced training techniques (mixed precision, label smoothing) ✅ Implemented comprehensive evaluation and error analysis ✅ Created reusable inference functions
Key Best Practices Demonstrated:
Data Preparation: - Proper train/validation splitting - Data augmentation for better generalization - Normalization using dataset statistics
Model Architecture: - Batch normalization for training stability - Dropout for regularization - Global average pooling for a compact classifier head - Proper weight initialization
Training Optimization: - Mixed precision training for speed - Label smoothing for better calibration - Gradient clipping for stability - Learning rate scheduling
Evaluation: - Multiple metrics for comprehensive assessment - Confusion matrix analysis - Per-class performance analysis - Error analysis for insights
Production Readiness: - Model checkpointing and saving - Inference pipeline - Comprehensive logging
Next Steps:
Experiment with different architectures (ResNet, EfficientNet)
Try transfer learning with pre-trained models
Implement advanced techniques (CutMix, MixUp)
Deploy the model using TorchServe or similar frameworks
This tutorial provides a solid foundation for real-world image classification projects using Treadmill! 🚀