VLA (Vision-Language-Action) model training workflow for autonomous driving. Use when setting up training pipelines, configuring distributed training (DeepSpeed/FSDP), preparing multimodal datasets (nuScenes/Waymo/custom), building VLA architectures, or debugging training issues. Also use for "训练", "train", "finetune", "微调", "分布式", "distributed", "数据处理", "data pipeline" queries.
End-to-end training pipeline for Vision-Language-Action models in autonomous driving.
Typical VLA model:
Camera Images     → Vision Encoder (CLIP/DINOv2/SigLIP) → Visual Tokens
Text Instructions → Tokenizer                            → Text Tokens
Visual Tokens + Text Tokens → Fused Multimodal Input → [LLM Backbone] → Action Tokens → Action Head → Waypoints/Controls
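A minimal sketch of that forward pass in PyTorch; the encoder, text embedder, and LLM backbone are placeholder modules, and fusion by simple token concatenation plus a single-token action readout are illustrative assumptions, not a fixed recipe:
import torch
import torch.nn as nn

class VLAModel(nn.Module):
    def __init__(self, vision_encoder, text_embedder, llm_backbone,
                 hidden_dim=1024, action_horizon=6, action_dim=2):
        super().__init__()
        self.vision_encoder = vision_encoder   # e.g. CLIP/DINOv2/SigLIP trunk -> patch tokens
        self.text_embedder = text_embedder     # token IDs -> embeddings
        self.llm_backbone = llm_backbone       # transformer over the fused token sequence
        self.action_head = nn.Linear(hidden_dim, action_horizon * action_dim)
        self.action_horizon, self.action_dim = action_horizon, action_dim

    def forward(self, images, input_ids):
        B, n_cam = images.shape[:2]
        # Encode every camera view, then flatten views into one visual token sequence
        vis = self.vision_encoder(images.flatten(0, 1))   # [B*n_cam, L_v, D]
        vis = vis.reshape(B, -1, vis.shape[-1])            # [B, n_cam*L_v, D]
        txt = self.text_embedder(input_ids)                # [B, L_t, D]
        fused = torch.cat([vis, txt], dim=1)               # fused multimodal input
        hidden = self.llm_backbone(fused)                  # placeholder: [B, L, D] hidden states
        actions = self.action_head(hidden[:, -1])          # read actions off the last token
        return actions.reshape(B, self.action_horizon, self.action_dim)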
# Standard nuScenes loading
from nuscenes.nuscenes import NuScenes
nusc = NuScenes(version='v1.0-trainval', dataroot='/data/nuscenes')
# Key tables: sample, sample_data, ego_pose, calibrated_sensor
# Camera keys: CAM_FRONT, CAM_FRONT_LEFT, CAM_FRONT_RIGHT, CAM_BACK, CAM_BACK_LEFT, CAM_BACK_RIGHT
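To illustrate how those tables connect, one sample can be resolved to its six image paths, ego pose, and camera calibration with the standard devkit calls:
sample = nusc.sample[0]                      # one annotated keyframe
cam_keys = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT',
            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT']

# sample['data'] maps each camera key to a sample_data token
image_paths = {cam: nusc.get_sample_data_path(sample['data'][cam]) for cam in cam_keys}

# Follow foreign keys from sample_data to ego_pose and calibrated_sensor
sd = nusc.get('sample_data', sample['data']['CAM_FRONT'])
ego_pose = nusc.get('ego_pose', sd['ego_pose_token'])                 # translation + rotation (quaternion)
calib = nusc.get('calibrated_sensor', sd['calibrated_sensor_token'])  # camera intrinsics + extrinsics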
import json
import torch
from PIL import Image
from torch.utils.data import Dataset

class VLADataset(Dataset):
    """
    Each sample: {
        'images':    Tensor[N_cam, C, H, W],   # multi-view camera images
        'text':      str,                      # language instruction / scene description
        'actions':   Tensor[T, action_dim],    # future waypoints or controls
        'ego_state': Tensor[state_dim],        # current ego vehicle state
    }
    """
    def __init__(self, data_root, split='train', transform=None):
        self.data_root = data_root
        self.split = split
        self.transform = transform
        self.samples = self._load_samples()

    def _load_samples(self):
        # Annotation file: one JSON list of per-sample dicts per split
        with open(f'{self.data_root}/{self.split}.json') as f:
            return json.load(f)

    def _load_images(self, image_paths):
        # Assumes paths point to regular image files readable by PIL
        return [Image.open(p).convert('RGB') for p in image_paths]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        images = self._load_images(sample['image_paths'])
        if self.transform:
            images = torch.stack([self.transform(img) for img in images])
        return {
            'images': images,
            'text': sample['instruction'],
            'actions': torch.tensor(sample['future_waypoints'], dtype=torch.float32),
            'ego_state': torch.tensor(sample['ego_state'], dtype=torch.float32),
        }
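For distributed training the dataset is wrapped with a DistributedSampler, and a collate function batches the fixed-shape tensors while tokenizing the variable-length text. A sketch, assuming a HuggingFace-style tokenizer and a train_transform created elsewhere (both are placeholders here):
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def make_collate_fn(tokenizer, max_length=64):
    def collate(batch):
        enc = tokenizer([b['text'] for b in batch], padding=True, truncation=True,
                        max_length=max_length, return_tensors='pt')
        return {
            'images': torch.stack([b['images'] for b in batch]),      # [B, N_cam, C, H, W]
            'input_ids': enc['input_ids'],
            'attention_mask': enc['attention_mask'],
            'actions': torch.stack([b['actions'] for b in batch]),    # [B, T, action_dim]
            'ego_state': torch.stack([b['ego_state'] for b in batch]),
        }
    return collate

train_set = VLADataset('/data/nuscenes', split='train', transform=train_transform)
sampler = DistributedSampler(train_set, shuffle=True)   # one shard per rank; call sampler.set_epoch(epoch) each epoch
train_loader = DataLoader(train_set, batch_size=4, sampler=sampler,
                          num_workers=8, pin_memory=True,
                          collate_fn=make_collate_fn(tokenizer))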
# Launch
deepspeed --num_gpus=8 train.py \
--deepspeed_config ds_config.json \
--model_name vla_base \
--data_root /data/nuscenes \
--output_dir /data/checkpoints/vla_exp01
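On the train.py side, DeepSpeed wires the model, optimizer, and ds_config.json together through deepspeed.initialize. A minimal sketch; build_vla_model and train_loader stand in for the model factory and the data pipeline above, and the L1 waypoint loss is just one reasonable choice:
import argparse
import torch
import torch.nn.functional as F
import deepspeed

parser = argparse.ArgumentParser()
parser.add_argument('--model_name')
parser.add_argument('--data_root')
parser.add_argument('--output_dir')
parser = deepspeed.add_config_arguments(parser)   # adds --deepspeed / --deepspeed_config
args = parser.parse_args()

model = build_vla_model(args.model_name)          # placeholder model factory
engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
)

for step, batch in enumerate(train_loader):
    batch = {k: v.to(engine.device) if torch.is_tensor(v) else v for k, v in batch.items()}
    pred = engine(batch['images'], batch['input_ids'])
    loss = F.l1_loss(pred, batch['actions'])       # e.g. L1 on future waypoints
    engine.backward(loss)                          # handles ZeRO partitioning + grad accumulation
    engine.step()
    if step > 0 and step % 1000 == 0:
        engine.save_checkpoint(args.output_dir, tag=f'step_{step}')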
ds_config.json (ZeRO Stage 2, good balance):
{
"bf16": {"enabled": true},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {"device": "none"},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"overlap_comm": true
},
"gradient_accumulation_steps": 4,
"gradient_clipping": 1.0,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
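The three batch settings must satisfy train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs. The "auto" values are only resolved when the config is driven through the HuggingFace Trainer integration; with a hand-written deepspeed.initialize loop like the sketch above, set them explicitly, e.g. for 8 GPUs:
{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 4,
  "train_batch_size": 128
}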
ZeRO Stage cheat sheet:
- Stage 0: plain data parallelism, nothing is sharded
- Stage 1: shards optimizer states across ranks
- Stage 2: shards optimizer states + gradients (the config above; good balance of memory and throughput)
- Stage 3: additionally shards the model parameters; needed for the largest backbones, at the cost of extra communication
FSDP alternative (PyTorch-native; SHARD_GRAD_OP is roughly equivalent to ZeRO-2):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
mp_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
model = FSDP(
model,
mixed_precision=mp_policy,
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, # ~ZeRO-2
device_id=torch.cuda.current_device(),
)
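With a transformer LLM backbone, FSDP usually also needs an auto-wrap policy so each decoder block becomes its own sharded unit; otherwise the whole model is wrapped flat and little memory is saved. A sketch, where LlamaDecoderLayer stands in for whatever transformer layer class your backbone actually uses:
import functools
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer  # assumption: LLaMA-style backbone

wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    mixed_precision=mp_policy,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    device_id=torch.cuda.current_device(),
)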
# torchrun (FSDP)
torchrun --nproc_per_node=8 --nnodes=1 train.py
# Multi-node: node_rank is the index of the node (0..nnodes-1), master_addr the first node's address
torchrun --nproc_per_node=8 --nnodes=2 \
  --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=29500 train.py
# SLURM: one launcher task per node; torchrun spawns the 8 per-GPU workers
srun --gres=gpu:8 --ntasks-per-node=1 --nodes=2 \
  bash -c 'python -m torch.distributed.run --nproc_per_node=8 --nnodes=2 \
    --node_rank=$SLURM_NODEID --master_addr=$MASTER_ADDR --master_port=29500 train.py'
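Inside train.py, each process launched by torchrun (or srun) joins the NCCL process group and binds to one GPU. A minimal sketch of that setup boilerplate:
import os
import torch
import torch.distributed as dist

def setup_distributed():
    # torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT per process
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

local_rank = setup_distributed()
# build model and dataloaders here, then wrap with FSDP as above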
# Optimizer