Skip to main content

Overview

Continual learning (also called lifelong learning) lets a model learn new tasks sequentially while retaining what it learned on earlier tasks. Neurenix provides several strategies to mitigate catastrophic forgetting, including Elastic Weight Consolidation (EWC), experience replay, regularization (L2 regularization and weight freezing), knowledge distillation, and Synaptic Intelligence.

Quick Start

Prevent forgetting with EWC:
import neurenix as nx
from neurenix.continual import EWC

# Create model
# (create_model, num_epochs, task1_loader and task2_loader are placeholders
# you provide; each batch is expected to expose inputs .x and targets .y)
model = create_model()
optimizer = nx.optim.Adam(model.parameters(), lr=0.001)
criterion = nx.nn.CrossEntropyLoss()

# Initialize EWC (lambda_reg scales how strongly earlier tasks are protected)
ewc = EWC(model, lambda_reg=1000.0)

# Train on Task 1
# Plain training: no task is registered yet, so no EWC penalty is added.
for epoch in range(num_epochs):
    for batch in task1_loader:
        optimizer.zero_grad()
        loss = criterion(model(batch.x), batch.y)
        loss.backward()
        optimizer.step()

# Register Task 1
# Computes parameter importance on Task 1 for use by later penalties.
ewc.register_task(task1_loader, criterion, optimizer)

# Train on Task 2 with EWC protection
for epoch in range(num_epochs):
    for batch in task2_loader:
        optimizer.zero_grad()
        
        # Standard loss
        loss = criterion(model(batch.x), batch.y)
        
        # Add EWC penalty
        # (grows as parameters important for Task 1 drift from their old values)
        loss += ewc.penalty()
        
        loss.backward()
        optimizer.step()
Reference: neurenix/continual/ewc.py:16

Elastic Weight Consolidation (EWC)

EWC slows down learning on parameters important for previous tasks.

Basic EWC

from neurenix.continual import EWC
import neurenix as nx

# Create EWC instance
# (model, optimizer, criterion, train_task and the loaders are defined elsewhere)
ewc = EWC(
    model=model,
    lambda_reg=1000.0,  # Regularization strength
    use_online=False     # Standard EWC (no online accumulation of importance)
)

# Train on first task
train_task(model, task1_loader, optimizer)

# Register task (compute importance)
ewc.register_task(
    dataloader=task1_loader,
    loss_fn=criterion,
    optimizer=optimizer,
    num_samples=1000  # Samples for Fisher matrix estimation
)

# Train on second task with EWC
# Single pass shown for brevity; wrap in an epoch loop as in the Quick Start.
for batch in task2_loader:
    optimizer.zero_grad()
    
    loss = criterion(model(batch.x), batch.y)
    loss += ewc.penalty()  # Protect important weights
    
    loss.backward()
    optimizer.step()
Reference: neurenix/continual/ewc.py:31

Online EWC

With `use_online=True`, EWC accumulates parameter importance across a sequence of tasks:
# Create online EWC
# (model, optimizer, criterion, num_epochs and task_loaders are defined elsewhere)
ewc = EWC(
    model=model,
    lambda_reg=5000.0,
    use_online=True  # Accumulate importance
)

# Train on multiple tasks
for task_id, task_loader in enumerate(task_loaders):
    print(f"Training on Task {task_id + 1}")
    
    for epoch in range(num_epochs):
        for batch in task_loader:
            optimizer.zero_grad()
            
            loss = criterion(model(batch.x), batch.y)
            
            if task_id > 0:  # Add penalty after first task
                # (nothing has been registered until the first task finishes)
                loss += ewc.penalty()
            
            loss.backward()
            optimizer.step()
    
    # Register task (accumulates importance)
    # Done after training so the next task is penalized relative to the
    # weights learned on this one.
    ewc.register_task(task_loader, criterion, optimizer)
Reference: neurenix/continual/ewc.py:35

EWC Penalty Computation

# Compute EWC penalty manually
penalty = ewc.penalty()

# Penalty formula:
#   penalty = (lambda_reg / 2) * sum(importance * (param - old_param) ** 2)
#
# summed over all model parameters (elementwise), where:
# - lambda_reg: regularization strength passed to EWC(...)
# - importance: Fisher information for each parameter
# - old_param: Parameter values saved from the previous task
# - param: Current parameter values
Reference: neurenix/continual/ewc.py:154

Experience Replay

Store and replay examples from previous tasks.

Basic Experience Replay

from neurenix.continual import ExperienceReplay
import neurenix as nx

