Provides guidance for interpreting and manipulating neural network internals using tdhook with TensorDict and PyTorch hooks. Use when needing attribution maps, activation analysis, probing, steering, activation patching, or weight-level interventions on PyTorch or TensorDict models.
Interpretability with TensorDict and PyTorch hooks. One core pattern: `with Method(...).prepare(model) as hooked: td = hooked(td)`.
Docs: Home · Methods · Tutorials · API · GitHub: Xmaster6y/tdhook · Paper: arXiv:2509.25475
Use tdhook when you need attribution maps, activation analysis, probing, steering, activation patching, or weight-level interventions on PyTorch or TensorDict models.
Consider alternatives: nnsight (remote 70B+ models), pyvene (declarative configs), TransformerLens (cached activations).
Goal: Which inputs (pixels, tokens) matter most for a prediction?
Checklist:
- Define `init_attr_targets(td, ctx)` returning a TensorDict with target outputs
- Build the input as `TensorDict({"input": x, ("baseline", "input"): baseline})`
- Run `hooked(td)`; read attributions from `td.get(("attr", "input"))`

```python
from tdhook.attribution import IntegratedGradients

def init_attr_targets(targets, _):
    return TensorDict(out=targets["output"][..., class_idx], batch_size=targets.batch_size)

with IntegratedGradients(init_attr_targets=init_attr_targets).prepare(model) as hooked:
    td = hooked(TensorDict({"input": x, ("baseline", "input"): baseline}))
    attr = td.get(("attr", "input"))
```
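For intuition, the path integral behind integrated gradients can be sketched in plain PyTorch. This is an illustrative approximation, not tdhook's implementation; the toy linear model and `target_idx` are placeholders:

```python
import torch

def integrated_gradients(model, x, baseline, target_idx, n_steps=32):
    """Average input gradients along the baseline->input path, scaled by (x - baseline)."""
    grads = []
    for alpha in torch.linspace(0.0, 1.0, n_steps):
        point = (baseline + alpha * (x - baseline)).requires_grad_(True)
        out = model(point)[..., target_idx].sum()
        grads.append(torch.autograd.grad(out, point)[0])
    avg_grad = torch.stack(grads).mean(dim=0)
    return (x - baseline) * avg_grad

# For a bias-free linear model with a zero baseline, the attributions
# sum exactly to the target output (the completeness axiom).
torch.manual_seed(0)
model = torch.nn.Linear(3, 2, bias=False)
x = torch.randn(1, 3)
attr = integrated_gradients(model, x, torch.zeros_like(x), target_idx=0)
```

The real `IntegratedGradients` class handles batching, baselines via `("baseline", "input")`, and hook placement for you; this sketch only shows the math being approximated.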
Goal: Train probes on representations for classification/diagnostics.
Checklist:
- Create `ProbeManager(sklearn_model, kwargs, compute_metrics)`
- Add `labels` and `step_type` (`"fit"` / `"predict"`) to the TensorDict
- Target modules with a regex such as `"transformer.h.(0|5|10).mlp$"`
- Pass `additional_keys=["labels", "step_type"]` to Probing
- Pass `in_keys=["input_ids"]`, `out_keys=["logits"]` to `prepare()`
- Run `hooked(train_td)` then `hooked(test_td)`; read `manager.fit_metrics`, `manager.predict_metrics`

```python
from tdhook.latent.probing import Probing, ProbeManager

manager = ProbeManager(LogisticRegression, {}, compute_metrics)
with Probing(
    "transformer.h.(0|5|10).mlp$",
    manager.probe_factory,
    additional_keys=["labels", "step_type"],
).prepare(model, in_keys=["input_ids"], out_keys=["logits"]) as hooked:
    hooked(train_td)  # step_type="fit"
    hooked(test_td)   # step_type="predict"
```
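What `Probing` and `ProbeManager` automate can be sketched with a plain forward hook plus scikit-learn. Illustrative only: the toy model and the `net[0]` module stand in for a real path such as `transformer.h.5.mlp`, and the synthetic labels are a placeholder task:

```python
import torch
from sklearn.linear_model import LogisticRegression

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

# Capture one module's activations during the forward pass
acts = []
handle = net[0].register_forward_hook(lambda mod, inp, out: acts.append(out.detach()))

X = torch.randn(64, 4)
y = (X[:, 0] > 0).long()  # probe target: linearly decodable from the input
net(X)                    # forward pass fills `acts`
handle.remove()

# Fit a linear probe on the captured representations
probe = LogisticRegression(max_iter=1000).fit(acts[0].numpy(), y.numpy())
acc = probe.score(acts[0].numpy(), y.numpy())
```

tdhook's version additionally routes `labels`/`step_type` through the TensorDict and collects `fit_metrics`/`predict_metrics` across every matched module.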
Goal: Inspect or patch activations at specific modules without high-level methods.
Checklist:
- Use `hooked_module.run(data, grad_enabled=...)` for low-level control
- `run.save("path.to.module")` to capture an activation
- `run.set("path.to.module", tensor)` to override one
- Call `proxy.resolve()` after the run to read cached tensors
- Gradients: `run.save_grad(...)`, `run.set_grad(...)`

```python
with hooked_module.run(data, grad_enabled=True) as run:
    run.save("layers.5.mlp")
    run.set("layers.5.attn", override_tensor)
cached = run.get("layers.5.mlp", cache_key="my_key").resolve()
```
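The save/set pattern corresponds to vanilla PyTorch forward hooks, which is what the `run` API manages for you. A minimal plain-torch sketch (toy two-layer model; not tdhook's implementation):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))

cache = {}
def save_hook(mod, inp, out):
    # Equivalent of run.save(...): stash the activation for later inspection
    cache["layer0"] = out.detach()

def set_hook(mod, inp, out):
    # Equivalent of run.set(...): a forward hook's return value replaces the output
    return torch.zeros_like(out)

h1 = model[0].register_forward_hook(save_hook)
h2 = model[1].register_forward_hook(set_hook)
out = model(torch.randn(3, 2))
h1.remove()
h2.remove()
```

Hook handles must be removed explicitly here; the `with ... as run:` context does that cleanup automatically.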
```python
# Attribution (needs a baseline for IG)
with IntegratedGradients(init_attr_targets=init_fn).prepare(model) as hooked:
    td = hooked(TensorDict({"input": x, ("baseline", "input"): baseline}))
    attr = td.get(("attr", "input"))

# Steering: extract a contrast vector, then apply it
with ActivationAddition(["layer.7.mlp"]).prepare(model) as hooked:
    steer = hooked(TensorDict({("positive", "input"): pos, ("negative", "input"): neg})).get(("steer", "layer.7.mlp"))

with SteeringVectors(modules_to_steer=["layer.7.mlp"], steer_fn=lambda k, o: o + scale * steer).prepare(model) as hooked:
    out = hooked(TensorDict({"input": x}))
```
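The extract-then-apply loop above can be sketched with plain forward hooks. Illustrative only: the two-layer toy model, the random positive/negative batches, and the `scale` value are placeholders, and `model[0]` stands in for a real path like `layer.7.mlp`:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
layer = model[0]

captured = {}
def capture(mod, inp, out):
    captured["act"] = out.detach()

# Extraction: contrast vector = mean positive activation - mean negative activation
h = layer.register_forward_hook(capture)
model(torch.randn(8, 4)); pos_act = captured["act"]
model(torch.randn(8, 4)); neg_act = captured["act"]
h.remove()
steer = pos_act.mean(0) - neg_act.mean(0)

# Application: add the scaled vector to that layer's output during forward
x = torch.randn(1, 4)
plain = model(x)
scale = 2.0
h = layer.register_forward_hook(lambda m, i, o: o + scale * steer)
steered = model(x)
h.remove()
```

`ActivationAddition` and `SteeringVectors` do the same two phases but keyed through `("positive", "input")` / `("negative", "input")` and a user-supplied `steer_fn`.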
| Key | Purpose |
|---|---|
| `("baseline", "input")` | Attribution baseline |
| `("positive", "input")`, `("negative", "input")` | Steering extraction |
| `("patched", "input")` | Patching source |
| `("attr", key)` | Attribution output |
| `labels`, `step_type` | Probing (via `additional_keys`) |
| Issue | Solution |
|---|---|
| KeyError on TensorDict | Use tuple keys: `("attr", "input")`, `("baseline", "input")` |
| HuggingFace model fails | Pass `in_keys=["input_ids"]`, `out_keys=["logits"]` to `prepare()` |
| BilinearProbeManager shape mismatch | Call `manager.before_all()` before forwards, `manager.after_all()` after |
| Module path not found | Use `relative=False` or correct the regex; see api.md Module Path Resolution |
| Probing `step_type` missing | Add `additional_keys=["labels", "step_type"]` and pass both in the TensorDict |
| IG baseline wrong shape | Ensure `("baseline", "input")` matches `"input"` in shape and device |
See references/issues.md for more patterns.
Validate attributions with sanity metrics (`tdhook.metrics`) before trusting heatmaps.

| Need | Primary Class | Key Params |
|---|---|---|
| Gradient w.r.t. input | Saliency | init_attr_targets, input_modules |
| Path-integral attribution | IntegratedGradients | init_attr_targets, n_steps, baseline in TensorDict |
| Channel-weighted spatial | GradCAM | modules_to_attribute (path → DimsConfig) |
| Extract contrast vector | ActivationAddition | module list, ("positive","input"), ("negative","input") |
| Apply precomputed vector | SteeringVectors | steer_fn(module_key, output) |
| Replace activations | ActivationPatching | patch_fn, ("patched","input") |
| Train classifiers on reps | Probing | key_pattern, probe_factory, additional_keys |
| Zero params by importance | Pruning | importance_callback, amount_to_prune |
| Insert modules inline | Adapters | adapters={path: (adapter, source, target)} |
```
pip install tdhook tensordict torch
```
For optional captum-based attribution or sklearn probing:
```
pip install captum scikit-learn
```
Colab dev setup: see tutorials.md Setup section.
- IG needs a baseline at `("baseline", "input")`; use zeros or a neutral input
- Probing needs `step_type` in the TensorDict (`"fit"` for training, `"predict"` for eval)
- Use tuple keys, `("attr", "input")`, not `"attr/input"`, for TensorDict
- Module paths vary by model: `transformer.h.5.mlp` vs `layers.5.mlp`

| File | Contents |
|---|---|
| references/README.md | Overview |
| references/api.md | Full API: HookedModule, methods by category |
| references/tutorials.md | Use-case tutorials |
| references/issues.md | GitHub issues & solutions |
| references/releases.md | Version history |
| references/file_structure.md | Codebase navigation |
Official docs: Home · Methods · Tutorials · API Reference