Diagnose, monitor, and debug ML model training runs. Use this skill when the user wants to: understand why a model isn't learning, diagnose training instability (NaN loss, spikes, divergence), interpret loss curves and training metrics, decide whether to adjust hyperparameters mid-run, figure out if a run is worth continuing or should be killed, debug quantization-aware training issues, or understand gradient behavior. Trigger when the user mentions: loss curves, training logs, gradient norms, learning rate schedules, NaN loss, training divergence, "model isn't learning", "loss is stuck", "should I kill this run", validation loss, overfitting/underfitting, warmup, cooldown, weight decay tuning, or any question about how training is going. Also use when someone pastes training logs and wants interpretation.
You are helping someone understand and fix their training runs. Training ML models is mostly watching numbers and making judgment calls about whether those numbers look right. This skill is about developing that judgment.
The loss curve is the single most informative signal during training. Here's what different shapes tell you.
```
Loss
|
|\.
| \.
|  \..
|    '...
|        '''''....___
|_________________________ Steps
```
Smooth, monotonically decreasing, with diminishing returns. Validation loss tracks training loss with a small gap. This is what you want.
Flat start then sudden drop (S-curve): Normal for some architectures (ternary models, very deep networks). The model is learning internal representations before they start helping. Don't kill the run during the flat part -- wait at least 2x longer than the flat period before deciding it's not learning.
Occasional small spikes (2-3x the running average) are usually fine, especially early in training. The model recovers. Worry when:
- Spikes grow larger or more frequent as training goes on
- The loss doesn't recover to its pre-spike trend within a few hundred steps
- A spike is followed by a new, permanently higher plateau
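A minimal running-average spike check, as a sketch -- the window size and 3x factor are judgment calls, not fixed constants:

```python
from collections import deque

def is_spike(loss, history, window=100, factor=3.0):
    """Flag a loss value that exceeds `factor` times the running average.

    history: deque of recent loss values, maintained by the caller.
    Returns False until the window has filled, to avoid noisy early flags.
    """
    if len(history) < window:
        return False
    running_avg = sum(history) / len(history)
    return loss > factor * running_avg

# Usage inside a training loop:
history = deque(maxlen=100)
for step, loss in enumerate([1.0] * 100 + [0.9, 4.0]):
    if is_spike(loss, history):
        print(f"step {step}: loss {loss:.2f} spiked vs running avg")
    history.append(loss)
```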
Loss goes to NaN: Almost always one of:
- Learning rate too high, especially without warmup
- A numerical operation producing NaN/Inf: log(0), division by zero, sqrt of a negative
- fp16/bf16 overflow in mixed-precision training
- Corrupt or extreme values in a batch of data
Debug by: lowering LR 10x, adding gradient clipping, checking for numerical operations that could produce NaN, printing intermediate activations to find where the NaN first appears.
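To find where the NaN first appears without sprinkling print statements, forward hooks work well. A debugging sketch (the `add_nan_hooks` helper is illustrative, not a library API):

```python
import torch
import torch.nn as nn

def add_nan_hooks(model: nn.Module):
    """Register forward hooks that raise on the first module whose output
    contains NaN/Inf, naming the culprit. Remove the hooks (h.remove())
    once you've localized the problem -- they slow down every step."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                raise RuntimeError(f"non-finite output first appeared in: {name}")
        handles.append(module.register_forward_hook(hook))
    return handles

# Usage: attach, run a forward pass, read the module name from the error.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
handles = add_nan_hooks(model)
model(torch.ones(2, 4))  # finite input: passes silently
```

Child hooks fire before the parent's, so the error names the deepest module that first produced a non-finite value.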
Loss plateaus then drops: This is often grokking -- the model memorizes first, then suddenly generalizes. More common with weight decay and smaller models. Usually a good sign, but verify the drop corresponds to actual generalization by checking validation metrics.
Validation loss diverges from training loss: Overfitting. Training loss keeps improving but validation loss gets worse. Solutions:
- More data or data augmentation
- More regularization (weight decay, dropout)
- A smaller model
- Early stopping at the validation minimum
Loss oscillates without converging: Learning rate too high, or batch size too small. Try: halving the learning rate, doubling the batch size, or both.
Training loss: The primary signal. Should decrease. Log it per step and per epoch.
Validation loss: Check every 100-500 steps. If it diverges from training loss, you're overfitting.
Learning rate: Especially important with schedules. Verify the schedule is doing what you think. Plot it.
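One way to verify a schedule is to dry-run it before training and print the values at a few milestones. A hand-written warmup + cosine sketch (every constant here is illustrative, not a recommendation):

```python
import math

def lr_at(step, max_lr=3e-4, warmup_steps=1000, total_steps=10000, min_lr=3e-5):
    """Linear warmup followed by cosine decay to min_lr, written out
    explicitly so the schedule can be inspected without running training."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps          # linear warmup
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))
    return min_lr + (max_lr - min_lr) * cosine             # cosine decay

# Dry-run a few checkpoints instead of trusting the config.
for step in [0, 500, 1000, 5000, 10000]:
    print(step, f"{lr_at(step):.2e}")
```

The same dry-run works with a real `torch.optim.lr_scheduler` object: step it in a loop without an optimizer update and record `get_last_lr()`.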
Step time (ms): Should be stable. If it creeps up, you may have a memory leak or data loading bottleneck. Sudden increases often mean the model fell off GPU onto CPU for some operation.
Memory usage: Should be stable after the first few steps. Increasing memory = memory leak (usually from accumulating computation graphs).
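The most common accumulating-graph leak is summing the loss tensor itself across steps, which keeps every step's computation graph alive. A sketch of the bug and the fix:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)

running_loss = 0.0
for _ in range(10):
    loss = (model(torch.randn(4, 8)) ** 2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

    # BUG: `running_loss += loss` would keep each step's graph alive,
    # so memory grows every step until OOM.
    # FIX: convert to a Python float before accumulating:
    running_loss += loss.item()

print(f"mean loss: {running_loss / 10:.4f}")
```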
Gradient norm: The L2 norm of all gradients. Healthy range depends on model size, but:
- It should stay roughly stable (same order of magnitude) through most of training
- A sudden spike usually precedes a loss spike -- log both so you can correlate them
- A norm decaying toward zero suggests vanishing gradients; steady growth suggests brewing instability
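Computing the global norm by hand is a few lines; it is the same quantity `torch.nn.utils.clip_grad_norm_` measures (and returns) before clipping:

```python
import torch
import torch.nn as nn

def global_grad_norm(model: nn.Module) -> float:
    """L2 norm over all parameter gradients, computed after backward()."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().float().norm(2).item() ** 2
    return total ** 0.5

model = nn.Linear(4, 2)
loss = model(torch.ones(1, 4)).sum()
loss.backward()
print(f"grad norm: {global_grad_norm(model):.4f}")
```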
Weight statistics: Mean, std, min, max of each parameter group. If weights grow unbounded or collapse to zero, something is wrong.
Activation statistics: Same thing, but for intermediate activations. Dead ReLU neurons (always zero) indicate the model is losing capacity.
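Dead units can be counted with a forward hook. A sketch (the `dead_relu_fraction` helper is illustrative; "dead" here means zero on every example in one batch, so track it across many batches before concluding anything):

```python
import torch
import torch.nn as nn

def dead_relu_fraction(model: nn.Module, batch: torch.Tensor) -> dict:
    """For each ReLU, the fraction of output units that are zero for every
    example in `batch`. Persistently high values indicate dead neurons."""
    fractions, handles = {}, []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            def hook(mod, inp, out, name=name):
                dead = (out == 0).all(dim=0)          # zero on all rows
                fractions[name] = dead.float().mean().item()
            handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(batch)
    for h in handles:
        h.remove()
    return fractions

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
print(dead_relu_fraction(model, torch.randn(16, 4)))
```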
Gradient statistics per layer: Are all layers receiving gradients? Is one layer's gradient 1000x larger than another's? This indicates an imbalance that will cause some layers to train much faster than others.
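Per-layer norms are a one-liner over `named_parameters()`; eyeballing the spread is usually enough to spot a 1000x imbalance:

```python
import torch
import torch.nn as nn

def per_layer_grad_norms(model: nn.Module) -> dict:
    """Per-parameter gradient L2 norms. A layer whose norm is ~1000x
    another's will train at a wildly different speed."""
    return {name: p.grad.detach().norm(2).item()
            for name, p in model.named_parameters() if p.grad is not None}

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
model(torch.ones(2, 4)).sum().backward()
for name, norm in per_layer_grad_norms(model).items():
    print(f"{name:12s} {norm:.4f}")
```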
QAT has its own failure modes:
Quantization gap widens during training: The difference between quantized and unquantized loss grows over time. This means the model is learning features that can't survive quantization. Solutions:
- Start QAT earlier, so the model never learns quantization-fragile features
- Use finer-grained (e.g. per-group) quantization scales
- Check that the scale factors are actually tracking the weight distribution
Ternary weight collapse: All weights in a layer drift to the same ternary value (usually 0). The layer is effectively dead. Solutions:
- Check the layer's scale/threshold -- a threshold that is too wide maps everything to 0
- Lower the learning rate or weight decay for the affected layer
- Monitor the per-layer ternary distribution so you catch the drift early
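A quick way to watch for collapse is to log the ternary bucket fractions per layer. A sketch assuming a simple symmetric-threshold scheme -- a real QAT setup should reuse its own quantizer here:

```python
import torch
import torch.nn as nn

def ternary_distribution(weight: torch.Tensor, threshold: float = 0.05) -> dict:
    """Fraction of weights that would quantize to -1 / 0 / +1 under a
    symmetric threshold (illustrative scheme, not any specific library's).
    A layer approaching 100% in a single bucket is collapsing."""
    neg = (weight < -threshold).float().mean().item()
    pos = (weight > threshold).float().mean().item()
    return {"-1": neg, "0": 1.0 - neg - pos, "+1": pos}

layer = nn.Linear(64, 64)
print(ternary_distribution(layer.weight.detach()))
```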
Mixed-precision instability: When some layers are quantized and others aren't, the gradient scales can be wildly different. Solutions:
- Log per-layer gradient norms and confirm the imbalance before fixing anything
- Use per-layer learning rates (or per-layer gradient scaling) to compensate
- Clip gradients per layer rather than globally
Learning rate -- Increase if: loss is decreasing very slowly and gradient norms are tiny; the model can handle more. Decrease if: loss is noisy, spiky, or diverging, or if gradient norms are very large. Standard approach: use a schedule (warmup + cosine decay) and don't touch it. If the schedule isn't working, redesign it rather than manually adjusting mid-run.
Batch size -- Increase if: training is too noisy, you have memory headroom, and you want smoother gradients. Decrease if: you're overfitting or need more gradient noise for exploration. Note: changing batch size mid-run changes the effective step size. The common linear-scaling heuristic is to scale LR with batch size: if you double the batch size, consider doubling the LR.
Weight decay -- Increase if: overfitting (val loss diverges from train loss) or weights are growing unbounded. Decrease if: the model isn't fitting the training data (underfitting) or weights are collapsing to zero.
Gradient clipping -- Add/tighten if: gradient norms spike or you see NaN loss. Loosen if: gradient norms are consistently well below the clip threshold (the clipping isn't doing anything, but also isn't hurting).
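`clip_grad_norm_` returns the total norm measured before clipping, so you can log how often the clip actually fires. A sketch (the 1.0 threshold is illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
clip_threshold = 1.0  # illustrative, not a recommendation
clipped_steps = 0

for step in range(20):
    loss = (model(torch.randn(8, 16)) ** 2).mean()
    loss.backward()
    # Returns the TOTAL norm measured BEFORE clipping is applied.
    pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_threshold)
    if pre_clip > clip_threshold:
        clipped_steps += 1
    model.zero_grad()

print(f"clip fired on {clipped_steps}/20 steps")
```

If this counter stays at zero for thousands of steps, the threshold is doing nothing; if it fires on nearly every step, the clip is effectively lowering your LR and the threshold (or LR) deserves a second look.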
Kill it if:
- Loss is NaN and restarting from a checkpoint with a lower LR also NaNs
- Loss has been flat for more than 2x any plausible flat phase and longer than comparable runs took to move
- Loss is steadily diverging and lowering the LR hasn't helped
- You found a config or data bug -- the run is invalid no matter what the curve does
Don't kill it if:
- You're still inside a plausible S-curve flat phase (see above)
- The loss spiked but is recovering
- Progress is slow but steady -- slow convergence is not failure
When someone pastes training logs, look for:
- The overall loss trend: decreasing, flat, oscillating, or diverging
- Spikes, and whether the loss recovers afterward
- The train/validation gap, and whether it's widening
- The learning rate at each logged step -- is the schedule doing what they think?
- Gradient norms, step times, and memory usage, if they're logged
Then give a diagnosis: "Your training looks healthy, loss is decreasing at a reasonable rate, no red flags. At this trajectory, you'll reach approximately X loss by step Y" or "Your loss spiked at step 500 and hasn't recovered. This is likely [cause]. Try [fix]."
| Symptom | Likely Cause | Fix |
|---|---|---|
| Loss NaN | LR too high, numerical instability | Lower LR 10x, add grad clip, check for log(0) |
| Loss flat from start | Model too small, LR too low, data bug | Check data, increase LR, verify forward pass |
| Loss spikes regularly | LR too high, bad batches | Lower LR, check data quality |
| Val loss diverges | Overfitting | More regularization, less model capacity |
| Training very slow | Data loading bottleneck, no compile | Profile, add workers, use torch.compile |
| OOM at step N | Memory leak, activation caching | Find tensors that need `.detach()`/`.item()`, use gradient checkpointing |
| Gradients all zero | Dead model, detached computation | Check requires_grad, verify backward pass |
| Loss decreases then plateaus early | LR schedule wrong, model capacity hit | Check schedule, try larger model |
| Quantized model much worse | QAT not working, precision too low | Start QAT earlier, use group quantization, check scaling |