# Create replay buffer
replay = ExperienceReplay(
    memory_size=10000,      # Store 10k examples
    strategy="random",       # 'random', 'reservoir', 'importance'
    per_class=True,         # Balanced memory per class
    sample_size=128         # Replay batch size
)

# Train on Task 1
for batch in task1_loader:
    # Store examples in memory
    replay.update_memory(batch.x, batch.y)
    
    # Train
    optimizer.zero_grad()
    loss = criterion(model(batch.x), batch.y)
    loss.backward()
    optimizer.step()

# Train on Task 2 with replay
for batch in task2_loader:
    optimizer.zero_grad()
    
    # Current task loss
    loss = criterion(model(batch.x), batch.y)
    
    # Sample from replay buffer
    # (filled during Task 1; batch_size=32 presumably overrides the
    #  sample_size=128 default above -- confirm in replay.py)
    replay_x, replay_y = replay.sample_memory(batch_size=32)
    replay_loss = criterion(model(replay_x), replay_y)
    
    # Combined loss
    # Replayed (old-task) examples are weighted at half the current-task loss.
    total_loss = loss + 0.5 * replay_loss
    
    total_loss.backward()
    optimizer.step()
    
    # Update memory with new examples
    # (after the step, so this batch cannot be sampled back in the same iteration)
    replay.update_memory(batch.x, batch.y)
Reference: neurenix/continual/replay.py:17

Reservoir Sampling Strategy

Maintain a representative sample of the data as it streams in:
# Reservoir sampling keeps a representative sample of the stream
# (assuming standard reservoir sampling: every example seen so far has an
# equal chance of being kept, regardless of which task it came from)
replay = ExperienceReplay(
    memory_size=5000,
    strategy="reservoir",  # Reservoir sampling
    per_class=False
)

for task_id, task_loader in enumerate(task_loaders):
    for batch in task_loader:
        # Update memory with reservoir sampling
        # (task_id presumably tags stored examples with their source task)
        replay.update_memory(batch.x, batch.y, task_id=task_id)
        
        # Train with replay
        # Replay starts with the second task; sample_memory() with no argument
        # presumably uses the buffer's default sample_size.
        if task_id > 0:
            replay_x, replay_y = replay.sample_memory()
            # ... training code
Reference: neurenix/continual/replay.py:49

Per-Class Balanced Replay

# Maintain balanced memory across classes
replay = ExperienceReplay(
    memory_size=10000,
    strategy="random",
    per_class=True,     # Balance classes
    sample_size=100
)

# Automatically balances across classes
# (train_loader is a placeholder for your data)
for batch in train_loader:
    replay.update_memory(batch.x, batch.y)

print(f"Total examples: {replay.get_memory_size()}")
# NOTE(review): assumes `replay.memory` is keyed by class when per_class=True,
# so its length is the number of classes stored -- confirm in replay.py.
print(f"Classes stored: {len(replay.memory)}")
Reference: neurenix/continual/replay.py:34

Regularization Methods

L2 Regularization

A simple penalty that pulls parameters back toward their values from the previous task:
from neurenix.continual import L2Regularization

# Create L2 regularizer
l2_reg = L2Regularization(
    model=model,
    lambda_reg=0.1  # Regularization strength
)

# Train on Task 1
train_task(model, task1_loader, optimizer)

# Register task
# No data is passed: presumably this only snapshots the current parameters
# as the reference point for the penalty -- confirm in regularization.py.
l2_reg.register_task()

# Train on Task 2
for batch in task2_loader:
    optimizer.zero_grad()
    
    loss = criterion(model(batch.x), batch.y)
    loss += l2_reg.penalty()  # L2 penalty toward the registered parameters
    
    loss.backward()
    optimizer.step()
Reference: neurenix/continual/regularization.py:15

Weight Freezing

After learning a task, freeze the weights that were most important for it:
from neurenix.continual import WeightFreezing

# Create weight freezing
freeze = WeightFreezing(
    model=model,
    importance_threshold=0.1,
    importance_method="magnitude"  # 'magnitude', 'gradient', 'fisher'
)

# Train on Task 1
train_task(model, task1_loader, optimizer)

# Register task and compute importance
freeze.register_task(
    dataloader=task1_loader,
    loss_fn=criterion
)

# Train on Task 2 with frozen weights
for batch in task2_loader:
    optimizer.zero_grad()
    loss = criterion(model(batch.x), batch.y)
    loss.backward()
    
    # Apply mask to freeze important weights
    # Called between backward() and step() so it affects this step's update.
    # NOTE(review): if apply_mask only zeroes gradients, optimizers with
    # momentum or weight decay (e.g. Adam) can still move "frozen" weights --
    # confirm apply_mask's behavior in regularization.py.
    freeze.apply_mask()
    
    optimizer.step()
