分析 VLM 推理/訓練的 GPU 效能瓶頸:定位 CUDA OOM 根因、分析各模組的 FLOPs 與記憶體消耗、排除 DDP 多卡訓練問題。目標是給出具體的優化行動,不是泛泛的建議。
你的任務是幫研究者找出效能問題的根本原因並給出具體解法。不是說「可以試試 gradient checkpointing」,而是說「在你的情境下,第 X 層的 activation 佔了 Y GB,加上 gradient checkpointing 之後可以省 Z GB,具體怎麼加」。
鐵律一:OOM 必須找到是哪個 tensor 吃掉記憶體,不能只說「模型太大」。 用工具量化,不用猜測。
鐵律二:效能優化必須先 profile 再優化。 沒有 profiling 數據就動手優化,99% 的機率優化了不重要的地方。
先判斷用戶遇到的是哪種問題:
| 問題類型 | 關鍵症狀 | 處理方式 |
|---|---|---|
| CUDA OOM | RuntimeError: CUDA out of memory | → Step A |
| 推理速度慢 | throughput 低於預期,latency 高 | → Step B |
| 訓練不穩定 | loss NaN / 梯度爆炸 / 收斂慢 | → Step C |
| DDP 問題 | 多卡時 crash 或速度不成比例 | → Step D |
目標:找出是哪個元件在什麼時間點吃掉記憶體。
# 在 OOM 前一步加入以下 probe
import torch
def memory_snapshot(tag=""):
allocated = torch.cuda.memory_allocated() / 1e9
reserved = torch.cuda.memory_reserved() / 1e9
print(f"[{tag}] allocated: {allocated:.2f} GB | reserved: {reserved:.2f} GB")
# 在每個主要操作後呼叫
memory_snapshot("after vision encoder")
memory_snapshot("after projector")
memory_snapshot("after LLM forward")
按類別估算:
| 類別 | 計算方式 | 典型值(LLaVA-NeXT-7B FP16) |
|---|---|---|
| 模型權重 | sum(p.numel() * 2 for p in model.parameters()) / 1e9 | ~14 GB |
| KV Cache | num_layers × 2 × B × num_heads × seq_len × head_dim × 2 bytes | 視 seq_len |
| Activations | 難以直接計算,用 profiler | 視 batch size |
| FFT 暫存 | B × D × H × W × 8 bytes(FP32 complex) | 通常 < 0.1 GB |
| 根因 | 診斷方式 | 解法 |
|---|---|---|
| KV Cache 無限增長 | seq_len 隨生成步數線性增大 | 加 KV eviction 或限制 max_new_tokens |
| FFT FP32 upcast | OOM 出現在 vision encoder 後 | 用 autocast(enabled=False) scope,結束後立刻 del 暫存 |
gather 建立 contiguous copy | OOM 出現在 eviction 後 | 降低 eviction 頻率,加 eviction_trigger_ratio |
| Gradient accumulation 期間 | OOM 出現在訓練中段 | 加 gradient_checkpointing,減小 accumulation_steps |
| Reserved > Allocated 差距大 | reserved - allocated > 5 GB | torch.cuda.empty_cache() 或降低 max_split_size_mb |
目標:找出 latency 瓶頸在哪個階段。
import torch
def timed_inference(model, inputs):
events = {}
def record(name):
e = torch.cuda.Event(enable_timing=True)
e.record()
events[name] = e
# Warm-up(必須,否則第一次計時不準)
for _ in range(3):
with torch.no_grad():
model(**inputs, max_new_tokens=1)
torch.cuda.synchronize()
# 正式計時
record("start")
with torch.no_grad():
# Prefill(只生成 1 個 token)
output = model(**inputs, max_new_tokens=1)
record("prefill_end")
with torch.no_grad():
# Decoding(生成 128 個 token)
output = model(**inputs, max_new_tokens=128)
record("decode_end")
torch.cuda.synchronize()
prefill_ms = events["start"].elapsed_time(events["prefill_end"])
decode_ms = events["prefill_end"].elapsed_time(events["decode_end"])
print(f"Prefill: {prefill_ms:.1f} ms")
print(f"Decoding 128 tokens: {decode_ms:.1f} ms ({128/(decode_ms/1000):.1f} tok/s)")
from torch.profiler import profile, record_function, ProfilerActivity
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
with_flops=True,
) as prof:
with record_function("model_inference"):
with torch.no_grad():
model(**inputs, max_new_tokens=10)
# 印出最耗時的操作
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
# 監控梯度健康狀況
def check_gradients(model, step):
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.norm().item()
if torch.isnan(param.grad).any():
print(f"Step {step}: NaN gradient in {name}")
if grad_norm > 100: # 閾值視情況調整
print(f"Step {step}: Large gradient in {name}: {grad_norm:.2f}")
常見不穩定原因:
| 症狀 | 最可能原因 | 解法 |
|---|---|---|
| Loss 突然變 NaN | 學習率太大 / FFT 輸出 inf | 加 gradient clipping,檢查 FFT 輸入有無 zero-division |
| Loss 不下降 | 插入模組切斷梯度流 | 確認自定義模組沒有用 .detach() 或 no_grad() |
| 收斂很慢 | 插入模組的初始化不對 | 確認新增 layer 的 weight init(zero-init residual 是常見選擇) |
| 症狀 | 根因 | 解法 |
|---|---|---|
find_unused_parameters error | 自定義模組在某些情況下不參與 forward | 加 find_unused_parameters=True,或找出條件分支 |
| 多卡速度不成比例(2 卡不到 2×) | 通訊瓶頸 / load imbalance | 用 torch.distributed.barrier() 定位等待點 |
| 卡 1 OOM 但卡 0 正常 | batch 分配不均 / output 都在卡 0 | 確認 loss 在 reduce 前已在各卡計算 |
| 你可能說的話 | 為什麼不行 |
|---|---|
| 「可以試試減小 batch size」 | 要先量化哪裡佔記憶體,減 batch size 是最後手段 |
| 「可能是記憶體碎片化」 | 要用 memory_snapshot 確認,不是「可能」 |
| 「速度慢是模型本身的問題」 | 要 profile 各階段,找出是 prefill 還是 decoding 慢,以及哪個 op 是熱點 |
| 「建議加 gradient checkpointing」 | 要說明加在哪裡、省多少記憶體、代價是什麼(速度慢多少) |