TorchDrug is a PyTorch-based machine learning platform for drug discovery. Use it for graph-based molecular representation learning, molecular property prediction (ADMET, activity), retrosynthesis prediction, drug-target interaction (DTI) modeling, and pretraining on large molecular datasets. Provides GNN layers (GraphConv, GAT, MPNN), pretrained models, and benchmark datasets in a unified PyTorch-compatible API.
TorchDrug is a comprehensive machine learning framework for drug discovery built on PyTorch. It provides:
- graph-based molecular representations (atoms as nodes, bonds as edges);
- a library of graph neural network (GNN) architectures;
- benchmark datasets;
- self-supervised pretraining tasks.

These cover molecular property prediction, drug-target interaction, retrosynthesis, and generative molecular design. TorchDrug ships its own training engine (`core.Engine`) instead of integrating with PyTorch Lightning. Everything else is standard PyTorch (modules, optimizers, `torch.utils.data`), which makes it accessible to both computational chemists and ML practitioners.
# Dependencies: torch, torch-scatter, torch-cluster, rdkit, torchdrug
# (torch-geometric and torchvision are not required by TorchDrug.)

# 1. PyTorch — pick the wheel index that matches your CUDA version (cu118 shown).
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118

# 2. torch-scatter / torch-cluster prebuilt wheels. They must match the installed
#    torch + CUDA build, e.g. TORCH_WHL=torch-2.0.0+cu118; otherwise pip tries
#    to compile them from source.
pip install torch-scatter torch-cluster -f https://data.pyg.org/whl/${TORCH_WHL}.html

# 3. RDKit and TorchDrug itself.
pip install rdkit
pip install torchdrug
import torch
from torchdrug import data, datasets, models, tasks, core

# Load a benchmark dataset and train a GNN for property prediction.
# Featurizers are selected with atom_feature / bond_feature: extra dataset kwargs
# are forwarded to Molecule.from_molecule, which has no node_feature/edge_feature
# parameters.
dataset = datasets.BBBP("~/data/bbbp", atom_feature="default", bond_feature="default")
print(f"Dataset: {len(dataset)} molecules, task: BBBP (blood-brain barrier penetration)")

# Define model: GIN encoder
model = models.GIN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[256, 256],
    short_cut=True,
    batch_norm=True,
    concat_hidden=True,
)

# Define training task
task = tasks.PropertyPrediction(
    model, task=dataset.tasks,
    criterion="bce", metric=("auprc", "auroc"),
)

# Train with the Engine. This quickstart trains on every molecule (no validation
# or test split); see the end-to-end examples below for held-out evaluation.
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
# Engine's default batch_size is 1, so set it explicitly; fall back to CPU when
# no CUDA device is present instead of crashing on gpus=[0].
gpus = [0] if torch.cuda.is_available() else None
solver = core.Engine(task, dataset, None, None, optimizer, gpus=gpus, batch_size=32)
solver.train(num_epoch=50)
print("Training complete")
TorchDrug represents molecules as typed graphs. data.Molecule is the core data structure.
from torchdrug import data
from rdkit import Chem

# Create a molecule from SMILES. Featurizers are chosen with atom_feature /
# bond_feature ("default" is also their default value).
smiles = "CC(=O)Oc1ccccc1C(=O)O"  # aspirin
mol = data.Molecule.from_smiles(smiles, atom_feature="default", bond_feature="default")
print(f"Atoms: {mol.num_node}")
# Every bond is stored as two directed edges, so num_edge == 2 * number of bonds.
print(f"Bonds: {mol.num_edge // 2}")
print(f"Node feature shape: {tuple(mol.node_feature.shape)}")  # (N_atoms, feature_dim)
print(f"Edge feature shape: {tuple(mol.edge_feature.shape)}")  # (N_bonds*2, feature_dim)
# Convert a MoleculeNet / custom SMILES list to a dataset
from torchdrug import data as td_data
import pandas as pd

df = pd.read_csv("compounds.csv")  # columns: smiles, label

# Keep only SMILES that actually parse:
# - empty cells are read as NaN, which is truthy, so `if s` would not skip them;
# - Molecule.from_smiles raises ValueError for SMILES RDKit cannot parse.
molecules = []
for s in df["smiles"].dropna():
    if not isinstance(s, str) or not s.strip():
        continue
    try:
        molecules.append(td_data.Molecule.from_smiles(s))
    except ValueError:
        continue
print(f"Loaded {len(molecules)} valid molecules")

# Check feature dimensions (guarded: the file may contain no valid SMILES)
if molecules:
    print(f"Default atom feature dim: {molecules[0].node_feature.shape[1]}")
TorchDrug provides GCN, RGCN, GAT, GIN, MPNN, NFP (neural fingerprint), SchNet, ChebNet, and more. GraphSAGE and AttentiveFP are not part of TorchDrug's model library.
from torchdrug import models, datasets

# ESOL aqueous solubility; TorchDrug ships this dataset as `datasets.Delaney`.
dataset = datasets.Delaney("~/data/esol", atom_feature="default", bond_feature="default")
feature_dim = dataset.node_feature_dim
# Take the edge feature width from the dataset's bond featurizer rather than
# hard-coding it, so edge-aware models always match the data.
edge_dim = dataset.edge_feature_dim

# Graph Isomorphism Network (GIN) — good default for property prediction
gin = models.GIN(
    input_dim=feature_dim,
    hidden_dims=[256, 256, 256],
    short_cut=True,
    batch_norm=True,
    concat_hidden=True,  # concatenate layer representations
)
print(f"GIN output_dim: {gin.output_dim}")

# Message Passing Neural Network (MPNN) — captures edge features
mpnn = models.MPNN(
    input_dim=feature_dim,
    hidden_dim=256,
    edge_input_dim=edge_dim,  # edge feature dimension
    num_layer=4,
    num_gru_layer=1,
)

# Graph Attention Network (GAT) — attention-weighted neighbors
gat = models.GAT(
    input_dim=feature_dim,
    hidden_dims=[256, 256],
    edge_input_dim=edge_dim,
    num_head=8,
    batch_norm=True,
)
print(f"MPNN output_dim: {mpnn.output_dim}, GAT output_dim: {gat.output_dim}")
Wrap a GNN encoder with a prediction head for classification or regression.
import torch
from torchdrug import datasets, models, tasks, core

# Regression example: ESOL aqueous solubility (TorchDrug class name: Delaney)
dataset = datasets.Delaney("~/data/esol", atom_feature="default", bond_feature="default")

# TorchDrug's MoleculeNet datasets do not define split(); use an 80/10/10 random split.
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths.append(len(dataset) - sum(lengths))
train, val, test = torch.utils.data.random_split(dataset, lengths)
print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")

model = models.GIN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[300, 300],
    short_cut=True,
    batch_norm=True,
    concat_hidden=True,
)
task = tasks.PropertyPrediction(
    model,
    task=dataset.tasks,  # list of property names
    criterion="mse",  # "mse" for regression, "bce" for classification
    metric=("mae", "rmse"),
    num_mlp_layer=2,
)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3, weight_decay=1e-5)
solver = core.Engine(task, train, val, test, optimizer,
                     batch_size=32, log_interval=50)
solver.train(num_epoch=100)

# Evaluate on test set. evaluate() keys include the task name
# ("<metric> [<task>]"), so iterate instead of indexing by "rmse"/"mae".
metrics = solver.evaluate("test")
for name, value in metrics.items():
    print(f"Test {name}: {float(value):.4f}")
Predict binding affinity between molecules and protein sequences.
from torchdrug import datasets, models, tasks, core
import torch

# Load a DTI dataset (e.g., Davis kinase binding affinities).
# NOTE(review): confirm that `datasets.Davis`, its mol_* feature kwargs, split()
# and `mol_node_feature_dim` exist in the installed TorchDrug version; the
# protein-ligand datasets documented upstream are BindingDB and PDBBind.
dataset = datasets.Davis("~/data/davis",
                         mol_node_feature="default",
                         mol_edge_feature="default")
train, val, test = dataset.split()

# Molecule encoder
mol_model = models.GIN(
    input_dim=dataset.mol_node_feature_dim,
    hidden_dims=[256, 256],
    short_cut=True,
    batch_norm=True,
    concat_hidden=True,
)
# Protein encoder (CNN on sequence)
prot_model = models.ProteinCNN(
    input_dim=21,  # amino acid vocabulary size
    hidden_dims=[128, 128, 128],
    kernel_size=3,
)

# NOTE(review): InteractionPrediction runs `model` on batch["graph1"] and `model2`
# on batch["graph2"]; make sure this encoder order matches which entity
# (protein or ligand) the dataset stores in graph1.
task = tasks.InteractionPrediction(
    mol_model, prot_model,
    task=dataset.tasks,
    criterion="mse",
    metric=("rmse", "pearsonr"),
)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train, val, test, optimizer,
                     batch_size=64, log_interval=100)
solver.train(num_epoch=50)

# evaluate() keys are "<metric> [<task>]", not bare "rmse"/"pearsonr".
metrics = solver.evaluate("test")
for name, value in metrics.items():
    print(f"DTI Test {name}: {float(value):.4f}")
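Identify reaction centers to locate plausible retrosynthetic disconnections. This is the first stage of TorchDrug's two-stage retrosynthesis pipeline (center identification, then synthon completion).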
from torchdrug import datasets, models, tasks, core
import torch

# USPTO-50k retrosynthesis benchmark.
# Settings follow TorchDrug's retrosynthesis tutorial for the reaction-center
# stage: the "center_identification" atom featurizer and kekulized molecules.
dataset = datasets.USPTO50k("~/data/uspto50k",
                            as_synthon=False,
                            with_hydrogen=False,
                            kekulize=True,
                            atom_feature="center_identification",
                            bond_feature="default")
train, val, test = dataset.split()  # USPTO50k ships a predefined split

# Reaction-center GNN: one relation type per bond type
model = models.RGCN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[256, 256, 256],
    num_relation=dataset.num_bond_type,
    batch_norm=True,
)

# CenterIdentification only predicts the reaction center (stage 1); completing
# the synthons into reactants is a separate task (SynthonCompletion).
task = tasks.CenterIdentification(
    model,
    feature=("graph", "atom", "bond"),
)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train, val, test, optimizer,
                     batch_size=64, log_interval=100)
solver.train(num_epoch=50)

# Report whatever metrics the task returns instead of guessing a key name.
metrics = solver.evaluate("test")
for name, value in metrics.items():
    print(f"Center identification {name}: {float(value):.4f}")
Reuse GNN encoder weights pretrained with TorchDrug's self-supervised tasks (e.g. attribute masking, context prediction) as the starting point for downstream tasks.
from torchdrug import models
import torch

# Rebuild the encoder exactly as it was configured during pretraining: the
# state dict only loads if every layer shape matches.
# NOTE(review): input_dim=39 must equal the node_feature_dim of the featurizer
# used at pretraining time — confirm (prefer dataset.node_feature_dim).
pretrained_gin = models.GIN(
    input_dim=39,
    hidden_dims=[300, 300, 300, 300, 300],
    short_cut=False,
    batch_norm=True,
    concat_hidden=False,
)

# Expected checkpoint: the file written by core.Engine.save() after a TorchDrug
# self-supervised task (e.g. tasks.AttributeMasking / tasks.ContextPrediction).
# Engine.save stores {"model": task.state_dict(), "optimizer": ...}, so the
# encoder's parameters are the entries prefixed with "model.". A plain GIN
# state_dict is still accepted as a fallback.
# weights_only=True avoids unpickling arbitrary objects from a downloaded file.
ckpt = torch.load("gin_supervised_contextpred.pth", map_location="cpu", weights_only=True)
task_state = ckpt.get("model", ckpt)
prefix = "model."
encoder_state = {
    key[len(prefix):]: value
    for key, value in task_state.items()
    if key.startswith(prefix)
}
pretrained_gin.load_state_dict(encoder_state or task_state)
pretrained_gin.eval()
print(f"Pretrained GIN loaded, output_dim={pretrained_gin.output_dim}")
print("Use as encoder in PropertyPrediction task for transfer learning")
Molecules are represented as attributed graphs: atoms are nodes with features (atomic number, degree, charge, aromaticity) and bonds are edges with features (bond type, ring membership). All TorchDrug models operate on these graph representations rather than SMILES strings or fingerprints.
from torchdrug import data

# Build benzene as a TorchDrug graph with the default featurizers.
mol = data.Molecule.from_smiles("c1ccccc1")
first_atom_features = mol.node_feature[0]
# Bonds appear as two directed edges each, so halve the edge count.
print(f"Atoms: {mol.num_node}, Bonds: {mol.num_edge // 2}")
print(f"Atom features (first atom): {first_atom_features}")
TorchDrug uses a core.Engine (also called Solver) to handle the training loop, logging, checkpointing, and multi-GPU setup. Pass the task, train/val/test splits, and optimizer to the Engine rather than writing a manual training loop.
# Engine handles: batch iteration, loss backward, logging, checkpointing.
# Fragment: task, train_set, valid_set, test_set and optimizer come from the
# setup shown in the earlier examples.
solver = core.Engine(
    task, train_set, valid_set, test_set, optimizer,
    batch_size=32,
    log_interval=100,
    # Multi-GPU: TorchDrug runs one process per GPU, so the number of listed GPUs
    # must equal the launched world size. Start the script with
    #   python -m torch.distributed.launch --nproc_per_node=2 script.py
    # In a plain `python script.py` run, list a single GPU (e.g. gpus=[0]).
    gpus=[0, 1],
)
solver.train(num_epoch=100)
# Writes the task's and the optimizer's state dicts to disk.
solver.save("checkpoint.pth")
Goal: Train a GIN model to predict blood-brain barrier penetration from SMILES, then predict on new compounds.
import torch
import pandas as pd
from torchdrug import data, datasets, models, tasks, core, utils

# 1. Load dataset (featurizers are selected with atom_feature / bond_feature)
dataset = datasets.BBBP("~/data/bbbp", atom_feature="default", bond_feature="default")
# BBBP has no built-in split(); hold out 10% validation and 10% test at random.
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths.append(len(dataset) - sum(lengths))
train, val, test = torch.utils.data.random_split(dataset, lengths)
print(f"BBBP: {len(train)} train, {len(val)} val, {len(test)} test molecules")

# 2. Build model
model = models.GIN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[256, 256],
    short_cut=True, batch_norm=True, concat_hidden=True,
)
task = tasks.PropertyPrediction(
    model, task=dataset.tasks,
    criterion="bce", metric=("auroc", "auprc"),
)

# 3. Train
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train, val, test, optimizer,
                     batch_size=32, log_interval=50)
solver.train(num_epoch=100)

# evaluate() keys are "<metric> [<task>]", so select AUROC entries by prefix.
metrics = solver.evaluate("test")
for name, value in metrics.items():
    if name.startswith("auroc"):
        print(f"Test {name}: {float(value):.4f}")

# 4. Predict on new SMILES
new_smiles = ["CC(=O)Oc1ccccc1C(=O)O", "c1ccc(cc1)N"]
task.eval()
with torch.no_grad():
    for smi in new_smiles:
        # Featurize exactly like the training data so input widths match the GIN.
        mol = data.Molecule.from_smiles(smi, atom_feature="default", bond_feature="default")
        # PropertyPrediction.predict takes a dict holding a packed graph batch
        # (data.Batch is a PyTorch Geometric API, not TorchDrug).
        batch = {"graph": data.Molecule.pack([mol])}
        # Move the batch to wherever the Engine placed the model.
        if solver.device.type == "cuda":
            batch = utils.cuda(batch, device=solver.device)
        pred = task.predict(batch)
        print(f" {smi}: BBB penetration probability = {pred.sigmoid().item():.3f}")
Goal: Simultaneously predict 12 toxicity endpoints using a shared GNN encoder.
import torch
from torchdrug import datasets, models, tasks, core

# Tox21: 12 toxicity assays, multi-label classification
dataset = datasets.Tox21("~/data/tox21", atom_feature="default", bond_feature="default")
# Tox21 has no built-in split(); use an 80/10/10 random split.
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths.append(len(dataset) - sum(lengths))
train, val, test = torch.utils.data.random_split(dataset, lengths)
print(f"Tox21 tasks ({len(dataset.tasks)}): {dataset.tasks}")

model = models.GIN(
    input_dim=dataset.node_feature_dim,
    hidden_dims=[300, 300, 300],
    short_cut=True, batch_norm=True, concat_hidden=True,
)
# Multi-task: one output head per toxicity assay
task = tasks.PropertyPrediction(
    model, task=dataset.tasks,
    criterion="bce",
    metric=("auroc",),
    num_mlp_layer=2,
)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train, val, test, optimizer, batch_size=64)
solver.train(num_epoch=100)

# One entry per (metric, assay), keyed "auroc [<assay>]".
metrics = solver.evaluate("test")
for name, val_score in metrics.items():
    print(f" {name}: {float(val_score):.4f}")
aurocs = [float(score) for name, score in metrics.items() if name.startswith("auroc")]
if aurocs:
    print(f"Mean AUROC over {len(aurocs)} assays: {sum(aurocs) / len(aurocs):.4f}")
| Parameter | Module | Default | Range / Options | Effect |
|---|---|---|---|---|
| hidden_dims | GIN/GAT | required (no default) | list of int | Width and depth of GNN layers (MPNN instead takes hidden_dim + num_layer) |
| short_cut | GIN | False | True, False | Add residual connection between layers |
| batch_norm | GIN/MPNN/GAT | False | True, False | Apply batch normalization after each layer |
| concat_hidden | GIN | False | True, False | Concatenate all layer outputs as final representation |
| num_mlp_layer | PropertyPrediction | 1 | 1–4 | Depth of MLP prediction head after GNN |
| criterion | PropertyPrediction | "mse" | "mse", "bce", "ce" | Loss function: "mse" for regression, "bce" for binary/multi-label, "ce" for multi-class classification |
| batch_size | Engine | 1 | 8–512 | Training batch size; the default of 1 is very slow, so always set it explicitly |
Use concat_hidden=True for GIN on small datasets: Concatenating all layer outputs provides a richer molecular representation and often improves performance when training data is limited (<10,000 molecules).
Apply batch_norm=True for training stability: Batch normalization reduces sensitivity to learning rate and initialization, especially with deep GNNs (3+ layers).
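Start with pretrained GNN weights for small datasets: TorchDrug does not ship a model zoo of pretrained weights. Instead it provides self-supervised pretraining tasks (e.g. `tasks.AttributeMasking`, `tasks.ContextPrediction`) for pretraining a GIN on a large unlabeled molecule set. Save the result with `solver.save(...)` and load it into the downstream task. Fine-tuning from such weights often outperforms random initialization on datasets with <1,000 molecules.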
Validate on scaffold splits, not random splits: random train/test splits overestimate generalization because structurally similar molecules appear in both sets. Use `torchdrug.data.scaffold_split(dataset, lengths)` (or `ordered_scaffold_split`), e.g. with `lengths = [n_train, n_valid, n_test]`, for more realistic evaluation.
Handle missing labels in multi-task datasets: Many MoleculeNet datasets (Tox21, SIDER) have missing assay values. TorchDrug's PropertyPrediction task handles NaN labels automatically, but verify that missing rates are not too high for rare assays.
When to use: Visualize a molecular library in embedding space or use GNN features in scikit-learn models.
import torch
import numpy as np
from torchdrug import data, models

smiles_list = ["CC(=O)O", "c1ccccc1", "CCN", "CC(=O)Oc1ccccc1C(=O)O"]

# Featurize first, so the encoder's input_dim is taken from the actual atom
# features instead of a hard-coded width.
mols = [data.Molecule.from_smiles(smi, atom_feature="default") for smi in smiles_list]
feature_dim = mols[0].node_feature.shape[1]

# NOTE: these are freshly initialized weights. Load trained or pretrained
# weights before relying on the embeddings for anything meaningful.
model = models.GIN(input_dim=feature_dim, hidden_dims=[300, 300], concat_hidden=True)
model.eval()

with torch.no_grad():
    # Pack all molecules into one TorchDrug batch (data.Batch is a PyG API);
    # graph_feature then holds one row per molecule.
    batch = data.Molecule.pack(mols)
    graph_feat = model(batch, batch.node_feature.float())["graph_feature"]

emb_matrix = graph_feat.cpu().numpy()
print(f"Embedding matrix: {emb_matrix.shape}")  # (N_mols, embed_dim)
When to use: Training on proprietary assay data rather than benchmark datasets.
from torchdrug import data
import torch


class CustomDataset(data.MoleculeDataset):
    """Molecule dataset built from a CSV with a SMILES column and one label column.

    The label column is exposed as a single task named ``"activity"``.

    Args:
        csv_path: Path to the CSV file.
        smiles_col: Column holding SMILES strings; rows where it is empty are dropped.
        label_col: Column holding the target values.
    """

    def __init__(self, csv_path, smiles_col="smiles", label_col="activity"):
        import pandas as pd

        df = pd.read_csv(csv_path).dropna(subset=[smiles_col])
        smiles_list = df[smiles_col].tolist()
        # Labels may contain NaN; PropertyPrediction masks missing labels in the loss.
        targets = df[label_col].tolist()
        # Featurizer kwargs are forwarded to Molecule.from_molecule, hence
        # atom_feature / bond_feature. `tasks` is a read-only property derived
        # from the keys of this targets dict, so it is not assigned here.
        self.load_smiles(smiles_list, {"activity": targets},
                         atom_feature="default", bond_feature="default")


dataset = CustomDataset("assay_data.csv", smiles_col="smiles", label_col="pIC50")
print(f"Custom dataset: {len(dataset)} molecules")
| Problem | Cause | Solution |
|---|---|---|
| ImportError: torchdrug | Package not installed | pip install torchdrug after installing PyTorch and matching torch-scatter / torch-cluster wheels |
| CUDA error: device-side assert | Label dtype mismatch | Ensure regression labels are float, classification labels are long |
| Poor test metrics with small dataset | Overfitting | Use pretrained weights, add dropout, or reduce model depth |
| KeyError: task name in dataset.tasks | Task name mismatch | Print dataset.tasks to see exact task names; pass the same list to PropertyPrediction |
| KeyError: 'rmse' / 'auroc' on evaluate() result | Metric keys include the task name ("<metric> [<task>]") | Iterate metrics.items() or print metrics.keys() first |
| RuntimeError: Expected all tensors on same device | Mixed CPU/GPU tensors | Use solver = core.Engine(..., gpus=[0]) to ensure consistent device placement |
| Slow training | CPU-only mode | Install CUDA-compatible PyTorch; set gpus=[0] in Engine |
| Missing assay values cause NaN loss | Dataset has missing labels | PropertyPrediction already masks NaN labels in the loss, whatever the criterion. If the loss is still NaN, check for assays with no labeled molecules in a split or batch |
- rdkit — molecular fingerprints and cheminformatics preprocessing before TorchDrug
- diffdock — structure-based docking, complementary to TorchDrug's ligand-based prediction