Reference: neurenix/continual/regularization.py:83

Knowledge Distillation

Transfer knowledge from the old model (teacher) to the new one (student):
from neurenix.continual import KnowledgeDistillation
import neurenix as nx

# Save old model
# Snapshot of the model after the previous task, used as the teacher;
# eval() puts it in evaluation mode.
old_model = model.clone()
old_model.eval()

# Create distillation
kd = KnowledgeDistillation(
    teacher_model=old_model,
    student_model=model,
    temperature=2.0,
    alpha=0.5  # Weight between task loss and distillation loss
)

# Train on new task with distillation
for batch in task2_loader:
    optimizer.zero_grad()
    
    # Task loss
    task_loss = criterion(model(batch.x), batch.y)
    
    # Distillation loss
    # (presumably compares student and teacher outputs on batch.x,
    #  softened by `temperature` -- confirm in the continual package)
    distill_loss = kd.compute_loss(batch.x)
    
    # Combined loss
    # alpha weights the task loss; (1 - alpha) weights the distillation loss.
    loss = kd.alpha * task_loss + (1 - kd.alpha) * distill_loss
    
    loss.backward()
    optimizer.step()
Reference: neurenix/continual/__init__.py:11

Synaptic Intelligence

Track parameter importance during training:
from neurenix.continual import SynapticIntelligence

# Create synaptic intelligence
si = SynapticIntelligence(
    model=model,
    lambda_reg=1.0,
    epsilon=1e-3
)

# Train on Task 1
for batch in task1_loader:
    optimizer.zero_grad()
    loss = criterion(model(batch.x), batch.y)
    loss.backward()
    
    # Track parameter changes
    # NOTE(review): called after backward() but before step(), i.e. before
    # this step's parameter change happens -- confirm which gradients and
    # updates update_importance() pairs up.
    si.update_importance()
    
    optimizer.step()

# Finalize importance for Task 1
# (presumably consolidates the importance tracked during Task 1)
si.finalize_importance()

# Train on Task 2
for batch in task2_loader:
    optimizer.zero_grad()
    
    loss = criterion(model(batch.x), batch.y)
    loss += si.penalty()  # SI penalty
    
    loss.backward()
    si.update_importance()  # keep tracking importance for the current task
    optimizer.step()
Reference: neurenix/continual/__init__.py:12

Complete Multi-Task Example

import neurenix as nx
from neurenix.continual import EWC, ExperienceReplay

# Create model and training components
# (create_model, evaluate, num_epochs and the task/test loaders are
# placeholders you provide)
model = create_model()
optimizer = nx.optim.Adam(model.parameters(), lr=0.001)
criterion = nx.nn.CrossEntropyLoss()

# Initialize continual learning strategies
ewc = EWC(model, lambda_reg=5000.0, use_online=True)
replay = ExperienceReplay(
    memory_size=5000,
    strategy="reservoir",
    per_class=True,
    sample_size=64
)

# Train on sequence of tasks
task_loaders = [task1_loader, task2_loader, task3_loader, task4_loader]
test_loaders = [test1_loader, test2_loader, test3_loader, test4_loader]

for task_id, task_loader in enumerate(task_loaders):
    print(f"\nTraining on Task {task_id + 1}")

    for epoch in range(num_epochs):
        for batch in task_loader:
            optimizer.zero_grad()

            # Current task loss
            loss = criterion(model(batch.x), batch.y)

            # Add EWC penalty (after first task; nothing is registered before)
            if task_id > 0:
                loss += ewc.penalty()

            # Add replay loss (after first task, once memory holds examples)
            if task_id > 0 and replay.get_memory_size() > 0:
                replay_x, replay_y = replay.sample_memory(batch_size=32)
                replay_loss = criterion(model(replay_x), replay_y)
                loss += 0.5 * replay_loss

            loss.backward()
            optimizer.step()

            # Update replay buffer (after sampling, so the current batch is
            # never replayed in the same step)
            replay.update_memory(batch.x, batch.y, task_id=task_id)

    # Register task with EWC (online mode accumulates importance)
    ewc.register_task(task_loader, criterion, optimizer, num_samples=1000)

    # Evaluate on all tasks seen so far
    print(f"\nEvaluation after Task {task_id + 1}:")
    for eval_id in range(task_id + 1):
        accuracy = evaluate(model, test_loaders[eval_id])
        print(f"  Task {eval_id + 1} accuracy: {accuracy:.2f}%")

