Guide for adding a new model to the Archon engine. Use when user wants to add support for a new HuggingFace model architecture in ArchonEngine.
Add support for a new HuggingFace model architecture in the Archon training engine.
This skill is triggered when:
- The user wants to add support for a new HuggingFace model architecture in Archon
- The user asks to create a ModelSpec or model type for Archon

Before starting, ensure:
- The target model is available on HuggingFace (a config.json with a model_type field)
- You know the HF model ID (e.g., meta-llama/Llama-3-8B)

Read the HuggingFace model's source code to extract key architecture information.
Action: Fetch and analyze the model's HuggingFace configuration and modeling files.
Read the model's config.json (via AutoConfig.from_pretrained) to identify:
- The model_type string (this is the key used for registry lookup)
- Optional architecture fields (e.g., qk_norm, attention_bias, MoE fields)

Read the HuggingFace modeling_*.py source to identify:
- The attention, FFN, MoE, RoPE, and normalization variants
- Does tie_word_embeddings appear in the config?

Summarize findings in a checklist like:
Target model: <name>
HF model_type: "<model_type>" (and variants like "<model_type>_moe" if applicable)
Attention: [standard GQA / with QK norm / with bias / sliding window / ...]
FFN: [SwiGLU / GeGLU / standard MLP / ...]
MoE: [no / yes - num_experts, top_k, shared_experts]
RoPE: [standard / YaRN / NTK-aware / ...]
Norm: [RMSNorm / LayerNorm] with [pre-norm / post-norm]
Weight tying: [yes / no]
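Several of the checklist entries can be read mechanically off the config object. A minimal sketch, using a stand-in namespace for the object returned by AutoConfig.from_pretrained; the field names follow common HF conventions and the values are illustrative, not from any real model:

```python
from types import SimpleNamespace

# Stand-in for AutoConfig.from_pretrained("<hf_model_id>"); values are
# illustrative only.
cfg = SimpleNamespace(
    model_type="llama",
    num_attention_heads=32,
    num_key_value_heads=8,
    tie_word_embeddings=False,
)

def summarize(cfg):
    """Derive the checklist entries that come straight from config.json."""
    n_heads = cfg.num_attention_heads
    # Fewer KV heads than query heads means grouped-query attention.
    n_kv = getattr(cfg, "num_key_value_heads", n_heads)
    return {
        "model_type": cfg.model_type,
        "attention": "GQA" if n_kv < n_heads else "MHA",
        "weight_tying": getattr(cfg, "tie_word_embeddings", False),
    }
```

The attention/FFN/RoPE *variants* still require reading the modeling source; only the scalar hyperparameters fall out of the config this easily.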
Choose the closest existing implementation as a starting point:
| Target characteristics | Reference | Why |
|---|---|---|
| Dense-only, standard GQA, no QK norm | qwen2 | Simplest baseline, pure dense |
| Has QK norm, or has MoE support | qwen3 | Supports QK norm + MoE + shared experts |
Action: Copy the reference model directory as the starting point:
areal/experimental/models/archon/<model>/
__init__.py
spec.py
model/
args.py
model.py
rope.py
state_dict_adapter.py
infra/
parallelize.py
### args.py

Adapt <Model>ModelArgs to match the target model's HuggingFace config fields.
Key changes from reference:
Update the @dataclass fields to match the target model's hyperparameters:
- Core fields (dim, n_layers, n_heads, n_kv_heads, vocab_size, head_dim, hidden_dim, norm_eps, rope_theta, etc.)
- Model-specific flags (attention_bias, qk_norm, sliding_window)

Update from_hf_config() to correctly map HuggingFace config attributes:
- Use getattr(hf_config, "field_name", default) for optional fields

Critical: Verify every field mapping against the HF model's config.json. Incorrect mappings here cause silent errors downstream.
Base class contract (BaseModelArgs):
```python
@dataclass
class <Model>ModelArgs(BaseModelArgs):
    # ... model-specific fields ...

    @classmethod
    def from_hf_config(
        cls,
        hf_config: PretrainedConfig,
        is_critic: bool = False,
        **kwargs,
    ) -> "<Model>ModelArgs":
        # Map HF config fields to Archon model args
        ...
```
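As a concrete (hypothetical) instance of this contract, the sketch below maps a few Llama-style HF config fields and guards the optional ones with getattr. The class name, field subset, and simplified signature are illustrative, not Archon's real code:

```python
from dataclasses import dataclass
from types import SimpleNamespace

@dataclass
class ExampleModelArgs:
    dim: int
    n_heads: int
    n_kv_heads: int
    attention_bias: bool

    @classmethod
    def from_hf_config(cls, hf_config, **kwargs) -> "ExampleModelArgs":
        return cls(
            dim=hf_config.hidden_size,
            n_heads=hf_config.num_attention_heads,
            # Optional HF fields: guard with getattr so older configs load too.
            n_kv_heads=getattr(hf_config, "num_key_value_heads",
                               hf_config.num_attention_heads),
            attention_bias=getattr(hf_config, "attention_bias", False),
        )

# Stand-in for a real PretrainedConfig; note it lacks the optional fields,
# so the getattr defaults kick in.
hf = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
args = ExampleModelArgs.from_hf_config(hf)
```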
### model.py

Adapt the model architecture to match the target model.
Key components to adapt:
Normalization (RMSNorm or similar):
- Check whether elementwise_affine is configurable
- If the model uses LayerNorm, implement accordingly

Attention module:
- Projection bias (nn.Linear(..., bias=True/False))
- Keep q_norm/k_norm if the model has them, remove them if it doesn't
- n_kv_heads < n_heads for grouped-query attention
- Keep the set_cp_group / _sp_enabled pattern from the reference

FeedForward module:
- SwiGLU: w2(silu(w1(x)) * w3(x)) -- most common for modern LLMs
- A MoE module replaces FeedForward on designated layers

TransformerBlock:
- Pre-norm (most modern LLMs) vs post-norm
- _is_moe_layer() if applicable

Top-level model (<Model>Model(BaseArchonModel)):
- Submodules: tok_embeddings, layers (as ModuleDict), norm, output/score
- init_weights(): Match the initialization scheme from HF
- init_buffers(): RoPE cache + MoE buffers
- forward(): Must follow the BaseArchonModel signature:
  (tokens, positions, cu_seqlens, max_seqlen, tree_attn_meta=None) -> Tensor

Base class contract (BaseArchonModel):
```python
class <Model>Model(BaseArchonModel):
    def forward(self, tokens, positions, cu_seqlens, max_seqlen, tree_attn_meta=None) -> torch.Tensor: ...
    def init_weights(self) -> None: ...
    def init_buffers(self, buffer_device) -> None: ...
```
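The SwiGLU dataflow mentioned above, reduced to scalars so the formula is easy to check by hand. Real code applies nn.Linear layers to tensors; this is only an illustration of w2(silu(w1(x)) * w3(x)):

```python
import math

def silu(x: float) -> float:
    # SiLU / swish: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def swiglu(x: float, w1: float, w2: float, w3: float) -> float:
    # Gate path (w1 then silu) multiplies the linear path (w3),
    # then w2 projects the product back down.
    return w2 * (silu(w1 * x) * (w3 * x))
```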
### rope.py

Handle the rotary position embedding variant.
Options:
Standard RoPE (same as qwen2/qwen3): Re-export from qwen2:
```python
from areal.experimental.models.archon.qwen2.model.rope import (
    apply_rotary_emb,
    precompute_rope_cache,
    repeat_kv,
    reshape_for_broadcast,
    rotate_half,
)
```
Custom RoPE (YaRN, NTK-aware, etc.): Implement custom precompute_rope_cache()
and apply_rotary_emb() functions. The key difference is usually in how inv_freq
is computed (scaling factors, interpolation, etc.).
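To make the "difference is in inv_freq" point concrete, here is a sketch of standard RoPE inverse frequencies next to an NTK-aware variant that rescales the base. The scaling formula base * s**(d/(d-2)) is the commonly cited NTK-aware rule; treat this as an illustration, not Archon's implementation:

```python
def rope_inv_freq(head_dim: int, theta: float = 10000.0) -> list[float]:
    # Standard RoPE: one frequency per pair of dimensions.
    return [theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

def ntk_inv_freq(head_dim: int, theta: float = 10000.0,
                 scale: float = 4.0) -> list[float]:
    # NTK-aware scaling: stretch the base so low frequencies slow down
    # while the highest frequency stays unchanged.
    theta_scaled = theta * scale ** (head_dim / (head_dim - 2))
    return [theta_scaled ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
```

A custom precompute_rope_cache() would then build the cos/sin cache from these frequencies exactly as the standard version does.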
### state_dict_adapter.py

Map between HuggingFace and Archon weight key names.
This is the most error-prone step. The adapter must correctly handle:
Key name mapping (from_hf_map dict):
- model.embed_tokens.weight -> tok_embeddings.weight
- model.layers.{}.self_attn.q_proj.weight -> layers.{}.attention.wq.weight
- model.layers.{}.mlp.gate_proj.weight -> layers.{}.feed_forward.w1.weight
- model.layers.{}.input_layernorm.weight -> layers.{}.attention_norm.weight
- lm_head.weight -> output.weight
- Keys mapped to None are dropped: rotary_emb.inv_freq (computed at runtime)

Reverse mapping (to_hf_map): Auto-generated from from_hf_map
MoE expert weights (if applicable): 3D<->2D conversion for expert weights. Copy the MoE handling from qwen3 if the model has MoE.
Weight tying: Skip output.weight during to_hf() if tie_word_embeddings=True
Verification approach: After implementation, the adapter should satisfy:
```python
# Roundtrip: archon -> hf -> archon preserves all keys
hf_sd = adapter.to_hf(archon_sd)
roundtrip_sd = adapter.from_hf(hf_sd)
assert set(roundtrip_sd.keys()) == set(archon_sd.keys())
```
Base class contract (BaseStateDictAdapter):
```python
class <Model>StateDictAdapter(BaseStateDictAdapter):
    def from_hf(self, hf_state_dict) -> dict[str, Any]: ...
    def to_hf(self, archon_state_dict) -> dict[str, Any]: ...
    def convert_single_to_hf(self, name, tensor) -> list[tuple[str, torch.Tensor]]: ...
```
### parallelize.py

Define the parallelization strategy for the model.

The parallelize function applies each parallelism strategy in a fixed order; follow the ordering used by the reference implementation.

Key adaptations by model architecture:
- Models with QK norm: use use_local_output=False (DTensor output for norm) and add SequenceParallel(sequence_dim=2) for q_norm/k_norm
- Models without QK norm: use_local_output=True
- MoE models: use apply_moe_ep_tp() and apply_non_moe_tp()

Function signature (must match the ParallelizeFn protocol):
```python
def parallelize_<model>(
    model: nn.Module,
    parallel_dims: ArchonParallelDims,
    param_dtype: torch.dtype = torch.bfloat16,
    reduce_dtype: torch.dtype = torch.float32,
    loss_parallel: bool = True,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
    ac_config: ActivationCheckpointConfig | None = None,
    enable_compile: bool = True,
) -> nn.Module:
    ...
```
### spec.py and registration

Assemble the ModelSpec and register it.
```python
from areal.experimental.models.archon.model_spec import ModelSpec, register_model_spec
from areal.experimental.models.archon.pipeline_parallel import pipeline_llm
from areal.experimental.models.archon.<model>.infra.parallelize import parallelize_<model>
from areal.experimental.models.archon.<model>.model.args import <Model>ModelArgs
from areal.experimental.models.archon.<model>.model.model import <Model>Model
from areal.experimental.models.archon.<model>.model.state_dict_adapter import (
    <Model>StateDictAdapter,
)

<MODEL>_SPEC = ModelSpec(
    name="<Model>",
    model_class=<Model>Model,
    model_args_class=<Model>ModelArgs,
    state_dict_adapter_class=<Model>StateDictAdapter,
    parallelize_fn=parallelize_<model>,
    supported_model_types=frozenset({"<model_type>"}),  # From HF config.json
    pipelining_fn=pipeline_llm,
)

# Auto-register when module is imported
register_model_spec(<MODEL>_SPEC)

__all__ = ["<MODEL>_SPEC"]
```
Note: supported_model_types should include all HF model_type strings that this
implementation handles (e.g., {"qwen3", "qwen3_moe"} for Qwen3).
### __init__.py

Add the import to areal/experimental/models/archon/__init__.py:

```python
from areal.experimental.models.archon.<model> import spec as <model>_spec  # noqa: F401
```
This triggers auto-registration when the module is imported.
Verification should be done in stages, adapting based on available hardware and the test
patterns in tests/experimental/archon/.
Before writing tests, examine the existing test files to understand current