Use when adding a new diffusion model or Diffusers pipeline to SGLang.
Use this skill when adding a new diffusion model or pipeline variant to sglang.multimodal_gen.
Hybrid style: the recommended default for most new models. It uses a three-stage structure:

BeforeDenoisingStage (model-specific) --> DenoisingStage (standard) --> DecodingStage (standard)
Why recommended? Modern diffusion models have highly heterogeneous pre-processing requirements (different text encoders, different latent formats, different conditioning mechanisms). The Hybrid approach keeps pre-processing isolated per model, avoids fragile shared stages with excessive conditional logic, and lets developers port Diffusers reference code quickly.
Modular style: builds the pipeline by composing the framework's fine-grained standard stages (TextEncodingStage, LatentPreparationStage, TimestepPreparationStage, etc.).
This style is appropriate when the model fits the framework's standard patterns: `add_standard_t2i_stages()` or `add_standard_ti2i_stages()` may be all you need. See existing Modular examples: QwenImagePipeline (uses add_standard_t2i_stages), FluxPipeline, WanPipeline.
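For orientation, here is a minimal sketch of what Modular-style composition looks like. Everything below uses stand-in stub classes, not the real SGLang classes (the actual `ComposedPipelineBase` and `add_standard_t2i_stages()` live in the framework, and the stage order shown is an assumption for illustration):

```python
# Illustrative stubs only -- the real classes live in sglang.multimodal_gen.
class StubStage:
    def __init__(self, name):
        self.name = name

class StubComposedPipelineBase:
    """Minimal stand-in for ComposedPipelineBase's composition API."""
    def __init__(self):
        self.stages = []

    def add_stage(self, stage):
        self.stages.append(stage)

    def add_standard_t2i_stages(self):
        # The real helper registers the framework's fine-grained stages;
        # here we just record names to show the resulting composition.
        for name in ("TextEncodingStage", "TimestepPreparationStage",
                     "LatentPreparationStage", "DenoisingStage", "DecodingStage"):
            self.add_stage(StubStage(name))

class MyModularPipeline(StubComposedPipelineBase):
    def create_pipeline_stages(self):
        self.add_standard_t2i_stages()

pipe = MyModularPipeline()
pipe.create_pipeline_stages()
print([s.name for s in pipe.stages])
```

The point is that a Modular pipeline class contains almost no logic of its own; it just declares which standard stages run, in what order.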
| Situation | Recommended Style |
|---|---|
| Model has unique/complex pre-processing (VLM captioning, AR token generation, custom latent packing, etc.) | Hybrid — consolidate into a BeforeDenoisingStage |
| Model fits neatly into standard text-to-image or text+image-to-image pattern | Modular — use add_standard_t2i_stages() / add_standard_ti2i_stages() |
| Porting a Diffusers pipeline with many custom steps | Hybrid — copy the __call__ logic into a single stage |
| Adding a variant of an existing model that shares most logic | Modular — reuse existing stages, customize via PipelineConfig callbacks |
| A specific pre-processing step needs special parallelism or profiling isolation | Modular — extract that step as a dedicated stage |
Key principle (both styles): The stage(s) before DenoisingStage must produce a Req batch object with all the standard tensor fields that DenoisingStage expects (latents, timesteps, prompt_embeds, etc.). As long as this contract is met, the pipeline remains composable regardless of which style you use.
| Purpose | Path |
|---|---|
| Pipeline classes | python/sglang/multimodal_gen/runtime/pipelines/ |
| Model-specific stages | python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/ |
| PipelineStage base class | python/sglang/multimodal_gen/runtime/pipelines_core/stages/base.py |
| Pipeline base class | python/sglang/multimodal_gen/runtime/pipelines_core/composed_pipeline_base.py |
| Standard stages (Denoising, Decoding) | python/sglang/multimodal_gen/runtime/pipelines_core/stages/ |
| Pipeline configs | python/sglang/multimodal_gen/configs/pipeline_configs/ |
| Sampling params | python/sglang/multimodal_gen/configs/sample/ |
| DiT model implementations | python/sglang/multimodal_gen/runtime/models/dits/ |
| VAE implementations | python/sglang/multimodal_gen/runtime/models/vaes/ |
| Encoder implementations | python/sglang/multimodal_gen/runtime/models/encoders/ |
| Scheduler implementations | python/sglang/multimodal_gen/runtime/models/schedulers/ |
| Model/VAE/DiT configs | python/sglang/multimodal_gen/configs/models/dits/, vaes/, encoders/ |
| Central registry | python/sglang/multimodal_gen/registry.py |
Before writing any code, obtain the model's reference implementation or Diffusers pipeline code. You need the actual source code to work from — do not guess or assume the model's architecture. If the user already gave a HuggingFace model ID or repo, inspect that yourself first. Ask the user only when the reference implementation is private, ambiguous, or otherwise unavailable. Typical sources are:
- The `pipeline_*.py` file from the `diffusers` library or the model's HuggingFace repo
- `model_index.json` and the associated pipeline class

Once you have the reference code, study it thoroughly:

- Read `model_index.json` to identify required modules (text_encoder, vae, transformer, scheduler, etc.)
- Trace the `__call__` method end-to-end. Identify the pre-processing steps, the inputs to the denoising loop, and the decoding/post-processing logic.
Before creating any new files, check whether an existing pipeline or stage can be reused or extended. Only create new pipelines/stages when the existing ones would require extensive modifications or when no similar implementation exists.
Specifically:
- An existing `BeforeDenoisingStage` may work with minor parameter differences
- `add_standard_t2i_stages()` / `add_standard_ti2i_stages()` / `add_standard_ti2v_stages()` may cover the model if it fits a standard pattern
- Browse `runtime/pipelines_core/stages/` and `stages/model_specific_stages/`. If an existing stage handles 80%+ of what the new model needs, extend it rather than duplicating it.
- The framework already ships common VAEs (e.g. `AutoencoderKL`), text encoders (CLIP, T5), and schedulers. Reuse these directly instead of re-implementing.

Rule of thumb: Only create a new file when the existing implementation would need substantial structural changes to accommodate the new model, or when no architecturally similar implementation exists.
Adapt or implement the model's core components in the appropriate directories.
DiT/Transformer (runtime/models/dits/{model_name}.py):
# python/sglang/multimodal_gen/runtime/models/dits/my_model.py
import torch
import torch.nn as nn
from sglang.multimodal_gen.runtime.layers.layernorm import (
LayerNormScaleShift,
RMSNormScaleShift,
)
from sglang.multimodal_gen.runtime.layers.attention.selector import (
get_attn_backend,
)
class MyModelTransformer2DModel(nn.Module):
"""DiT model for MyModel.
Adapt from the Diffusers/reference implementation. Key points:
- Use SGLang's fused LayerNorm/RMSNorm ops (see `existing-fast-paths.md` under the benchmark/profile skill)
- Use SGLang's attention backend selector
- Keep the same parameter naming as Diffusers for weight loading compatibility
"""
def __init__(self, config):
super().__init__()
# ... model layers ...
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.Tensor,
# ... model-specific kwargs ...
) -> torch.Tensor:
# ... forward pass ...
return output
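Since weight loading relies on parameter names matching the Diffusers checkpoint, a quick key-diff check catches renames early. This is a framework-independent sketch; the parameter names below are made up for illustration (in practice, compare `model.state_dict().keys()` against the checkpoint's keys, e.g. from safetensors):

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Report parameter names present on one side but not the other."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    return {
        "missing_from_model": sorted(checkpoint_keys - model_keys),
        "unexpected_in_model": sorted(model_keys - checkpoint_keys),
    }

# Hypothetical names: a renamed q-projection shows up on both sides of the diff.
ckpt = ["blocks.0.attn.to_q.weight", "blocks.0.attn.to_k.weight"]
model = ["blocks.0.attn.q_proj.weight", "blocks.0.attn.to_k.weight"]
report = diff_state_dict_keys(model, ckpt)
print(report)
```

An empty diff in both directions is a necessary (not sufficient) condition for clean weight loading.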
Tensor Parallel (TP) and Sequence Parallel (SP): For multi-GPU deployment, add TP/SP support to the DiT model. This can be done incrementally after the single-GPU implementation is verified. Reference existing implementations and adapt them to your model's architecture:
Wan video DiT (`runtime/models/dits/wanvideo.py`) — Full TP + SP reference:

- TP: `ColumnParallelLinear` for Q/K/V projections, `RowParallelLinear` for output projections, attention heads divided by `tp_size`
- SP: `get_sp_world_size()`, padding for alignment, `sequence_model_parallel_all_gather` for aggregation
- Cross-attention skips sequence parallelism (`skip_sequence_parallel=is_cross_attention`)

Qwen-Image DiT (`runtime/models/dits/qwen_image.py`) — SP + USPAttention reference:

- `USPAttention` (Ulysses + Ring Attention), configured via `--ulysses-degree` / `--ring-degree`
- `MergedColumnParallelLinear` for QKV (with Nunchaku quantization), `ReplicatedLinear` otherwise

Important: These are references only — each model has its own architecture and parallelism requirements. At minimum, check whether attention head counts divide evenly by `tp_size` and whether sequence lengths need padding to align with the SP world size.
Key imports for distributed support:
from sglang.multimodal_gen.runtime.distributed import (
divide,
get_sp_group,
get_sp_world_size,
get_tp_world_size,
sequence_model_parallel_all_gather,
)
from sglang.multimodal_gen.runtime.layers.linear import (
ColumnParallelLinear,
RowParallelLinear,
ReplicatedLinear,
)
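The core arithmetic behind these imports is simple and worth sanity-checking before wiring real layers: heads must divide evenly across TP ranks, and sequence lengths are padded up to a multiple of the SP world size before scattering. A framework-free sketch (these functions mirror, but are not, the real `divide` helper and padding logic):

```python
def divide(numerator, denominator):
    """Integer division that insists on exact divisibility, as TP sharding does."""
    assert numerator % denominator == 0, (
        f"{numerator} is not divisible by {denominator}"
    )
    return numerator // denominator

def padded_seq_len(seq_len, sp_world_size):
    """Pad a sequence length up to the next multiple of the SP world size."""
    remainder = seq_len % sp_world_size
    return seq_len if remainder == 0 else seq_len + (sp_world_size - remainder)

# e.g. 24 attention heads sharded across 4 TP ranks -> 6 heads per rank
heads_per_rank = divide(24, 4)
# e.g. a 4097-token sequence on 8 SP ranks -> padded to 4104
padded = padded_seq_len(4097, 8)
print(heads_per_rank, padded)
```

If `divide` fails for your model's head count at the target `tp_size`, the model needs a different sharding scheme (or a smaller TP degree).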
VAE (runtime/models/vaes/{model_name}.py): Implement if the model uses a non-standard VAE. Many models reuse existing VAEs.
Encoders (runtime/models/encoders/{model_name}.py): Implement if the model uses custom text/image encoders.
Schedulers (runtime/models/schedulers/{scheduler_name}.py): Implement if the model requires a custom scheduler not available in Diffusers.
DiT Config (configs/models/dits/{model_name}.py):
# python/sglang/multimodal_gen/configs/models/dits/mymodel.py
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.dits.base import DiTConfig
@dataclass
class MyModelDitConfig(DiTConfig):
arch_config: dict = field(default_factory=lambda: {
"in_channels": 16,
"num_layers": 24,
"patch_size": 2,
# ... model-specific architecture params ...
})
VAE Config (configs/models/vaes/{model_name}.py):
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig
@dataclass
class MyModelVAEConfig(VAEConfig):
vae_scale_factor: int = 8
# ... VAE-specific params ...
Sampling Params (configs/sample/{model_name}.py):
from dataclasses import dataclass
from sglang.multimodal_gen.configs.sample.base import SamplingParams
@dataclass
class MyModelSamplingParams(SamplingParams):
num_inference_steps: int = 50
guidance_scale: float = 7.5
height: int = 1024
width: int = 1024
# ... model-specific defaults ...
The PipelineConfig holds static model configuration and defines callback methods used by the standard DenoisingStage and DecodingStage.
# python/sglang/multimodal_gen/configs/pipeline_configs/my_model.py
import torch
from dataclasses import dataclass, field
from sglang.multimodal_gen.configs.pipeline_configs.base import (
    ImagePipelineConfig,  # for image generation
    # SpatialImagePipelineConfig,  # alternative base
    # VideoPipelineConfig,  # for video generation
)
from sglang.multimodal_gen.configs.models.dits.base import DiTConfig
from sglang.multimodal_gen.configs.models.dits.mymodel import MyModelDitConfig
from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig
from sglang.multimodal_gen.configs.models.vaes.mymodel import MyModelVAEConfig
# Also import ModelTaskType from wherever it is defined in your tree.
@dataclass
class MyModelPipelineConfig(ImagePipelineConfig):
"""Pipeline config for MyModel.
This config provides callbacks that the standard DenoisingStage and
DecodingStage use during execution. The BeforeDenoisingStage handles
all model-specific pre-processing independently.
"""
task_type: ModelTaskType = ModelTaskType.T2I
vae_precision: str = "bf16"
should_use_guidance: bool = True
vae_tiling: bool = False
enable_autocast: bool = False
dit_config: DiTConfig = field(default_factory=MyModelDitConfig)
vae_config: VAEConfig = field(default_factory=MyModelVAEConfig)
# --- Callbacks used by DenoisingStage ---
def get_freqs_cis(self, batch, device, rotary_emb, dtype):
"""Prepare rotary position embeddings for the DiT."""
# Model-specific RoPE computation
...
return freqs_cis
def prepare_pos_cond_kwargs(self, batch, latent_model_input, t, **kwargs):
"""Build positive conditioning kwargs for each denoising step."""
return {
"hidden_states": latent_model_input,
"encoder_hidden_states": batch.prompt_embeds[0],
"timestep": t,
# ... model-specific kwargs ...
}
def prepare_neg_cond_kwargs(self, batch, latent_model_input, t, **kwargs):
"""Build negative conditioning kwargs for CFG."""
return {
"hidden_states": latent_model_input,
"encoder_hidden_states": batch.negative_prompt_embeds[0],
"timestep": t,
# ... model-specific kwargs ...
}
# --- Callbacks used by DecodingStage ---
def get_decode_scale_and_shift(self):
"""Return (scale, shift) for latent denormalization before VAE decode."""
return self.vae_config.latents_std, self.vae_config.latents_mean
def post_denoising_loop(self, latents, batch):
"""Optional post-processing after the denoising loop finishes."""
return latents.to(torch.bfloat16)
def post_decoding(self, frames, server_args):
"""Optional post-processing after VAE decoding."""
return frames
Important: The prepare_pos_cond_kwargs / prepare_neg_cond_kwargs methods define what the DiT receives at each denoising step. These must match the DiT's forward() signature.
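To make the contract concrete, here is a framework-free sketch of how a denoising loop consumes these callbacks: the dict returned by `prepare_pos_cond_kwargs` is splatted straight into the transformer, so any key that is not a `forward()` parameter fails immediately with a TypeError. All names below are stand-ins, not the real DenoisingStage:

```python
def fake_transformer_forward(hidden_states, encoder_hidden_states, timestep):
    """Stand-in for the DiT forward(); returns a dummy 'denoised' latent."""
    return [h - 0.1 * timestep for h in hidden_states]

class FakeConfig:
    def prepare_pos_cond_kwargs(self, latents, embeds, t):
        # Keys here MUST match fake_transformer_forward's parameter names.
        return {
            "hidden_states": latents,
            "encoder_hidden_states": embeds,
            "timestep": t,
        }

config = FakeConfig()
latents, embeds = [1.0, 2.0], [0.5]
for t in (1.0, 0.5):  # toy "timestep schedule"
    kwargs = config.prepare_pos_cond_kwargs(latents, embeds, t)
    latents = fake_transformer_forward(**kwargs)
print(latents)
```

A misspelled key (say, `hidden_state`) would raise at the first denoising step, which is exactly the failure mode the "must match" rule above is warning about.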
This is the heart of the Hybrid pattern. Create a single stage that handles ALL pre-processing.
# python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/my_model.py
import torch
from typing import List, Optional, Union
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
logger = init_logger(__name__)
class MyModelBeforeDenoisingStage(PipelineStage):
"""Monolithic pre-processing stage for MyModel.
Consolidates all logic before the denoising loop:
- Input validation
- Text/image encoding
- Latent preparation
- Timestep/sigma computation
This stage produces a Req batch with all fields required by
the standard DenoisingStage.
"""
def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler):
super().__init__()
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.transformer = transformer
self.scheduler = scheduler
# ... other initialization (image processors, scale factors, etc.) ...
# --- Internal helper methods ---
# Copy/adapt directly from the Diffusers reference pipeline.
# These are private to this stage; no need to make them reusable.
def _encode_prompt(self, prompt, device, dtype):
"""Encode text prompt into embeddings."""
# ... model-specific text encoding logic ...
return prompt_embeds, negative_prompt_embeds
def _prepare_latents(self, batch_size, height, width, dtype, device, generator):
"""Create initial noisy latents."""
# ... model-specific latent preparation ...
return latents
def _prepare_timesteps(self, num_inference_steps, device):
"""Compute the timestep/sigma schedule."""
# ... model-specific timestep computation ...
return timesteps, sigmas
# --- Main forward method ---
@torch.no_grad()
def forward(self, batch: Req, server_args: ServerArgs) -> Req:
"""Execute all pre-processing and populate batch for DenoisingStage.
This method mirrors the first half of a Diffusers pipeline __call__,
up to (but not including) the denoising loop.
"""
device = get_local_torch_device()
dtype = torch.bfloat16
generator = torch.Generator(device=device).manual_seed(batch.seed)
# 1. Encode prompt
prompt_embeds, negative_prompt_embeds = self._encode_prompt(
batch.prompt, device, dtype
)
# 2. Prepare latents
latents = self._prepare_latents(
batch_size=1,
height=batch.height,
width=batch.width,
dtype=dtype,
device=device,
generator=generator,
)
# 3. Prepare timesteps
timesteps, sigmas = self._prepare_timesteps(
batch.num_inference_steps, device
)
# 4. Populate batch with everything DenoisingStage needs
batch.prompt_embeds = [prompt_embeds]
batch.negative_prompt_embeds = [negative_prompt_embeds]
batch.latents = latents
batch.timesteps = timesteps
batch.num_inference_steps = len(timesteps)
batch.sigmas = sigmas
batch.generator = generator
batch.raw_latent_shape = latents.shape
# batch.height / batch.width already come from the request; no need to reassign
return batch
Key fields that DenoisingStage expects on the batch (set these in your forward):
| Field | Type | Description |
|---|---|---|
batch.latents | torch.Tensor | Initial noisy latent tensor |
batch.timesteps | torch.Tensor | Timestep schedule |
batch.num_inference_steps | int | Number of denoising steps |
batch.sigmas | list[float] | Sigma schedule (as a list, not numpy) |
batch.prompt_embeds | list[torch.Tensor] | Positive prompt embeddings (wrapped in list) |
batch.negative_prompt_embeds | list[torch.Tensor] | Negative prompt embeddings (wrapped in list) |
batch.generator | torch.Generator | RNG generator for reproducibility |
batch.raw_latent_shape | tuple | Original latent shape before any packing |
batch.height / batch.width | int | Output dimensions |
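A small duck-typed validator can catch contract violations before they surface as cryptic DenoisingStage errors. This is illustrative, not part of the framework: field names match the table above, and the type checks are simplified (plain `list` checks) to keep the sketch torch-free:

```python
from types import SimpleNamespace

REQUIRED_FIELDS = (
    "latents", "timesteps", "num_inference_steps", "sigmas",
    "prompt_embeds", "negative_prompt_embeds", "generator",
    "raw_latent_shape", "height", "width",
)

def validate_batch_contract(batch):
    """Return a list of problems with the pre-denoising batch (empty = OK)."""
    problems = [f"missing field: {f}" for f in REQUIRED_FIELDS
                if getattr(batch, f, None) is None]
    if not isinstance(getattr(batch, "sigmas", None), list):
        problems.append("sigmas must be a Python list (use .tolist())")
    for f in ("prompt_embeds", "negative_prompt_embeds"):
        if not isinstance(getattr(batch, f, None), list):
            problems.append(f"{f} must be a list of tensors (wrap with [tensor])")
    return problems

# A batch with one deliberate mistake: sigmas as a tuple instead of a list.
bad = SimpleNamespace(latents=object(), timesteps=object(),
                      num_inference_steps=50, sigmas=(1.0, 0.5),
                      prompt_embeds=[object()], negative_prompt_embeds=[object()],
                      generator=object(), raw_latent_shape=(1, 16, 64, 64),
                      height=1024, width=1024)
print(validate_batch_contract(bad))
```

Calling something like this at the end of `BeforeDenoisingStage.forward()` during development is cheap insurance; remove or gate it once the pipeline is verified.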
The pipeline class is minimal -- it just wires the stages together.
# python/sglang/multimodal_gen/runtime/pipelines/my_model.py
from sglang.multimodal_gen.runtime.pipelines_core import LoRAPipeline
from sglang.multimodal_gen.runtime.pipelines_core.composed_pipeline_base import (
ComposedPipelineBase,
)
from sglang.multimodal_gen.runtime.pipelines_core.stages import DenoisingStage
from sglang.multimodal_gen.runtime.pipelines_core.stages.model_specific_stages.my_model import (
MyModelBeforeDenoisingStage,
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
class MyModelPipeline(LoRAPipeline, ComposedPipelineBase):
pipeline_name = "MyModelPipeline" # Must match model_index.json _class_name
_required_config_modules = [
"text_encoder",
"tokenizer",
"vae",
"transformer",
"scheduler",
# ... list all modules from model_index.json ...
]
def create_pipeline_stages(self, server_args: ServerArgs):
# 1. Monolithic pre-processing (model-specific)
self.add_stage(
MyModelBeforeDenoisingStage(
vae=self.get_module("vae"),
text_encoder=self.get_module("text_encoder"),
tokenizer=self.get_module("tokenizer"),
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
# 2. Standard denoising loop (framework-provided)
self.add_stage(
DenoisingStage(
transformer=self.get_module("transformer"),
scheduler=self.get_module("scheduler"),
),
)
# 3. Standard VAE decoding (framework-provided)
self.add_standard_decoding_stage()
# REQUIRED: This is how the registry discovers the pipeline
EntryClass = [MyModelPipeline]
In python/sglang/multimodal_gen/registry.py, register your configs:
register_configs(
model_family="my_model",
sampling_param_cls=MyModelSamplingParams,
pipeline_config_cls=MyModelPipelineConfig,
hf_model_paths=[
"org/my-model-name", # HuggingFace model ID(s)
],
)
The EntryClass in your pipeline file is automatically discovered by the registry's _discover_and_register_pipelines() function -- no additional registration needed for the pipeline class itself.
After implementation, you must verify that the generated output is not noise. A noisy or garbled output image/video is the most common sign of an incorrect implementation. Common causes include:
- Wrong latent denormalization (`get_decode_scale_and_shift` returning wrong values)
- Conditioning kwargs that do not match the DiT's `forward()` signature
- Incorrect VAE handling (wrong `vae_scale_factor`, missing denormalization)
- RoPE mismatch (`is_neox_style` set incorrectly)

If the output is noise, the implementation is incorrect — do not ship it. Debug by:
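One lightweight way to localize a divergence is to dump summary statistics of each stage's output tensors and compare them against the same point in a Diffusers reference run; pure noise usually shows up as a sudden blow-up in mean or std at one specific stage. A framework-free sketch of that comparison (in practice you would flatten real tensors with `.flatten().tolist()`; the tolerance is an arbitrary illustrative choice):

```python
import math

def tensor_stats(values):
    """Mean/std summary of a flattened tensor (given as a list of floats)."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return {"mean": mean, "std": math.sqrt(var)}

def stats_close(ours, reference, rtol=0.05):
    """Loose check: per-stage stats should roughly match the reference run."""
    return all(abs(ours[k] - reference[k]) <= rtol * (abs(reference[k]) + 1e-8)
               for k in ours)

ours = tensor_stats([0.1, -0.2, 0.3, 0.0])
ref = tensor_stats([0.1, -0.2, 0.3, 0.0])
print(stats_close(ours, ref))
```

Walking stage by stage (text encoding, latent prep, each denoising step, decode) and finding the first point where `stats_close` fails pinpoints which component to inspect.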
| Model | Pipeline | BeforeDenoisingStage | PipelineConfig |
|---|---|---|---|
| GLM-Image | runtime/pipelines/glm_image.py | stages/model_specific_stages/glm_image.py | configs/pipeline_configs/glm_image.py |
| Qwen-Image-Layered | runtime/pipelines/qwen_image.py (QwenImageLayeredPipeline) | stages/model_specific_stages/qwen_image_layered.py | configs/pipeline_configs/qwen_image.py (QwenImageLayeredPipelineConfig) |
| Model | Pipeline | Notes |
|---|---|---|
| Qwen-Image (T2I) | runtime/pipelines/qwen_image.py | Uses add_standard_t2i_stages() — standard text encoding + latent prep fits this model |
| Qwen-Image-Edit | runtime/pipelines/qwen_image.py | Uses add_standard_ti2i_stages() — standard image-to-image flow |
| Flux | runtime/pipelines/flux.py | Uses add_standard_t2i_stages() with custom prepare_mu |
| Wan | runtime/pipelines/wan_pipeline.py | Uses add_standard_ti2v_stages() |
Before submitting, verify:
Common (both styles):
- Pipeline class at `runtime/pipelines/{model_name}.py` with `EntryClass`
- PipelineConfig at `configs/pipeline_configs/{model_name}.py`
- SamplingParams at `configs/sample/{model_name}.py`
- DiT implementation at `runtime/models/dits/{model_name}.py`
- DiT config at `configs/models/dits/{model_name}.py`
- VAE: reuse an existing one (e.g. `AutoencoderKL`) or create new at `runtime/models/vaes/`
- VAE config at `configs/models/vaes/{model_name}.py`
- Configs registered in `registry.py` via `register_configs()`
- `pipeline_name` matches the Diffusers `model_index.json` `_class_name`
- `_required_config_modules` lists all modules from `model_index.json`
- PipelineConfig callbacks (`prepare_pos_cond_kwargs`, `get_freqs_cis`, etc.) match the DiT's `forward()` signature
- Fused ops used where applicable (see `existing-fast-paths.md` under the benchmark/profile skill)
- TP/SP support considered (see `wanvideo.py` for TP+SP, `qwen_image.py` for USPAttention)

Hybrid style only:
- BeforeDenoisingStage at `stages/model_specific_stages/{model_name}.py`
- `BeforeDenoisingStage.forward()` populates all fields needed by `DenoisingStage`
- `batch.sigmas` must be a Python list, not a numpy array. Use `.tolist()` to convert.
- `batch.prompt_embeds` is a list of tensors (one per encoder), not a single tensor. Wrap with `[tensor]`.
- `batch.raw_latent_shape` is set -- `DecodingStage` uses it to unpack latents.
- `is_neox_style=True` = split-half rotation, `is_neox_style=False` = interleaved. Check the reference model carefully.
- Check the reference VAE's dtype and set `vae_precision` in the PipelineConfig accordingly.

After the model produces non-noise output, read
references/testing-and-accuracy.md before
adding GPU cases, component-accuracy skips/hooks, suite entries, or benchmark
claims. That reference tracks the current gpu_cases.py / testcase_configs.py
/ run_suite.py split and the component-accuracy decision rules.