# Final evaluation: evaluate each task once and reuse those scores for the
# average instead of running every test loader a second time.
print("\nFinal Results:")
final_accuracies = []
for task_id, test_loader in enumerate(test_loaders):
    accuracy = evaluate(model, test_loader)
    final_accuracies.append(accuracy)
    print(f"Task {task_id + 1}: {accuracy:.2f}%")

avg_accuracy = sum(final_accuracies) / len(final_accuracies)
print(f"\nAverage accuracy: {avg_accuracy:.2f}%")

Combining Strategies

EWC + Experience Replay

# Best of both: regularization + memory
# (model, optimizer, criterion and task_loaders are defined elsewhere;
# one pass per task is shown for brevity)
ewc = EWC(model, lambda_reg=1000.0)
replay = ExperienceReplay(memory_size=5000)

for task_id, task_loader in enumerate(task_loaders):
    for batch in task_loader:
        optimizer.zero_grad()
        
        # Task loss
        loss = criterion(model(batch.x), batch.y)
        
        # EWC regularization
        # (only once at least one task has been registered)
        if task_id > 0:
            loss += ewc.penalty()
        
        # Replay previous examples
        # Gated only on buffer size, so during the first task this replays
        # examples from that same task.
        if replay.get_memory_size() > 0:
            replay_x, replay_y = replay.sample_memory()
            loss += criterion(model(replay_x), replay_y)
        
        loss.backward()
        optimizer.step()
        replay.update_memory(batch.x, batch.y)
    
    ewc.register_task(task_loader, criterion, optimizer)

EWC + Knowledge Distillation

# Regularization + soft targets
# Assumes the model has just been trained on Task 1 (see the sections above).

# Snapshot the Task 1 model as the teacher and put it in evaluation mode,
# as in the Knowledge Distillation section.
old_model = model.clone()
old_model.eval()

# Register Task 1 before using the penalty: ewc.penalty() compares against
# the parameters and importance recorded at registration, so without this
# call there is nothing for EWC to protect.
ewc = EWC(model, lambda_reg=5000.0)
ewc.register_task(task1_loader, criterion, optimizer)

# alpha balances the task loss against the distillation loss
kd = KnowledgeDistillation(old_model, model, temperature=2.0, alpha=0.7)

for batch in task2_loader:
    optimizer.zero_grad()

    task_loss = criterion(model(batch.x), batch.y)
    ewc_penalty = ewc.penalty()
    distill_loss = kd.compute_loss(batch.x)

    # Weight task vs. distillation loss with kd.alpha (same convention as the
    # Knowledge Distillation example), then add the EWC penalty on top.
    loss = kd.alpha * task_loss + (1 - kd.alpha) * distill_loss + ewc_penalty

    loss.backward()
    optimizer.step()

Evaluation Metrics

Average Accuracy

def compute_average_accuracy(model, test_loaders):
    """Return the mean of ``evaluate(model, loader)`` over *test_loaders*.

    Each loader is evaluated exactly once, in order. An empty
    *test_loaders* raises ZeroDivisionError.
    """
    scores = [evaluate(model, loader) for loader in test_loaders]
    return sum(scores) / len(scores)

Forgetting Measure

def compute_forgetting(accuracies_matrix):
    """
    Measure how much the model forgets, averaged over earlier tasks.

    For every task ``i`` except the last, forgetting is the best accuracy
    reached on task ``i`` from the stage where it was learned onward
    (``accuracies_matrix[i][i:n_tasks]``) minus its final accuracy
    (``accuracies_matrix[i][-1]``). For a square matrix the final value lies
    inside that window, so each per-task term is non-negative.

    Args:
        accuracies_matrix: accuracies_matrix[i][j] = accuracy on task i
                          after training on task j

    Returns:
        Mean forgetting over tasks 0..n_tasks-2, or 0.0 when fewer than two
        tasks are given (there is no earlier task to forget).
    """
    n_tasks = len(accuracies_matrix)
    if n_tasks < 2:
        # Guard: the average below would otherwise divide by zero.
        return 0.0

    forgetting = 0.0
    for i in range(n_tasks - 1):
        row = accuracies_matrix[i]
        # Max accuracy achieved on task i since it was learned
        max_acc = max(row[i:n_tasks])
        # Forgetting for task i: peak accuracy minus final accuracy
        forgetting += max_acc - row[-1]

    return forgetting / (n_tasks - 1)

Backward Transfer

