Encoder-Decoder Architecture Example

This example demonstrates how to build and train encoder-decoder architectures for sequence-to-sequence tasks using Treadmill. We’ll cover various seq2seq models including attention mechanisms and transformer-style architectures.

Overview

What you’ll learn:

  • Basic encoder-decoder architecture
  • Attention mechanisms (Bahdanau-style additive attention)
  • Transformer-based encoder-decoder
  • Sequence-to-sequence training techniques
  • Custom loss functions for sequences
  • Greedy decoding, with beam search as an extension

Use cases these techniques apply to:

  • Machine translation (English to French, demonstrated here with a toy dataset)
  • Text summarization
  • Sequence generation
  • Time series forecasting

Estimated time: 45-60 minutes

Prerequisites

pip install -e ".[full]"
pip install nltk spacy datasets transformers

Basic Encoder-Decoder Architecture

Let’s start with a simple RNN-based encoder-decoder:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import random
from collections import defaultdict

from treadmill import Trainer, TrainingConfig, OptimizerConfig
from treadmill.callbacks import EarlyStopping, ModelCheckpoint

# Set seeds
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

Step 1: Data Preparation for Seq2Seq

class Seq2SeqDataset(Dataset):
    """Dataset for sequence-to-sequence tasks."""

    def __init__(self, source_sequences, target_sequences,
                 source_vocab, target_vocab, max_length=50):
        self.source_sequences = source_sequences
        self.target_sequences = target_sequences
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        self.max_length = max_length

        # Special tokens
        self.source_vocab['<PAD>'] = 0
        self.source_vocab['<UNK>'] = 1
        self.source_vocab['<SOS>'] = 2
        self.source_vocab['<EOS>'] = 3

        self.target_vocab['<PAD>'] = 0
        self.target_vocab['<UNK>'] = 1
        self.target_vocab['<SOS>'] = 2
        self.target_vocab['<EOS>'] = 3

    def __len__(self):
        return len(self.source_sequences)

    def __getitem__(self, idx):
        source_seq = self.source_sequences[idx]
        target_seq = self.target_sequences[idx]

        # Convert to indices and add special tokens
        source_indices = [self.source_vocab['<SOS>']]
        source_indices.extend([
            self.source_vocab.get(token, self.source_vocab['<UNK>'])
            for token in source_seq
        ])
        source_indices.append(self.source_vocab['<EOS>'])

        target_indices = [self.target_vocab['<SOS>']]
        target_indices.extend([
            self.target_vocab.get(token, self.target_vocab['<UNK>'])
            for token in target_seq
        ])
        target_indices.append(self.target_vocab['<EOS>'])

        # Pad sequences
        source_indices = source_indices[:self.max_length]
        target_indices = target_indices[:self.max_length]

        source_indices += [0] * (self.max_length - len(source_indices))
        target_indices += [0] * (self.max_length - len(target_indices))

        return {
            'source': torch.LongTensor(source_indices),
            'target': torch.LongTensor(target_indices),
            'source_length': min(len(source_seq) + 2, self.max_length),  # +2 for SOS, EOS (capped at max_length)
            'target_length': min(len(target_seq) + 2, self.max_length)
        }

def create_translation_data():
    """Create simple translation dataset (English to simplified French)."""

    # Simple English-French pairs for demonstration
    pairs = [
        (["hello", "world"], ["bonjour", "monde"]),
        (["good", "morning"], ["bon", "matin"]),
        (["how", "are", "you"], ["comment", "allez", "vous"]),
        (["thank", "you"], ["merci"]),
        (["goodbye"], ["au", "revoir"]),
        (["yes"], ["oui"]),
        (["no"], ["non"]),
        (["please"], ["s'il", "vous", "plait"]),
        (["excuse", "me"], ["excusez", "moi"]),
        (["I", "love", "you"], ["je", "t'aime"])
    ] * 100  # Repeat for more training data

    # Build vocabularies
    # Build vocabularies (separate id counters so each vocab's indices stay
    # within its own embedding range; ids 0-3 are reserved for special tokens)
    source_vocab = {}
    target_vocab = {}
    source_id = 4
    target_id = 4

    for source, target in pairs:
        for token in source:
            if token not in source_vocab:
                source_vocab[token] = source_id
                source_id += 1
        for token in target:
            if token not in target_vocab:
                target_vocab[token] = target_id
                target_id += 1

    # Split data
    train_pairs = pairs[:int(0.8 * len(pairs))]
    val_pairs = pairs[int(0.8 * len(pairs)):]

    train_source = [pair[0] for pair in train_pairs]
    train_target = [pair[1] for pair in train_pairs]
    val_source = [pair[0] for pair in val_pairs]
    val_target = [pair[1] for pair in val_pairs]

    return (train_source, train_target, val_source, val_target,
            source_vocab, target_vocab)

# Create dataset
(train_source, train_target, val_source, val_target,
 source_vocab, target_vocab) = create_translation_data()

print(f"Dataset created:")
print(f"  Source vocab size: {len(source_vocab)}")
print(f"  Target vocab size: {len(target_vocab)}")
print(f"  Training pairs: {len(train_source)}")
print(f"  Validation pairs: {len(val_source)}")

Step 2: Basic Encoder-Decoder Model

class Encoder(nn.Module):
    """RNN-based encoder."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                           batch_first=True, bidirectional=True)

    def forward(self, x, lengths=None):
        # Embed input
        embedded = self.embedding(x)

        # Pack sequences for efficiency
        if lengths is not None:
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )

        # Encode
        output, (hidden, cell) = self.lstm(embedded)

        # Unpack
        if lengths is not None:
            output, _ = nn.utils.rnn.pad_packed_sequence(
                output, batch_first=True
            )

        return output, (hidden, cell)

class Decoder(nn.Module):
    """RNN-based decoder."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.output_projection = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden_state, cell_state):
        # Embed input
        embedded = self.embedding(x)

        # Decode
        output, (hidden, cell) = self.lstm(embedded, (hidden_state, cell_state))

        # Project to vocabulary
        output = self.output_projection(output)

        return output, (hidden, cell)

class Seq2Seq(nn.Module):
    """Basic sequence-to-sequence model."""

    def __init__(self, source_vocab_size, target_vocab_size,
                 embed_dim=256, hidden_dim=512, num_layers=1):
        super().__init__()

        # The encoder LSTM is bidirectional, so its states come out with size
        # hidden_dim * 2; the bridge layers below map them down to hidden_dim
        self.encoder = Encoder(source_vocab_size, embed_dim,
                               hidden_dim, num_layers)
        self.decoder = Decoder(target_vocab_size, embed_dim,
                               hidden_dim, num_layers)

        # Bridge from bidirectional encoder to decoder
        self.bridge_hidden = nn.Linear(hidden_dim * 2, hidden_dim)
        self.bridge_cell = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, source, target, source_lengths=None):
        batch_size = source.size(0)
        target_length = target.size(1)

        # Encode
        encoder_outputs, (encoder_hidden, encoder_cell) = self.encoder(
            source, source_lengths
        )

        # Bridge encoder states to decoder
        # Take the last layer's hidden states from both directions
        encoder_hidden = encoder_hidden[-2:].transpose(0, 1).contiguous()
        encoder_cell = encoder_cell[-2:].transpose(0, 1).contiguous()

        encoder_hidden = encoder_hidden.view(batch_size, -1)
        encoder_cell = encoder_cell.view(batch_size, -1)

        decoder_hidden = self.bridge_hidden(encoder_hidden).unsqueeze(0)
        decoder_cell = self.bridge_cell(encoder_cell).unsqueeze(0)

        # Decode
        outputs = []
        decoder_input = target[:, 0:1]  # Start with SOS token

        for t in range(target_length - 1):
            output, (decoder_hidden, decoder_cell) = self.decoder(
                decoder_input, decoder_hidden, decoder_cell
            )
            outputs.append(output)
            decoder_input = target[:, t+1:t+2]  # Teacher forcing

        return torch.cat(outputs, dim=1)
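
Before adding attention, a quick shape check with dummy tensors (hypothetical sizes, untrained model) confirms that teacher-forced decoding produces one prediction per target step after <SOS>:

# Sanity check: output is (batch, target_len - 1, target_vocab_size)
_model = Seq2Seq(source_vocab_size=20, target_vocab_size=24, embed_dim=32, hidden_dim=64)
_src = torch.randint(0, 20, (4, 10))
_tgt = torch.randint(0, 24, (4, 12))
print(_model(_src, _tgt).shape)  # torch.Size([4, 11, 24])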

Step 3: Attention Mechanism
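
For reference, the additive (Bahdanau) attention implemented below scores each encoder output h_i against the current decoder state s_t and turns the normalized scores into a context vector:

e_{t,i} = v_a^\top \tanh(W_e h_i + W_d s_t), \qquad
\alpha_{t,i} = \operatorname{softmax}_i(e_{t,i}), \qquad
c_t = \sum_i \alpha_{t,i} h_i

Here W_e, W_d, and v_a correspond to encoder_projection, decoder_projection, and attention_vector in the code.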

class BahdanauAttention(nn.Module):
    """Bahdanau (additive) attention mechanism."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_projection = nn.Linear(encoder_dim, attention_dim)
        self.decoder_projection = nn.Linear(decoder_dim, attention_dim)
        self.attention_vector = nn.Linear(attention_dim, 1)

    def forward(self, encoder_outputs, decoder_hidden):
        # encoder_outputs: (batch, seq_len, encoder_dim)
        # decoder_hidden: (batch, decoder_dim)

        batch_size, seq_len, encoder_dim = encoder_outputs.shape

        # Project encoder outputs
        encoder_proj = self.encoder_projection(encoder_outputs)  # (batch, seq_len, att_dim)

        # Project and expand decoder hidden state
        decoder_proj = self.decoder_projection(decoder_hidden)  # (batch, att_dim)
        decoder_proj = decoder_proj.unsqueeze(1).expand(-1, seq_len, -1)  # (batch, seq_len, att_dim)

        # Compute attention scores
        energy = torch.tanh(encoder_proj + decoder_proj)  # (batch, seq_len, att_dim)
        attention_scores = self.attention_vector(energy).squeeze(2)  # (batch, seq_len)

        # Compute attention weights
        attention_weights = F.softmax(attention_scores, dim=1)  # (batch, seq_len)

        # Compute context vector
        context = torch.bmm(attention_weights.unsqueeze(1),
                           encoder_outputs).squeeze(1)  # (batch, encoder_dim)

        return context, attention_weights

class AttentionDecoder(nn.Module):
    """Decoder with attention mechanism."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, encoder_dim,
                 attention_dim=256, num_layers=1):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.attention = BahdanauAttention(encoder_dim, hidden_dim, attention_dim)

        # LSTM input includes embedding + context
        self.lstm = nn.LSTM(embed_dim + encoder_dim, hidden_dim,
                           num_layers, batch_first=True)

        # Output projection
        self.output_projection = nn.Linear(hidden_dim + encoder_dim + embed_dim,
                                         vocab_size)

    def forward(self, x, hidden_state, cell_state, encoder_outputs):
        # Embed input
        embedded = self.embedding(x)

        # Compute attention
        context, attention_weights = self.attention(
            encoder_outputs, hidden_state.squeeze(0)
        )

        # Concatenate embedding and context
        lstm_input = torch.cat([embedded, context.unsqueeze(1)], dim=2)

        # Decode
        output, (hidden, cell) = self.lstm(lstm_input, (hidden_state, cell_state))

        # Concatenate LSTM output, context, and embedding for final projection
        final_output = torch.cat([output, context.unsqueeze(1), embedded], dim=2)
        output = self.output_projection(final_output)

        return output, (hidden, cell), attention_weights

class AttentionSeq2Seq(nn.Module):
    """Sequence-to-sequence model with attention."""

    def __init__(self, source_vocab_size, target_vocab_size,
                 embed_dim=256, hidden_dim=512, attention_dim=256, num_layers=1):
        super().__init__()

        encoder_hidden_dim = hidden_dim
        self.encoder = Encoder(source_vocab_size, embed_dim,
                             encoder_hidden_dim, num_layers)

        self.decoder = AttentionDecoder(
            target_vocab_size, embed_dim, hidden_dim,
            encoder_hidden_dim * 2, attention_dim, num_layers  # *2 for bidirectional
        )

        # Bridge layers
        self.bridge_hidden = nn.Linear(encoder_hidden_dim * 2, hidden_dim)
        self.bridge_cell = nn.Linear(encoder_hidden_dim * 2, hidden_dim)

    def forward(self, source, target, source_lengths=None):
        batch_size = source.size(0)
        target_length = target.size(1)

        # Encode
        encoder_outputs, (encoder_hidden, encoder_cell) = self.encoder(
            source, source_lengths
        )

        # Bridge encoder states to decoder
        encoder_hidden = encoder_hidden[-2:].transpose(0, 1).contiguous()
        encoder_cell = encoder_cell[-2:].transpose(0, 1).contiguous()

        encoder_hidden = encoder_hidden.view(batch_size, -1)
        encoder_cell = encoder_cell.view(batch_size, -1)

        decoder_hidden = self.bridge_hidden(encoder_hidden).unsqueeze(0)
        decoder_cell = self.bridge_cell(encoder_cell).unsqueeze(0)

        # Decode with attention
        outputs = []
        attention_weights_list = []
        decoder_input = target[:, 0:1]  # Start with SOS token

        for t in range(target_length - 1):
            output, (decoder_hidden, decoder_cell), attention_weights = self.decoder(
                decoder_input, decoder_hidden, decoder_cell, encoder_outputs
            )
            outputs.append(output)
            attention_weights_list.append(attention_weights)
            decoder_input = target[:, t+1:t+2]  # Teacher forcing

        return torch.cat(outputs, dim=1), torch.stack(attention_weights_list, dim=1)
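
A quick check with dummy tensors (hypothetical sizes, untrained model) shows the second return value: one attention distribution over source positions per decoding step, each summing to 1:

_m = AttentionSeq2Seq(source_vocab_size=20, target_vocab_size=24,
                      embed_dim=32, hidden_dim=64, attention_dim=32)
_out, _attn = _m(torch.randint(0, 20, (2, 7)), torch.randint(0, 24, (2, 5)))
print(_out.shape, _attn.shape)  # torch.Size([2, 4, 24]) torch.Size([2, 4, 7])
print(_attn.sum(dim=-1))        # every entry ≈ 1.0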

Step 4: Custom Loss Function and Training

class Seq2SeqLoss(nn.Module):
    """Custom loss function for sequence-to-sequence tasks."""

    def __init__(self, vocab_size, pad_idx=0, label_smoothing=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.pad_idx = pad_idx
        self.label_smoothing = label_smoothing

    def forward(self, predictions, targets):
        # predictions: (batch, seq_len, vocab_size)
        # targets: (batch, seq_len)

        batch_size, seq_len, vocab_size = predictions.shape

        # Flatten for cross entropy
        predictions = predictions.view(-1, vocab_size)
        targets = targets[:, 1:].contiguous().view(-1)  # Remove SOS token from targets

        # Create mask for padding tokens
        mask = targets != self.pad_idx

        # Apply label smoothing
        if self.label_smoothing > 0:
            # Create one-hot encoding
            one_hot = torch.zeros_like(predictions)
            one_hot.scatter_(1, targets.unsqueeze(1), 1)

            # Apply label smoothing
            smooth_one_hot = one_hot * (1 - self.label_smoothing) + \
                           self.label_smoothing / vocab_size

            # Compute loss
            log_probs = F.log_softmax(predictions, dim=1)
            loss = -(smooth_one_hot * log_probs).sum(dim=1)
        else:
            loss = F.cross_entropy(predictions, targets, reduction='none')

        # Apply mask and compute mean
        loss = loss.masked_select(mask).mean()

        return loss
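
A minimal sanity check for the loss (toy tensors, not real model output): with random logits, no padding, and label smoothing disabled, the masked cross-entropy should land near ln(vocab_size):

_V = 24
_preds = torch.randn(4, 11, _V)        # (batch, target_len - 1, vocab_size)
_tgts = torch.randint(1, _V, (4, 12))  # (batch, target_len), no <PAD> tokens
print(Seq2SeqLoss(vocab_size=_V, pad_idx=0, label_smoothing=0.0)(_preds, _tgts))  # ≈ ln(24) ≈ 3.18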

Step 5: Training Setup

# Create datasets
train_dataset = Seq2SeqDataset(train_source, train_target,
                              source_vocab, target_vocab, max_length=20)
val_dataset = Seq2SeqDataset(val_source, val_target,
                            source_vocab, target_vocab, max_length=20)

# Custom collate function for variable length sequences
def collate_fn(batch):
    sources = torch.stack([item['source'] for item in batch])
    targets = torch.stack([item['target'] for item in batch])
    source_lengths = torch.tensor([item['source_length'] for item in batch])
    target_lengths = torch.tensor([item['target_length'] for item in batch])

    return {
        'source': sources,
        'target': targets,
        'source_lengths': source_lengths,
        'target_lengths': target_lengths
    }

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                         collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,
                       collate_fn=collate_fn)

# Create model
model = AttentionSeq2Seq(
    source_vocab_size=len(source_vocab),
    target_vocab_size=len(target_vocab),
    embed_dim=128,
    hidden_dim=256,
    attention_dim=128
)

# Custom loss
loss_fn = Seq2SeqLoss(vocab_size=len(target_vocab), pad_idx=0)

print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

Step 6: Custom Metrics for Seq2Seq

def sequence_accuracy(predictions, batch_data):
    """Calculate sequence-level accuracy."""
    targets = batch_data['target'][:, 1:]  # Remove SOS token

    # Get predicted tokens
    pred_tokens = torch.argmax(predictions, dim=2)

    # Calculate accuracy (ignoring padding)
    mask = targets != 0  # Padding token is 0
    correct = (pred_tokens == targets) | (~mask)
    sequence_correct = correct.all(dim=1).float().mean().item()

    return sequence_correct

def token_accuracy(predictions, batch_data):
    """Calculate token-level accuracy."""
    targets = batch_data['target'][:, 1:]  # Remove SOS token

    # Get predicted tokens
    pred_tokens = torch.argmax(predictions, dim=2)

    # Calculate accuracy (ignoring padding)
    mask = targets != 0
    correct = (pred_tokens == targets) & mask
    total_tokens = mask.sum().item()
    correct_tokens = correct.sum().item()

    return correct_tokens / (total_tokens + 1e-8)

# Create metrics that work with the custom data format
def create_seq2seq_metrics():
    def seq_acc_wrapper(predictions, batch_data):
        return sequence_accuracy(predictions, batch_data)

    def token_acc_wrapper(predictions, batch_data):
        return token_accuracy(predictions, batch_data)

    return {
        'sequence_accuracy': seq_acc_wrapper,
        'token_accuracy': token_acc_wrapper
    }
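
A quick, self-contained check of both metrics with a "perfect" prediction (one-hot logits of the shifted targets); both should report roughly 1.0:

_tgt = torch.tensor([[2, 5, 6, 3, 0, 0]])                  # <SOS> w1 w2 <EOS> <PAD> <PAD>
_batch = {'target': _tgt}
_logits = F.one_hot(_tgt[:, 1:], num_classes=10).float()   # (1, 5, 10)
print(token_accuracy(_logits, _batch), sequence_accuracy(_logits, _batch))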

Step 7: Advanced Training Configuration

# Advanced optimizer with learning rate scheduling
optimizer_config = OptimizerConfig(
    optimizer_class="AdamW",
    lr=0.001,
    weight_decay=0.01,
    params={
        'betas': (0.9, 0.98),
        'eps': 1e-9
    }
)

# Training configuration
config = TrainingConfig(
    epochs=100,
    device="auto",
    mixed_precision=False,  # Can cause issues with attention
    accumulate_grad_batches=2,
    grad_clip_norm=1.0,

    validation_frequency=1,
    log_frequency=10,

    early_stopping_patience=15,
    early_stopping_min_delta=0.001,

    checkpoint_dir="./checkpoints/seq2seq",
    save_best_model=True,
    save_last_model=True,

    optimizer=optimizer_config
)

# Callbacks
callbacks = [
    EarlyStopping(
        monitor='val_sequence_accuracy',
        patience=15,
        min_delta=0.001,
        mode='max',
        verbose=True
    ),
    ModelCheckpoint(
        filepath='./checkpoints/seq2seq/best_model_{epoch:03d}_{val_sequence_accuracy:.4f}.pt',
        monitor='val_sequence_accuracy',
        save_best_only=True,
        mode='max',
        verbose=True
    )
]

Step 8: Custom Trainer for Seq2Seq

class Seq2SeqTrainer(Trainer):
    """Custom trainer for sequence-to-sequence models."""

    def _compute_loss(self, batch, predictions):
        """Custom loss computation for seq2seq."""
        if isinstance(predictions, tuple):
            predictions, attention_weights = predictions

        return self.loss_fn(predictions, batch['target'])

    def _compute_metrics(self, batch, predictions):
        """Custom metrics computation for seq2seq."""
        if isinstance(predictions, tuple):
            predictions, attention_weights = predictions

        metrics = {}
        for name, metric_fn in self.metric_fns.items():
            try:
                metrics[name] = metric_fn(predictions, batch)
            except Exception:
                # Handle metric computation errors gracefully by reporting 0.0
                metrics[name] = 0.0

        return metrics

    def _forward_pass(self, batch):
        """Custom forward pass for seq2seq."""
        return self.model(
            batch['source'],
            batch['target'],
            batch.get('source_lengths')
        )

# Create custom trainer
trainer = Seq2SeqTrainer(
    model=model,
    config=config,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=loss_fn,
    metric_fns=create_seq2seq_metrics(),
    callbacks=callbacks
)

# Train the model
print("🚀 Starting Seq2Seq Training...")
history = trainer.train()

# Evaluate
test_results = trainer.evaluate(val_loader)
print(f"\n📊 Seq2Seq Results:")
for metric, value in test_results.items():
    print(f"  {metric.replace('_', ' ').title()}: {value:.4f}")

Summary and Advanced Techniques

🎯 What We Built:

  • Basic Encoder-Decoder: RNN-based seq2seq architecture ✅
  • Attention Mechanism: Bahdanau attention for better alignment ✅
  • Custom Training: Specialized trainer for sequence tasks ✅
  • Advanced Metrics: Sequence- and token-level evaluation ✅
  • Inference Pipeline: Greedy decoding for generation (see the sketch below) ✅
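
The greedy decoder is not part of the training loop above; here is a minimal sketch (assuming the trained AttentionSeq2Seq model and the vocabularies built earlier) that feeds the model's own predictions back in until <EOS>:

def greedy_decode(model, source_tokens, source_vocab, target_vocab, max_len=20):
    """Greedy decoding sketch for AttentionSeq2Seq (batch size 1, CPU)."""
    inv_target_vocab = {idx: tok for tok, idx in target_vocab.items()}
    model.eval()
    with torch.no_grad():
        # Encode the source sentence with <SOS>/<EOS> markers
        indices = ([source_vocab['<SOS>']]
                   + [source_vocab.get(t, source_vocab['<UNK>']) for t in source_tokens]
                   + [source_vocab['<EOS>']])
        source = torch.LongTensor(indices).unsqueeze(0)
        encoder_outputs, (hidden, cell) = model.encoder(source)

        # Bridge bidirectional encoder states to the decoder, as in forward()
        hidden = hidden[-2:].transpose(0, 1).contiguous().view(1, -1)
        cell = cell[-2:].transpose(0, 1).contiguous().view(1, -1)
        decoder_hidden = model.bridge_hidden(hidden).unsqueeze(0)
        decoder_cell = model.bridge_cell(cell).unsqueeze(0)

        # Generate token by token, always taking the argmax
        decoder_input = torch.LongTensor([[target_vocab['<SOS>']]])
        decoded = []
        for _ in range(max_len):
            output, (decoder_hidden, decoder_cell), _ = model.decoder(
                decoder_input, decoder_hidden, decoder_cell, encoder_outputs
            )
            next_token = output.argmax(dim=-1)
            if next_token.item() == target_vocab['<EOS>']:
                break
            decoded.append(inv_target_vocab.get(next_token.item(), '<UNK>'))
            decoder_input = next_token
        return decoded

# After training (move the model to CPU, or keep inputs on its device):
print(greedy_decode(model.cpu(), ["hello", "world"], source_vocab, target_vocab))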

🚀 Advanced Extensions:

# Transformer-based encoder-decoder (modern approach)
class TransformerSeq2Seq(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512,
                 nhead=8, num_layers=6):
        super().__init__()
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True
        )
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        # Causal mask so each target position only attends to earlier positions
        # (padding masks and positional encodings are omitted for brevity)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt.size(1)
        ).to(tgt.device)

        src_emb = self.src_embedding(src)
        tgt_emb = self.tgt_embedding(tgt)

        output = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)
        return self.output_projection(output)
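
Note that nn.Transformer does not add positional information on its own. A sinusoidal positional-encoding sketch (a common addition, not part of the class above) that could be applied to both embeddings:

import math

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding; add to token embeddings before the transformer."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):  # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]

# Illustrative usage inside TransformerSeq2Seq.forward (names are hypothetical):
#   src_emb = self.pos_encoding(self.src_embedding(src) * math.sqrt(self.d_model))
#   tgt_emb = self.pos_encoding(self.tgt_embedding(tgt) * math.sqrt(self.d_model))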

📊 Performance Improvements:

  1. Attention Visualization: Visualize attention weights to understand model focus

  2. Beam Search: Implement beam search for better generation quality

  3. Length Penalties: Add length normalization for better sequence generation

  4. Teacher Forcing Schedule: Gradually reduce teacher forcing during training (see the sketch after this list)

  5. Transformer Models: Use modern transformer architectures
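
As referenced in item 4, a sketch of scheduled teacher forcing: a hypothetical variant of the decoding loop in AttentionSeq2Seq.forward that takes a teacher_forcing_ratio argument and sometimes feeds the model its own prediction instead of the ground-truth token:

# Hypothetical replacement for the decoding loop (not part of the model above)
def decode_loop_with_scheduled_teacher_forcing(self, target, decoder_hidden, decoder_cell,
                                               encoder_outputs, teacher_forcing_ratio=0.5):
    outputs = []
    decoder_input = target[:, 0:1]  # <SOS>
    for t in range(target.size(1) - 1):
        output, (decoder_hidden, decoder_cell), _ = self.decoder(
            decoder_input, decoder_hidden, decoder_cell, encoder_outputs
        )
        outputs.append(output)
        if random.random() < teacher_forcing_ratio:
            decoder_input = target[:, t + 1:t + 2]   # teacher forcing: ground-truth token
        else:
            decoder_input = output.argmax(dim=-1)    # feed back the model's own prediction
    return torch.cat(outputs, dim=1)

# A simple schedule: decay the ratio each epoch, e.g.
#   teacher_forcing_ratio = max(0.1, 1.0 - 0.02 * epoch)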

🔄 Production Considerations:

  • Batched Inference: Optimize for batch processing

  • Model Compression: Use techniques like knowledge distillation

  • Caching: Cache encoder outputs for faster decoding

  • Multi-GPU: Scale training across multiple GPUs

This comprehensive example shows how Treadmill can handle complex sequence-to-sequence tasks while maintaining clean, extensible code! 🏃‍♀️‍➡️

Next Steps

  • Experiment with Transformer architectures

  • Try different attention mechanisms (self-attention, multi-head)

  • Implement beam search and other decoding strategies

  • Apply to real-world datasets (WMT translation, CNN/DailyMail summarization)

  • Explore pre-trained models (T5, BART, etc.)