CS224W Lecture 16

LLM + GNN

Language models know that "Aspirin" is a drug. GNNs know that Aspirin is three hops from migraines in the biomedical graph. Neither alone is enough. Together, they're remarkably powerful.

Prerequisites: Transformer attention + GNN message passing. That's it.

Chapter 0: Why Combine?

You're building a drug repurposing system. You have a biomedical knowledge graph with 100,000 entities: drugs, genes, diseases, proteins. Each entity has a name. Some have long text descriptions. The graph has 5 million edges representing known interactions.

You try two approaches. First, train a GNN using one-hot entity IDs as node features. It learns graph topology brilliantly — it knows that Aspirin is three hops from migraines via COX2 — but it doesn't know that Aspirin's text description mentions "anti-inflammatory" and "analgesic," which would help identify which other diseases it might treat.

Second, feed all entity text descriptions to an LLM. It understands that "Aspirin" and "ibuprofen" are both NSAIDs with similar mechanisms. But it doesn't know the structure of the graph — it doesn't know that Aspirin is connected to 47 diseases in your KG while a different drug connects to only 3.

Both approaches use only half the information. The graph structure is the missing context for the LLM. The rich text semantics are the missing features for the GNN. Combining them gets you both.

What Each Architecture Brings

LLMs excel at: understanding natural language, zero-shot generalization to unseen entities (from text), reasoning over entity descriptions, capturing semantic similarity ("Aspirin" ≈ "ibuprofen" because their descriptions overlap), and leveraging massive pre-training on world knowledge.
GNNs excel at: capturing multi-hop structural patterns ("friends of friends"), aggregating information from large neighborhoods efficiently, handling graph-specific tasks (link prediction, community detection), and operating on graphs too large for quadratic attention.

A paper entity in a citation graph has a 200-word abstract. What information does a GNN using one-hot node IDs fail to capture that an LLM could provide?

Chapter 1: LLM as Feature Encoder

The simplest combination: use a frozen LLM to encode each node's text description into a fixed embedding vector, then use that vector as the node feature for a downstream GNN. The LLM is a preprocessing step. It runs once. The GNN trains on the resulting features.

This works because modern LLMs (BERT, Sentence-BERT, LLaMA) are excellent sentence encoders. Given the text "aspirin: a nonsteroidal anti-inflammatory drug used to treat pain, fever, and inflammation," a BERT encoder produces a 768-dimensional vector that captures the semantic content. Two drugs with similar descriptions will have similar vectors — a richer initialization than random or one-hot.

Why freeze the LLM? Fine-tuning an LLM on a graph task requires backpropagating through both the GNN and the LLM, which is computationally expensive. More importantly, the LLM has hundreds of millions to billions of parameters trained on massive text data, and fine-tuning it on a small graph dataset risks catastrophic forgetting (the LLM "forgets" its general language knowledge while adapting to your graph).

Concrete Pipeline

Node text descriptions: Each node has a text string, e.g. "Aspirin: NSAID used for pain, fever, inflammation. Mechanism: inhibits COX enzymes."
Frozen LLM encoding: LLM(text) → pooled sentence embedding ([CLS] token for BERT, mean pooling for Sentence-BERT models) → [N, d] feature matrix, with d = 768 for BERT-base or 384 for all-MiniLM-L6-v2. Runs once offline. Cached to disk. No gradients.
GNN training: Use the d-dim LLM embeddings as node features x_i. Train the GNN normally on the graph task (node classification, link prediction). Only GNN weights are updated.
Output: The GNN combines text semantics (from the LLM initialization) with structural context (from message passing). Better than either alone.
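```python
# Step 1: encode all node texts with a frozen LLM
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from torch_geometric.nn import SAGEConv

encoder = SentenceTransformer('all-MiniLM-L6-v2')  # frozen: used for inference only
node_texts = ["Aspirin: NSAID...", "COX2: enzyme..."]  # ... one text string per node
node_feats = encoder.encode(node_texts, convert_to_tensor=True)
# node_feats: [N, 384], one vector per node, from text alone
# (all-MiniLM-L6-v2 outputs 384-dim embeddings; BERT-base outputs 768)

# Step 2: train GNN using LLM features as input
class GraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)  # input dim = LLM embedding dim
        self.conv2 = SAGEConv(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))  # aggregate neighbors' LLM features
        return self.conv2(x, edge_index)       # per-node class logits

# num_classes, edge_index, and labels come from your dataset
model = GraphSAGE(in_dim=node_feats.shape[1], hidden_dim=256, num_classes=num_classes)
# Train model on (node_feats, edge_index, labels). The LLM is never called again.
```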
Which LLM to use? Sentence-BERT variants (all-MiniLM, all-mpnet) are fast and often sufficient. For domain-specific graphs (biomedical, legal, code), domain-specific encoders (PubMedBERT, LegalBERT, CodeBERT) often outperform general-purpose encoders. The LLM's pre-training domain should match the graph's domain.
In the "LLM as frozen feature encoder" approach, what is updated during GNN training?

Chapter 2: GNN as Structure Encoder

Now flip the roles. The GNN is the preprocessing step. It runs over the graph and computes structural embeddings for each node: where is this node in the graph topology? How central is it? What does its neighborhood look like? These structural features are then injected into an LLM as additional context.

This approach is motivated by a core LLM limitation: LLMs process text sequences. They have no native ability to represent "this entity is the hub of a star subgraph" or "these two entities are 4 hops apart." Graph structure is a different inductive bias from language. The GNN bridges this gap by converting graph topology into a fixed-size vector that the LLM can process.

What Structural Embeddings Capture

Local structure: Node degree, clustering coefficient, triangle counts. A node with degree 500 is a hub; degree 2 is a leaf. These features change how you should interpret the node's text (a hub paper is a seminal work; a leaf paper is niche).
Relative position: The structural embedding of node A, computed relative to query node Q, encodes how far A is from Q and through what types of paths. "A is 2 hops from Q via relation type R" is structural context that changes how Q should attend to A.
Global position: Graph autoencoder embeddings or spectral embeddings (Laplacian eigenvectors) place nodes in a space where nearby positions mean similar graph-structural roles. Node A and B with similar spectral embeddings are "structurally similar" even if far apart in the graph.
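The local and relative-position features above are easy to compute directly. A minimal plain-Python sketch, assuming an undirected graph stored as adjacency sets; the five-node graph is a toy invented for illustration (not real pharmacology), and in a full system a GNN would learn vectors from such signals rather than use these hand-computed numbers. Spectral (global) embeddings are left out to keep the sketch dependency-free.

```python
from collections import deque

def degrees(adj):
    """Node degree: number of neighbors (hub vs. leaf)."""
    return {v: len(nbrs) for v, nbrs in adj.items()}

def clustering_coefficient(adj, v):
    """Fraction of pairs of v's neighbors that are themselves connected."""
    nbrs = list(adj[v])
    k = len(nbrs)
    if k < 2:
        return 0.0
    links = sum(1 for i in range(k) for j in range(i + 1, k) if nbrs[j] in adj[nbrs[i]])
    return 2.0 * links / (k * (k - 1))

def hop_distances(adj, query):
    """Breadth-first search: hop distance from the query node to every reachable node."""
    dist = {query: 0}
    frontier = deque([query])
    while frontier:
        u = frontier.popleft()
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                frontier.append(w)
    return dist

# Toy undirected graph (hypothetical edges, for illustration only).
adj = {
    "Aspirin": {"COX2", "Ibuprofen"},
    "Ibuprofen": {"Aspirin", "COX2"},
    "COX2": {"Aspirin", "Ibuprofen", "Prostaglandin"},
    "Prostaglandin": {"COX2", "Migraine"},
    "Migraine": {"Prostaglandin"},
}

deg = degrees(adj)
dist = hop_distances(adj, "Aspirin")  # relative position w.r.t. query node "Aspirin"
features = {v: (deg[v], round(clustering_coefficient(adj, v), 3), dist.get(v)) for v in adj}
for node, (d, cc, hops) in features.items():
    print(f"{node}: degree={d}, clustering={cc}, hops_from_Aspirin={hops}")
```

In this toy graph Migraine sits three hops from Aspirin via COX2, and COX2 has the highest degree, which is the kind of structural fact that never appears in either node's text.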
The GNN computes these structural features as vectors; the LLM sees them as additional input tokens.

Why can't a standard LLM reason about "node X is a hub with 500 connections" just from the node's text description?

Chapter 3: LLM + GNN Pipeline (SHOWCASE)

The showcase architecture: an end-to-end pipeline where an LLM encodes node text into semantic features, a GNN aggregates those features through the graph structure, and the combined representation is used for prediction. Variants of this design appear in systems such as GIANT, TAPE, and GraphGPT.

Watch the data flow through each stage. Every tensor shape is shown. Every operation is motivated. By the end, you'll understand exactly why each component exists and what would break without it.

Input, a text-attributed graph: N = 1000 nodes, each with a text description; E = 5000 edges; labels for 200 nodes. Task: classify the remaining 800 nodes.
LLM encoding, text → [N, 768]: Each node's text description is fed to BERT, and the [CLS] token gives a 768-dim embedding. Shape: [1000, 768]. Frozen. Runs once.
Linear projection (dimension reduction), [N, 768] → [N, 128]: A lightweight trainable projection to the GNN hidden dimension. It shrinks every downstream GNN weight matrix and, because it is trained, lets the task select the parts of the frozen 768-dim LLM features that matter.
GNN aggregation, [N, 128] + edges → [N, 128]: A 2-layer GraphSAGE. Each node aggregates its neighbors' projected LLM features, capturing "what kinds of things are my neighbors?" in structured form.
Classifier (prediction), [N, 128] → [N, C]: A linear layer produces class probabilities, with cross-entropy loss on the labeled nodes. Only the projection, GNN, and classifier weights are trained.
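The stage shapes can be traced with a NumPy sketch of the forward pass. It is a sketch under stated assumptions: random vectors stand in for the cached BERT embeddings, edges and labels are random, C = 40 classes is an arbitrary choice, the layer is a simplified hand-written mean-aggregation GraphSAGE update (not a library implementation), and no optimizer step is shown.

```python
import numpy as np

rng = np.random.default_rng(0)
N, E, D_LLM, D_HID, C = 1000, 5000, 768, 128, 40  # C = 40 is an assumed class count
N_LABELED = 200

# Stand-in for cached, frozen [CLS] embeddings (real ones come from running BERT once).
x_llm = rng.standard_normal((N, D_LLM))
# Random directed edges, shape [2, E]: row 0 = source node, row 1 = target node.
edge_index = rng.integers(0, N, size=(2, E))

def sage_mean(h, edge_index, w_self, w_neigh):
    """Simplified GraphSAGE layer: W_self h_i + W_neigh * mean of h_j over edges j -> i."""
    src, dst = edge_index
    agg = np.zeros_like(h)
    np.add.at(agg, dst, h[src])                        # sum incoming neighbor vectors
    deg = np.bincount(dst, minlength=len(h))[:, None]
    agg = agg / np.maximum(deg, 1)                     # mean; nodes without in-edges get zeros
    return h @ w_self + agg @ w_neigh

def init(*shape):
    return rng.standard_normal(shape) / np.sqrt(shape[0])

# Trainable parameters (randomly initialized here; training is omitted).
w_proj = init(D_LLM, D_HID)
w1_self, w1_neigh = init(D_HID, D_HID), init(D_HID, D_HID)
w2_self, w2_neigh = init(D_HID, D_HID), init(D_HID, D_HID)
w_cls = init(D_HID, C)

h0 = x_llm @ w_proj                                                  # [N, 128]
h1 = np.maximum(sage_mean(h0, edge_index, w1_self, w1_neigh), 0.0)   # [N, 128], ReLU
h2 = np.maximum(sage_mean(h1, edge_index, w2_self, w2_neigh), 0.0)   # [N, 128], ReLU
logits = h2 @ w_cls                                                  # [N, C]

# Softmax over classes, then cross-entropy on the labeled nodes only.
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
labels = rng.integers(0, C, size=N_LABELED)   # hypothetical labels for nodes 0..199
loss = -np.mean(np.log(probs[np.arange(N_LABELED), labels]))
print(h0.shape, h2.shape, logits.shape, float(loss))
```

Only w_proj, the four SAGE matrices, and w_cls would be trained; x_llm is a fixed input, which is what keeps the LLM stage a one-time cost.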
Ablation results (TAPE, He et al. 2023) on ogbn-arxiv (citation graph with paper abstracts): GNN only (no LLM features): 71.7% accuracy. LLM only (no graph): 73.4%. LLM features + GNN: 76.5%. Fine-tuned LLM + GNN: 77.9%. Each addition provides a consistent lift: the LLM and GNN are genuinely complementary.
In the LLM+GNN pipeline, why project LLM embeddings from 768 to 128 dimensions before the GNN?

Chapter 4: Joint Training vs Cascade

Chapter 1's approach (LLM encoder → frozen → GNN trains) is called a cascade: the two components are separate, one feeds the other, no gradients flow between them. Chapter 3's full pipeline can also be a cascade, or it can be a joint training setup where gradients from the GNN task flow back into the LLM to fine-tune it for graph-specific features.

This is one of the most important design decisions in LLM+GNN systems. The tradeoff is sharp.

Cascade (Frozen LLM)

Pros: Fast. Cheap. The LLM runs once at preprocessing time. The 768-dim node features are cached. GNN training is the only expensive step. No risk of catastrophic forgetting. Works with any LLM you can't fine-tune (API-only models).
Cons: The LLM encodes text without seeing the graph. It doesn't know that "Aspirin" should be encoded differently for a drug interaction graph vs. a historical document graph. The LLM's generic text features may not be optimally informative for the specific graph task. There's a fixed mismatch between what the LLM optimized for (general language modeling) and what the GNN needs (graph task features).

Joint Training (Fine-tuned LLM)

Pros: The LLM adapts to the graph task. If classifying papers by research area, the LLM learns to emphasize method names and dataset names (strong area predictors) over author acknowledgments (weak predictors). The node features become task-specific, not generic. Consistently 1-3% better accuracy in benchmarks.
Cons: Expensive. You backpropagate through the LLM on every training step. For BERT-base with 110M parameters, this is ~10× slower than cascade. Risk of catastrophic forgetting if learning rate is too high. Requires fitting the LLM in GPU memory during training — often requires model parallelism or gradient checkpointing.

The Middle Ground: LoRA Fine-tuning

LoRA (Low-Rank Adaptation) offers a practical compromise. Instead of updating all LLM weights, inject small low-rank adapter matrices into the LLM's attention layers. Only the adapters (typically <1% of total parameters) are trainable; the backbone weights stay frozen. Gradients still propagate backward through the frozen backbone layers to reach the adapters, but only the adapter weights are updated.
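The low-rank update itself is small enough to write out. A NumPy sketch of one adapted linear layer under the standard LoRA formulation, y = Wx + (alpha / r)·BAx, with B initialized to zero so training starts from the pretrained behavior. The sizes assume a BERT-base-like model (hidden size 768, 12 layers, ~110M parameters) with adapters on the query and value projections only; the count ignores biases.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 768, 768, 16, 32

W = rng.standard_normal((d_out, d_in)) / np.sqrt(d_in)  # frozen pretrained weight
A = rng.standard_normal((r, d_in)) * 0.01               # trainable, small random init
B = np.zeros((d_out, r))                                # trainable, zero init

def lora_linear(x):
    """y = W x + (alpha / r) * B A x, applied row-wise to a batch x of shape [batch, d_in]."""
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T

x = rng.standard_normal((4, d_in))
# With B = 0 the adapted layer starts out identical to the frozen one.
assert np.allclose(lora_linear(x), x @ W.T)

# Trainable parameters if the q and v projections in all 12 layers get rank-16 adapters.
lora_params = 12 * 2 * r * (d_in + d_out)
print(lora_params, f"{lora_params / 110e6:.2%}")
```

The count comes out to roughly 0.5% of a 110M-parameter model, consistent with the "<1% trainable" figure. Because adapter size grows linearly with the hidden size while the backbone grows roughly quadratically, larger models with the same rank have an even smaller trainable fraction.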

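```python
# LoRA: fine-tune the LLM with <1% of parameters trainable
from peft import LoraConfig, get_peft_model
from torch.optim import Adam

# base_llm (a pretrained transformers model) and gnn (a torch.nn.Module) are defined elsewhere.
lora_config = LoraConfig(
    r=16,             # rank of the adapter matrices
    lora_alpha=32,    # scaling factor: the update is scaled by lora_alpha / r
    # Which layers to adapt; module names depend on the architecture
    # (LLaMA-style models use q_proj/v_proj, BERT uses query/value).
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
)
llm = get_peft_model(base_llm, lora_config)
llm.print_trainable_parameters()  # e.g. ~0.5% trainable for a BERT-base-sized model
# Backprop still passes through the frozen backbone, but only LoRA weights are updated:
# fast, cheap, and with a lower risk of catastrophic forgetting.

# Use in LLM+GNN pipeline: LLM(+LoRA) -> node_feats -> GNN -> task loss -> backprop
trainable_llm_params = [p for p in llm.parameters() if p.requires_grad]
optimizer = Adam([*trainable_llm_params, *gnn.parameters()], lr=1e-4)
```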
Why does cascade (frozen LLM) training produce slightly worse results than joint training, even though the GNN can still see the same graph structure?

Chapter 5: Graph-Augmented LLMs

Flip the architecture again. Instead of using the LLM to help the GNN, use the graph to help the LLM. The LLM remains the primary model — it generates text, answers questions, summarizes. But before the LLM processes a query, you augment its context with graph-derived information about the relevant entities.

This is graph-augmented generation, a graph-structured variant of retrieval-augmented generation (RAG). In standard RAG, you retrieve relevant text passages from a corpus and prepend them to the LLM's prompt. In graph-augmented generation, you retrieve relevant graph facts and paths, convert them to text, and prepend them.

Knowledge Graph + LLM for Question Answering

Question: "What disease does the drug that inhibits COX2 treat?" To answer, you need: (Drug, inhibits, COX2) to find the drug, then (Drug, treats, Disease) to find the disease. This is a 2-hop KG query. The LLM alone might guess "Aspirin treats pain" from parametric memory — but parametric memory can be wrong or outdated. Augmenting with the KG provides verified, updatable facts.

The augmentation pipeline: (1) parse the question to extract the query entities, (2) run a KG query to retrieve relevant subgraph paths, (3) verbalize the paths ("Aspirin inhibits COX2; COX2 causes Migraine; therefore Aspirin treats Migraine"), (4) prepend verbalized paths to the LLM prompt, (5) LLM answers with this verified context grounding its response.
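Steps (2) and (3) can be sketched on a toy in-memory triple list, following the two-hop pattern from the question above: first (?, inhibits, COX2), then (that drug, treats, ?). A minimal plain-Python version, assuming the KG fits in a Python list; the triples are illustrative, not a curated medical KG, and the LLM call in steps (4) and (5) is omitted.

```python
from collections import defaultdict

# Toy triples for illustration only; a real system would query a curated KG.
TRIPLES = [
    ("Aspirin", "inhibits", "COX2"),
    ("Aspirin", "treats", "Migraine"),
    ("Metformin", "activates", "AMPK"),
    ("Metformin", "treats", "Type 2 diabetes"),
]

def build_index(triples):
    """Index triples by (head, relation) and by (relation, tail) for fast lookups."""
    by_head_rel, by_rel_tail = defaultdict(list), defaultdict(list)
    for h, r, t in triples:
        by_head_rel[(h, r)].append(t)
        by_rel_tail[(r, t)].append(h)
    return by_head_rel, by_rel_tail

def two_hop(triples, rel1, target, rel2):
    """Return supporting paths [(x, rel1, target), (x, rel2, y)] for every match."""
    by_head_rel, by_rel_tail = build_index(triples)
    paths = []
    for drug in by_rel_tail[(rel1, target)]:       # hop 1: (Drug, inhibits, COX2)
        for disease in by_head_rel[(drug, rel2)]:  # hop 2: (Drug, treats, Disease)
            paths.append([(drug, rel1, target), (drug, rel2, disease)])
    return paths

def verbalize(path):
    """Template verbalization: (e1, rel, e2) -> 'e1 rel e2.'"""
    return " ".join(f"{h} {r} {t}." for h, r, t in path)

paths = two_hop(TRIPLES, "inhibits", "COX2", "treats")
context = "\n".join(verbalize(p) for p in paths)
print(context)  # Aspirin inhibits COX2. Aspirin treats Migraine.
```

The string in context is what gets prepended to the LLM prompt; each sentence traces back to a specific triple, which is what makes the answer auditable.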

Verbalization: Graphs to Text

Verbalization is the process of converting a graph subgraph into a natural language description the LLM can process. Simple template: "(entity1, relation, entity2)" → "entity1 [relation] entity2." More sophisticated: train a small seq2seq model to generate fluent descriptions of subgraph patterns.

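```python
# Graph-augmented LLM query.
# entity_linker, kg.shortest_paths, and llm.generate stand for whatever entity linker,
# KG store, and LLM client you use; they are placeholder interfaces, not a specific library.
def verbalize(paths):
    """Template verbalization: (e1, rel, e2) -> 'e1 rel e2.'"""
    return " ".join(f"{h} {r} {t}." for h, r, t in paths)

def graph_augmented_answer(question, kg, llm, entity_linker):
    # 1. Extract entities mentioned in the question
    entities = entity_linker(question)   # e.g. ["Aspirin", "Migraine"]

    # 2. Retrieve relevant KG paths between the entities (up to 3 hops)
    paths = kg.shortest_paths(entities, max_hops=3)
    # paths: [("Aspirin", "inhibits", "COX2"), ("COX2", "causedBy", "Migraine")]

    # 3. Verbalize paths to text
    context = verbalize(paths)
    # template output: "Aspirin inhibits COX2. COX2 causedBy Migraine."
    # (a learned verbalizer could write "COX2 is causally related to Migraine.")

    # 4. Augment the prompt with graph context
    prompt = f"""Knowledge graph context: {context}

Question: {question}
Answer based on the provided context:"""

    # 5. LLM generates a grounded answer
    return llm.generate(prompt)
```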
Why this beats pure parametric memory: LLM parametric knowledge is frozen at training time and can be wrong or outdated. KG facts are curated, verified, and updatable. A drug interaction updated in 2024 appears immediately in the KG; the LLM trained in 2023 doesn't know about it. Graph augmentation grounds the LLM in current, verified knowledge.
In graph-augmented LLM generation, why is "verbalization" (converting graph paths to text) necessary?

Chapter 6: LLM-Augmented GNNs

Beyond just using LLM embeddings as static node features, LLMs can help GNNs in more dynamic ways: generating pseudo-labels for unlabeled nodes, suggesting which edges should exist (edge prediction augmentation), and explaining GNN predictions in natural language. This chapter covers three LLM→GNN augmentation strategies that go beyond simple feature encoding.

LLM as Annotator (Pseudo-Labels)

Most graph datasets have very few labeled nodes. Training a GNN on 10 labels out of 10,000 nodes is hard. LLMs can help: given a node's text description, prompt the LLM to predict its label (e.g., paper topic category). These are pseudo-labels — noisy but much more numerous than human labels.

The catch: LLM pseudo-labels are biased by what the LLM knows (or doesn't know). A paper about "RLHF" might be mislabeled by a 2021 LLM because RLHF wasn't prominent then. And pseudo-labels are independent — the LLM doesn't see the graph structure when making them. A training scheme that trusts pseudo-labels for common classes but defers to the GNN for rare/ambiguous cases works better than naive pseudo-labeling.

LLM for Edge Prediction Augmentation

Real-world graphs are often missing edges that should exist. A citation graph is missing citations for papers published after the scraping date. A collaboration graph is missing co-authorships across institutions. Use an LLM to predict which edges are likely: given two node descriptions, prompt "are these likely to interact?" Add high-confidence predicted edges to the graph before GNN training.
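A sketch of the augmentation loop, with one loud assumption: the LLM's "are these likely to interact?" judgment is replaced by a hypothetical word-overlap score (Jaccard similarity of word sets), so the code runs without a model. Swapping in a real LLM call means replacing score_fn; the 0.5 threshold and the paper texts are arbitrary.

```python
from itertools import combinations

def word_overlap_score(text_a, text_b):
    """Stand-in for an LLM's 'likely to interact?' judgment: Jaccard overlap of
    lowercase word sets. A crude placeholder, NOT an LLM."""
    a, b = set(text_a.lower().split()), set(text_b.lower().split())
    return len(a & b) / len(a | b) if a | b else 0.0

def augment_edges(texts, edges, score_fn, threshold):
    """Add undirected edges whose predicted score clears the threshold."""
    existing = {frozenset(e) for e in edges}
    added = []
    for u, v in combinations(sorted(texts), 2):  # every candidate pair
        if frozenset((u, v)) in existing:
            continue                             # already in the graph
        if score_fn(texts[u], texts[v]) >= threshold:
            added.append((u, v))
    return list(edges) + added, added

texts = {
    "p1": "graph neural networks for molecule property prediction",
    "p2": "graph neural networks for drug property prediction",
    "p3": "a survey of medieval trade routes",
}
edges = [("p2", "p3")]
new_edges, added = augment_edges(texts, edges, word_overlap_score, threshold=0.5)
print(added)  # [('p1', 'p2')]
```

Scoring every pair is quadratic in the number of nodes, so with a real LLM you would likely score only a pre-filtered candidate set (for example, pairs whose text embeddings are already close).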

LLM as Explainer

After the GNN makes a prediction, use the LLM to generate a natural language explanation. Pipeline: GNN identifies important subgraph (via GNNExplainer or attention weights) → verbalize the subgraph → prompt LLM to explain the prediction in terms of the verbalized subgraph → LLM outputs human-readable explanation.
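The verbalize-and-prompt step might look like this. It assumes the explainer has already produced an importance score per edge (GNNExplainer, for instance, learns a soft edge mask); the scores, paper names, and the label cs.LG are made up, and the function only builds the prompt string rather than calling an LLM.

```python
def explanation_prompt(node, predicted_label, edge_importance, top_k=3):
    """Build an LLM prompt from the most important edges behind a GNN prediction.
    edge_importance maps (head, relation, tail) -> importance score."""
    top = sorted(edge_importance.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    facts = "\n".join(f"- {h} {r} {t} (importance {s:.2f})" for (h, r, t), s in top)
    return (
        f"A graph neural network predicted that {node} belongs to class "
        f"'{predicted_label}'.\nThe most influential connections were:\n{facts}\n"
        "Explain the prediction in plain language, using only these connections."
    )

# Hypothetical explainer output for one prediction.
scores = {
    ("Paper A", "cites", "Paper B"): 0.91,
    ("Paper A", "cites", "Paper C"): 0.12,
    ("Paper B", "cites", "Paper D"): 0.67,
    ("Paper A", "cites", "Paper E"): 0.40,
}
prompt = explanation_prompt("Paper A", "cs.LG", scores, top_k=2)
print(prompt)
```

Restricting the LLM to the verbalized top-k edges is a design choice: it keeps the explanation tied to what the GNN actually used, rather than to whatever the LLM happens to believe about the papers.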

```python
# LLM-augmented pseudo-label generation.
# llm.generate and llm.logprob are placeholder interfaces, not a specific library's API.
def generate_pseudolabels(unlabeled_nodes, llm, categories, threshold):
    pseudo = {}
    for node in unlabeled_nodes:
        prompt = f"""Paper abstract: {node.text}

Classify this paper into one of: {', '.join(categories)}
Respond with exactly one category name."""
        pred = llm.generate(prompt, max_tokens=10).strip()
        if pred not in categories:       # discard malformed or off-list answers
            continue
        confidence = llm.logprob(pred)   # log-probability of the answer (<= 0)
        if confidence > threshold:       # threshold is on the same log-prob scale
            pseudo[node.id] = pred       # keep only high-confidence predictions
    return pseudo

# Add pseudo-labels to the training set; weight them lower than human labels:
# loss = human_loss + lambda * pseudo_loss  (lambda < 1)
```
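The weighted objective in the comment above can be made concrete. A plain-Python sketch, assuming each model prediction is a dict of class probabilities; the default lam = 0.3 is an arbitrary choice, and the only constraint taken from the text is lam < 1.

```python
import math

def mean_cross_entropy(examples):
    """examples: list of (class_probabilities: dict, label). Mean negative log-likelihood."""
    if not examples:
        return 0.0
    return -sum(math.log(probs[label]) for probs, label in examples) / len(examples)

def combined_loss(human, pseudo, lam=0.3):
    """Human-label loss plus a down-weighted pseudo-label loss (lam < 1)."""
    assert 0.0 <= lam < 1.0, "pseudo-labels should count less than human labels"
    return mean_cross_entropy(human) + lam * mean_cross_entropy(pseudo)

human = [({"cs.LG": 0.5, "cs.CV": 0.5}, "cs.LG")]
pseudo = [({"cs.LG": 0.25, "cs.CV": 0.75}, "cs.LG")]
print(round(combined_loss(human, pseudo, lam=0.5), 4))  # 1.3863
```

Here the pseudo-labeled example is twice as costly per example (-ln 0.25 vs. -ln 0.5), but lam = 0.5 halves its pull, so a noisy pseudo-label cannot dominate a clean human label.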
TAPE (He et al. 2023) uses LLM-generated explanations as additional node features. Rather than using only the paper abstract, it also uses the LLM's step-by-step reasoning about the paper's topic as a feature. This "explanation-as-feature" strategy improves node classification by ~1.5% on ogbn-arxiv: the LLM's chain of thought is informative as a feature, not just as output.
LLM-generated pseudo-labels are "noisy." In a training scheme combining human labels and pseudo-labels, how should they be weighted?

Chapter 7: Benchmarks and Results

How much does combining LLMs with GNNs actually help? Let's look at real numbers from key benchmarks. Numbers tell the story better than any abstract argument.

ogbn-arxiv: Citation Graph Node Classification

ogbn-arxiv has 169,343 arXiv papers. Task: classify each paper into one of 40 subject areas. Each node's text is the paper's title and abstract. Edges are citation links.

Method | Test accuracy | Key design
--- | --- | ---
GCN (OGB default features) | 71.7% | Averaged skip-gram word embeddings of title and abstract, no LLM
BERT features + GCN (cascade) | 73.3% | Frozen BERT embeddings as input to GCN
GIANT (LM fine-tuned on neighbors) | 75.9% | LM trained to predict neighbor IDs from text
TAPE (LLM explanation features) | 76.5% | Frozen LLaMA explanations + GNN
TAPE (fine-tuned LLM + GNN) | 77.9% | Joint training with LoRA fine-tuning

ogbl-citation2: Link Prediction

Predict missing citation links: given a paper, rank which other papers it cites. Text features from abstracts significantly boost performance over graph-structure-only methods.

Method | MRR | Key design
--- | --- | ---
GraphSAGE (one-hot) | 82.6% | No text features
GraphSAGE + BERT features | 87.3% | Frozen BERT cascade
LLM + GNN (joint) | 90.1% | Fine-tuned LLM + GNN, joint training
Key pattern: the gain from adding LLM features is larger when the graph is sparse (fewer edges per node). When graph structure is rich, the GNN can compensate for weak features. When the graph is sparse, node text becomes the primary signal — making LLM encoding critically important. Know your graph density before choosing your architecture.
On a dense graph (many edges per node), why does adding LLM features help less than on a sparse graph?

Chapter 8: Challenges

LLM+GNN systems work in practice, as the benchmarks above show. But they come with significant engineering challenges that don't appear in research papers. This chapter is honest about what's hard.

Scalability: The Context Length Problem

A node's "context" in a graph includes its k-hop neighborhood: potentially hundreds of connected entities. Serializing this neighborhood as text (for graph-augmented LLMs) can quickly exceed many LLM context windows. A 2-hop neighborhood in ogbn-arxiv has ~50 papers. Each abstract is 200 words. 50 × 200 = 10,000 words. That's a very long context, and it grows exponentially with hop depth (roughly d^k nodes at k hops for average degree d).
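The growth is easy to estimate. A back-of-the-envelope sketch, assuming an average degree of d = 7 (chosen so that two hops give roughly the ~50 papers quoted above) and ignoring overlap between neighborhoods, so the counts are upper bounds:

```python
def k_hop_context_words(avg_degree, hops, words_per_node=200):
    """Upper bound on nodes within k hops (d + d^2 + ... + d^k, ignoring overlaps)
    and the resulting number of words if every node contributes words_per_node."""
    nodes = sum(avg_degree ** k for k in range(1, hops + 1))
    return nodes, nodes * words_per_node

for hops in (1, 2, 3):
    nodes, words = k_hop_context_words(avg_degree=7, hops=hops)
    print(f"{hops} hop(s): ~{nodes} nodes, ~{words:,} words")
```

Going from two to three hops multiplies the text by about 7×, which is why truncation or compression becomes unavoidable.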

Current solutions: (1) truncate the neighborhood aggressively — only include the 5 closest neighbors, (2) use a GNN to aggregate the neighborhood into a fixed-size vector and inject that vector as a "soft prompt," (3) use a cross-attention mechanism where the LLM attends to GNN-computed neighbor representations rather than verbalized text.
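Solution (2) can be sketched in NumPy to show why it sidesteps the length problem. Assumptions: mean pooling stands in for a trained GNN readout, the projection to N_PROMPT = 4 virtual tokens is random rather than learned, and the dimensions are illustrative. The point is that the prepended graph context has a fixed length no matter how many neighbors the node has.

```python
import numpy as np

rng = np.random.default_rng(0)
D_GNN, D_LLM, N_PROMPT = 128, 768, 4  # hypothetical sizes

# Projection from the pooled graph vector to N_PROMPT token embeddings (learned in practice).
w_proj = rng.standard_normal((D_GNN, N_PROMPT * D_LLM)) / np.sqrt(D_GNN)

def graph_soft_prompt(neighbor_reprs):
    """Pool a variable-size set of neighbor vectors [n, D_GNN] into
    N_PROMPT LLM-sized 'virtual token' embeddings [N_PROMPT, D_LLM]."""
    pooled = neighbor_reprs.mean(axis=0)  # stand-in for a trained GNN readout
    return (pooled @ w_proj).reshape(N_PROMPT, D_LLM)

def prepend_soft_prompt(neighbor_reprs, token_embeddings):
    """Concatenate graph tokens before the text tokens along the sequence axis."""
    return np.concatenate([graph_soft_prompt(neighbor_reprs), token_embeddings], axis=0)

text_tokens = rng.standard_normal((20, D_LLM))  # embeddings of a 20-token question
small = prepend_soft_prompt(rng.standard_normal((5, D_GNN)), text_tokens)
large = prepend_soft_prompt(rng.standard_normal((500, D_GNN)), text_tokens)
print(small.shape, large.shape)  # same sequence length for 5 or 500 neighbors
```

The tradeoff is that the LLM no longer reads the neighbors' text directly; it only sees whatever the pooled vector preserves.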

Alignment: The Representation Mismatch

LLM hidden states live in a semantic embedding space optimized for language modeling. GNN hidden states live in a space optimized for graph topology prediction. These spaces don't naturally align — a point in LLM space doesn't correspond to the same point in GNN space, even for the same entity. Combining them requires a learned alignment (projection layer), but the projection may lose information from both.
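One common way to learn such an alignment, borrowed from CLIP-style multimodal training, is a contrastive (InfoNCE) loss that pulls each entity's LLM and GNN embeddings together in a shared space. A NumPy sketch, assuming random projection heads and embeddings; it computes the one-directional loss, whereas CLIP averages both directions. This illustrates the general technique, not a method the lecture prescribes.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D_TEXT, D_GRAPH, D_SHARED = 8, 768, 128, 64  # hypothetical sizes

def normalize(z):
    return z / np.linalg.norm(z, axis=1, keepdims=True)

def info_nce(text_z, graph_z, temperature=0.07):
    """Mean cross-entropy where row i of text_z should match row i of graph_z."""
    logits = normalize(text_z) @ normalize(graph_z).T / temperature  # [N, N] cosine sims
    logits -= logits.max(axis=1, keepdims=True)                      # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))

# Projection heads into a shared space (random here; trained in practice).
p_text = rng.standard_normal((D_TEXT, D_SHARED)) / np.sqrt(D_TEXT)
p_graph = rng.standard_normal((D_GRAPH, D_SHARED)) / np.sqrt(D_GRAPH)

llm_emb = rng.standard_normal((N, D_TEXT))   # per-entity LLM embeddings
gnn_emb = rng.standard_normal((N, D_GRAPH))  # per-entity GNN embeddings
loss = info_nce(llm_emb @ p_text, gnn_emb @ p_graph)

# Sanity check: matched pairs score lower than deliberately mismatched ones.
z = rng.standard_normal((N, D_SHARED))
print(info_nce(z, z) < info_nce(z, np.roll(z, 1, axis=0)))  # True
```

Training the two projection heads to minimize this loss is what makes "the same entity" land near the same point in both views; as the text notes, the projections can still discard information from either side.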

The Knowledge Conflict Problem

LLMs encode world knowledge from pre-training. KGs encode factual knowledge from structured data. These can disagree. An LLM might believe "Drug X treats Disease Y" based on medical text it saw during pre-training. The KG might not have this edge because it was added to the medical literature after the KG was compiled. Which source wins? Systems that blindly concatenate LLM context and KG facts without resolving conflicts produce inconsistent outputs.

The temporal alignment problem: LLMs have a training cutoff; KGs are updated continuously. A 2023 LLM augmented with a 2025 KG gets the latest facts but may generate inconsistent reasoning (the LLM's implicit world model is 2 years out of date). How to handle the temporal mismatch between static parametric LLM knowledge and dynamic graph knowledge is an open research question.
An LLM is augmented with a KG. The LLM's parametric memory says "Drug A treats Disease B." The KG says there's no edge (Drug A, treats, Disease B). What is the core challenge this illustrates?

Chapter 9: Connections

LLM+GNN is where two major threads of machine learning converge. Understanding the connections to adjacent fields prevents reinventing wheels and reveals which open problems are really the same problem in disguise.

LLM+GNN in the Wider Landscape

LLM+GNN concept | Parallel in another field | Shared insight
--- | --- | ---
LLM as feature encoder (frozen) | ImageNet pre-training for computer vision | Transferable representations reduce task-specific data needs
Graph-augmented LLM generation | Retrieval-augmented generation (RAG) | Ground LLM in external verified knowledge at inference time
Joint training (LLM+GNN) | Multi-modal learning (CLIP) | Two modalities (text + graph) benefit from joint optimization
LoRA for LLM adaptation | Adapter tuning in NLP | Efficient fine-tuning by adding small, task-specific modules
LLM as pseudo-labeler | Self-training / semi-supervised learning | Model generates labels for unlabeled data to expand training set
GNN as structural context provider | Position encodings in transformers | Inject structural information (graph position) into attention

Open Questions You Could Research Today

Graph position encoding for LLMs: transformers use 1D position encodings (position in a sequence). How do you generalize this to graph structure, which has no natural ordering? A node's "position" in a graph is a complex, multi-dimensional concept. Designing learnable graph positional encodings that inject graph structure into LLM attention without verbalization is an open problem.
Instruction-following for graph tasks: large LLMs can follow instructions in natural language. Can you instruct an LLM to perform arbitrary graph tasks ("find the node most central to this subgraph") by serializing the graph as structured text? The answer is partially yes, but current LLMs fail on graphs larger than ~20 nodes due to context length and reasoning limitations.
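For comparison, a widely used baseline for graph positional encodings (a starting point, not a solution to the open problem) takes eigenvectors of the normalized graph Laplacian. A NumPy sketch, assuming a small connected, undirected graph given as a dense adjacency matrix:

```python
import numpy as np

def laplacian_positional_encoding(adj, k):
    """Return the k eigenvectors of the symmetric normalized Laplacian
    L = I - D^{-1/2} A D^{-1/2} with the smallest nonzero eigenvalues
    (assumes a connected graph, so exactly one eigenvalue is zero)."""
    deg = adj.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    lap = np.eye(len(adj)) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    eigvals, eigvecs = np.linalg.eigh(lap)  # eigenvalues in ascending order
    return eigvecs[:, 1:k + 1]              # skip the trivial first eigenvector

# 6-node cycle graph: every node has the same local structure, but positions differ.
n = 6
adj = np.zeros((n, n))
for i in range(n):
    adj[i, (i + 1) % n] = adj[(i + 1) % n, i] = 1.0

pe = laplacian_positional_encoding(adj, k=2)  # [6, 2], one row per node
print(np.round(pe, 3))
```

On a 6-node cycle the two eigenvectors place the nodes evenly on a circle, so the encoding recovers the ring's geometry. Eigenvectors are only defined up to sign (and up to rotation when eigenvalues repeat, as they do here), which is one reason turning them into stable inputs for LLM attention is still hard.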

Where This Series Has Brought You

CS224W Lectures 1-5, graph basics + node embeddings: DeepWalk, node2vec, link prediction fundamentals.
Lectures 6-8, GNN message passing + theory: GCN, GraphSAGE, GAT, WL test, expressiveness.
Lectures 9-13, specialized graph types: heterogeneous graphs, KGs, recommender systems, relational databases.
Lectures 14-16, frontier research: advanced GNNs, KG foundation models, LLM+GNN integration.
→ KG Foundation Models (Lecture 15) — The other half of graph foundation models: handling new entities and schemas in knowledge graphs.
→ Advanced GNN Topics (Lecture 14) — GNN explainability, equivariant GNNs, dynamic graphs — the broader GNN frontier.
"The great thing about a language is you can use it to think with — and the great thing about a graph is you can use it to remember things with. Together, they almost think."
— paraphrased from the spirit of CS224W Lecture 16