TPU Multislice Training with JobSet
TPU Multislice allows you to scale training workloads beyond a single TPU pod (slice) by connecting multiple slices together.
Understanding the hardware network is critical for these workloads:
- Intra-slice (ICI): Communication between TPU chips within a slice happens over the inter-chip interconnect (ICI). This is a dedicated, ultra-high-bandwidth physical network.
- Inter-slice (DCN): The communication between slices happens over the Data Center Network (DCN).
The Google Cloud blog post on scaling AI workloads with Multislice provides helpful diagrams that visualize the difference between the high-speed ICI network within a slice and the DCN used between slices.
Because ICI relies on physical wiring within a specific hardware boundary (represented in GKE as a Node Pool), it is crucial that Kubernetes does not fragment a single logical slice across multiple Node Pools. JobSet solves this exact problem using Exclusive Topology.
Understanding Exclusive Topology
To ensure optimal performance and prevent Job crashes due to broken ICI links, we must guarantee that all Pods belonging to a single TPU slice are scheduled onto the exact same physical Node Pool.
The exclusive placement feature is enabled by creating the JobSet with the annotation alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool. This annotation configures Pod affinity so that all Pods of a child Job are scheduled within the same topology domain, which here is a single GKE Node pool (one slice).
It is commonly used to ensure a 1:1 mapping between child Jobs and GKE Node pools. That is, if two Pods are part of the same child Job, they run on Nodes in the same Node pool. Otherwise, they run on Nodes from different Node pools.
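The 1:1 rule can be stated as a small check. The sketch below is illustrative only and is not part of JobSet: the function name `is_exclusive_placement` and the sample records are hypothetical, and it assumes you have already collected, for each Pod, the name of its child Job and the node pool of the Node it runs on.

```python
from collections import defaultdict

def is_exclusive_placement(pods):
    """Return True if child Jobs and node pools map 1:1.

    `pods` is an iterable of (job_name, node_pool) pairs, one per Pod.
    """
    pools_by_job = defaultdict(set)
    jobs_by_pool = defaultdict(set)
    for job_name, node_pool in pods:
        pools_by_job[job_name].add(node_pool)
        jobs_by_pool[node_pool].add(job_name)
    # Every Job must use exactly one pool, and every pool must host exactly one Job.
    one_pool_per_job = all(len(p) == 1 for p in pools_by_job.values())
    one_job_per_pool = all(len(j) == 1 for j in jobs_by_pool.values())
    return one_pool_per_job and one_job_per_pool

# Hypothetical sample data.
good = [("slice-0", "tpu-slice-a"), ("slice-0", "tpu-slice-a"), ("slice-1", "tpu-slice-b")]
fragmented = [("slice-0", "tpu-slice-a"), ("slice-0", "tpu-slice-b")]  # one Job split across pools
shared = [("slice-0", "tpu-slice-a"), ("slice-1", "tpu-slice-a")]      # two Jobs in one pool

print(is_exclusive_placement(good), is_exclusive_placement(fragmented), is_exclusive_placement(shared))
```

Only the first placement satisfies both directions of the mapping; the other two violate it in one direction each.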
How JobSet Schedules the Pods (Leader/Follower)
When a Multislice JobSet is submitted, the JobSet controller actively manages the scheduling sequence:
- Leader Scheduling: The JobSet controller, through its Pod webhook and Pod controller, ensures that the leader Pods (Pods with value 0 for the label batch.kubernetes.io/job-completion-index; there is one per child Job) are scheduled first.
- Follower Binding: The follower Pods (Pods with a nonzero value for the label batch.kubernetes.io/job-completion-index) are allowed to be scheduled only after their leader has been scheduled. The webhook intercepts the follower Pods and dynamically injects nodeSelectors to force them into the exact same Node Pool as their leader.
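The two phases above can be illustrated with a toy simulation. This is a simplified model, not the JobSet controller's code: the `Pod` dataclass, `schedule_leader`, and `bind_follower` are hypothetical names, and the real webhook, Pod controller, and scheduler involve more steps than shown here.

```python
from dataclasses import dataclass, field
from typing import Optional

NODEPOOL_KEY = "cloud.google.com/gke-nodepool"

@dataclass
class Pod:
    job: str
    completion_index: int
    node_selector: dict = field(default_factory=dict)
    node_pool: Optional[str] = None  # set once the Pod is "scheduled"

def schedule_leader(pod, free_pools):
    """Place a leader (completion index 0) on the next free pool and claim it."""
    assert pod.completion_index == 0
    pod.node_pool = free_pools.pop(0)
    return pod.node_pool

def bind_follower(pod, leader):
    """Admit a follower only after its leader is placed, pinning it to the leader's pool."""
    if leader.node_pool is None:
        return False  # leader not scheduled yet, so the follower must wait
    pod.node_selector[NODEPOOL_KEY] = leader.node_pool
    pod.node_pool = leader.node_pool
    return True

free_pools = ["tpu-slice-a", "tpu-slice-b"]
leader = Pod(job="slice-0", completion_index=0)
follower = Pod(job="slice-0", completion_index=1)

blocked = bind_follower(follower, leader)  # attempted before the leader is placed
schedule_leader(leader, free_pools)
bound = bind_follower(follower, leader)    # attempted after the leader is placed
print(blocked, bound, follower.node_selector)
```

In this model the leader keeps an empty node selector while the follower receives one naming the leader's pool.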
Note
Active development is underway to bring native support for gang scheduling to JobSet, which will enable more robust topology-aware scheduling (the core objective of the exclusive topology feature). For progress and design details, please follow issue #969: Gang Scheduling of JobSets.
Example: JAX on TPU Trillium (v6e) Multislice
Before you begin, ensure you have the following set up in your GKE cluster:
Create a GKE cluster in a v6e TPU–supported location: Ensure you have created a GKE cluster in a location that supports v6e (Trillium) TPUs. Check the TPU regions and zones to confirm availability.
Install JobSet and Kueue: Make sure both the JobSet and Kueue controllers are installed.
Create TPU Node Pools: For this example multislice workload, you need at least two separate TPU node pools, one for each slice. This allows JobSet’s exclusive placement to assign each ReplicatedJob (representing a slice) to its own dedicated node pool. This will acquire 32 TPU chips in total.
Replace the placeholders and run the following commands to create two ct6e-standard-4t node pools with a 4x4 topology:
# Set your cluster variables
export PROJECT_ID=my-project-id # Replace with your Google Cloud Project ID
export CLUSTER_NAME=my-tpu-cluster # Replace with your GKE cluster name
export CONTROL_PLANE_LOCATION=us-central1 # Replace with your GKE control plane region
export NODE_LOCATION=us-central1-b # Replace with the zone for TPU creation
# Create the first node pool
gcloud container node-pools create tpu-slice-a \
  --location=$CONTROL_PLANE_LOCATION \
  --cluster=$CLUSTER_NAME \
  --node-locations=$NODE_LOCATION \
  --machine-type=ct6e-standard-4t \
  --tpu-topology=4x4 \
  --project=$PROJECT_ID
# Create the second node pool
gcloud container node-pools create tpu-slice-b \
  --location=$CONTROL_PLANE_LOCATION \
  --cluster=$CLUSTER_NAME \
  --node-locations=$NODE_LOCATION \
  --machine-type=ct6e-standard-4t \
  --tpu-topology=4x4 \
  --project=$PROJECT_ID
Configure Kueue: Create the necessary Kueue and Kubernetes resources to manage the TPU workloads. This includes defining a ResourceFlavor for the TPUs and setting up queues. For more details on these configurations, see the Kueue and GKE integration documentation.
apiVersion: kueue.x-k8s.io/v1beta2
kind: ResourceFlavor
metadata:
  name: v6e-4x4
spec:
  nodeLabels:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 4x4
---
apiVersion: kueue.x-k8s.io/v1beta2
kind: ClusterQueue
metadata:
  name: multislice-cluster-queue
spec:
  namespaceSelector: {}
  queueingStrategy: BestEffortFIFO
  resourceGroups:
  - coveredResources: ["google.com/tpu"]
    flavors:
    - name: "v6e-4x4"
      resources:
      - name: "google.com/tpu"
        nominalQuota: "32" # 2 slices * 16 chips/slice = 32
---
apiVersion: kueue.x-k8s.io/v1beta2
kind: LocalQueue
metadata:
  namespace: default # Or the namespace where you will run your JobSet
  name: multislice-queue
spec:
  clusterQueue: multislice-cluster-queue
Apply the configurations:
kubectl apply -f kueue-config.yaml
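The figures used in this setup (16 chips per slice and a quota of 32) follow from the topology string and the number of chips per node. As a quick sanity check, the sketch below derives them. `slice_shape` is a hypothetical helper, and it assumes 4 chips per ct6e-standard-4t node, as this example states.

```python
from math import prod

def slice_shape(topology: str, chips_per_node: int, num_slices: int) -> dict:
    """Derive per-slice and total counts from a TPU topology string such as '4x4'."""
    chips_per_slice = prod(int(dim) for dim in topology.split("x"))
    if chips_per_slice % chips_per_node:
        raise ValueError("topology is not divisible by chips per node")
    return {
        "chips_per_slice": chips_per_slice,
        "nodes_per_slice": chips_per_slice // chips_per_node,  # Job parallelism/completions
        "total_chips": chips_per_slice * num_slices,           # Kueue nominalQuota
    }

print(slice_shape("4x4", chips_per_node=4, num_slices=2))
# {'chips_per_slice': 16, 'nodes_per_slice': 4, 'total_chips': 32}
```

The `nodes_per_slice` value is the number that the JobSet manifest uses for parallelism and completions.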
Example JobSet
The following example runs a distributed JAX workload across 2 slices of TPU Trillium (v6e). It demonstrates how to submit the workload to a Kueue LocalQueue and configure the specialized networking required for v6e machines. For the latest configuration options, refer to the GKE TPU Multislice tutorial.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: v6e-multislice
  labels:
    # Kueue integration: Routes the workload to a specific LocalQueue
    kueue.x-k8s.io/queue-name: multislice-queue
  annotations:
    # Ensures 1:1 mapping between child Jobs and physical TPU Node Pools
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    # Synchronous JAX training requires JobSet to restart if any node fails
    maxRestarts: 4
  replicatedJobs:
  - name: slice
    # Provision 2 independent TPU v6e slices (multislice)
    replicas: 2
    template:
      spec:
        # A 4x4 v6e slice has 16 chips. With 4 chips per node, we need 4 nodes per slice.
        parallelism: 4
        completions: 4
        backoffLimit: 0
        template:
          spec:
            hostNetwork: true
            dnsPolicy: ClusterFirstWithHostNet
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 4x4
            containers:
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
              securityContext:
                privileged: true
              command:
              - bash
              - -c
              - |
                set -e
                cat <<'EOF' > distributed_train.py
                import os
                import jax
                import jax.numpy as jnp
                import numpy as np
                from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

                # 1. Initialize distributed JAX cluster across all slices
                jax.distributed.initialize()
                process_id = jax.process_index()
                num_processes = jax.process_count() # Total TPU hosts (e.g., 8)
                local_device_count = jax.local_device_count() # Chips per host (e.g., 4)
                global_device_count = jax.device_count() # Total chips across all hosts (e.g., 32)
                nodes_per_slice = int(os.environ.get("JOBSET_REPLICATEDJOB_PARALLELISM", 4))
                devices_per_slice = local_device_count * nodes_per_slice
                num_slices = global_device_count // devices_per_slice
                if process_id == 0:
                    print(f"=== Multislice JAX Cluster Initialized ===")
                    print(f"Total Processes (Nodes): {num_processes}")
                    print(f"Total Global TPU Devices: {global_device_count}")
                    print(f"Calculated Slices: {num_slices}, Devices per Slice: {devices_per_slice}")
                print(f"[Process {process_id}] Local Devices: {jax.local_devices()}")

                # 2. Define a Device Mesh across all slices
                global_devices = jax.devices()
                devices_2d = np.array(global_devices).reshape(num_slices, devices_per_slice)
                mesh = Mesh(devices_2d, ('slice', 'tpu_in_slice'))

                # 3. Create a large sharded matrix (distributed training simulation)
                sharding = NamedSharding(mesh, P(('slice', 'tpu_in_slice'), None))
                x = jnp.ones((8192, 8192), dtype=jnp.float32)
                x_sharded = jax.device_put(x, sharding)

                # 4. Perform parallel computation with JIT compilation
                # Multiplying two 8192x8192 all-ones matrices together.
                # The dot product of each row and column scales every single element in the
                # resulting matrix 'y' from 1.0 to exactly 8192.0.
                y = jax.jit(lambda a: a @ a)(x_sharded)
                # Block until calculation is done to ensure correct execution
                y.block_until_ready()
                print(f"[Process {process_id}] MatMul calculation succeeded. Output shape: {y.shape}")

                # Summing up all elements in the sharded matrix 'y'.
                # Since matrix 'y' has 8192x8192 elements and each element equals 8192.0,
                # the predictable theoretical global constant will be 8192 * 8192 * 8192 = 8192^3.
                # Note: This jnp.sum() must run outside any conditional blocks to trigger a synchronous
                # collective Reduce-Sum operation across all nodes and slices to prevent deadlocks.
                matrix_sum = jnp.sum(y)
                print(f"Global matrix sum result: {matrix_sum}")
                EOF
                # Execute the proper training script
                python3 distributed_train.py
                sleep 10
              resources:
                limits:
                  google.com/tpu: 4 # 4 Trillium chips per node
If you create a JobSet with replicas: 2 (2 Slices) and parallelism: 4 (4 Nodes per Slice), JobSet creates two child Jobs:
$ kubectl get jobs -o wide
NAME                     COMPLETIONS   DURATION   AGE   CONTAINERS   IMAGES                 SELECTOR
v6e-multislice-slice-0   0/4           2m         2m    jax-tpu      us-docker.pkg.../tpu   controller-uid=1111...
v6e-multislice-slice-1   0/4           2m         2m    jax-tpu      us-docker.pkg.../tpu   controller-uid=2222...
When checking the Pods, you will see the 1:1 mapping in action:
$ kubectl get pods -o custom-columns="NAME:.metadata.name,STATUS:.status.phase,NODE:.spec.nodeName,NODEPOOL_SELECTOR:.spec.nodeSelector.cloud\.google\.com/gke-nodepool"
NAME                             STATUS    NODE                    NODEPOOL_SELECTOR
# --------------------------------------------------------------------------------------------------
# 4 Pods of v6e-multislice-slice-0 -> Exclusively bound to tpu-slice-a (ICI intact)
# --------------------------------------------------------------------------------------------------
v6e-multislice-slice-0-0-sk2b6   Running   gke-tpu-3335d306-dhnf   <none>
v6e-multislice-slice-0-1-g5kzv   Running   gke-tpu-3335d306-3ww8   tpu-slice-a
v6e-multislice-slice-0-2-rnp55   Running   gke-tpu-3335d306-gqwm   tpu-slice-a
v6e-multislice-slice-0-3-2dtbf   Running   gke-tpu-3335d306-ccp2   tpu-slice-a
# --------------------------------------------------------------------------------------------------
# 4 Pods of v6e-multislice-slice-1 -> Exclusively bound to tpu-slice-b (ICI intact)
# --------------------------------------------------------------------------------------------------
v6e-multislice-slice-1-0-t4xq6   Running   gke-tpu-08d18b59-rkb0   <none>
v6e-multislice-slice-1-1-nw5hp   Running   gke-tpu-08d18b59-zvvp   tpu-slice-b
v6e-multislice-slice-1-2-gjt79   Running   gke-tpu-08d18b59-9mqf   tpu-slice-b
v6e-multislice-slice-1-3-gzm7p   Running   gke-tpu-08d18b59-j1fb   tpu-slice-b
The output demonstrates JobSet’s Leader/Follower scheduling mechanism for exclusive-topology. The first pod (Leader, index 0) is scheduled first without a specific nodepool restriction (showing <none>). Once it is placed on a node, JobSet intercepts the remaining follower pods and dynamically injects the Leader’s selected nodepool into their node selectors. This guarantees that all follower pods are forced into the exact same physical nodepool as the leader, keeping the high-speed ICI network between the TPU chips intact.
With the pods correctly scheduled, the final step is to verify that the JAX application can communicate across both slices and utilize all available TPU chips. By inspecting the logs of any pod, you can confirm that the multislice setup is working:
$ kubectl logs job/v6e-multislice-slice-0 --all-pods
[pod/v6e-multislice-slice-0-2-rnp55/jax-tpu] === Multislice JAX Cluster Initialized ===
[pod/v6e-multislice-slice-0-2-rnp55/jax-tpu] Total Processes (Nodes): 8
[pod/v6e-multislice-slice-0-2-rnp55/jax-tpu] Total Global TPU Devices: 32
[pod/v6e-multislice-slice-0-2-rnp55/jax-tpu] Calculated Slices: 2, Devices per Slice: 16
[pod/v6e-multislice-slice-0-2-rnp55/jax-tpu] [Process 0] Local Devices: [MegaScalePjRtDevice(wrapped=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), slice_id=0), MegaScalePjRtDevice(wrapped=TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), slice_id=0), MegaScalePjRtDevice(wrapped=TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), slice_id=0), MegaScalePjRtDevice(wrapped=TpuDevice(id=5, process_index=0, coords=(1,1,0), core_on_chip=0), slice_id=0)]
[pod/v6e-multislice-slice-0-2-rnp55/jax-tpu] [Process 0] MatMul calculation succeeded. Output shape: (8192, 8192)
[pod/v6e-multislice-slice-0-2-rnp55/jax-tpu] Global matrix sum result: 549755813888.0
...
The log output Total Global TPU Devices: 32 confirms the application sees all 32 chips (2 slices × 16 chips per slice). The final global matrix sum, 549755813888.0, equals 8192^3 and provides a numeric validation that all 32 chips are actively computing and communicating.
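The expected value can be checked by hand or with a few lines of NumPy. The sketch below is a CPU-only sanity check, not part of the workload: it verifies the identity on a small matrix (n = 64 is an arbitrary choice) and then evaluates n^3 for n = 8192. The helper name `ones_matmul_sum` is hypothetical.

```python
import numpy as np

def ones_matmul_sum(n: int) -> float:
    """Sum of all entries of ones(n, n) @ ones(n, n), computed directly."""
    x = np.ones((n, n), dtype=np.float32)
    return float(np.sum(x @ x, dtype=np.float64))

# Each entry of the product equals n, and there are n * n entries, so the sum is n**3.
assert ones_matmul_sum(64) == 64 ** 3

expected = 8192 ** 3
print(expected)  # 549755813888
# 8192**3 is 2**39, a power of two, so float32 can represent it exactly.
print(expected == 2 ** 39, float(np.float32(expected)) == expected)  # True True
```

Because the expected sum is exactly representable in float32, the logged value 549755813888.0 can be compared against it directly.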
This demonstrates that JobSet has successfully orchestrated a multislice workload, enabling communication across the DCN and presenting a unified accelerator environment to the training job.
Further Reading
For more detailed information and official tutorials on TPU Multislice and Kueue integration, please refer to the following resources:
- Run a Kueue scheduled JobSet: Official Kueue documentation detailing queue selection and resource configuration for JobSets.