My First JAX Contribution: Modernizing SPMD Examples from pmap to jax.jit

January 2025

You know that feeling when you’re diving deep into a codebase, trying to learn something new, and you stumble upon outdated examples? That’s exactly how my first contribution to JAX began.

The Discovery

I was working through JAX’s distributed training examples, specifically the MNIST SPMD (Single Program, Multiple Data) tutorial. As someone passionate about high-performance ML systems, I wanted to understand how JAX handles data parallelism at scale. But there was a problem – the example was still using pmap, which I discovered was deprecated.

Looking at issue #20312, it was clear: pmap was on its way out, replaced by the more powerful jax.jit with sharding annotations. The old way wasn’t just deprecated; it was teaching newcomers an outdated pattern.

Why This Mattered

The shift from pmap to jax.jit with NamedSharding isn’t just a syntax change – it represents JAX’s evolution toward a more unified and flexible parallelism model. Here’s what changed:

The Old Way (pmap)

# Old pmap-based approach
from functools import partial

# pmap needs an axis name so that collectives like pmean can refer to it
@partial(jax.pmap, axis_name='batch')
def train_step(state, batch):
    def loss_fn(params):
        logits = model.apply(params, batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['label']
        )
        return loss.mean()

    grads = jax.grad(loss_fn)(state.params)
    # pmap splits the leading axis of the inputs across devices;
    # gradients are averaged with an explicit all-reduce over that axis
    grads = jax.lax.pmean(grads, 'batch')
    state = state.apply_gradients(grads=grads)
    return state
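
For context, calling the pmap version also meant host-side bookkeeping that the example never spelled out: every input, including the train state, needs a leading device axis. The snippet below is not from the PR, just a minimal sketch of that setup, assuming images and labels are host arrays of shape (global_batch, ...):

# Hypothetical host-side setup for the pmap version
num_devices = jax.local_device_count()

# Replicate the train state so it has a leading device axis
replicated_state = jax.device_put_replicated(state, jax.local_devices())

# Reshape the batch so each device gets one slice along the new leading axis
sharded_batch = {
    'image': images.reshape(num_devices, -1, *images.shape[1:]),
    'label': labels.reshape(num_devices, -1),
}
replicated_state = train_step(replicated_state, sharded_batch)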

The New Way (jax.jit with NamedSharding)

# Modern jax.jit with explicit sharding
import jax
import numpy as np
import optax
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P

# Define the mesh (logical view of devices)
mesh = Mesh(np.array(jax.devices()), ('data',))

# Create sharding specifications
# Leading (batch) axis sharded across 'data'; any remaining axes replicated
data_sharding = NamedSharding(mesh, P('data'))
# Fully replicated (used for parameters and optimizer state)
replicated_sharding = NamedSharding(mesh, P())

@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn(params, batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['label']
        )
        return loss.mean()

    grads = jax.grad(loss_fn)(state.params)
    # No explicit collective here: because the batch is sharded over the
    # 'data' axis, the compiler inserts the cross-device reduction when
    # the loss is averaged over the batch.
    state = state.apply_gradients(grads=grads)
    return state

# Replicate the train state and shard the data explicitly
state = jax.device_put(state, replicated_sharding)
batch = {'image': images, 'label': labels}
batch = jax.device_put(batch, data_sharding)
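
A related detail, not part of the updated example itself: jax.jit can also be told about shardings directly via in_shardings and out_shardings, instead of relying only on where the inputs were placed. A hedged sketch, assuming train_step had been written without the @jax.jit decorator:

# Alternative: declare shardings on jit itself rather than only via device_put
train_step = jax.jit(
    train_step,
    in_shardings=(replicated_sharding, data_sharding),  # (state, batch)
    out_shardings=replicated_sharding,                   # keep the new state replicated
    donate_argnums=(0,),                                 # reuse the old state's buffers
)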

The Learning Journey

What I love about this change is how much more explicit and educational it is. With pmap, a lot of the sharding magic happened behind the scenes. With NamedSharding, you’re forced to think about three things (a short sketch tying them together follows the list):

  1. Device Mesh Topology: How are your devices arranged? In my PR, I kept it simple with a 1D mesh:
    mesh = Mesh(np.array(jax.devices()), ('data',))
    
  2. Partition Specifications: What gets sharded and how?
    # Batch axis is sharded across 'data'; remaining axes are replicated
    data_sharding = NamedSharding(mesh, P('data'))
    # Model parameters are replicated across all devices
    replicated_sharding = NamedSharding(mesh, P())
    
  3. Explicit Device Placement: You manually place data on devices:
    sharded_batch = jax.device_put(batch, data_sharding)
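
To convince yourself the pieces fit together, JAX will happily show you how an array is laid out across the mesh. Here is a small, self-contained sketch with toy shapes of my own choosing (not from the example):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('data',))
data_sharding = NamedSharding(mesh, P('data'))

# A toy batch: two rows per device, four features each
n = 2 * jax.device_count()
x = jnp.arange(float(n * 4)).reshape(n, 4)
x = jax.device_put(x, data_sharding)

print(x.sharding)                      # the NamedSharding we attached
jax.debug.visualize_array_sharding(x)  # ASCII map of which device owns which rows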
    

The Subtle Details That Matter

One thing I discovered while working on this PR was how the collective operations change. With pmap, you averaged gradients yourself with jax.lax.pmean, under whatever axis name you picked when wrapping the function. Under jax.jit with sharded inputs there is no pmean at all: because the batch is sharded over the mesh’s 'data' axis, the compiler inserts the cross-device reduction when the loss is averaged. And if you do write a collective by hand, its axis name has to match a mesh axis name rather than being arbitrary:

# pmap: explicit all-reduce under an axis name you chose yourself
grads = jax.lax.pmean(grads, 'batch')

# jit + NamedSharding: no pmean needed; averaging the loss over the
# sharded batch axis makes the compiler insert the all-reduce for you
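
If you ever do want to write the collective yourself, the modern escape hatch is shard_map, and there the naming rule really bites: the axis name passed to pmean must be one of the mesh’s axis names. A minimal sketch under that assumption, not taken from the PR:

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ('data',))

# Each device sees only its local shard of x; pmean averages across 'data'.
# Any axis name other than 'data' would be an unbound-axis error here.
@partial(shard_map, mesh=mesh, in_specs=P('data'), out_specs=P())
def global_mean(x):
    return jax.lax.pmean(x.mean(), 'data')

x = jnp.arange(8.0 * jax.device_count())
print(global_mean(x))  # mean of the whole array via an explicit all-reduce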

Another gotcha was ensuring the data shapes align with the sharding strategy. If you have 8 devices and want to shard data across them, your batch size needs to be divisible by 8:

# Ensure batch size is divisible by device count
num_devices = jax.device_count()
batch_size = 128 * num_devices  # Each device gets 128 examples
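
A quick way to check that the split actually happened the way you expect (the shapes here are illustrative, reusing mesh and data_sharding from earlier) is to look at the addressable shards of a placed array:

import jax.numpy as jnp

# Hypothetical sanity check: each device should hold an equal slice of the batch
images = jnp.zeros((batch_size, 28 * 28))
images = jax.device_put(images, data_sharding)

for shard in images.addressable_shards:
    print(shard.device, shard.data.shape)  # expect (128, 784) on every device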

Why This Contribution Felt Important

As someone diving deep into distributed ML, I believe good examples are crucial. They’re often the first thing developers see when learning a new framework. An outdated example doesn’t just teach old patterns – it can actively confuse newcomers who then encounter deprecation warnings or conflicting advice in other parts of the documentation.

By updating this example, future JAX users will:

  • Learn the modern, recommended approach from day one
  • Understand the explicit nature of device sharding
  • Be better prepared for more complex parallelism patterns like tensor parallelism and pipeline parallelism

What I Learned

Contributing to JAX taught me more than just the technical details of SPMD:

  1. Reading deprecation discussions is a goldmine: Issue #20312 had an incredible discussion about why pmap was being deprecated and the philosophy behind the new approach.

  2. Small contributions matter: This wasn’t a massive feature or optimization. It was a documentation update with about +64/-48 lines changed. But it helps everyone who comes after.

  3. The JAX community is thoughtful: The design decisions around jax.jit and sharding show deep thinking about usability and power. The new approach is more verbose but also more flexible and educational.

Looking Forward

This PR opened my eyes to the elegance of JAX’s approach to parallelism. The explicit nature of NamedSharding and PartitionSpec makes it easier to reason about what’s happening across your devices.

I’m now exploring more complex sharding patterns – like how to implement ZeRO-style parameter sharding or how to efficiently shard attention mechanisms for large language models. The foundation that JAX provides with its sharding APIs makes these advanced techniques surprisingly approachable.

If you’re interested in distributed ML or just starting with JAX, I’d encourage you to look at the sharding examples in the documentation. And if you find something that could be improved – don’t hesitate to contribute! Even small updates to examples can make a big difference for the next person learning.


Have you contributed to open source ML frameworks? What was your experience like? I’d love to hear about it!

Links: