Structured checklist for building and verifying ML training pipelines. Every training script must satisfy each applicable item before it ships. Items marked (conditional) apply only when the stated condition is true.
Framework: PyTorch Lightning is the default training framework for all neural network training (SL and RL). Tree-based models (XGBoost, LightGBM, CatBoost) use their native APIs directly.
Field references (read on demand for detailed field tables):
- Config is a @dataclass inheriting from the correct level in the hierarchy (Base → SL Neural / SL Tree / RL → task-specific). See references/config-system.md for the full hierarchy.
- to_dict() / from_dict() round-trip via JSON.
- --config flag selects among config classes by name. Per-field CLI overrides apply on top. Priority: dataclass defaults < config class < CLI flags.
- Wall-clock budget enforced (max_wall_clock_hours).
- num_epochs is set high (effectively unlimited) so early stopping is the binding constraint — not a hard epoch cap.
- Experiment variants are @dataclass classes inheriting from a base, overriding only differing fields. Includes: baseline, single-dimension variants, and at least one aggressive multi-dimension variant.
- partN/ layout: config.py, data.py, model.py, train.py. See references/codebase-structure.md.
- Shared code lives in src/ — config hierarchy, W&B integration (src/wandb_utils.py), system metrics. Never duplicated per-part.
- Each part's train.py entry point is main().

The model is a lightning.LightningModule:
- __init__: store hyperparams via self.save_hyperparameters(), build layers
- forward(): pure forward pass (inference)
- training_step(batch, batch_idx): forward + loss, return loss. Log via self.log()
- validation_step(batch, batch_idx): forward + loss + eval metrics. Log via self.log()
- configure_optimizers(): return optimizer and optional LR scheduler dict

Data handling is a lightning.LightningDataModule:
- setup(stage): load/split datasets
- train_dataloader(): return training DataLoader
- val_dataloader(): return validation DataLoader
- test_dataloader(): return test DataLoader (optional)

model.py additionally exposes save_model() and load_model_from_checkpoint() for standalone checkpoint loading outside Lightning.

Wire up W&B via WandbLogger:
```python
from lightning.pytorch.loggers import WandbLogger

logger = WandbLogger(project="my-project", name=cfg.name, log_model=False)
trainer = Trainer(logger=logger)
```
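The config round-trip requirement above (to_dict() / from_dict() via JSON) can be sketched with stdlib dataclasses. BaseConfig here is a hypothetical example, not the repo's actual class:

```python
import json
from dataclasses import dataclass, asdict, fields

@dataclass
class BaseConfig:
    """Hypothetical config illustrating the JSON round-trip requirement."""
    seed: int = 42
    lr: float = 1e-3

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_dict(cls, d):
        # Ignore unknown keys so older snapshots still load.
        names = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in names})

cfg = BaseConfig(lr=3e-4)
restored = BaseConfig.from_dict(json.loads(json.dumps(cfg.to_dict())))
assert restored == cfg
```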
- Call self.save_hyperparameters() in LightningModule.__init__().
- Log metrics with self.log() and self.log_dict() inside training_step / validation_step:
```python
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
self.log_dict({"val_loss": val_loss, "val_rmse": rmse}, on_epoch=True)
```
Metric namespaces:
- Per-epoch core metrics (train_loss, val_loss, eval metrics, lr)
- batch/: per-step metrics (use on_step=True)
- timing/: per-epoch timing breakdown (see Phase 3.5)
- tracking/: per-epoch (best_{metric}, epochs_since_improvement)
- system/: per-epoch (GPU/CPU/RAM — Lightning logs GPU metrics automatically when log_every_n_steps is set)
- Run metadata via logger.experiment.config.update() or self.log_dict() at start: total_params, trainable_params, num_train_samples, num_val_samples, gpu_name.
- Checkpoint artifacts via WandbLogger(log_model="all") or manual wandb.log_artifact() — once per run at end of training.

Profile with Lightning's built-in profilers:

```python
from lightning.pytorch.profilers import SimpleProfiler

trainer = Trainer(profiler=SimpleProfiler())  # or AdvancedProfiler() for per-function breakdown
```
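Beyond the built-in profilers, per-phase wall-clock accounting can be accumulated manually and converted to percentages. An illustrative stdlib sketch (not Lightning code; class and method names are assumptions):

```python
import time
from collections import defaultdict

class PhaseTimer:
    """Accumulates wall-clock seconds per named phase (illustrative sketch)."""
    def __init__(self):
        self.totals = defaultdict(float)
        self._starts = {}

    def start(self, phase):
        self._starts[phase] = time.perf_counter()

    def stop(self, phase):
        self.totals[phase] += time.perf_counter() - self._starts.pop(phase)

    def percentages(self):
        # Share of total iteration time per phase, e.g. for timing/*_pct metrics.
        total = sum(self.totals.values()) or 1.0
        return {k: 100.0 * v / total for k, v in self.totals.items()}
```

In a callback, start()/stop() would bracket each phase and the results would be logged under the timing/ namespace.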
Custom timing callback (hooks: on_train_batch_start/end, on_validation_start/end):
- timing/env_step_seconds — environment stepping (RL: time inside env.step() or collector.rollout())
- timing/inference_seconds — policy forward pass
- timing/backward_seconds — loss computation + backward pass + optimizer step
- timing/data_seconds — data loading, replay buffer sampling
- timing/overhead_seconds — everything else (logging, checkpointing)
- Log under the timing/ namespace via self.log() in callbacks.
- With parallel envs, timing/env_step_seconds measures the full ParallelEnv.step() call including inter-process communication.
- Also log timing/env_step_pct, timing/backward_pct etc. as percentage of total iteration time.

Mixed precision:

```python
trainer = Trainer(precision="bf16-mixed")
```

- Not fp16 — bf16 has the same exponent range as fp32, so no GradScaler is needed. Lightning handles the AMP context automatically.
- No manual torch.amp.autocast() calls in model code — Lightning wraps training_step and validation_step automatically.
- (conditional: hardware lacks bf16 support) precision="32-true" is set instead.

Use ParallelEnv (not SerialEnv) to pipeline CPU env stepping with GPU policy inference:
```python
from torchrl.envs import ParallelEnv

env = ParallelEnv(
    num_workers=config.num_envs,
    create_env_fn=lambda: make_env(config),
)
```
- Wrap non-picklable env factories in CloudpickleWrapper, not a lambda capturing local state.
- Ensure env.close() is called during cleanup to terminate worker processes (prevents zombie processes).
- With physics simulators, ParallelEnv is especially critical — the GPU sits idle during serial physics stepping.
- Append a RewardSum transform: env.append_transform(RewardSum()). Without it, TorchRL collectors never populate episode_reward and monitoring is blind.
- SyncDataCollector silently drops keys not in observation_spec. Add Unbounded specs for diagnostic fields.
- Use env.step_and_maybe_reset() (not env.step() + manual step_mdp()) — ParallelEnv.step() does NOT auto-reset done environments.
- self._device must be a torch.device object, not a string — TorchRL's BatchedEnvBase._reset() calls .type on it.
- Cap BLAS/OpenMP threads:

```python
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OMP_NUM_THREADS", "1")
```
- Read next-state data from batch["next"], not the batch root.
- Use the current TorchRL spec names: BoundedTensorSpec→Bounded, CompositeSpec→Composite, UnboundedContinuousTensorSpec→Unbounded.
- Keep an actual_updates counter for metric averaging — KL early stopping means actual_updates << num_epochs * num_batches.

Tune batch size with the BatchSizeFinder callback:
```python
from lightning.pytorch.callbacks import BatchSizeFinder

trainer = Trainer(callbacks=[BatchSizeFinder(mode="binsearch")])
```
The LightningModule or LightningDataModule must expose a batch_size attribute that BatchSizeFinder can modify. Typically set in config and passed through:
```python
class MyDataModule(L.LightningDataModule):
    def __init__(self, batch_size=32, ...):
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, ...)
```
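What binsearch mode does can be illustrated with a pure-Python sketch: grow the batch size until it no longer fits, then binary-search the boundary. Here fits() stands in for an OOM-probing trial run; this is not Lightning's actual code:

```python
def find_max_batch(fits, start=2, limit=2**16):
    """Illustrative binary search for the largest batch size that fits.
    `fits(n)` returns True if batch size n runs without OOM."""
    lo, hi = 0, start
    # Grow geometrically until a batch size fails (or the limit is hit).
    while hi < limit and fits(hi):
        lo, hi = hi, hi * 2
    if hi >= limit:
        return lo
    # Binary search in the open interval (lo, hi).
    while lo + 1 < hi:
        mid = (lo + hi) // 2
        if fits(mid):
            lo = mid
        else:
            hi = mid
    return lo
```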
- No custom probe_auto_batch_size() functions — use BatchSizeFinder instead.
- Keep gradient_accumulation_steps available to extend the effective batch beyond the VRAM ceiling:

```python
trainer = Trainer(accumulate_grad_batches=cfg.gradient_accumulation_steps)
```
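Choosing the accumulation factor for a target effective batch is simple arithmetic (effective batch = per-step batch × accumulation steps); accumulation_steps_for below is a hypothetical helper:

```python
import math

def accumulation_steps_for(target_effective_batch, max_fitting_batch):
    """Smallest number of accumulation steps so that
    max_fitting_batch * steps >= target_effective_batch."""
    return max(1, math.ceil(target_effective_batch / max_fitting_batch))
```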
- cleanup_vram() is called between configs: delete model and trainer → torch.cuda.empty_cache() → gc.collect().

main() follows this sequence:
1. lightning.seed_everything(cfg.seed)
2. Build the LightningDataModule (handles data loading)
3. Build the LightningModule (handles model + optimizer + scheduler)
4. Create the WandbLogger
5. Create the Trainer with all settings
6. trainer.fit(model, datamodule=dm) — handles the entire training loop
7. trainer.test(model, datamodule=dm, ckpt_path="best") — final eval on best checkpoint
8. wandb.finish()

Standard callbacks wired into the Trainer:
```python
from lightning.pytorch.callbacks import (
    EarlyStopping, ModelCheckpoint, BatchSizeFinder,
    LearningRateMonitor, RichProgressBar,
)

callbacks = [
    EarlyStopping(monitor="val_loss", patience=cfg.patience, mode="min"),
    ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, save_last=True),
    BatchSizeFinder(mode="binsearch"),
    LearningRateMonitor(logging_interval="epoch"),
]
trainer = Trainer(callbacks=callbacks, ...)
```
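The patience semantics used by EarlyStopping above can be illustrated with a stdlib sketch (not Lightning's implementation; mode="min" assumed):

```python
class EarlyStopper:
    """Illustrative patience-based early stopping for a minimized metric."""
    def __init__(self, patience=5, min_delta=0.0):
        self.patience, self.min_delta = patience, min_delta
        self.best = float("inf")
        self.bad_cycles = 0

    def update(self, val_loss):
        """Call once per eval cycle; returns True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best, self.bad_cycles = val_loss, 0
        else:
            self.bad_cycles += 1
        return self.bad_cycles >= self.patience
```

Patience counts eval cycles without an improvement of at least min_delta, not raw epochs.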
- Early stopping via the EarlyStopping callback. Patience is measured in eval cycles. Optional min_delta sets a minimum improvement threshold.
- Keep both the best model (via ModelCheckpoint(save_top_k=1)) and the last model (via save_last=True).
- Check for a STOP file between epochs via on_train_epoch_end:
```python
from pathlib import Path

class StopFileCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        if Path("STOP").exists():
            trainer.should_stop = True
```
- A graceful stop sets trainer.should_stop = True, which lets Lightning finish the current epoch and run cleanup.

Wrap main() with GpuLock() in the if __name__ == "__main__": block:
```python
if __name__ == "__main__":
    from src.utils.gpu_lock import GpuLock

    with GpuLock():
        main()
```
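A minimal sketch of what such a lock might look like using POSIX flock; this is illustrative only — the real implementation lives in src/utils/gpu_lock.py and is not shown here, and the lock path is an assumption:

```python
import fcntl

class FileLock:
    """Illustrative blocking file lock; concurrent holders queue on flock."""
    def __init__(self, path="/tmp/gpu-task.lock"):
        self.path = path

    def __enter__(self):
        self._fh = open(self.path, "w")
        fcntl.flock(self._fh, fcntl.LOCK_EX)  # blocks until the lock is free
        return self

    def __exit__(self, *exc):
        fcntl.flock(self._fh, fcntl.LOCK_UN)
        self._fh.close()
```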
- GpuLock uses flock on /tmp/gpu-task.lock. Concurrent GPU tasks queue (they don't error).
- Before launching, run ps aux | grep -E "python.*train" | grep -v grep to confirm the GPU is free.
- Launch long runs with nohup and unbuffered output:

```bash
PYTHONUNBUFFERED=1 nohup <command> > output/<descriptive_log>.txt 2>&1 &
```
- Monitor via tail -f output/<log>.txt or /loop 10m /babysit-training.
- Each run writes to {output.base_dir}/{name}_{YYYYMMDD_HHMMSS}/:

```python
trainer = Trainer(default_root_dir=run_dir)
```
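The run-directory naming scheme above can be sketched as follows; make_run_dir is a hypothetical helper, not a function from the repo:

```python
from datetime import datetime
from pathlib import Path

def make_run_dir(base_dir, name):
    """Create {base_dir}/{name}_{YYYYMMDD_HHMMSS}/ and return it."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = Path(base_dir) / f"{name}_{stamp}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir
```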
- The run dir contains config.json (full snapshot saved manually), console.log, and Lightning checkpoint files.
- ModelCheckpoint saves to {run_dir}/checkpoints/: best.ckpt, last.ckpt.
- Per-epoch metrics are appended to metrics.jsonl via a custom callback.
- Also saved: metrics.json (final summary) and plots/ (feature importance, pred vs actual, SHAP).

Sweeps (conditional: search space > 3 variants):
- The sweep entry point lives at partN/sweep.py.
- main_with_config(cfg) is extracted from main() so sweeps can call the pipeline programmatically.
- setup_run() is sweep-aware: it detects wandb.run.sweep_id, skips wandb.init(), and updates the config instead.
- An architecture preset parameter is decoded into config fields.
- --max-hours bounds total sweep wall clock (not per trial).
- Launch inside tmux: tmux new -s sweep && python partN/sweep.py --max-hours 12.
- sweep_train() catches exceptions explicitly to unwind the stack before cleanup_vram(). Double gc.collect() pass for reference cycles.

Babysit long runs with /loop 10m /babysit-training:
- Covers: process health, metric trending, GPU/system checks, checkpoint integrity, hung process detection, auto-restart from checkpoint, issue documentation.
- Findings are written to logs/ immediately (not batched).
- Results are recorded under experiments/.
- Issues are documented in issues/ before or alongside the fix (each entry has a type, status, and type-specific properties).
- Numerical guards: isfinite(loss) before backward, isfinite(grad_norm) after clipping, per-batch KL early stopping.
- Verify clip_grad_norm_() is actually called in _update() for both critic and actor — the config field alone is not enough.
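The isfinite guards above can be sketched as follows (stdlib floats for illustration; real training code would use torch.isfinite on tensors):

```python
import math

def check_finite(loss, grad_norm):
    """Abort the update on non-finite loss or post-clipping gradient norm."""
    if not math.isfinite(loss):
        raise FloatingPointError(f"non-finite loss: {loss}")
    if not math.isfinite(grad_norm):
        raise FloatingPointError(f"non-finite grad norm after clipping: {grad_norm}")
```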