def compute_backward_transfer(accuracies_matrix):
    """
    Measure improvement on previous tasks.

    For every task ``i`` except the last, backward transfer is its final
    accuracy (``accuracies_matrix[i][-1]``) minus its accuracy right after it
    was trained (``accuracies_matrix[i][i]``), averaged over those tasks.
    Positive values indicate beneficial backward transfer; negative values
    indicate forgetting.

    Args:
        accuracies_matrix: accuracies_matrix[i][j] = accuracy on task i
                          after training on task j

    Returns:
        Mean backward transfer over tasks 0..n_tasks-2, or 0.0 when fewer
        than two tasks are given (there is no earlier task to compare).
    """
    n_tasks = len(accuracies_matrix)
    if n_tasks < 2:
        # Guard: the average below would otherwise divide by zero.
        return 0.0

    backward_transfer = 0.0
    for i in range(n_tasks - 1):
        # Accuracy on task i after training on task i
        acc_after_i = accuracies_matrix[i][i]
        # Final accuracy on task i
        final_acc = accuracies_matrix[i][-1]
        # Transfer for task i
        backward_transfer += final_acc - acc_after_i

    return backward_transfer / (n_tasks - 1)

Best Practices

1. Choose the Right Strategy

if memory_constrained:
    # For limited memory: use EWC or regularization (no stored examples)
    ewc = EWC(model, lambda_reg=5000.0)
else:
    # For abundant memory: use experience replay
    replay = ExperienceReplay(memory_size=50000)

# For best results: combine strategies. Construct both explicitly -- the
# branch above defines only one of `ewc` / `replay`, so referring to both
# names here would raise NameError.
combined = (
    EWC(model, lambda_reg=5000.0),
    ExperienceReplay(memory_size=50000),
)

2. Tune Regularization Strength

# Start with moderate values
lambda_values = [100, 1000, 5000, 10000]

for lambda_reg in lambda_values:
    # Train a fresh copy for each setting. EWC must wrap the same model that
    # is trained, and reusing `model` would carry training from one lambda
    # into the next, making the comparison unfair.
    # (train_and_evaluate is assumed to build its own optimizer for the
    # model it receives)
    candidate = model.clone()
    ewc = EWC(candidate, lambda_reg=lambda_reg)
    # Train and evaluate
    avg_acc = train_and_evaluate(candidate, tasks, ewc)
    print(f"Lambda {lambda_reg}: {avg_acc:.2f}%")

3. Balance Old and New Tasks

# Adjust loss weighting
# (task_loss, replay_loss and task_id come from the surrounding training loop)
new_task_weight = 1.0
old_task_weight = 0.5  # Lower weight for old tasks

loss = new_task_weight * task_loss
# Old-task (replay) loss only applies once a previous task has been seen.
if task_id > 0:
    loss += old_task_weight * replay_loss

4. Monitor Forgetting

# Track accuracy on all tasks
# (num_tasks, train_on_task, evaluate and test_loaders are defined elsewhere;
# accuracies are in percent)
accuracies = {task_id: [] for task_id in range(num_tasks)}

for current_task in range(num_tasks):
    train_on_task(model, current_task)

    # Evaluate on all tasks seen so far
    for eval_task in range(current_task + 1):
        acc = evaluate(model, test_loaders[eval_task])
        history = accuracies[eval_task]
        history.append(acc)

        # Alert if significant forgetting. Compare against the best accuracy
        # seen so far on this task (not just the previous evaluation), so a
        # slow, steady decline is caught too -- the same "peak minus current"
        # idea used by compute_forgetting.
        if len(history) > 1:
            drop = max(history[:-1]) - acc
            if drop > 5.0:  # more than 5 percentage points below the best
                print(f"Warning: Task {eval_task + 1} forgot {drop:.1f}%")

5. Optimize Memory Usage

# Use efficient replay strategies
replay = ExperienceReplay(
    memory_size=5000,
    strategy="reservoir",  # Representative sample of the whole stream
    per_class=True,        # Balanced classes
    sample_size=64         # Smaller batches
)

# Or use online EWC, which accumulates importance across tasks
# (presumably into a single running estimate rather than one per task --
# confirm in ewc.py)
ewc = EWC(model, lambda_reg=1000.0, use_online=True)

Performance Tips

  1. Use online EWC for multiple tasks
  2. Combine EWC with replay for best results
  3. Tune regularization strength per dataset
  4. Use reservoir sampling for streaming data
  5. Monitor forgetting metrics during training
  6. Balance replay buffer across classes and tasks