MaxDiffusion provides high-performance inference for latent diffusion models on Google Cloud TPUs and GPUs. All models leverage JAX for compilation and XLA for hardware-specific optimizations.

Supported models

MaxDiffusion supports the following models for inference:

Stable Diffusion

SD 2 base and SD 2.1 for 512×512 image generation

Stable Diffusion XL

SDXL for high-quality 1024×1024 images with dual text encoders

Flux

Flux dev and schnell variants with optimized flash attention

Wan

Wan 2.1 and 2.2 for text-to-video and image-to-video generation

LTX Video

LTX-Video for high-quality video generation with conditioning

ControlNet

Conditional generation with ControlNet for SD 1.4 and SDXL

Key features

Sharding strategies

MaxDiffusion supports multiple parallelism strategies for efficient inference:
  • Data parallelism (DDP): Replicate the model across devices and process different prompts in parallel
  • FSDP: Shard model parameters across devices to fit larger models in memory
  • Context parallelism: Split the sequence dimension across devices to handle longer contexts
Configure sharding with the following parameters:
ici_data_parallelism=4      # Number of data parallel devices
ici_fsdp_parallelism=-1     # Fully shard model parameters
ici_context_parallelism=2   # Context parallel degree
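
The way these three degrees combine can be sketched in plain Python. This is an illustration, not MaxDiffusion code: it assumes that the product of the ici_* degrees must equal the number of devices, and that a value of -1 means "use whatever devices are left over". The helper name resolve_parallelism is hypothetical.

```python
from math import prod


def resolve_parallelism(num_devices: int, **axes: int) -> dict[str, int]:
    """Fill in a single -1 axis so the product of all axes equals num_devices.

    Assumption: -1 marks the axis that absorbs the remaining devices.
    """
    for name, size in axes.items():
        if size != -1 and size < 1:
            raise ValueError(f"{name} must be a positive integer or -1")

    auto = [name for name, size in axes.items() if size == -1]
    if len(auto) > 1:
        raise ValueError("at most one axis may be set to -1")

    # Product of the explicitly sized axes (1 if there are none).
    fixed = prod(size for size in axes.values() if size != -1)
    if num_devices % fixed != 0:
        raise ValueError(f"{num_devices} devices cannot be split {fixed} ways")

    resolved = dict(axes)
    if auto:
        resolved[auto[0]] = num_devices // fixed
    elif fixed != num_devices:
        raise ValueError("axis sizes must multiply to the device count")
    return resolved


# The settings shown above, applied to a hypothetical 16-device slice.
mesh = resolve_parallelism(
    16,
    ici_data_parallelism=4,
    ici_fsdp_parallelism=-1,
    ici_context_parallelism=2,
)
print(mesh)
```

Under these assumptions the FSDP axis resolves to 2 on 16 devices and to 1 on 8 devices, so the same settings give no parameter sharding on the smaller slice.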

Trillium optimizations

TPU v6e (Trillium) benefits from optimized flash attention block sizes. Enable them by uncommenting the flash_block_sizes configuration in the model config files.

Encoder offloading

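For models with large text encoders (like Flux), the offload_encoders parameter controls whether the encoders are offloaded so that the transformer and VAE stay in HBM. Set it to False to keep every component in HBM: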
offload_encoders=False  # Keep all components in HBM

Precision control

All models use bfloat16 by default for optimal performance on TPUs:
  • Activations: bfloat16
  • Weights: bfloat16
  • Latents: float32 for numerical stability
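
One plausible reason for keeping latents in float32 is that they are updated many times during denoising, and small updates can vanish in bfloat16. The sketch below illustrates that rounding effect with the standard library only. It is not MaxDiffusion code, and the helpers to_bfloat16, to_float32, and accumulate are hypothetical names introduced for this example.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (64-bit) to the nearest float32 value."""
    return struct.unpack(">f", struct.pack(">f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even).

    bfloat16 keeps the top 16 bits of the float32 bit pattern.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even
    bits = ((bits + bias) & 0xFFFFFFFF) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def accumulate(start: float, delta: float, steps: int, cast) -> float:
    """Add delta to start `steps` times, rounding with `cast` after each add."""
    value = cast(start)
    for _ in range(steps):
        value = cast(value + delta)
    return value


# 1000 small updates of 0.001 applied to a value near 1.0.
low = accumulate(1.0, 0.001, 1000, to_bfloat16)
high = accumulate(1.0, 0.001, 1000, to_float32)
print(f"bfloat16: {low}, float32: {high}")
```

Near 1.0 the spacing between adjacent bfloat16 values is 2^-7 (about 0.0078), so each 0.001 update rounds away and the bfloat16 result never leaves 1.0, while the float32 result ends close to 2.0.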

Common parameters

All inference scripts accept these common parameters:
Parameter               Description                             Default
prompt                  Text prompt for generation              Required
negative_prompt         Negative prompt to avoid concepts       Empty string
num_inference_steps     Number of denoising steps               Model-specific
guidance_scale          Classifier-free guidance strength       7.5
per_device_batch_size   Batch size per device                   1
seed                    Random seed for reproducibility         0
output_dir              Directory for saving outputs            /tmp/
jax_cache_dir           JAX compilation cache directory         Required
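
As a rough sketch of how these parameters might be passed, the snippet below assembles a command line from key=value overrides in the style used for the sharding settings. The script name generate.py, the config file config.yml, the prompt text, the cache path, and the helper build_command are placeholders for illustration, not names taken from MaxDiffusion.

```python
import shlex


def build_command(script: str, config: str, overrides: dict) -> list[str]:
    """Return an argv list: interpreter, script, config file, key=value overrides."""
    argv = ["python", script, config]
    argv += [f"{key}={value}" for key, value in overrides.items()]
    return argv


# Values mirror the defaults in the table; prompt and jax_cache_dir are
# required, so example values are supplied here.
overrides = {
    "prompt": "a photo of a red fox in the snow",
    "negative_prompt": "",
    "guidance_scale": 7.5,
    "per_device_batch_size": 1,
    "seed": 0,
    "output_dir": "/tmp/",
    "jax_cache_dir": "/tmp/jax_cache",  # hypothetical cache location
}

# "generate.py" and "config.yml" are placeholder names.
cmd = build_command("generate.py", "config.yml", overrides)
print(shlex.join(cmd))  # shell-quoted form, safe to paste into a terminal
```

Building the argument list first and quoting it with shlex.join avoids shell-escaping mistakes when a prompt contains spaces or quotes.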

Performance tips

  1. Use flash attention: Set attention="flash" for 2-4x speedup on supported hardware
  2. Enable HF transfer: Set HF_HUB_ENABLE_HF_TRANSFER=1 for faster model downloads
  3. Cache compilations: Use jax_cache_dir to avoid recompiling on subsequent runs
  4. Optimize batch size: Increase per_device_batch_size to maximize hardware utilization
  5. Use async collectives: Set LIBTPU_INIT_ARGS for better communication overlap on TPUs
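
Tips 2 and 5 are environment variables, so they have to be set before the inference process starts. A minimal sketch, assuming the script is launched as a subprocess: the helper inference_env is hypothetical, and the LIBTPU_INIT_ARGS value is left for the caller to supply because the specific flags are not listed here.

```python
import os


def inference_env(base=None, libtpu_init_args=None) -> dict:
    """Return a copy of the environment with the performance settings applied."""
    env = dict(os.environ if base is None else base)
    # Tip 2: faster model downloads from the Hugging Face Hub.
    env["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    # Tip 5: only set when the caller provides the flag string.
    if libtpu_init_args:
        env["LIBTPU_INIT_ARGS"] = libtpu_init_args
    return env


env = inference_env(base={})
print(env)
```

The returned dictionary can be passed as the env argument of subprocess.run when launching an inference script.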

Next steps

Stable Diffusion XL

Generate high-quality images with SDXL

Flux

Fast inference with Flux dev and schnell

Wan video generation

Create videos with Wan 2.1 and 2.2

LoRA loading

Load custom LoRA adapters