Xu, Li, Tian, Sonobe, Kawarabayashi, Jegelka — ICML 2018 · arXiv 1806.03536

JK-Net: Jumping Knowledge Networks

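In a graph, different nodes need different receptive fields. A hub in a dense core reaches most of the graph within a few hops, so its signal washes out quickly; a node on a sparse periphery may need many hops to gather enough context. JK-Net's fix: keep embeddings from every layer, then let the model choose which layers to draw on for each node.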

Prerequisites: GCN basics + what graph layers do. That's it.

Chapter 0: One Depth Doesn't Fit All

Imagine two nodes in the same graph: a peripheral node with 2 neighbors in a sparse, tree-like region, and a hub node connected to 100 others in a dense core. With 2 GCN layers, the peripheral node sees only a handful of nodes, which may be too little context. The hub, after the same 2 layers, has already aggregated its 100 neighbors and their neighbors: its own signal is diluted almost immediately.

Now flip it. Give the whole network 5 layers so the peripheral node can see further. Its 5-hop neighborhood is finally informative, but the hub's 5-hop neighborhood may be the entire graph. The hub's features become an average of every node's features, and it loses its identity in a sea of global information.

Simulation: The Receptive Field Problem. A leaf node (few neighbors) and a hub node (many neighbors) as GCN depth K varies; green marks an informative neighborhood, red an oversmoothed one.

The core tension: GCN depth K is a global hyperparameter, one number for the entire graph. But different nodes in the same graph need different amounts of aggregation. A node on the sparse periphery needs deeper aggregation, because its neighborhood grows slowly with each hop. A node in the dense core needs shallow aggregation (or almost none), because its immediate neighbors already carry diverse signal and every extra hop pulls in a large share of the graph. No single K is right for all nodes.
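To make the tension concrete, here is a small sketch that counts how many nodes fall within k hops of a hub and of a peripheral node, which is the set a k-layer GCN aggregates over. The toy graph (a clique with a path attached) and the function names are hypothetical, chosen only for illustration.

```python
from collections import deque


def khop_sizes(adj, source, max_k):
    """Number of nodes within k hops of `source`, for k = 0..max_k (BFS distances)."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if w not in dist:
                dist[w] = dist[u] + 1
                queue.append(w)
    return [sum(1 for d in dist.values() if d <= k) for k in range(max_k + 1)]


def build_toy_graph():
    """Hypothetical toy graph: a 6-node clique (dense core) plus a 5-edge path (sparse periphery)."""
    edges = [(i, j) for i in range(6) for j in range(i + 1, 6)]  # clique on nodes 0..5
    edges += [(5, 6), (6, 7), (7, 8), (8, 9), (9, 10)]           # path hanging off node 5
    adj = {v: set() for v in range(11)}
    for a, b in edges:
        adj[a].add(b)
        adj[b].add(a)
    return adj


adj = build_toy_graph()
hub, peripheral = 5, 10  # node 5 has the highest degree; node 10 is the end of the path
print("hub:       ", khop_sizes(adj, hub, 5))         # [1, 7, 8, 9, 10, 11]
print("peripheral:", khop_sizes(adj, peripheral, 5))  # [1, 2, 3, 4, 5, 6]
```

With one hop the hub already covers 7 of the 11 nodes, while the peripheral node needs 5 hops just to reach the hub.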

Standard GCN handles this by choosing a single K (usually 2) and hoping it's good enough for most nodes. JK-Net asks: why not keep the representations from every depth and let the model decide, per node, which depth is most useful?

Why is a single GCN depth K problematic for graphs with mixed node degrees?

Chapter 1: Layer Aggregation

JK-Net's solution is elegant: run K GCN layers as usual, but instead of only using the last layer's output, keep every layer's output and combine them.

Let $h_v^{(k)}$ be node v's embedding after layer k. A standard GCN uses only $h_v^{(K)}$ for prediction. JK-Net uses all of them: $h_v^{(1)}, h_v^{(2)}, \dots, h_v^{(K)}$.

$$h_v^{\text{final}} = \mathrm{AGG}\big(h_v^{(1)}, h_v^{(2)}, \dots, h_v^{(K)}\big)$$

The aggregation function AGG is the design choice that makes JK-Net flexible. It can be concatenation (keep everything), max-pooling (take the maximum across layers per dimension), or LSTM-attention (learn which layers to attend to).
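A minimal sketch of layer aggregation, assuming a weight-free averaging layer in place of a trained GCN layer so that it runs without any training: it keeps every layer's output and hands the whole list to a pluggable AGG function. The names `propagate`, `layer_outputs`, `jk_readout`, and `concat` are hypothetical.

```python
def propagate(adj, h):
    """One weight-free layer (an assumption, not a trained GCN layer):
    each node averages its own features with its neighbors' features."""
    out = {}
    for v, nbrs in adj.items():
        group = [v] + sorted(nbrs)
        dim = len(h[v])
        out[v] = [sum(h[u][j] for u in group) / len(group) for j in range(dim)]
    return out


def layer_outputs(adj, x, K):
    """Run K layers and keep every output: [H^(1), ..., H^(K)]."""
    layers, h = [], x
    for _ in range(K):
        h = propagate(adj, h)
        layers.append(h)
    return layers


def standard_readout(layers, v):
    return layers[-1][v]                    # plain GCN: only h_v^(K)


def jk_readout(layers, v, agg):
    return agg([H[v] for H in layers])      # JK-Net: AGG(h_v^(1), ..., h_v^(K))


def concat(vectors):
    return [x for vec in vectors for x in vec]


# Hypothetical path graph 0 - 1 - 2 with a 1-dimensional feature that is 1 only at node 0.
adj = {0: {1}, 1: {0, 2}, 2: {1}}
x = {0: [1.0], 1: [0.0], 2: [0.0]}
layers = layer_outputs(adj, x, K=2)
print(standard_readout(layers, 2))    # only the 2-hop view of node 2
print(jk_readout(layers, 2, concat))  # the 1-hop and 2-hop views side by side
```

Swapping `concat` for another function (for example a feature-wise max) changes the aggregation without touching the layers, which is the design point of AGG.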

The "jumping" analogy: Information "jumps" from any layer directly to the final representation — like residual connections but across ALL layers simultaneously. In a standard GCN, information must pass through every subsequent layer. In JK-Net, the model can "jump" to access the representation at exactly layer k, for any k from 1 to K.

Why does this help? Because $h_v^{(k)}$ encodes node v's features aggregated from its k-hop neighborhood. By having access to every k from 1 to K, the model can learn to emphasize the k that provides the most useful signal for each node's position in the graph.

What distinguishes JK-Net from standard GCN in terms of which layer representations are used?

Chapter 2: JK-Net Architecture

The full JK-Net architecture: K standard GCN layers, each producing a node embedding, followed by a layer-wise aggregation module that selects the right representation per node.

Simulation: JK-Net Live, layer contributions per node. Clicking a node shows how much each GCN layer contributes to its final representation, and a selector compares aggregation methods. Peripheral nodes (few neighbors, sparse surroundings) tend to draw more on deeper layers; hub nodes tend to rely on shallow layers.

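Data flow through JK-Net:
Input: $X \in \mathbb{R}^{n \times d}$ (node features), with $H^{(0)} = X$
Layer 1: $H^{(1)} = \mathrm{ReLU}(\hat{A} X W^{(0)}) \in \mathbb{R}^{n \times d'}$
Layer 2: $H^{(2)} = \mathrm{ReLU}(\hat{A} H^{(1)} W^{(1)}) \in \mathbb{R}^{n \times d'}$
... (K layers; layer k computes $H^{(k)} = \mathrm{ReLU}(\hat{A} H^{(k-1)} W^{(k-1)})$)
Jump: $H^{\text{final}} = \mathrm{AGG}(H^{(1)}, \dots, H^{(K)}) \in \mathbb{R}^{n \times K d'}$ [concat] or $\mathbb{R}^{n \times d'}$ [maxpool/lstm]
Predict: $\hat{Y} = \mathrm{softmax}(H^{\text{final}} W_{\text{out}})$
Here $\hat{A} = \tilde{D}^{-1/2}(A + I)\tilde{D}^{-1/2}$ is the symmetrically normalized adjacency matrix with self-loops, where $\tilde{D}$ is the degree matrix of $A + I$.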
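The data flow can be sketched in plain Python with lists as matrices. This is a minimal, untrained sketch: the weights are random (drawn from a hypothetical seeded Gaussian), the toy graph and features are made up, and only the concatenation and max-pooling readouts are shown. Its purpose is to make the shapes at each step concrete.

```python
import math
import random


def normalized_adjacency(n, edges):
    """A_hat = D^{-1/2} (A + I) D^{-1/2}, with D the degree matrix of A + I."""
    a = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]  # start from I
    for i, j in edges:
        a[i][j] = a[j][i] = 1.0                                          # add A
    deg = [sum(row) for row in a]
    return [[a[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]


def matmul(p, q):
    return [[sum(p[i][t] * q[t][j] for t in range(len(q))) for j in range(len(q[0]))]
            for i in range(len(p))]


def relu(m):
    return [[max(0.0, v) for v in row] for row in m]


def softmax_rows(m):
    out = []
    for row in m:
        top = max(row)                              # subtract the max for numerical stability
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


def jknet_forward(a_hat, x, K, hidden, n_classes, agg="concat", seed=0):
    """Untrained JK-Net forward pass with random weights (shapes only, no learning)."""
    rng = random.Random(seed)
    h, layers = x, []
    for _ in range(K):
        w = rand_matrix(len(h[0]), hidden, rng)     # W^(k-1)
        h = relu(matmul(matmul(a_hat, h), w))       # H^(k) = ReLU(A_hat H^(k-1) W^(k-1))
        layers.append(h)
    n = len(x)
    if agg == "concat":                             # n x (K * d')
        h_final = [[v for H in layers for v in H[i]] for i in range(n)]
    elif agg == "maxpool":                          # n x d'
        h_final = [[max(H[i][j] for H in layers) for j in range(hidden)] for i in range(n)]
    else:
        raise ValueError("agg must be 'concat' or 'maxpool'")
    w_out = rand_matrix(len(h_final[0]), n_classes, rng)
    return h_final, softmax_rows(matmul(h_final, w_out))  # Y_hat = softmax(H_final W_out)


# Hypothetical 5-node graph with 3-dimensional input features.
edges = [(0, 1), (1, 2), (2, 3), (3, 0), (3, 4)]
a_hat = normalized_adjacency(5, edges)
x = [[float((i + j) % 3) for j in range(3)] for i in range(5)]
for agg in ("concat", "maxpool"):
    h_final, y_hat = jknet_forward(a_hat, x, K=4, hidden=64, n_classes=3, agg=agg)
    print(agg, "H_final:", (len(h_final), len(h_final[0])), "Y_hat:", (len(y_hat), len(y_hat[0])))
```

Plain lists keep the example dependency-free; a real implementation would use a tensor library and train every $W^{(k)}$ and $W_{\text{out}}$.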
What is the output dimension of JK-Net's final representation using concatenation aggregation with K=4 layers and d'=64 hidden units per layer?

Chapter 3: Aggregation Options

JK-Net proposes three ways to combine layer-wise representations. Each has different expressiveness and computational cost.

1. Concatenation

Simply concatenate all K layer outputs for each node:

$$h_v^{\text{final}} = \big[\, h_v^{(1)} \,\|\, h_v^{(2)} \,\|\, \cdots \,\|\, h_v^{(K)} \,\big]$$

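Output dimension: K × d'. The final classifier (a linear layer) then learns how much weight to give each layer's block of features. That weight matrix is shared by every node, so the emphasis on each layer is the same for all nodes: concatenation on its own is not node-adaptive. It adds no parameters to the aggregator itself, but the representation (and the classifier's input width) grows linearly with K.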
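A small sketch of why concatenation is not node-adaptive: the output layer's weight vector splits into one block per layer, and that single set of blocks is applied to every node. The layer outputs and weights below are hypothetical numbers, and only one output class is modeled.

```python
def concat_layers(layer_vecs):
    """JK concatenation for one node: [h^(1) || h^(2) || ... || h^(K)]."""
    return [x for h in layer_vecs for x in h]


def per_layer_weights(w, K):
    """Split a classifier weight vector of length K * d' into K blocks, one per layer."""
    d = len(w) // K
    return [w[k * d:(k + 1) * d] for k in range(K)]


def score(w, h):
    return sum(wi * hi for wi, hi in zip(w, h))


# Hypothetical layer outputs (K = 3, d' = 2) for two different nodes.
node_a = [[1.0, 0.0], [0.5, 0.5], [0.2, 0.1]]
node_b = [[0.0, 1.0], [0.3, 0.9], [0.4, 0.4]]

# One output weight vector (a single class, for brevity) shared by every node.
w = [0.9, 0.9, 0.1, 0.1, 0.0, 0.0]
print(per_layer_weights(w, K=3))  # [[0.9, 0.9], [0.1, 0.1], [0.0, 0.0]] for both nodes
print(score(w, concat_layers(node_a)), score(w, concat_layers(node_b)))
```

Because the blocks of `w` are fixed after training, layer 1 dominates for node_a and node_b alike; only the layer outputs themselves differ between nodes.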

2. Max-Pooling

For each feature dimension j, take the maximum across all K layers:

$$h_{v,j}^{\text{final}} = \max_{k=1,\dots,K} h_{v,j}^{(k)}$$

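Output dimension: d' (same as each layer, which requires every layer to have the same width). This is permutation-invariant over layers: the result does not depend on layer ordering. It selects the "most activated" value for each feature across all depths, and unlike concatenation it is node-adaptive, since each node (and each feature) can take its value from a different layer.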
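A minimal sketch of max-pooling over layers for a single node, with a helper that reports which layer supplied each feature's maximum. The layer outputs are hypothetical; the point is that two nodes can end up drawing on different layers, and that reordering the layers leaves the result unchanged.

```python
def max_pool_layers(layer_vecs):
    """JK max-pooling for one node: feature-wise max over h^(1), ..., h^(K)."""
    return [max(column) for column in zip(*layer_vecs)]


def source_layers(layer_vecs):
    """For each feature j, the (1-based) layer k whose value won the max."""
    K = len(layer_vecs)
    return [max(range(K), key=lambda k: layer_vecs[k][j]) + 1
            for j in range(len(layer_vecs[0]))]


# Hypothetical layer outputs (K = 3, d' = 3) for two nodes.
node_a = [[0.9, 0.1, 0.0], [0.4, 0.6, 0.2], [0.2, 0.3, 0.8]]
node_b = [[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.7, 0.8, 0.9]]

print(max_pool_layers(node_a), source_layers(node_a))  # [0.9, 0.6, 0.8] [1, 2, 3]
print(max_pool_layers(node_b), source_layers(node_b))  # [0.7, 0.8, 0.9] [3, 3, 3]
# Reordering the layers does not change the pooled vector (permutation invariance).
print(max_pool_layers(node_a[::-1]) == max_pool_layers(node_a))  # True
```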

3. LSTM-Attention

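Treat the K layer representations as a sequence and run a bidirectional LSTM over them. For each layer k, the forward and backward LSTM states $f_v^{(k)}$ and $b_v^{(k)}$ are concatenated and mapped by a learned vector $w$ to an attention score; a softmax over the layers turns the scores into weights, and the final representation is the weighted sum of the layer outputs:

$$s_v^{(k)} = \frac{\exp\big(w^\top [f_v^{(k)} \,\|\, b_v^{(k)}]\big)}{\sum_{l=1}^{K} \exp\big(w^\top [f_v^{(l)} \,\|\, b_v^{(l)}]\big)}, \qquad h_v^{\text{final}} = \sum_{k=1}^{K} s_v^{(k)}\, h_v^{(k)}$$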

This is the most expressive — the LSTM learns to weight different layers differently and can model interactions between layer representations. But it requires training additional parameters (the LSTM weights) and adds complexity.
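A minimal sketch of the attention readout. In the paper's formulation, the bi-LSTM turns the layer sequence into per-layer attention scores, and the final representation is a softmax-weighted sum of the layer outputs. To keep the example short, the bi-LSTM is swapped here for a hypothetical fixed linear scorer (`toy_score`), so only the softmax weighting and the weighted sum are faithful to the method.

```python
import math


def softmax(scores):
    top = max(scores)                               # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_readout(layer_vecs, score_fn):
    """Softmax-weighted sum of a node's layer outputs h^(1), ..., h^(K)."""
    weights = softmax([score_fn(h) for h in layer_vecs])   # one weight per layer, sums to 1
    dim = len(layer_vecs[0])
    h_final = [sum(w * h[j] for w, h in zip(weights, layer_vecs)) for j in range(dim)]
    return weights, h_final


def toy_score(h):
    """Hypothetical stand-in for the bi-LSTM scorer: a fixed linear map of one layer's vector.
    Unlike the real module, it ignores the other layers in the sequence."""
    return 2.0 * h[0] - 1.0 * h[1]


layers_v = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]   # hypothetical h_v^(1..3), d' = 2
weights, h_final = attention_readout(layers_v, toy_score)
print([round(w, 3) for w in weights])   # most weight on layer 1, whose score is highest
print([round(v, 3) for v in h_final])   # output keeps dimension d' = 2
```

In a trained model the scorer's parameters are learned jointly with the GCN layers; that is where the extra LSTM parameters in the table come from.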

| Method | Output Dim | Extra Params | Order-Sensitive | Best For |
|---|---|---|---|---|
| Concatenation | K·d' | None | Yes | Simple, all-around |
| Max-Pool | d' | None | No | Compact representations |
| LSTM-Attn | d' | LSTM params | Yes | Complex depth patterns |
Which is best? Empirically, max-pooling is often competitive with or better than LSTM-attention despite being simpler. This suggests that for most graph tasks, the key is accessing the right depth, not learning complex interactions between depths. Concatenation often wins when the downstream task can afford the larger representation size.
Why might max-pooling outperform LSTM-attention in JK-Net despite being simpler?

Chapter 4: Influence Distribution

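Xu et al. introduce a diagnostic tool: the influence distribution of node v, a probability distribution over all nodes u that captures how much each node's input features $x_u$ influence node v's representation after K layers. The influence score sums the absolute values of the entries of the Jacobian, and the distribution normalizes these scores over all nodes:

$$I(v,u) = \sum_{i,j} \left| \left[ \frac{\partial h_v^{(K)}}{\partial x_u} \right]_{ij} \right|, \qquad \alpha_v(u) = \frac{I(v,u)}{\sum_{w \in V} I(v,w)}$$

For a standard K-layer GCN, the paper shows that, in expectation and under a simplifying assumption about which ReLU paths are active, node v's influence distribution is equivalent to the K-step random walk distribution starting from v on the graph with self-loops. It is fixed by the graph structure and K: the model cannot learn to widen or narrow it for a particular task or node.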
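The link between influence and random walks can be checked numerically in a simplified setting. This sketch assumes linear, weight-free layers that average each node with its neighbors (no ReLU, no W), a simplification of the GCN in the paper's analysis. Under that assumption, the influence distribution, estimated by finite differences of $h_v^{(K)}$ with respect to each input $x_u$, matches the K-step random walk distribution exactly rather than only in expectation. All function names are hypothetical.

```python
def random_walk_dist(adj, v, K):
    """K-step random walk from v with self-loops: from u, step uniformly to u or one of its neighbors."""
    p = {u: 0.0 for u in adj}
    p[v] = 1.0
    for _ in range(K):
        nxt = {u: 0.0 for u in adj}
        for u, mass in p.items():
            group = [u] + sorted(adj[u])
            for w in group:
                nxt[w] += mass / len(group)
        p = nxt
    return p


def mean_propagate(adj, x, K):
    """K linear, weight-free layers (an assumption): h^(k)(v) = mean of h^(k-1) over v and its neighbors."""
    h = dict(x)
    for _ in range(K):
        h = {v: (h[v] + sum(h[u] for u in adj[v])) / (len(adj[v]) + 1) for v in adj}
    return h


def influence_distribution(adj, v, K, eps=1e-6):
    """Finite-difference estimate of |dh_v^(K) / dx_u| for every u, normalized to sum to 1."""
    base = {u: 0.0 for u in adj}
    h_base = mean_propagate(adj, base, K)[v]
    scores = {}
    for u in adj:
        bumped = dict(base)
        bumped[u] = eps
        scores[u] = abs(mean_propagate(adj, bumped, K)[v] - h_base) / eps
    total = sum(scores.values())
    return {u: s / total for u, s in scores.items()}


# Hypothetical path graph 0 - 1 - 2 - 3, scalar features, K = 2 layers, target node 0.
adj = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2}}
walk = random_walk_dist(adj, 0, 2)
infl = influence_distribution(adj, 0, 2)
for u in adj:
    print(u, round(walk[u], 4), round(infl[u], 4))  # the two columns agree
```

Node 3 gets zero influence because it is 3 hops from node 0 and the model has only 2 layers.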

The oversmoothing diagnosis: for a node at the periphery of a sparse graph, the K-step distribution is narrow (only a few nodes are reachable in K steps), so the node may see too little context. For a node in a dense clique, the K-step distribution becomes nearly uniform over the entire clique: every node in the clique has roughly equal influence, regardless of relevance. This is oversmoothing: all nodes in the clique end up with nearly the same representation.

JK-Net changes the influence distribution fundamentally. Because node v's final representation aggregates all K layers, and layer k draws on k-hop neighborhoods, the paper shows (under the same assumptions) that the effective influence distribution is a mixture of the 1-step through K-step random walk distributions. With max-pooling or LSTM-attention, the mixture weights are learned and node-dependent: they can concentrate on shallow depths for hub nodes in dense regions and on deeper depths for peripheral nodes in sparse regions.
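The mixture view can be sketched directly: compute the 1-step through K-step random walk distributions from a node, then take a convex combination. This assumes the random-walk picture from the influence analysis; the path graph and the mixture weights are hypothetical, standing in for weights JK-Net would learn. The comparison shows how much influence stays within one hop of the node when the weights favor shallow depths versus when all weight sits on the K-step distribution.

```python
def walk_dists(adj, v, K):
    """[RW_1, ..., RW_K]: k-step random-walk distributions from v, with self-loops."""
    p = {u: 0.0 for u in adj}
    p[v] = 1.0
    dists = []
    for _ in range(K):
        nxt = {u: 0.0 for u in adj}
        for u, mass in p.items():
            group = [u] + sorted(adj[u])
            for w in group:
                nxt[w] += mass / len(group)
        p = nxt
        dists.append(p)
    return dists


def mixture(dists, weights):
    """Convex combination sum_k weights[k] * RW_(k+1); weights must form a probability vector."""
    if any(w < 0 for w in weights) or abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("weights must be non-negative and sum to 1")
    return {u: sum(w * d[u] for w, d in zip(weights, dists)) for u in dists[0]}


def mass_near_target(dist, r):
    """Mass on nodes 0..r, i.e. within r hops of node 0 on the path below."""
    return sum(dist[u] for u in range(r + 1))


# Hypothetical path graph 0 - 1 - 2 - 3 - 4 - 5, target node 0, K = 4.
adj = {i: {j for j in (i - 1, i + 1) if 0 <= j < 6} for i in range(6)}
dists = walk_dists(adj, 0, 4)

shallow = mixture(dists, [0.7, 0.2, 0.1, 0.0])  # hypothetical weights favoring 1-2 hops
deep = mixture(dists, [0.0, 0.0, 0.0, 1.0])     # all weight on K = 4, the fixed plain-GCN case
print(round(mass_near_target(shallow, 1), 3), round(mass_near_target(deep, 1), 3))
```

Setting all weight on the last distribution recovers the fixed K-step case, so the plain GCN influence is one point inside the family JK-Net can choose from.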

Simulation: Influence Distribution, GCN vs JK-Net. For a peripheral or hub node, it shows which other nodes influence the node's representation: GCN uses a fixed K-step distribution, while JK-Net learns a mixture.

How does JK-Net change the influence distribution compared to standard GCN?

Chapter 5: Results

Xu et al. evaluated JK-Net on node classification (Cora, Citeseer, Pubmed) and on larger graphs (Reddit, PPI). The main takeaway is less the raw accuracy than the trend with depth: JK-Net holds up as layers are added, while plain GCN degrades.

| Dataset | GCN (K=2) | GCN (K=6) | JK-Concat (K=6) | JK-MaxPool (K=6) |
|---|---|---|---|---|
| Cora | 81.5% | 79.8% | 83.3% | 83.6% |
| Citeseer | 70.3% | 68.1% | 72.6% | 73.0% |
| Pubmed | 79.0% | 78.2% | 79.8% | 80.2% |
The crucial comparison: GCN accuracy drops when going from K=2 to K=6 (e.g., Cora: 81.5% → 79.8%), while JK-Net with K=6 beats GCN with K=2 (Cora: 83.6% vs 81.5%). This is consistent with the hypothesis: some nodes benefit from information at deeper layers, but a plain GCN cannot reach it without oversmoothing everything else. JK-Net can use the deep layers where they help and fall back on shallow ones where they do not.

The paper also reports gains on the larger Reddit (a social network) and PPI (protein-protein interaction) datasets. Hub nodes in social networks benefit particularly from the ability to selectively attend to shallow layers, avoiding the dilution caused by aggregating thousands of neighbors.

What does the GCN K=2 vs K=6 comparison reveal about standard GCN?

Chapter 6: vs Deep GCN

Several approaches tackle the "deep GCN" problem — the fact that standard GCN degrades with depth. JK-Net is one. Let's compare the strategies.

| Method | Strategy | Handles Oversmoothing | Per-Node Adaptation |
|---|---|---|---|
| GCN (baseline) | Fixed K layers, use last only | No | No |
| ResGCN | Skip connections: $h^{(k+1)} = f(h^{(k)}) + h^{(k)}$ | Partially | No |
| DenseGCN | All previous layers → current layer input | Better | No |
| JK-Net | All layers → final aggregation only | Mitigates (prediction can rely on shallow layers) | Yes (max-pool, LSTM-attention) |
| DropEdge | Random edge dropout during training | Partially | No |
| PairNorm | Normalize to prevent oversmoothing | Partially | No |
JK-Net vs ResGCN: ResGCN adds a skip connection from layer k to layer k+1. This helps gradient flow, but the final layer still sees one summed representation rather than each intermediate representation separately. JK-Net is architecturally different: its aim is representation access rather than gradient flow. The final prediction reads representations from all depths directly, not a shortcut-assisted version of the last layer.

The key distinction: ResGCN improves the training of deep GCNs. JK-Net improves the representation available for prediction. These are complementary: nothing prevents using residual connections inside the GCN stack and a JK aggregation at the output.

What makes JK-Net architecturally different from ResGCN (skip connections)?

Chapter 7: Connections

JK-Net popularized multi-scale representation learning for graphs, an idea that reappears in many subsequent architectures and remains a core design principle for graph neural networks on graphs whose local structure varies from node to node.

| Method | Key Idea | Relation to JK-Net |
|---|---|---|
| APPNP | Personalized PageRank as aggregation weights | Large receptive field that stays anchored at the root node (different angle) |
| SIGN | Multi-scale precomputed features | JK-Net idea + SGC precomputation |
| MixHop | Concatenates $\hat{A}^{j} H$ for several powers $j$ inside each layer | Multi-hop mixing within each layer (not just at the output) |
| DAGNN | Decouple propagation and transformation | Similar philosophy to JK + SGC |
| Design Space for GNNs | Systematic study of GNN design choices | JK-style skip connections as one design dimension |
JK-Net in the design space: the Design Space for GNNs paper (You et al., 2020) systematically studies which GNN design choices matter. It reports that inter-layer connections (skip connections, of which JK-Net is the most aggressive form) are among the choices that matter, particularly for deep GNNs. JK-Net's full-layer access sits at the far end of the skip-connection spectrum.

Closing thought: JK-Net embodies a principle that appears throughout deep learning: when you don't know which level of representation is most informative, keep them all and let the model learn. This is the same insight behind DenseNet in vision and multi-scale feature pyramids in object detection. JK-Net brings this idea to the graph domain, where the relevant "scale" is not spatial resolution but graph-topological distance.