Guide for building Graph Neural Networks with PyTorch Geometric (PyG). Use this skill whenever the user asks about graph neural networks, GNNs, node classification, link prediction, graph classification, message passing networks, heterogeneous graphs, neighbor sampling, or any task involving torch_geometric / PyG. Also trigger when you see imports from torch_geometric, or the user mentions graph convolutions (GCN, GAT, GraphSAGE, GIN), graph data structures, or working with relational/network data. Even if the user just says 'graph learning' or 'geometric deep learning', use this skill.
PyG is the standard library for Graph Neural Networks built on PyTorch. It provides data structures for graphs, 60+ GNN layer implementations, scalable mini-batch training, and support for heterogeneous graphs.
Install: uv add torch_geometric (or uv pip install torch_geometric; requires PyTorch). Optional: pyg-lib, torch-scatter, torch-sparse, torch-cluster for accelerated ops.
Data and HeteroData
A graph lives in a Data object. The key attributes:
from torch_geometric.data import Data
data = Data(
    x=node_features,        # [num_nodes, num_node_features]
    edge_index=edge_index,  # [2, num_edges], COO format, dtype=torch.long
    edge_attr=edge_features,  # [num_edges, num_edge_features]
    y=labels,               # node-level [num_nodes, *] or graph-level [1, *]
    pos=positions,          # [num_nodes, num_dimensions] (for point clouds/spatial)
)
edge_index format is critical: it's a [2, num_edges] tensor where edge_index[0] = source nodes, edge_index[1] = target nodes. It is NOT a list of tuples. If you have edge pairs as rows, transpose and call .contiguous():
# If edges are [[src1, dst1], [src2, dst2], ...] — transpose first:
edge_index = edge_pairs.t().contiguous()
For undirected graphs, include both directions: edge (0,1) needs both [0,1] and [1,0] in edge_index.
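As a rough illustration of the layout described above, here is a plain-Python sketch (no PyG required) that turns a list of edge pairs into the two-row COO layout and adds the reverse direction of every edge. The helper name `to_undirected_coo` is hypothetical and not part of PyG; with real tensors you would typically wrap the result in `torch.tensor(..., dtype=torch.long)` or rely on `T.ToUndirected()`.

```python
def to_undirected_coo(edge_pairs):
    """Return [sources, targets] containing every edge in both directions."""
    seen = set()
    for src, dst in edge_pairs:
        seen.add((src, dst))
        seen.add((dst, src))  # reverse direction for undirected graphs
    ordered = sorted(seen)    # deterministic order, duplicates removed
    sources = [s for s, _ in ordered]
    targets = [t for _, t in ordered]
    return [sources, targets]


# Two undirected edges: 0-1 and 1-2
edge_index = to_undirected_coo([(0, 1), (1, 2)])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

The result has exactly two rows, so `edge_index[0]` holds sources and `edge_index[1]` holds targets, matching the shape PyG expects.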
For heterogeneous graphs, use HeteroData — see the Heterogeneous Graphs section below.
PyG bundles many standard datasets that auto-download and preprocess:
from torch_geometric.datasets import Planetoid, TUDataset
# Single-graph node classification (Cora, Citeseer, Pubmed)
dataset = Planetoid(root='./data', name='Cora')
data = dataset[0] # single graph with train/val/test masks
# Multi-graph classification (ENZYMES, MUTAG, IMDB-BINARY, etc.)
dataset = TUDataset(root='./data', name='ENZYMES')
# dataset[0], dataset[1], ... are individual graphs
Common datasets by task:
Transforms preprocess or augment graph data, analogous to torchvision transforms:
import torch_geometric.transforms as T
# Common transforms
T.NormalizeFeatures() # Row-normalize node features to sum to 1
T.ToUndirected() # Add reverse edges to make graph undirected
T.AddSelfLoops() # Add self-loop edges
T.KNNGraph(k=6) # Build k-NN graph from point cloud positions
T.RandomJitter(0.01) # Random noise augmentation on positions
T.Compose([...]) # Chain multiple transforms
# Apply as pre_transform (once, saved to disk) or transform (every access)
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='./data', pre_transform=T.KNNGraph(k=6),
                   transform=T.RandomJitter(0.01))
The fastest way to build a GNN — stack conv layers from torch_geometric.nn:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
Important: PyG conv layers do NOT include activation functions — apply them yourself after each layer. This is by design for flexibility.
Pick based on your task and graph structure:
| Layer | Best for | Key idea |
|---|---|---|
| GCNConv | Homogeneous, semi-supervised node classification | Spectral-inspired, degree-normalized aggregation |
| GATConv / GATv2Conv | When neighbor importance varies | Attention-weighted messages |
| SAGEConv | Large graphs, inductive settings | Sampling-friendly, learnable aggregation |
| GINConv | Graph classification, maximizing expressiveness | As expressive as the 1-WL isomorphism test |
| TransformerConv | Rich edge features, complex interactions | Multi-head attention with edge features |
| EdgeConv | Point clouds, dynamic graphs | MLP on edge features (x_i, x_j - x_i) |
| RGCNConv | Heterogeneous with many relation types | Relation-specific weight matrices |
| HGTConv | Heterogeneous graphs | Type-specific attention |
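Most homogeneous conv layers accept (x, edge_index) at minimum, and many also accept edge_attr for edge features. Exceptions include RGCNConv, which additionally needs an edge_type tensor, and HGTConv, which takes x_dict and edge_index_dict.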
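To make "degree-normalized aggregation" in the GCNConv row more concrete, the following plain-Python sketch computes the symmetric edge weights 1/sqrt(deg_i * deg_j) after adding self-loops. It is a simplified illustration under the assumption of an unweighted graph; `gcn_norm_sketch` is a hypothetical name, and GCNConv's own implementation additionally handles edge weights, caching, and sparse formats.

```python
import math


def gcn_norm_sketch(edge_index, num_nodes):
    """Add self-loops, then weight each edge by 1 / sqrt(deg[src] * deg[dst])."""
    src, dst = edge_index
    src = list(src) + list(range(num_nodes))  # self-loops appended at the end
    dst = list(dst) + list(range(num_nodes))
    deg = [0] * num_nodes
    for t in dst:                             # in-degree, self-loop included
        deg[t] += 1
    weights = [1.0 / math.sqrt(deg[s] * deg[t]) for s, t in zip(src, dst)]
    return [src, dst], weights


# Undirected path graph 0 - 1 - 2 (both directions listed)
path_edges = [[0, 1, 1, 2], [1, 0, 2, 1]]
norm_edges, norm_weights = gcn_norm_sketch(path_edges, num_nodes=3)
# Degrees with self-loops: node 0 -> 2, node 1 -> 3, node 2 -> 2
print(norm_weights[0])  # weight of edge 0 -> 1, i.e. 1 / sqrt(2 * 3)
```

High-degree nodes therefore contribute and receive proportionally smaller messages, which keeps feature magnitudes comparable across nodes.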
Use -1 for input channels to let PyG infer dimensions automatically — especially useful for heterogeneous models:
from torch_geometric.nn import SAGEConv
conv = SAGEConv((-1, -1), 64)  # Input dims inferred on first forward pass
# Initialize lazy modules with one forward pass before training:
with torch.no_grad():
    out = model(data.x, data.edge_index)
For common architectures, PyG provides ready-made model classes:
from torch_geometric.nn import GraphSAGE, GCN, GAT, GIN
model = GraphSAGE(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2,
)
To implement a novel GNN layer, subclass MessagePassing. The framework is:
- propagate() orchestrates the message passing
- message() defines what information flows along each edge (the phi function)
- aggregate() combines messages at each node (sum/mean/max)
- update() transforms the aggregated result (the gamma function)

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class MyConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "add", "mean", or "max"
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # Pre-processing before message passing
        x = self.lin(x)
        # Start message passing
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j: features of source nodes for each edge [num_edges, features]
        # The _j suffix auto-indexes source nodes, _i indexes target nodes
        return x_j
The _i / _j convention: any tensor passed to propagate() can be auto-indexed by appending _i (target/central node) or _j (source/neighbor node) in the message() signature. So if you pass x=... to propagate, you can access x_i and x_j in message().
Read references/message_passing.md for the full GCN and EdgeConv implementation examples.
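The message/aggregate flow and the _j (source) versus _i (target) indexing described above can be pictured with a small plain-Python sketch. This is only a conceptual stand-in for what propagate() does with aggr='add' and an identity message; `propagate_sum` is a hypothetical name, and the real framework operates on tensors with scatter operations.

```python
def propagate_sum(edge_index, x):
    """Sum-aggregate neighbor features: out[i] = sum of x[j] over edges j -> i."""
    sources, targets = edge_index
    num_features = len(x[0])
    out = [[0.0] * num_features for _ in x]
    for j, i in zip(sources, targets):
        x_j = x[j]                     # message(): features of the source node
        for f in range(num_features):  # aggregate(): add into the target node
            out[i][f] += x_j[f]
    return out


x = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
mp_edges = [[0, 1, 2], [1, 2, 1]]  # edges 0->1, 1->2, 2->1
mp_out = propagate_sum(mp_edges, x)
print(mp_out)  # [[0.0, 0.0], [4.0, 3.0], [0.0, 2.0]]
```

Node 1 receives messages from nodes 0 and 2, node 2 receives one from node 1, and node 0 has no incoming edges, so its output stays at zero.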
# Full-batch training on a single graph (e.g., Cora)
# Learning rate and weight decay are typical starting values, not tuned settings
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
# Evaluation
model.eval()
with torch.no_grad():
    pred = model(data.x, data.edge_index).argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
Multiple graphs — use DataLoader for mini-batching and global pooling to get graph-level representations:
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
loader = DataLoader(dataset, batch_size=32, shuffle=True)
class GraphClassifier(torch.nn.Module):
    def __init__(self, in_ch, hidden_ch, out_ch):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, hidden_ch)
        self.lin = torch.nn.Linear(hidden_ch, out_ch)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)  # [num_graphs_in_batch, hidden_ch]
        return self.lin(x)
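# Training loop (assumes model and optimizer are already constructed)
model.train()
for data in loader:
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, data.batch)
    loss = F.cross_entropy(out, data.y)
    loss.backward()
    optimizer.step()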
PyG's DataLoader batches multiple graphs by creating block-diagonal adjacency matrices. The batch tensor maps each node to its graph index. Pooling ops (global_mean_pool, global_max_pool, global_add_pool) use this to aggregate per-graph.
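The block-diagonal batching and the role of the batch tensor can be sketched in plain Python as follows. This is an illustrative approximation, not PyG's collate code: `collate_sketch` and `mean_pool_sketch` are hypothetical names, and graphs are represented as simple (num_nodes, edge_index) tuples.

```python
def collate_sketch(graphs):
    """Merge graphs into one disconnected graph; return edge_index and batch vector."""
    sources, targets, batch = [], [], []
    offset = 0
    for graph_id, (num_nodes, (src, dst)) in enumerate(graphs):
        sources += [s + offset for s in src]  # shift node ids past earlier graphs
        targets += [t + offset for t in dst]
        batch += [graph_id] * num_nodes       # node -> graph index
        offset += num_nodes
    return [sources, targets], batch


def mean_pool_sketch(x, batch):
    """Average node features per graph, using the batch vector for grouping."""
    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for feats, g in zip(x, batch):
        counts[g] += 1
        for f in range(dim):
            sums[g][f] += feats[f]
    return [[v / counts[g] for v in sums[g]] for g in range(num_graphs)]


g1 = (2, [[0, 1], [1, 0]])        # 2 nodes, one undirected edge
g2 = (3, [[0, 1, 2], [1, 2, 0]])  # 3 nodes, directed cycle
merged_edges, batch_vec = collate_sketch([g1, g2])
node_x = [[1.0], [3.0], [2.0], [4.0], [6.0]]
pooled = mean_pool_sketch(node_x, batch_vec)
print(merged_edges)  # [[0, 1, 2, 3, 4], [1, 0, 3, 4, 2]]
print(batch_vec)     # [0, 0, 1, 1, 1]
print(pooled)        # [[2.0], [4.0]]
```

Because no edge crosses between the two node ranges, message passing on the merged graph never mixes information from different graphs.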
Split edges into train/val/test, use negative sampling:
from torch_geometric.transforms import RandomLinkSplit
transform = RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    add_negative_train_samples=False,
)
train_data, val_data, test_data = transform(data)
# Encode nodes, then score edges
z = model.encode(train_data.x, train_data.edge_index)
# Positive edges
pos_score = (z[train_data.edge_label_index[0]] * z[train_data.edge_label_index[1]]).sum(dim=1)
Read references/link_prediction.md for the complete link prediction guide: GAE/VGAE autoencoders, full training loops, LinkNeighborLoader for large graphs, heterogeneous link prediction, and evaluation metrics.
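The scoring and negative-sampling steps above can be sketched without PyG. The example below uses plain Python lists for embeddings; the helper names are hypothetical, and the rejection-based sampler is only one simple strategy (PyG ships its own `negative_sampling` utility in torch_geometric.utils, which this sketch does not reproduce).

```python
import math
import random


def dot_score(z, edges):
    """Dot-product decoder: one score per (source, target) pair."""
    src, dst = edges
    return [sum(a * b for a, b in zip(z[s], z[d])) for s, d in zip(src, dst)]


def sample_negatives(num_nodes, pos_edges, num_samples, rng):
    """Rejection-sample node pairs that are not positive edges or self-loops."""
    existing = set(zip(*pos_edges))
    negatives = set()
    while len(negatives) < num_samples:
        s, d = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if s != d and (s, d) not in existing:
            negatives.add((s, d))
    ordered = sorted(negatives)
    return [[s for s, _ in ordered], [d for _, d in ordered]]


def bce_with_logits(scores, labels):
    """Numerically stable binary cross-entropy on raw scores."""
    total = 0.0
    for s, y in zip(scores, labels):
        total += max(s, 0.0) - s * y + math.log1p(math.exp(-abs(s)))
    return total / len(scores)


z = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]  # toy node embeddings
pos_edges = [[0, 1], [1, 2]]                          # edges 0->1 and 1->2
neg_edges = sample_negatives(4, pos_edges, 2, random.Random(0))

pos_scores = dot_score(z, pos_edges)
neg_scores = dot_score(z, neg_edges)
lp_loss = bce_with_logits(pos_scores + neg_scores, [1.0, 1.0, 0.0, 0.0])
print(pos_scores, lp_loss)
```

Sampling fresh negatives at each training step is what `add_negative_train_samples=False` leaves room for: the split keeps only positive training edges, and negatives are drawn on the fly.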
For graphs that don't fit in GPU memory, use neighbor sampling via NeighborLoader:
from torch_geometric.loader import NeighborLoader
train_loader = NeighborLoader(
    data,
    num_neighbors=[15, 10],       # Sample 15 neighbors in hop 1, 10 in hop 2
    batch_size=128,               # Number of seed nodes per batch
    input_nodes=data.train_mask,  # Which nodes to sample from
    shuffle=True,
)
for batch in train_loader:
    batch = batch.to(device)
    optimizer.zero_grad()
    out = model(batch.x, batch.edge_index)
    # Only use first batch_size nodes for loss (these are the seed nodes)
    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
Key points about NeighborLoader:
- The num_neighbors list length should match GNN depth (number of message passing layers)
- Seed nodes are the first batch.batch_size nodes in the output
- batch.n_id maps relabeled indices back to original node IDs
- Works with both Data and HeteroData
- For link prediction, use LinkNeighborLoader instead

Other scalability options: ClusterLoader (ClusterGCN), GraphSAINTSampler, ShaDowKHopSampler. For multi-GPU training, DDP, PyTorch Lightning integration, and torch.compile support, read references/scaling.md.
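To show why the seed nodes come first and what the n_id mapping means, here is a conceptual plain-Python sketch of multi-hop neighbor sampling with relabeling. It is not NeighborLoader's actual algorithm: `sample_subgraph` is a hypothetical helper, and the adjacency is assumed to be a dict mapping each node to the nodes with an edge pointing into it.

```python
import random


def sample_subgraph(adj_in, seeds, num_neighbors, rng):
    """Sample up to num_neighbors[h] in-neighbors per node at hop h.

    Returns (n_id, edge_index, batch_size); edge_index uses local indices.
    """
    n_id = list(seeds)  # seeds take local indices 0..batch_size-1
    local = {v: i for i, v in enumerate(n_id)}
    sources, targets = [], []
    frontier = list(seeds)
    for fanout in num_neighbors:  # one entry per hop / GNN layer
        next_frontier = []
        for v in frontier:
            neighbors = adj_in.get(v, [])
            picked = rng.sample(neighbors, min(fanout, len(neighbors)))
            for u in picked:
                if u not in local:  # first time we see u: relabel it
                    local[u] = len(n_id)
                    n_id.append(u)
                    next_frontier.append(u)
                sources.append(local[u])
                targets.append(local[v])
        frontier = next_frontier
    return n_id, [sources, targets], len(seeds)


adj_in = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0, 5], 4: [1], 5: [3]}
n_id, sub_edges, batch_size = sample_subgraph(
    adj_in, seeds=[0, 1], num_neighbors=[2, 1], rng=random.Random(0)
)
print(n_id[:batch_size])  # [0, 1]: seed nodes occupy the first positions
```

Slicing model outputs with `[:batch_size]` then selects exactly the seed nodes, and `n_id[local_index]` recovers the original node ID for any sampled node.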
For graphs with multiple node and edge types (social networks, knowledge graphs, recommendation):
from torch_geometric.data import HeteroData
data = HeteroData()
# Node features — indexed by node type string
data['user'].x = torch.randn(1000, 64)
data['movie'].x = torch.randn(500, 128)
# Edge indices — indexed by (src_type, edge_type, dst_type) triplet
data['user', 'rates', 'movie'].edge_index = torch.randint(0, 500, (2, 3000))
data['user', 'follows', 'user'].edge_index = torch.randint(0, 1000, (2, 5000))
# Access convenience dicts
data.x_dict # {'user': tensor, 'movie': tensor}
data.edge_index_dict # {('user','rates','movie'): tensor, ...}
data.metadata() # ([node_types], [edge_types])
1. Auto-convert with to_hetero() — write a homogeneous model, convert automatically:
from torch_geometric.nn import SAGEConv, to_hetero
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
model = GNN(64, dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
# Now accepts dicts:
out = model(data.x_dict, data.edge_index_dict)
Use (-1, -1) for bipartite input channels (source, target may differ). Lazy init handles the rest.
2. HeteroConv wrapper — different conv per edge type:
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv
conv = HeteroConv({
    ('paper', 'cites', 'paper'): GCNConv(-1, 64),
    ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
    ('paper', 'rev_writes', 'author'): GATConv((-1, -1), 64, add_self_loops=False),
}, aggr='sum')
3. Native heterogeneous operators like HGTConv:
from torch_geometric.nn import HGTConv
conv = HGTConv(hidden_channels, hidden_channels, data.metadata(), heads=4)
Important for heterogeneous graphs:
- Use T.ToUndirected() to add reverse edge types for bidirectional message flow
- Disable add_self_loops in bipartite conv layers (different source/dest types) and use skip connections instead: conv(x, edge_index) + lin(x)
- With NeighborLoader, specify input_nodes as a ('node_type', mask) tuple
- num_neighbors can be a dict keyed by edge type for fine-grained control

Read references/heterogeneous.md for complete examples including training loops and NeighborLoader usage with heterogeneous graphs.
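The aggr='sum' argument used with to_hetero() and HeteroConv above controls how results from different edge types that point at the same destination node type are combined. A plain-Python sketch of that grouping step is shown here; `hetero_aggregate` is a hypothetical name, and the per-edge-type outputs are assumed to have already been computed by the individual conv layers.

```python
def hetero_aggregate(outputs_per_edge_type):
    """Sum per-edge-type outputs that share the same destination node type."""
    grouped = {}
    for (src_type, rel, dst_type), out in outputs_per_edge_type.items():
        if dst_type not in grouped:
            grouped[dst_type] = [row[:] for row in out]  # copy the first result
        else:
            acc = grouped[dst_type]
            for i, row in enumerate(out):
                for f, value in enumerate(row):
                    acc[i][f] += value                   # aggr='sum'
    return grouped


# Toy outputs: 2 paper nodes and 1 author node, 2 features each
edge_type_outputs = {
    ('paper', 'cites', 'paper'): [[1.0, 1.0], [2.0, 2.0]],
    ('author', 'writes', 'paper'): [[0.5, 0.0], [0.0, 0.5]],
    ('paper', 'rev_writes', 'author'): [[3.0, 3.0]],
}
x_dict_out = hetero_aggregate(edge_type_outputs)
print(x_dict_out['paper'])   # [[1.5, 1.0], [2.0, 2.5]]
print(x_dict_out['author'])  # [[3.0, 3.0]]
```

A node type that is never a destination receives no output at all, which is one reason reverse edge types matter for bidirectional message flow.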
For loading your own data into PyG:
- Small datasets: create Data objects directly and pass a list to DataLoader
- Datasets that fit in memory: subclass InMemoryDataset and override raw_file_names, processed_file_names, download(), process()
- Larger datasets: subclass Dataset and also override len() and get()
- CSV or other tabular sources: load into Data or HeteroData
- from_networkx(G) converts a NetworkX graph directly
- from_scipy_sparse_matrix(adj) extracts edge_index

Read references/custom_datasets.md for complete examples with all patterns, CSV loading with encoders, and the MovieLens walkthrough.
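When the source data is an adjacency matrix, the conversion to edge_index amounts to listing the coordinates of the nonzero entries. The plain-Python sketch below illustrates the idea on a dense nested list; `adjacency_to_edge_index` is a hypothetical helper, and the real from_scipy_sparse_matrix works on SciPy sparse matrices and also returns edge weights.

```python
def adjacency_to_edge_index(adj):
    """List the (row, column) coordinates of nonzero entries as [sources, targets]."""
    sources, targets = [], []
    for i, row in enumerate(adj):
        for j, value in enumerate(row):
            if value != 0:
                sources.append(i)
                targets.append(j)
    return [sources, targets]


# Symmetric adjacency matrix of the path graph 0 - 1 - 2
adj = [
    [0, 1, 0],
    [1, 0, 1],
    [0, 1, 0],
]
adj_edges = adjacency_to_edge_index(adj)
print(adj_edges)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

A symmetric matrix naturally yields both directions of every edge, so no extra undirected conversion is needed in that case.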
PyG provides torch_geometric.explain for interpreting GNN predictions:
from torch_geometric.explain import Explainer, GNNExplainer
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        # Must match what the model returns; use 'raw' if it outputs logits
        return_type='log_probs',
    ),
)
explanation = explainer(data.x, data.edge_index, index=10)
explanation.visualize_graph() # Important subgraph
explanation.visualize_feature_importance(top_k=10) # Feature importance
Available algorithms: GNNExplainer (optimization-based), PGExplainer (parametric, trained), CaptumExplainer (gradient-based via Captum), AttentionExplainer (attention weights). Works for both homogeneous and heterogeneous graphs.
Read references/explainability.md for all algorithms, heterogeneous explanations, evaluation metrics, and PGExplainer training.
Common pitfalls:
- edge_index shape: it must be [2, num_edges], not [num_edges, 2]. Transpose if needed.
- Bipartite self-loops: don't use add_self_loops=True when source and dest node types differ. Use skip connections instead.
- NeighborLoader outputs: only the first batch.batch_size nodes are your seed nodes. Slice predictions and labels accordingly.
- Undirected graphs: include both directions in edge_index, or use T.ToUndirected().
- Lazy initialization: models with -1 input channels need one forward pass with torch.no_grad() before training to initialize parameters.
- Graph-level tasks: use global_mean_pool(x, batch) (not manual reshape) to aggregate node features to graph-level.
- Sampling depth: keep len(num_neighbors) equal to the number of GNN layers. More hops than layers wastes compute; fewer wastes model capacity.
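Several of the edge_index pitfalls above can be caught before training with a quick sanity check. The following plain-Python sketch is one possible way to do that on list-based edge data; `check_edge_index` is a hypothetical helper rather than a PyG function, and for Data objects PyG offers its own `data.validate()` method.

```python
def check_edge_index(edge_index, num_nodes, expect_undirected=False):
    """Return a list of human-readable problems; an empty list means no issue found."""
    problems = []
    if len(edge_index) != 2:
        problems.append(
            f"expected 2 rows, got {len(edge_index)} (transpose needed?)"
        )
        return problems
    src, dst = list(edge_index[0]), list(edge_index[1])
    if len(src) != len(dst):
        problems.append("source and target rows differ in length")
        return problems
    if any(not (0 <= n < num_nodes) for n in src + dst):
        problems.append("node index out of range")
    if expect_undirected:
        edges = set(zip(src, dst))
        missing = [(s, d) for s, d in edges if (d, s) not in edges]
        if missing:
            problems.append(f"{len(missing)} edges lack a reverse direction")
    return problems


ok_report = check_edge_index([[0, 1], [1, 0]], num_nodes=2, expect_undirected=True)
bad_shape = check_edge_index([[0, 1], [1, 2], [2, 0]], num_nodes=3)
one_way = check_edge_index([[0], [1]], num_nodes=2, expect_undirected=True)
print(ok_report, bad_shape, one_way)
```

Running a check like this once after loading data is cheap compared with debugging silently wrong message passing later.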