ControlNet enables precise control over image generation using conditioning signals like edge maps, depth maps, and segmentation masks.

Installation

ControlNet uses OpenCV for image processing, which may require these system libraries:
apt-get update && apt-get install ffmpeg libsm6 libxext6 -y

Quick start

python src/maxdiffusion/controlnet/generate_controlnet_replicated.py

Supported models

MaxDiffusion supports ControlNet for:
  • Stable Diffusion 1.4: Uses runwayml/stable-diffusion-v1-5 base model
  • Stable Diffusion XL: Uses SDXL base model with ControlNet

Architecture

ControlNet adds trainable copies of the UNet's encoder layers; their outputs are injected into the base model as conditioning:
  • Control input: Edge map, depth map, or other conditioning signal
  • ControlNet: Trainable encoder that processes control input
  • Base model: Stable Diffusion UNet with injected control features
  • Conditioning scale: Adjustable influence of control signal
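The data flow above can be sketched with a toy, single-layer model. This is a minimal illustration, not MaxDiffusion's code: `block`, `unet_feature`, the weight shapes, and the one-layer "encoder" are hypothetical stand-ins. It assumes the standard ControlNet design, in which the trainable copy feeds the base model through a zero-initialized projection ("zero convolution"), so an untrained ControlNet leaves the base model's features unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 8

# Hypothetical frozen encoder block of the base UNet (toy: one linear layer + ReLU).
w_base = rng.normal(size=(dim, dim))

def block(x, w):
    return np.maximum(x @ w, 0.0)

# ControlNet: a trainable copy initialized from the frozen weights ...
w_copy = w_base.copy()
# ... whose output passes through a zero-initialized projection.
w_zero = np.zeros((dim, dim))

def unet_feature(x, control, scale):
    base = block(x, w_base)                         # frozen base path
    residual = block(x + control, w_copy) @ w_zero  # control path
    return base + scale * residual                  # control feature injected into the base

x = rng.normal(size=(1, dim))        # toy latent feature
control = rng.normal(size=(1, dim))  # toy encoded control input (e.g. an edge map)

# With the projection still zero, the control input has no effect yet.
print(np.allclose(unet_feature(x, control, 1.0), block(x, w_base)))  # True
```

Training moves `w_zero` away from zero, which is how the control signal gradually gains influence without disturbing the pretrained base model at the start.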

SD 1.4 ControlNet

Basic usage

The SD 1.4 pipeline (src/maxdiffusion/controlnet/generate_controlnet_replicated.py) uses Canny edge conditioning:
python src/maxdiffusion/controlnet/generate_controlnet_replicated.py

Configuration

Customize via config parameters:
config.prompt = "a photograph of a modern building"
config.negative_prompt = "blurry, distorted"
config.controlnet_image = "https://example.com/edge_map.png"
config.controlnet_model_name_or_path = "lllyasviel/sd-controlnet-canny"
config.controlnet_from_pt = True
config.controlnet_conditioning_scale = 1.0
config.num_inference_steps = 50
config.per_device_batch_size = 1
config.seed = 0

Implementation

The SD 1.4 pipeline loads the ControlNet and the base model (generate_controlnet_replicated.py:41-47):
# Load ControlNet
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    config.controlnet_model_name_or_path, 
    from_pt=config.controlnet_from_pt, 
    dtype=jnp.float32
)

# Load pipeline with ControlNet
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    config.pretrained_model_name_or_path, 
    controlnet=controlnet, 
    revision=config.revision, 
    dtype=jnp.float32
)
params["controlnet"] = controlnet_params

Inference

The pipeline prepares the text and control-image inputs, shards them across devices, and generates the conditioned output (generate_controlnet_replicated.py:52-74):
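# replicate/shard come from Flax; pipe, params, rng, num_samples, prompts,
# negative_prompts, canny_image and controlnet_conditioning_scale are defined
# earlier in the script.
from flax.jax_utils import replicate
from flax.training.common_utils import shard

# Prepare inputs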
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)

# Replicate and shard
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
negative_prompt_ids = shard(negative_prompt_ids)
processed_image = shard(processed_image)

# Generate
output = pipe(
    prompt_ids=prompt_ids,
    image=processed_image,
    params=p_params,
    prng_seed=rng,
    num_inference_steps=config.num_inference_steps,
    neg_prompt_ids=negative_prompt_ids,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    jit=True,
).images

SDXL ControlNet

Basic usage

The SDXL pipeline (src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py) includes Canny edge detection:
python src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py

Edge detection

The SDXL pipeline applies Canny edge detection to input images (generate_controlnet_sdxl_replicated.py:44-49):
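import cv2
import numpy as np
from PIL import Image

# load_image is the script's image-loading helper (accepts a URL or local path)
image = load_image(config.controlnet_image)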
image = np.array(image)

# Apply Canny edge detection
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
image = Image.fromarray(image)
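cv2.Canny returns a single-channel (H, W) array in which edge pixels are 255 and all others are 0. Adding a channel axis and concatenating it three times produces the 3-channel (H, W, 3) edge map used as ControlNet conditioning.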
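The two numbers passed to cv2.Canny (100 and 200) are its low and high hysteresis thresholds. The NumPy sketch below illustrates only that final hysteresis step, under simplifying assumptions: it takes a precomputed gradient-magnitude array, whereas the real Canny detector first smooths the image, computes gradients, and applies non-maximum suppression. `hysteresis` is a hypothetical helper, not an OpenCV function.

```python
import numpy as np

def hysteresis(magnitude, low=100, high=200):
    """Keep strong edges plus weak edges connected to them (8-neighbourhood)."""
    strong = magnitude >= high                 # always kept
    weak = (magnitude >= low) & ~strong        # kept only if connected to an edge
    edges = strong.copy()
    h, w = edges.shape
    while True:
        padded = np.pad(edges, 1)              # pads with False
        neighbour = np.zeros_like(edges)
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                neighbour |= padded[1 + dy:1 + dy + h, 1 + dx:1 + dx + w]
        grown = edges | (weak & neighbour)     # absorb weak pixels touching an edge
        if np.array_equal(grown, edges):
            return (edges * 255).astype(np.uint8)  # same 0/255 convention as Canny
        edges = grown

magnitude = np.array([
    [250, 150, 150, 50],
    [0, 0, 0, 0],
    [150, 0, 0, 0],
])
# The 150s next to the 250 are kept; the isolated 150 and the 50 are dropped.
print(hysteresis(magnitude))
```

Raising the low threshold drops more faint, disconnected detail; raising the high threshold means fewer pixels seed edges in the first place.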

Implementation

SDXL ControlNet loads both models with config.activations_dtype and casts the base pipeline parameters to bfloat16 (generate_controlnet_sdxl_replicated.py:51-63):
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    config.controlnet_model_name_or_path, 
    from_pt=config.controlnet_from_pt, 
    dtype=config.activations_dtype
)

pipe, params = FlaxStableDiffusionXLControlNetPipeline.from_pretrained(
    config.pretrained_model_name_or_path, 
    controlnet=controlnet, 
    revision=config.revision, 
    dtype=config.activations_dtype
)

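# Cast pipeline params to bfloat16, leaving the scheduler state untouched.
# controlnet_params are attached after the cast, so they keep the dtype they were loaded with.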
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state
params["controlnet"] = controlnet_params

Parameters

| Parameter | Description | Default |
|---|---|---|
| prompt | Text description of desired image | Required |
| negative_prompt | Concepts to avoid | Empty |
| controlnet_image | URL or path to control image | Required |
| controlnet_model_name_or_path | ControlNet model checkpoint | Required |
| controlnet_from_pt | Convert from PyTorch format | True |
| controlnet_conditioning_scale | Control signal strength (0.0-2.0) | 1.0 |
| num_inference_steps | Denoising steps | 50 |
| per_device_batch_size | Images per device | 1 |
| seed | Random seed | 0 |

Conditioning scale

The controlnet_conditioning_scale parameter multiplies the ControlNet's outputs before they are added to the UNet, so it sets how strongly the control signal influences generation:
  • 0.0: No control (standard generation)
  • 0.5: Light control, more creative freedom
  • 1.0: Balanced control (recommended)
  • 1.5: Strong control adherence
  • 2.0: Very strict control following
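To see why 0.0 reduces to standard generation, here is a NumPy sketch assuming MaxDiffusion follows the diffusers ControlNet design (its class names mirror diffusers): the ControlNet emits one residual per UNet down block plus one for the mid block, and each is multiplied by the scale before being added to the UNet's features. `apply_control` and the feature shapes are hypothetical.

```python
import numpy as np

def apply_control(skip_features, mid_feature, down_residuals, mid_residual, scale):
    """Add scaled ControlNet residuals to the UNet's skip and mid-block features."""
    new_skips = [s + scale * r for s, r in zip(skip_features, down_residuals)]
    return new_skips, mid_feature + scale * mid_residual

rng = np.random.default_rng(0)
# Hypothetical (batch, H, W, C) shapes for three down blocks and one mid block.
shapes = [(1, 8, 8, 4), (1, 4, 4, 8), (1, 2, 2, 16)]
skips = [rng.normal(size=s) for s in shapes]
mid = rng.normal(size=(1, 2, 2, 16))
down_res = [rng.normal(size=s) for s in shapes]
mid_res = rng.normal(size=(1, 2, 2, 16))

for scale in (0.0, 0.5, 1.0, 1.5):
    new_skips, new_mid = apply_control(skips, mid, down_res, mid_res, scale)
    # The shift away from the uncontrolled features grows linearly with the scale.
    shift = sum(np.abs(n - s).sum() for n, s in zip(new_skips, skips))
    print(f"scale={scale}: total shift {shift:.2f}")
```

The feature shift is linear in the scale, but its effect on the final image is not, which is why the strength labels in the list above are qualitative guidance rather than exact behavior.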

Example: Adjusting control strength

# Light control for more variation
config.controlnet_conditioning_scale = 0.5

# Strong control for precise structure
config.controlnet_conditioning_scale = 1.5

Control signal types

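ControlNet supports various conditioning types. The lllyasviel/sd-controlnet-* checkpoints listed below are trained for Stable Diffusion 1.x, so they pair with the SD 1.4 pipeline rather than SDXL: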

Canny edges

  • Use case: Preserve structural composition
  • Model: lllyasviel/sd-controlnet-canny
  • Preprocessing: Canny edge detection (low threshold 100, high threshold 200)

Depth maps

  • Use case: Control spatial depth and 3D structure
  • Model: lllyasviel/sd-controlnet-depth
  • Preprocessing: MiDaS depth estimation

Segmentation

  • Use case: Control object layout and positioning
  • Model: lllyasviel/sd-controlnet-seg
  • Preprocessing: Semantic segmentation

Human pose

  • Use case: Control human figure poses
  • Model: lllyasviel/sd-controlnet-openpose
  • Preprocessing: OpenPose skeleton detection

Custom control images

Provide custom edge maps or control signals:
config.controlnet_image = "/path/to/custom_edge_map.png"
Control images should:
  • Match the target generation resolution
  • Be grayscale or 3-channel (RGB)
  • Clearly define structural elements
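The requirements above follow from how a control image becomes a tensor. The NumPy sketch below assumes preprocessing similar to diffusers' Flax ControlNet pipeline (convert to RGB, trim width and height to multiples of 64, scale to [0, 1], move channels first). It crops instead of resizing to stay dependency-free, and `prepare_control_array` is a hypothetical helper, not a MaxDiffusion function.

```python
import numpy as np

def prepare_control_array(image):
    """Turn an (H, W) or (H, W, 3) uint8 control image into a (1, 3, H', W') float32 array."""
    if image.ndim == 2:  # grayscale edge map -> 3 identical channels
        image = np.stack([image] * 3, axis=-1)
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(f"expected grayscale or RGB, got shape {image.shape}")
    h, w = image.shape[:2]
    h, w = h - h % 64, w - w % 64  # trim to multiples of 64 (assumed pipeline resizes instead)
    image = image[:h, :w].astype(np.float32) / 255.0  # scale to [0, 1]
    return image[None].transpose(0, 3, 1, 2)  # (1, H, W, C) -> (1, C, H, W)

edges = np.zeros((520, 770), dtype=np.uint8)  # hypothetical edge map size
edges[100:110, :] = 255                       # one horizontal edge band
batch = prepare_control_array(edges)
print(batch.shape)  # (1, 3, 512, 768)
```

Because the output resolution follows the control image, supplying it at the target size avoids unexpected cropping or resizing.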

Multi-device inference

ControlNet inference runs data-parallel with pmap: parameters are replicated to every device and inputs are sharded along the batch axis:
num_samples = jax.device_count() * config.per_device_batch_size
rng = jax.random.split(rng, jax.device_count())

# Replicate params across devices
p_params = replicate(params)

# Shard inputs
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)
Each device generates per_device_batch_size images with its own RNG key, so one run produces num_samples images in total.
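The effect of replicate and shard on array shapes can be sketched in NumPy. This is an illustration under stated assumptions: `fake_replicate` and `fake_shard` are hypothetical stand-ins for flax.jax_utils.replicate and flax.training.common_utils.shard, and the device count is simulated rather than read from jax.device_count().

```python
import numpy as np

n_devices = 4             # simulated; the script uses jax.device_count()
per_device_batch_size = 2
num_samples = n_devices * per_device_batch_size

def fake_replicate(x, n):
    # One full copy of the parameters per device: adds a leading device axis.
    return np.broadcast_to(x, (n,) + x.shape)

def fake_shard(x, n):
    # Split the global batch: (num_samples, ...) -> (n, num_samples // n, ...).
    return x.reshape((n, -1) + x.shape[1:])

weights = np.ones((3, 3), dtype=np.float32)               # stand-in for params
prompt_ids = np.zeros((num_samples, 77), dtype=np.int32)  # 77 = CLIP token length
print(fake_replicate(weights, n_devices).shape)  # (4, 3, 3)
print(fake_shard(prompt_ids, n_devices).shape)   # (4, 2, 77)
```

Sharding only works when num_samples is divisible by the device count, which the script guarantees by defining num_samples as jax.device_count() * per_device_batch_size.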

Output

Generated images are saved as generated_image.png. The first image in the batch is saved by default. To save all images:
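# Flatten the (devices, per-device batch, H, W, C) output into a list of PIL images
output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
for i, image in enumerate(output_images):
  image.save(f"generated_image_{i}.png")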

Examples

Building from edges

config.prompt = "a modern glass office building, blue sky, photorealistic"
config.negative_prompt = "blurry, cartoon, painting"
config.controlnet_image = "building_edges.png"
config.controlnet_conditioning_scale = 1.2

Portrait from pose

config.prompt = "portrait of a woman in a red dress, studio lighting"
config.negative_prompt = "deformed, blurry"
config.controlnet_model_name_or_path = "lllyasviel/sd-controlnet-openpose"
config.controlnet_image = "pose_skeleton.png"
config.controlnet_conditioning_scale = 1.0

Landscape from depth

config.prompt = "mountain landscape at sunset, dramatic clouds"
config.negative_prompt = "flat, boring"
config.controlnet_model_name_or_path = "lllyasviel/sd-controlnet-depth"
config.controlnet_image = "depth_map.png"
config.controlnet_conditioning_scale = 0.8

Next steps

  • SDXL inference: Higher quality base model for ControlNet
  • Stable Diffusion: Standard SD inference without control
  • Training overview: Train custom ControlNet models
  • Configuration: Full configuration reference