Node Prediction with the Simple API¶
Node prediction GNNs use supervised learning to predict node feature values (categorical labels for classification, and numerical labels for regression).
Part 1: This tutorial shows how to train, analyze, and evaluate a model to
predict the label feature in the MAG citation graph. In this graph, authors,
papers, fields of study, and institutions constitute distinct node sets.
Relationships—such as authors writing papers, author affiliations with
institutions, paper topics, and citations between papers—are defined by specific
edge sets. The objective is to predict the domain of a paper stored in the
labels feature.
Part 2: The first part of the tutorial trains a GNN with access to all articles and citations. The second part addresses real-world constraints where data availability depends on the time of publication. This model is trained and tested using only historical data for each node's prediction.
Installing DGF¶
Make sure your machine has a GPU or TPU; otherwise, training will be very slow. On Google Colab, you can get one for free: go to Edit > Notebook settings and select TPU or GPU as the hardware accelerator.
# Install DGF (Distributed Graph Flow) and OGB (for the toy dataset).
!pip install dgf ogb -U
Importing libraries¶
import os
os.environ["TF_USE_LEGACY_KERAS"] = "1"
import dgf # Import Graph Flow
Getting the graph data¶
We first fetch the MAG graph.
# Download the MAG graph from the OGB repo.
graph, schema = dgf.io.fetch_ogb_graph("mag")
Caching mag graph at /tmp/gf_fetch/mag.cache
OGB dependency not available. Downloading graph from CNS.
Let's look at the graph structure.
# Show the schema
dgf.analyse.print_schema(schema)
Graph Schema:
Node Sets:
author:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
field_of_study:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
institution:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
paper:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|-------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
| #split | BYTES | CATEGORICAL | None | None |
| feat | FLOAT_32 | EMBEDDING | (128,) | None |
| labels | INTEGER_64 | CATEGORICAL | None | None |
| year | INTEGER_64 | NUMERICAL | None | None |
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No features)
cites: (Source: paper, Target: paper)
(No features)
has_topic: (Source: paper, Target: field_of_study)
(No features)
writes: (Source: author, Target: paper)
(No features)
Part 1: Training a time-agnostic model¶
We train a model to predict the labels feature of the paper nodeset.
model = dgf.learning.train_node_model(
graph=graph,
schema=schema,
target_nodeset="paper",
target_column="labels",
# Reduce the number of hops and train steps, for the demo to be faster.
num_sampling_hops=1,
num_train_steps=5000,
)
Using gpu JAX backend
Graph input schema:
Node Sets:
author:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
field_of_study:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
institution:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
paper:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|-------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
| #split | BYTES | CATEGORICAL | None | None |
| feat | FLOAT_32 | EMBEDDING | (128,) | None |
| labels | INTEGER_64 | CATEGORICAL | None | None |
| year | INTEGER_64 | NUMERICAL | None | None |
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No features)
cites: (Source: paper, Target: paper)
(No features)
has_topic: (Source: paper, Target: field_of_study)
(No features)
writes: (Source: author, Target: paper)
(No features)
Preparing dataset
Num. training seed nodes: 704389, Num. validation seed nodes: 32000
Create graph sampler
Compute feature statistics
GraphFeatureStatistics:
Node Sets (4):
'author':
'#id': count=46985, min=nan, max=nan
'field_of_study':
'#id': count=101920, min=nan, max=nan
'institution':
'#id': count=0, min=nan, max=nan
'paper':
'#id': count=111535, min=nan, max=nan
'#split': count=111535, min=nan, max=nan, dictionary=(3)['train': 94675, 'valid': 9821, 'test': 7039]
'feat': count=111535, min=nan, max=nan
'labels': count=111535, min=0.0000, max=348.0000
'year': count=111535, min=2010.0000, max=2019.0000, quantiles=(100)[2010.0000, 2010.0000, 2010.0000, ..., 2019.0000, 2019.0000, 2019.0000]
[Warning] No normalizer created for node set 'author', feature '#id'.
[Warning] No normalizer created for node set 'paper', feature '#id'.
[Warning] No normalizer created for node set 'field_of_study', feature '#id'.
[Warning] No normalizer created for node set 'institution', feature '#id'.
Compute graph statistics for padding
padding: Node Sets:
author: 256 nodes
field_of_study: 384 nodes
institution: 2 nodes
paper: 542 nodes
Edge Sets:
affiliated_with: 2 edges
cites: 507 edges
has_topic: 384 edges
writes: 256 edges
Preparing dataset finished in 4.40 seconds
Normalizer:
Graph Normalizer:
Node Sets:
author:
(No normalizers)
field_of_study:
(No normalizers)
institution:
(No normalizers)
paper:
- #split: DictionaryIndexNormalizer
- feat: IdentityNormalizer
- labels: IdentityNormalizer
- year: SoftQuantileNormalizer
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No normalizers)
cites: (Source: paper, Target: paper)
(No normalizers)
has_topic: (Source: paper, Target: field_of_study)
(No normalizers)
writes: (Source: author, Target: paper)
(No normalizers)
Normalized graph schema:
Node Sets:
author:
(No features)
field_of_study:
(No features)
institution:
(No features)
paper:
| Feature | Format | Semantic | Shape | Num cat. vals |
|--------------------|------------|-------------|---------|-----------------|
| #split_INDEX | INTEGER_64 | CATEGORICAL | None | 4 |
| feat | FLOAT_32 | EMBEDDING | (128,) | None |
| labels | INTEGER_64 | CATEGORICAL | None | 349 |
| year_SOFT_QUANTILE | FLOAT_32 | EMBEDDING | None | None |
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No features)
cites: (Source: paper, Target: paper)
(No features)
has_topic: (Source: paper, Target: field_of_study)
(No features)
writes: (Source: author, Target: paper)
(No features)
Core model config:
EmbedGraph(cat-embedding=64)
Dense(128)
Activation(silu)
Norm(layer_norm)
Graph Convolution Block x2:
X = ...
MPNN:
Message:
Dense(128)
Activation(silu)
Dense(128)
Update:
Dense(128)
Activation(silu)
Dropout(0.1)
Dense(128)
Residual(X)
# Post MPNN
X = ...
Norm(rms_norm)
Dense(512)
Activation(silu)
Dense(128)
Dropout(0.1)
Residual(X)
Identity
Dense(349) # Classification head
Normalized input features:
Node Sets:
author:
(No features)
field_of_study:
(No features)
institution:
(No features)
paper:
| Feature | Format | Semantic | Shape | Num cat. vals |
|--------------------|------------|-------------|---------|-----------------|
| #split_INDEX | INTEGER_64 | CATEGORICAL | None | 4 |
| feat | FLOAT_32 | EMBEDDING | (128,) | None |
| year_SOFT_QUANTILE | FLOAT_32 | EMBEDDING | None | None |
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No features)
cites: (Source: paper, Target: paper)
(No features)
has_topic: (Source: paper, Target: field_of_study)
(No features)
writes: (Source: author, Target: paper)
(No features)
Caching validation dataset
Caching validation dataset: 100%|██████████| 1000/1000 [00:06<00:00, 148.22it/s]
Caching validation dataset finished in 6.75 seconds
Number of cache validation batches: 1000
Training model
Generate first batch to initialize model
Create model variables
...Tracing model
Create model variables finished in 20.42 seconds
Will validate model every 1000 step(s)
Will checkpoint model every 1000 step(s)
Start training. The first two steps are generally slow.
Training: 0%| | 0/10000 [00:00<?, ?it/s]
...Tracing model
Training: 10%|▉ | 999/10000 [00:32<01:43, 86.89it/s, step=1000, train-accuracy=0.2009, train-loss=3.5856]
...Tracing model
Training: 10%|█ | 1017/10000 [00:35<11:23, 13.13it/s, step=1000, train-accuracy=0.2009, train-loss=3.5856]
Validation loop took 2.76s (only printed once)
step:1000 train-accuracy:0.2009 train-loss:3.5856 valid-accuracy:0.2043 valid-loss:3.5337
Training: 20%|██ | 2013/10000 [00:48<05:12, 25.52it/s, step=2000, train-accuracy=0.2056, train-loss=3.4303, valid-accuracy=0.2043, valid-loss=3.5337]
step:2000 train-accuracy:0.2056 train-loss:3.4303 valid-accuracy:0.2192 valid-loss:3.3695
Training: 30%|███ | 3015/10000 [01:01<04:23, 26.46it/s, step=3000, train-accuracy=0.2316, train-loss=3.2759, valid-accuracy=0.2192, valid-loss=3.3695]
step:3000 train-accuracy:0.2316 train-loss:3.2759 valid-accuracy:0.2280 valid-loss:3.2890
Training: 40%|████ | 4014/10000 [01:13<03:54, 25.50it/s, step=4000, train-accuracy=0.2381, train-loss=3.2077, valid-accuracy=0.2280, valid-loss=3.2890]
step:4000 train-accuracy:0.2381 train-loss:3.2077 valid-accuracy:0.2349 valid-loss:3.2208
Training: 50%|█████ | 5019/10000 [01:26<03:05, 26.82it/s, step=5000, train-accuracy=0.2466, train-loss=3.1467, valid-accuracy=0.2349, valid-loss=3.2208]
step:5000 train-accuracy:0.2466 train-loss:3.1467 valid-accuracy:0.2451 valid-loss:3.1452
Training: 60%|██████ | 6010/10000 [01:39<02:30, 26.53it/s, step=6000, train-accuracy=0.2428, train-loss=3.1536, valid-accuracy=0.2451, valid-loss=3.1452]
step:6000 train-accuracy:0.2428 train-loss:3.1536 valid-accuracy:0.2518 valid-loss:3.0848
Training: 70%|███████ | 7013/10000 [01:52<01:50, 27.07it/s, step=7000, train-accuracy=0.2600, train-loss=3.0593, valid-accuracy=0.2518, valid-loss=3.0848]
step:7000 train-accuracy:0.2600 train-loss:3.0593 valid-accuracy:0.2624 valid-loss:3.0291
Training: 80%|████████ | 8020/10000 [02:04<01:07, 29.48it/s, step=8000, train-accuracy=0.2634, train-loss=2.9889, valid-accuracy=0.2624, valid-loss=3.0291]
step:8000 train-accuracy:0.2634 train-loss:2.9889 valid-accuracy:0.2637 valid-loss:2.9987
Training: 90%|█████████ | 9013/10000 [02:17<00:38, 25.91it/s, step=9000, train-accuracy=0.2753, train-loss=2.9800, valid-accuracy=0.2637, valid-loss=2.9987]
step:9000 train-accuracy:0.2753 train-loss:2.9800 valid-accuracy:0.2680 valid-loss:2.9804
Training: 100%|██████████| 10000/10000 [02:28<00:00, 67.42it/s, step=9900, train-accuracy=0.2691, train-loss=2.9823, valid-accuracy=0.2680, valid-loss=2.9804]
step:10000 train-accuracy:0.2691 train-loss:2.9823 valid-accuracy:0.2672 valid-loss:2.9761
Final metrics: {'step': '9900', 'train-accuracy': '0.2691', 'train-loss': '2.9823', 'valid-accuracy': '0.2680', 'valid-loss': '2.9804'}
Training model finished in 170.22 seconds
Once trained, it is important to look at the model:
model.describe()
Node prediction model: Predict the value of a node feature.
- Target nodeset: paper
- Target column: labels
- Number of label classes: 349
- Number of training seed nodes: 704389
- Number of validation seed nodes: 32000
- Training duration: 3m 1s
num_sampling_hops=1
sampling_width=15
num_layers=2
batch_size=32
max_training_time_seconds=None
num_train_steps=10000
random_seed=42
node_embedding_dim=128
learning_rate=0.001
opt_weight_decay=0.0001
dropout=0.1
message_pooling='sum'
architecture=
Node Sets:
author:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
field_of_study:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
institution:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
paper:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|-------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
| #split | BYTES | CATEGORICAL | None | None |
| feat | FLOAT_32 | EMBEDDING | (128,) | None |
| labels | INTEGER_64 | CATEGORICAL | None | None |
| year | INTEGER_64 | NUMERICAL | None | None |
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No features)
cites: (Source: paper, Target: paper)
(No features)
has_topic: (Source: paper, Target: field_of_study)
(No features)
writes: (Source: author, Target: paper)
(No features)
Normalized schema
Node Sets:
author:
(No features)
field_of_study:
(No features)
institution:
(No features)
paper:
| Feature | Format | Semantic | Shape | Num cat. vals |
|--------------------|------------|-------------|---------|-----------------|
| #split_INDEX | INTEGER_64 | CATEGORICAL | None | 4 |
| feat | FLOAT_32 | EMBEDDING | (128,) | None |
| labels | INTEGER_64 | CATEGORICAL | None | 349 |
| year_SOFT_QUANTILE | FLOAT_32 | EMBEDDING | None | None |
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No features)
cites: (Source: paper, Target: paper)
(No features)
has_topic: (Source: paper, Target: field_of_study)
(No features)
writes: (Source: author, Target: paper)
(No features)
GraphFeatureStatistics:
Node Sets (4):
'author':
'#id': count=46985, min=nan, max=nan
'field_of_study':
'#id': count=101920, min=nan, max=nan
'institution':
'#id': count=0, min=nan, max=nan
'paper':
'#id': count=111535, min=nan, max=nan
'#split': count=111535, min=nan, max=nan, dictionary=(3)['train': 94675, 'valid': 9821, 'test': 7039]
'feat': count=111535, min=nan, max=nan
'labels': count=111535, min=0.0000, max=348.0000
'year': count=111535, min=2010.0000, max=2019.0000, quantiles=(100)[2010.0000, 2010.0000, 2010.0000, ..., 2019.0000, 2019.0000, 2019.0000]
Root: paper
├── cites [width=15] ➔ paper
├── cites (reversed) [width=15] ➔ paper
├── has_topic [width=15] ➔ field_of_study
└── writes (reversed) [width=15] ➔ author
EmbedGraph(cat-embedding=64)
Dense(128)
Activation(silu)
Norm(layer_norm)
Graph Convolution Block x2:
X = ...
MPNN:
Message:
Dense(128)
Activation(silu)
Dense(128)
Update:
Dense(128)
Activation(silu)
Dropout(0.1)
Dense(128)
Residual(X)
# Post MPNN
X = ...
Norm(rms_norm)
Dense(512)
Activation(silu)
Dense(128)
Dropout(0.1)
Residual(X)
Identity
Dense(349) # Classification head
Model Weights
{'float32': 2312413}
Node Sets:
author: 256 nodes
field_of_study: 384 nodes
institution: 2 nodes
paper: 542 nodes
Edge Sets:
affiliated_with: 2 edges
cites: 507 edges
has_topic: 384 edges
writes: 256 edges
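The description above reports message_pooling='sum'. One common reading of this setting is that each node adds up the messages arriving from its sampled neighbors before the update step. The sketch below illustrates only that aggregation, in plain Python and without the Dense layers; the function name `sum_pool_messages` and the toy data are hypothetical and are not part of DGF.

```python
def sum_pool_messages(node_states, edges):
    """Sum, for every target node, the state vectors of its source neighbors."""
    dim = len(next(iter(node_states.values())))
    pooled = {node: [0.0] * dim for node in node_states}
    for src, tgt in edges:
        # Add the source node's state to the running sum of the target node.
        pooled[tgt] = [a + b for a, b in zip(pooled[tgt], node_states[src])]
    return pooled

# Toy graph: three papers with 2-dimensional states and three citation edges.
node_states = {"p1": [1.0, 0.0], "p2": [0.0, 2.0], "p3": [3.0, 1.0]}
edges = [("p2", "p1"), ("p3", "p1"), ("p3", "p2")]  # (source, target)

pooled = sum_pool_messages(node_states, edges)
print(pooled)
```

Nodes without incoming edges (p3 here) keep an all-zero pooled vector, which is one reason the architecture listing also carries a residual connection around the MPNN block.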
After training, a model is generally evaluated on a test dataset/graph.
Note: We don't have a test graph, so we use our training dataset here. In a real pipeline, evaluating a model on its training dataset makes little sense.
model.evaluate(graph)
Evaluating model on 10000 samples
Inference: 0%| | 0/625 [00:00<?, ?it/s]
...Tracing model
Inference: 100%|██████████| 625/625 [00:05<00:00, 115.33it/s]
- Accuracy: 0.2723
- Num Examples: 10000
We can make predictions for individual nodes:
# Predict the probability of each label class for the nodes 0 and 1.
predictions = model.predict(graph, seed_node_idxs=[0, 1])
print("\npredictions's shape [node idx, class idx]:", predictions.shape)
Inference: 0%| | 0/1 [00:00<?, ?it/s]
...Tracing model
Inference: 100%|██████████| 1/1 [00:01<00:00, 1.85s/it]
predictions's shape [node idx, class idx]: (2, 349)
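Assuming each row of the predictions holds one probability per label class, as the code comment above describes, the most likely class for a node is the index of the largest value in its row. The sketch below uses a small hand-written list in place of the real model output; the helper name `top_class` and the values are illustrative only.

```python
def top_class(probabilities):
    """Return (class index, probability) of the most likely class for one node."""
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return best, probabilities[best]

# Stand-in for the model output: 2 nodes, 4 classes (the real output has 349).
fake_predictions = [
    [0.10, 0.60, 0.20, 0.10],
    [0.05, 0.05, 0.15, 0.75],
]

predicted = [top_class(row) for row in fake_predictions]
print(predicted)
```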
Part 2: Training a time-aware model¶
The graph schema (see dgf.analyse.print_schema()) shows that the paper
nodeset has a year feature: the paper's year of publication. In this second
part, we train and evaluate the model as if it were applied at the time of
publication, that is, without access to information about papers published
afterwards. This is called time-aware modeling.
For time-aware modeling, we need a timestamp on the target nodeset
(paper in our case) and on the edgesets we want to filter (writes, cites
and has_topic; but not affiliated_with). In this MAG dataset, it makes sense
to simply propagate the paper nodeset timestamp to those edgesets: A paper is
written at the time of its publication.
The method dgf.transform.propagate_timestamp_to_edges does this automatically.
In addition, we want the paper's year feature to also have a TIMESTAMP semantic.
time_graph, time_schema = dgf.transform.propagate_timestamp_to_edges(
graph=graph,
schema=schema,
target_edgesets=["has_topic", "cites", "writes"],
node_timestamps={"paper": "year"},
target_feature="year", # Name of the new edge features.
)
time_schema.node_sets["paper"].features[
"year"
].semantic = dgf.data.FeatureSemantic.TIMESTAMP
dgf.analyse.print_schema(time_schema)
Graph Schema:
Node Sets:
author:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
field_of_study:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
institution:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
paper:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|-------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
| #split | BYTES | CATEGORICAL | None | None |
| feat | FLOAT_32 | EMBEDDING | (128,) | None |
| labels | INTEGER_64 | CATEGORICAL | None | None |
| year | INTEGER_64 | TIMESTAMP | None | None |
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No features)
cites: (Source: paper, Target: paper)
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|------------|---------|-----------------|
| year | INTEGER_64 | TIMESTAMP | None | None |
has_topic: (Source: paper, Target: field_of_study)
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|------------|---------|-----------------|
| year | INTEGER_64 | TIMESTAMP | None | None |
writes: (Source: author, Target: paper)
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|------------|---------|-----------------|
| year | INTEGER_64 | TIMESTAMP | None | None |
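To make the propagation step concrete, here is a minimal plain-Python sketch of the idea: every edge receives the publication year of the paper it is attached to. It does not reproduce DGF's implementation. The helper `propagate_year_to_edges`, its `paper_side` argument, and the toy data are hypothetical, and dating a citation by the citing (source) paper is an assumption made for this illustration.

```python
def propagate_year_to_edges(edges, paper_year, paper_side):
    """Attach to each edge the year of the paper found on `paper_side`.

    edges: list of (source_id, target_id) pairs.
    paper_year: dict mapping paper id -> publication year.
    paper_side: "source" or "target", the endpoint that is a paper.
    """
    if paper_side not in ("source", "target"):
        raise ValueError("paper_side must be 'source' or 'target'")
    index = 0 if paper_side == "source" else 1
    return [(src, tgt, paper_year[(src, tgt)[index]]) for src, tgt in edges]

paper_year = {"p1": 2012, "p2": 2015, "p3": 2018}
writes = [("a1", "p1"), ("a1", "p3"), ("a2", "p2")]  # author -> paper
cites = [("p2", "p1"), ("p3", "p1"), ("p3", "p2")]   # paper -> paper

writes_with_year = propagate_year_to_edges(writes, paper_year, "target")
# Assumption: a citation is dated by the citing (source) paper.
cites_with_year = propagate_year_to_edges(cites, paper_year, "source")
print(writes_with_year)
print(cites_with_year)
```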
We can train the model with time_aware=True.
model = dgf.learning.train_node_model(
graph=time_graph,
schema=time_schema,
target_nodeset="paper",
target_column="labels",
time_aware=True,
verbose=1, # Print fewer logs than before.
# Reduce the number of hops and train steps, for the demo to be faster.
num_sampling_hops=1,
num_train_steps=5000,
)
Preparing dataset
Num. training seed nodes: 704389, Num. validation seed nodes: 32000
[Warning] No normalizer created for node set 'author', feature '#id'.
[Warning] No normalizer created for node set 'paper', feature '#id'.
[Warning] No normalizer created for node set 'paper', feature 'year'.
[Warning] No normalizer created for node set 'field_of_study', feature '#id'.
[Warning] No normalizer created for node set 'institution', feature '#id'.
Preparing dataset finished in 4.46 seconds
Caching validation dataset
Caching validation dataset finished in 4.06 seconds
Number of cache validation batches: 1000
Training model
Generate first batch to initialize model
Create model variables
...Tracing model
Create model variables finished in 15.80 seconds
Will validate model every 1000 step(s)
Will checkpoint model every 1000 step(s)
Start training. The first two steps are generally slow.
Training: 0%| | 0/10000 [00:00<?, ?it/s]
...Tracing model
Training: 10%|▉ | 997/10000 [00:26<01:42, 87.93it/s, step=1000, train-accuracy=0.1894, train-loss=3.5854]
...Tracing model
Training: 10%|█ | 1015/10000 [00:29<11:08, 13.44it/s, step=1000, train-accuracy=0.1894, train-loss=3.5854]
Validation loop took 2.70s (only printed once)
Training: 100%|██████████| 10000/10000 [02:21<00:00, 70.86it/s, step=9900, train-accuracy=0.2644, train-loss=2.9738, valid-accuracy=0.2657, valid-loss=2.9960]
Final metrics: {'step': '9900', 'train-accuracy': '0.2644', 'train-loss': '2.9738', 'valid-accuracy': '0.2657', 'valid-loss': '2.9960'}
Training model finished in 158.20 seconds
We look at the model:
model.describe()
Node prediction model: Predict the value of a node feature.
- Target nodeset: paper
- Target column: labels
- Number of label classes: 349
- Number of training seed nodes: 704389
- Number of validation seed nodes: 32000
- Training duration: 2m 46s
num_sampling_hops=1
sampling_width=15
num_layers=2
batch_size=32
max_training_time_seconds=None
num_train_steps=10000
random_seed=42
node_embedding_dim=128
learning_rate=0.001
opt_weight_decay=0.0001
dropout=0.1
message_pooling='sum'
architecture=
Node Sets:
author:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
field_of_study:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
institution:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|----------|------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
paper:
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|-------------|---------|-----------------|
| #id | BYTES | PRIMARY_ID | None | None |
| #split | BYTES | CATEGORICAL | None | None |
| feat | FLOAT_32 | EMBEDDING | (128,) | None |
| labels | INTEGER_64 | CATEGORICAL | None | None |
| year | INTEGER_64 | TIMESTAMP | None | None |
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No features)
cites: (Source: paper, Target: paper)
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|------------|---------|-----------------|
| year | INTEGER_64 | TIMESTAMP | None | None |
has_topic: (Source: paper, Target: field_of_study)
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|------------|---------|-----------------|
| year | INTEGER_64 | TIMESTAMP | None | None |
writes: (Source: author, Target: paper)
| Feature | Format | Semantic | Shape | Num cat. vals |
|-----------|------------|------------|---------|-----------------|
| year | INTEGER_64 | TIMESTAMP | None | None |
Normalized schema
Node Sets:
author:
(No features)
field_of_study:
(No features)
institution:
(No features)
paper:
| Feature | Format | Semantic | Shape | Num cat. vals |
|--------------|------------|-------------|---------|-----------------|
| #split_INDEX | INTEGER_64 | CATEGORICAL | None | 4 |
| feat | FLOAT_32 | EMBEDDING | (128,) | None |
| labels | INTEGER_64 | CATEGORICAL | None | 349 |
Edge Sets:
affiliated_with: (Source: author, Target: institution)
(No features)
cites: (Source: paper, Target: paper)
(No features)
has_topic: (Source: paper, Target: field_of_study)
(No features)
writes: (Source: author, Target: paper)
(No features)
GraphFeatureStatistics:
Node Sets (4):
'author':
'#id': count=47004, min=nan, max=nan
'field_of_study':
'#id': count=101860, min=nan, max=nan
'institution':
'#id': count=0, min=nan, max=nan
'paper':
'#id': count=72371, min=nan, max=nan
'#split': count=72371, min=nan, max=nan, dictionary=(3)['train': 68044, 'valid': 3061, 'test': 1266]
'feat': count=72371, min=nan, max=nan
'labels': count=72371, min=0.0000, max=348.0000
'year': count=72371, min=2010.0000, max=2019.0000
Root: paper
├── cites [width=15] ➔ paper
├── cites (reversed) [width=15] ➔ paper
├── has_topic [width=15] ➔ field_of_study
└── writes (reversed) [width=15] ➔ author
Temporal Features: {'has_topic': 'year', 'cites': 'year', 'writes': 'year'}
EmbedGraph(cat-embedding=64)
Dense(128)
Activation(silu)
Norm(layer_norm)
Graph Convolution Block x2:
X = ...
MPNN:
Message:
Dense(128)
Activation(silu)
Dense(128)
Update:
Dense(128)
Activation(silu)
Dropout(0.1)
Dense(128)
Residual(X)
# Post MPNN
X = ...
Norm(rms_norm)
Dense(512)
Activation(silu)
Dense(128)
Dropout(0.1)
Residual(X)
Identity
Dense(349) # Classification head
Model Weights
{'float32': 2312285}
Node Sets:
author: 233 nodes
field_of_study: 384 nodes
institution: 2 nodes
paper: 369 nodes
Edge Sets:
affiliated_with: 2 edges
cites: 335 edges
has_topic: 384 edges
writes: 233 edges
You can see the effect of time-aware sampling in two places:
- In the "graph sampling" tab, the edgeset_timestamp_features field shows the edge features used for time filtering.
- In the "padding" tab, the paper nodeset padding is smaller than before (369 nodes instead of 542). Because of filtering, each paper node has, on average, half the number of edges, and the graph samples (used internally to train the GNN) are smaller.
The model quality is only slightly reduced: the final validation accuracy is 0.2657, compared with 0.2680 for the time-agnostic model. In this dataset, future information is not critical.
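The filtering behind time-aware sampling can be sketched in a few lines of plain Python: when sampling the neighborhood of a seed paper, edges dated after the seed's own year are hidden. It does not reproduce DGF's sampler. The function `sample_neighbors` and the toy data are hypothetical, and treating edges from the same year as visible (`<=` rather than `<`) is an assumption.

```python
import random

def sample_neighbors(edges_with_year, seed, seed_year, width, rng):
    """Sample up to `width` edges touching `seed` dated no later than `seed_year`."""
    visible = [
        (src, tgt, year)
        for src, tgt, year in edges_with_year
        # Follow edges in both directions, but hide anything from the future.
        if seed in (src, tgt) and year <= seed_year
    ]
    if len(visible) <= width:
        return visible
    return rng.sample(visible, width)

# (source, target, year) citation edges; p2 was published in 2015.
cites = [("p2", "p1", 2015), ("p3", "p1", 2018), ("p3", "p2", 2018)]

sampled = sample_neighbors(cites, seed="p2", seed_year=2015, width=15,
                           rng=random.Random(42))
print(sampled)
```

The 2018 citation of p2 by p3 is dropped, so the sample for p2 contains a single edge. This shrinking of neighborhoods is the effect visible in the smaller padding sizes.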