Converting PyTorch model weights to Keras h5 format for keras_cv_attention_models
Convert pretrained PyTorch weights to Keras h5 format.
The core task is aligning the weight name order between the torch and keras models. `download_and_load.keras_reload_from_torch_model` is a convenience helper that automates this, but converting manually is also fine.
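Once the names are aligned, each torch weight goes to the keras weight at the same position. A misordered list therefore either fails on a shape mismatch or quietly loads weights into the wrong layer. The plain-Python sketch below shows that positional pairing under stated assumptions. `pair_by_order` and the sample names and shapes are hypothetical, and this is not the library's implementation.

```python
from itertools import zip_longest
from math import prod


def pair_by_order(torch_weights, keras_weights):
    """Pair two ordered lists of (name, shape) tuples by position.

    Returns (pairs, problems). Element counts are compared rather than raw
    shapes, because layouts differ between frameworks (e.g. conv kernels
    are transposed), so only the total size has to agree.
    """
    pairs, problems = [], []
    for tw, kw in zip_longest(torch_weights, keras_weights):
        if tw is None or kw is None:
            # One list ran out before the other: weight counts differ
            problems.append((tw, kw, "count mismatch"))
            continue
        pairs.append((tw[0], kw[0]))
        if prod(tw[1]) != prod(kw[1]):
            problems.append((tw, kw, "size mismatch"))
    return pairs, problems


# Hypothetical example names and shapes
torch_w = [("stem.conv.weight", (32, 3, 3, 3)), ("head.fc.weight", (10, 32))]
keras_w = [("stem_conv/kernel", (3, 3, 3, 32)), ("head/kernel", (32, 10))]

print(pair_by_order(torch_w, keras_w)[1])        # aligned: no problems
print(pair_by_order(torch_w, keras_w[::-1])[1])  # misordered: two size mismatches
```

Matching element counts are necessary but not sufficient. Two same-sized weights that play different roles would still pair without complaint, so the final prediction comparison is still needed.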
The conversion runs in three steps:

1. `state_dict_stack_by_layer`: Groups torch `state_dict` entries by layer name (splitting on `.`), filtering them through `skip_weights` and `unstack_weights`.
2. `align_layer_names_multi_stage`: Reorders the keras layer names to match the torch weight order.
3. `keras_reload_stacked_state_dict`: Applies the standard transforms (Conv2D/Dense transpose, etc.) plus any custom `additional_transfer` overrides, then saves the h5 file.

Torch weight filtering parameters:

| Parameter | Purpose |
|---|---|
| `skip_weights` | Weight name suffixes to drop (e.g., `["num_batches_tracked", "relative_position_index"]`) |
| `unstack_weights` | Weights kept as individual entries instead of being grouped with their layer (e.g., `["cls_token", "pos_embed", "gamma_1"]`) |
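Here is a rough sketch of the grouping step. It assumes each name is split at its last `.` into a layer name and a weight name, and it matches both filter lists by suffix. `stack_by_layer` is a hypothetical stand-in for `state_dict_stack_by_layer`, whose actual matching rules may differ. Strings take the place of tensors.

```python
from collections import OrderedDict


def stack_by_layer(state_dict, skip_weights=("num_batches_tracked",), unstack_weights=()):
    """Group a flat {name: value} dict into {layer_name: [values, ...]}, keeping order."""
    stacked = OrderedDict()
    for name, value in state_dict.items():
        if name.endswith(tuple(skip_weights)):
            continue  # drop buffers such as BatchNorm's num_batches_tracked
        if name.endswith(tuple(unstack_weights)):
            stacked[name] = [value]  # standalone parameter: keep its full name as the key
            continue
        layer_name = name.rsplit(".", 1)[0]  # "blocks.0.norm1.weight" -> "blocks.0.norm1"
        stacked.setdefault(layer_name, []).append(value)
    return stacked


# Hypothetical torch state_dict (strings stand in for tensors)
sd = OrderedDict([
    ("patch_embed.proj.weight", "w0"),
    ("patch_embed.proj.bias", "b0"),
    ("blocks.0.gamma_1", "g1"),
    ("blocks.0.norm1.weight", "n_w"),
    ("blocks.0.norm1.bias", "n_b"),
    ("bn.weight", "bw"),
    ("bn.bias", "bb"),
    ("bn.running_mean", "bm"),
    ("bn.running_var", "bv"),
    ("bn.num_batches_tracked", 0),
])

print(list(stack_by_layer(sd)))                              # gamma_1 merged into "blocks.0"
print(list(stack_by_layer(sd, unstack_weights=["gamma_1"])))  # gamma_1 kept separate
```

If `gamma_1` is not listed in `unstack_weights`, it ends up in a spurious `blocks.0` layer with no keras counterpart. That is the situation the `unstack_weights` parameter is meant to handle.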
Keras name alignment parameters:

| Parameter | Purpose |
|---|---|
| `tail_align_dict` | Repositions layers by tail name: `{tail_name: offset}`. A negative offset moves the layer earlier. Can be scoped by stack: `{"stack3": {"attn_gamma": -6}}` |
| `full_name_align_dict` | Repositions layers by exact name. The value can be a negative offset, an absolute position, or another layer's name string |
| `tail_split_position` | Where to split a layer name into head and tail (default `2`). E.g., with `1`: head=`stack1`, tail=`attn_gamma` |
| `specific_match_func` | Function returning the complete ordered name list, bypassing all alignment logic. Use it for complex cases the dicts can't express |
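The following is a simplified, hypothetical reimplementation of just the `{tail_name: offset}` rule, meant to show what an offset does. Each keras name ending with the given tail is moved `offset` positions, and a negative offset moves it earlier. It leaves out stack scoping, `full_name_align_dict`, and `tail_split_position`, matches tails with a plain `endswith`, and assumes names are unique. It is not the library's algorithm.

```python
def apply_tail_offsets(names, tail_align_dict):
    """Return a copy of names with each tail-matched name moved by its offset."""
    names = list(names)  # do not mutate the caller's list
    for tail, offset in tail_align_dict.items():
        # Snapshot the matches first, since moving items shifts indices
        targets = [nn for nn in names if nn.endswith(tail)]
        for target in targets:
            idx = names.index(target)
            names.pop(idx)
            new_idx = max(0, min(len(names), idx + offset))  # clamp to list bounds
            names.insert(new_idx, target)
    return names


# Hypothetical keras names where torch stores attn_gamma before qkv/out in each block
keras_names = [
    "stack1_block1_attn_qkv", "stack1_block1_attn_out", "stack1_block1_attn_gamma",
    "stack1_block2_attn_qkv", "stack1_block2_attn_out", "stack1_block2_attn_gamma",
]
print(apply_tail_offsets(keras_names, {"attn_gamma": -2}))
```

Since the move is relative, one entry fixes the same misplacement in every block. For that reason a tail offset is usually tried before `full_name_align_dict` or `specific_match_func`.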
Weight transfer parameters:

| Parameter | Purpose |
|---|---|
| `additional_transfer` | Custom transforms: `{LayerClass: lambda ww: [...]}` or `{"name_suffix": lambda ww: [...]}`. Applied after the default Conv2D/Dense transposes |
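The default transposes come from the two frameworks' weight layouts. Torch stores a Conv2D kernel as `(out, in, kh, kw)` and a Linear weight as `(out, in)`, whereas Keras expects `(kh, kw, in, out)` and `(in, out)`. The numpy sketch below applies those defaults and then an `additional_transfer`-style override selected by name suffix. The `transfer` function and the layer-type strings are illustrative assumptions. The real option also accepts Keras layer classes as keys.

```python
import numpy as np


def default_transfer(layer_type, weights):
    """Convert torch-layout arrays to Keras layout for Conv2D and Dense."""
    if layer_type == "Conv2D":
        # torch (out, in, kh, kw) -> keras (kh, kw, in, out); bias unchanged
        return [np.transpose(weights[0], (2, 3, 1, 0))] + list(weights[1:])
    if layer_type == "Dense":
        # torch Linear (out, in) -> keras Dense (in, out); bias unchanged
        return [weights[0].T] + list(weights[1:])
    return list(weights)


def transfer(layer_name, layer_type, weights, additional_transfer=None):
    """Apply default transforms, then any suffix-matched custom override."""
    ww = default_transfer(layer_type, weights)
    for suffix, func in (additional_transfer or {}).items():
        if layer_name.endswith(suffix):
            ww = func(ww)  # custom override runs after the defaults
    return ww


conv = transfer("stem_conv", "Conv2D", [np.zeros((8, 3, 5, 5)), np.zeros(8)])
dense = transfer("head", "Dense", [np.zeros((10, 64)), np.zeros(10)])
# Hypothetical override: reshape a per-channel gamma (C,) to (1, 1, C)
gamma = transfer("stack1_gamma", "Scale", [np.ones(16)],
                 {"gamma": lambda ww: [ww[0].reshape(1, 1, -1)]})
print(conv[0].shape, dense[0].shape, gamma[0].shape)
```

Layers whose layout doesn't match either default, such as depthwise convolutions or fused attention projections, are where a custom `additional_transfer` entry would be needed.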
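Workflow:

1. Build the keras model with `pretrained=None, classifier_activation=None`.
2. Run with `do_convert=False` first to inspect both name lists.
3. Once the names line up, run with `do_convert=True`. It predicts with both models and prints the results.
4. Check that both models give the same top prediction (e.g., `Egyptian_cat` for the cat test image).
5. Run `md5sum output.h5` and add the hash to `PRETRAINED_DICT`.

Troubleshooting:

- Standalone torch parameters that don't belong to a layer (e.g., `cls_token`, `pos_embed`): add them to `unstack_weights`.
- Layers out of order: run with `do_convert=False` to see the lists side-by-side, then adjust the offsets or use `specific_match_func` for full control.
- Predictions differ between the two models: check `rescale_mode` in `add_pre_post_process()`, `classifier_activation`, or intermediate layer outputs.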
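When `md5sum` isn't available (for example on Windows), Python's standard `hashlib` module gives the same hex digest for the saved file. Below is a minimal sketch. `file_md5` is a hypothetical helper name.

```python
import hashlib


def file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it in 1 MiB chunks."""
    md5 = hashlib.md5()
    with open(path, "rb") as ff:
        # iter() with a sentinel stops when read() returns b"" at end of file
        for chunk in iter(lambda: ff.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()
```

Calling it as `file_md5("output.h5")` returns the same string `md5sum output.h5` prints in its first column. Reading in chunks keeps memory flat even for large weight files.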