LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels

Joint Embedding Predictive Architectures (JEPAs) promise efficient world models that learn from pixels, but existing methods suffer from training instability and complex hyperparameter tuning. LeWorldModel (LeWM) addresses both problems with a streamlined two-term objective, and it plans about 48× faster than DINO-WM while maintaining competitive performance.

The Collapse Problem in World Models

Current JEPA methods face a fundamental challenge: representation collapse. When learning to predict future states in latent space, models often map all inputs to identical representations to trivially satisfy the prediction objective. This renders the learned representations useless.

Existing solutions create new problems:

  • PLDM requires six hyperparameters and complex multi-term losses
  • DINO-WM relies on frozen pre-trained encoders, limiting end-to-end learning
  • Traditional methods use heuristic tricks like exponential moving averages without theoretical guarantees

LeWorldModel’s Two-Term Solution

LeWM introduces a principled approach using only two loss terms:

1. Next-Embedding Prediction Loss

L_pred = ||ẑ_{t+1} - z_{t+1}||²

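Here ẑ_{t+1} is the predictor's estimate of the next embedding and z_{t+1} is the encoder's embedding of the next observation. This standard prediction loss encourages the model to learn useful dynamics.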

2. SIGReg Regularization

The Sketched-Isotropic-Gaussian Regularizer (SIGReg) prevents collapse by pushing the distribution of latent embeddings toward an isotropic Gaussian:

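import torch

def SIGReg(embeddings, num_directions=64):
    # Sample random unit directions, one per column (64 is an illustrative default)
    directions = torch.randn(embeddings.shape[-1], num_directions, device=embeddings.device)
    directions = directions / directions.norm(dim=0, keepdim=True)
    # Project embeddings onto each direction: (batch, num_directions)
    projections = embeddings @ directions
    # Apply a 1D normality test (e.g. Epps-Pulley, defined elsewhere) per direction
    return torch.stack([normality_test(proj) for proj in projections.T]).mean()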

SIGReg works by:

  • Projecting high-dimensional embeddings onto random unit vectors
  • Testing each projection for normality using the Epps-Pulley statistic
  • Leveraging the Cramér-Wold theorem: a distribution is fully determined by its 1D projections, so matching every projection to a Gaussian matches the full distribution
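The three steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not LeWM's implementation: the function names, the 16 directions, the integration grid, and the Gaussian weighting are all choices made for this example. The statistic follows the idea behind the Epps-Pulley test (a weighted distance between the empirical characteristic function and that of a standard normal) but approximates the integral numerically and omits the sample-size scaling.

```python
import cmath
import math
import random

def cf_normality_statistic(samples, t_max=3.0, num_points=17):
    """Weighted squared distance between the empirical characteristic
    function of `samples` and that of N(0, 1), integrated over
    [-t_max, t_max] with the trapezoidal rule."""
    step = 2.0 * t_max / (num_points - 1)
    total = 0.0
    for k in range(num_points):
        t = -t_max + k * step
        # Empirical characteristic function at t: mean of exp(i * t * x)
        ecf = sum(cmath.exp(1j * t * x) for x in samples) / len(samples)
        target = math.exp(-0.5 * t * t)  # characteristic function of N(0, 1)
        weight = math.exp(-0.5 * t * t)  # Gaussian weighting (assumed choice)
        trapezoid = 0.5 if k in (0, num_points - 1) else 1.0
        total += trapezoid * abs(ecf - target) ** 2 * weight * step
    return total

def random_unit_vector(dim, rng):
    """Draw a direction uniformly on the unit sphere."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]

def sigreg(embeddings, num_directions=16, seed=0):
    """Average the 1D normality statistic over random unit directions."""
    rng = random.Random(seed)
    dim = len(embeddings[0])
    total = 0.0
    for _ in range(num_directions):
        direction = random_unit_vector(dim, rng)
        projection = [sum(e * d for e, d in zip(row, direction))
                      for row in embeddings]
        total += cf_normality_statistic(projection)
    return total / num_directions

# Toy check: Gaussian embeddings versus fully collapsed (constant) embeddings
rng = random.Random(1)
gaussian = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(500)]
collapsed = [[0.0] * 8 for _ in range(500)]
gaussian_penalty = sigreg(gaussian)
collapsed_penalty = sigreg(collapsed)
```

In this toy check the collapsed embeddings receive a much larger penalty than the Gaussian ones, which is the behavior that discourages collapse during training.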

Architecture and Training

LeWM uses a simple encoder-predictor architecture:

  • Encoder: Vision Transformer (ViT-Tiny, ~5M parameters) maps pixels to latent representations
  • Predictor: Transformer (~10M parameters) models temporal dynamics with action conditioning
  • Total: ~15M parameters, trainable on a single GPU

The complete training objective becomes:

L_LeWM = L_pred + λ * SIGReg(Z)

With only one effective hyperparameter (λ), tuning cost drops from O(n⁶), a grid search over six hyperparameters with n candidate values each, to O(log n) via bisection search over λ.
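To make the O(n⁶) versus O(log n) comparison concrete, the short calculation below counts training runs for an illustrative n = 16 candidate values. The helper names are hypothetical, and the bisection count assumes the selection criterion behaves monotonically in λ, which is what makes bisection applicable.

```python
import math

def grid_search_runs(values_per_hyperparameter, num_hyperparameters):
    """Exhaustive grid: every combination of candidate values is one run."""
    return values_per_hyperparameter ** num_hyperparameters

def bisection_runs(num_candidates):
    """Bisection halves the candidate range after each run."""
    return math.ceil(math.log2(num_candidates))

n = 16  # illustrative number of candidate values per hyperparameter
six_hyperparameter_runs = grid_search_runs(n, 6)  # 16**6 = 16,777,216
single_lambda_runs = bisection_runs(n)            # log2(16) = 4
```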

Performance Results

LeWM demonstrates strong performance across diverse environments:

Planning Speed

  • 48× faster than DINO-WM (0.98s vs 47s for full planning)
  • Enables near real-time control with competitive accuracy

Control Performance

  • Push-T: 96% success rate (vs 92% DINO-WM, 78% PLDM)
  • OGBench-Cube: 74% success rate, competitive with foundation models
  • Reacher: 86% success rate, outperforming baselines

Training Stability

LeWM’s two-term objective shows smooth, monotonic convergence compared to PLDM’s noisy multi-term optimization.

Physical Understanding

LeWM’s latent space captures meaningful physical structure:

Probing Results

Linear probes successfully extract physical quantities:

  • Agent/object positions with 97%+ correlation
  • Block orientations and velocities
  • End-effector states in 3D manipulation

Violation-of-Expectation Tests

The model reliably detects physically implausible events, assigning higher surprise to physical violations such as object teleportation than to purely visual perturbations such as color changes.

Emergent Properties

LeWM exhibits temporal path straightening: latent trajectories become increasingly linear over training even though no loss term explicitly encourages this, suggesting efficient representation learning.
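One way to quantify how straight a latent trajectory is: average the cosine similarity between consecutive displacement vectors. The text does not say which measure was used, so the metric and the function name below are assumptions made for illustration. A perfectly straight path scores 1, while sharp turns pull the score down.

```python
import math

def mean_consecutive_cosine(trajectory):
    """Mean cosine similarity between consecutive displacement vectors."""
    steps = [[b - a for a, b in zip(p, q)]
             for p, q in zip(trajectory, trajectory[1:])]
    cosines = []
    for u, v in zip(steps, steps[1:]):
        dot = sum(a * b for a, b in zip(u, v))
        norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
        cosines.append(dot / norm)
    return sum(cosines) / len(cosines)

# Toy 2D trajectories (made up for illustration, not model outputs)
straight = [[float(t), 2.0 * t] for t in range(5)]
zigzag = [[float(t), float(t % 2)] for t in range(5)]
straight_score = mean_consecutive_cosine(straight)
zigzag_score = mean_consecutive_cosine(zigzag)
```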

Implementation Details

Key Components

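# Encoder: ViT-Tiny with projection layer
encoder = ViT(patch_size=14, layers=12, heads=3, dim=192)

# Predictor: Transformer with action conditioning
predictor = Transformer(layers=6, heads=16, dropout=0.1)

# Training loop (assumes each batch holds observations, actions, and next
# observations, and that `optimizer` covers encoder and predictor parameters)
for observations, actions, next_observations in dataloader:
    optimizer.zero_grad()

    embeddings = encoder(observations)
    predictions = predictor(embeddings, actions)
    targets = encoder(next_observations)  # embeddings of the next observations

    pred_loss = mse_loss(predictions, targets)
    sigreg_loss = SIGReg(embeddings)

    # `lambda` is a reserved keyword in Python, so the weight is named lambda_
    total_loss = pred_loss + lambda_ * sigreg_loss
    total_loss.backward()
    optimizer.step()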

Planning with CEM

LeWM uses Cross-Entropy Method for trajectory optimization:

  • Sample 300 action sequences per iteration
  • Select top 30 candidates as elites
  • Update sampling distribution iteratively
  • Execute first action, then replan (MPC)
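The planning loop above can be sketched as follows. The sample and elite counts (300 and 30) come from the list; everything else is an assumption for illustration. In particular, `rollout_cost` is a hypothetical 1D stand-in for rolling out the learned predictor in latent space, and the horizon, iteration count, and initial sampling distribution are arbitrary choices.

```python
import random
import statistics

def rollout_cost(start, actions, goal):
    """Toy stand-in for a latent rollout: each action shifts a scalar
    state, and the cost sums squared distances to the goal."""
    state, cost = start, 0.0
    for action in actions:
        state += action  # hypothetical dynamics, not the learned predictor
        cost += (state - goal) ** 2
    return cost

def cem_plan(start, goal, horizon=5, num_samples=300, num_elites=30,
             iterations=5, seed=0):
    """Return the mean action sequence after a few CEM iterations."""
    rng = random.Random(seed)
    means = [0.0] * horizon
    stds = [1.0] * horizon
    for _ in range(iterations):
        # Sample candidate action sequences from the current distribution
        candidates = [[rng.gauss(m, s) for m, s in zip(means, stds)]
                      for _ in range(num_samples)]
        # Keep the lowest-cost sequences as elites
        candidates.sort(key=lambda seq: rollout_cost(start, seq, goal))
        elites = candidates[:num_elites]
        # Refit the sampling distribution to the elites
        means = [statistics.fmean([seq[t] for seq in elites])
                 for t in range(horizon)]
        stds = [statistics.pstdev([seq[t] for seq in elites]) + 1e-6
                for t in range(horizon)]
    return means

# MPC: execute only the first planned action, then replan from the new state
state, goal = 0.0, 3.0
for _ in range(10):
    plan = cem_plan(state, goal)
    state += plan[0]
```

Replanning at every step lets the controller correct for errors in earlier plans, at the price of running the full CEM optimization once per executed action.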

Limitations and Future Work

Current Constraints

  • Short horizons: Planning limited to ~25 timesteps due to error accumulation
  • Data requirements: Needs sufficient interaction coverage in offline datasets
  • Simple environments: SIGReg may struggle with very low-dimensional dynamics

Research Directions

  • Hierarchical modeling for long-horizon planning
  • Pre-training on diverse video datasets to reduce data requirements
  • Inverse dynamics to reduce reliance on action labels

Conclusion

LeWorldModel represents a significant advance in world model learning, offering:

  • Simplicity: Two-term objective vs. six+ terms in alternatives
  • Stability: Provable anti-collapse guarantees through SIGReg
  • Efficiency: 48× faster planning with competitive performance
  • Accessibility: Single GPU training in hours vs. days

By solving the fundamental collapse problem with principled regularization, LeWM enables practical deployment of learned world models for real-time control applications.

The method’s combination of theoretical grounding and empirical performance makes it a compelling foundation for future world model research, particularly in robotics and autonomous systems where fast, reliable planning is essential.