Temporal Graph Networks: A Deep Dive into Dynamic Graph Learning

Real-world networks are rarely static. Social networks evolve as users form new connections, financial networks change with each transaction, and biological networks transform as proteins interact. Traditional Graph Neural Networks (GNNs) weren't designed for this dynamism. Enter Temporal Graph Networks (TGNs), a powerful framework for learning on dynamic graphs.

Understanding Dynamic Graphs

Before diving into TGNs, let's clarify what we mean by dynamic graphs. A temporal graph can be represented as a sequence of time-stamped events: G = {x(t₁), x(t₂), …}. Each event can be:

  1. Node-wise: Adding/updating a node with features
  2. Edge-wise: Creating an interaction between nodes
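
As a concrete illustration, here is one minimal way to represent such an event stream in Python. The TemporalEvent class and its field names are illustrative choices, not part of any particular library:

from dataclasses import dataclass, field
from typing import Optional

import numpy as np


@dataclass
class TemporalEvent:
    """A single time-stamped event in the temporal graph (hypothetical schema)."""
    time: float
    source: int                    # node that initiates the event
    target: Optional[int] = None   # None for node-wise events
    features: np.ndarray = field(default_factory=lambda: np.zeros(4))


# A temporal graph is simply the time-ordered list of such events
events = [
    TemporalEvent(time=1.0, source=0, target=3, features=np.ones(4)),  # edge-wise
    TemporalEvent(time=2.5, source=7),                                 # node-wise
]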

Core Components of TGN

Memory Module

The memory module is TGN's secret weapon. Each node maintains a memory state that captures its history of interactions. This memory is updated with each new event, allowing the network to learn long-term dependencies.
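
Before looking at each component in detail, it helps to see how memory flows through a single event. The sketch below strings together the methods defined in the rest of this section (batching and message aggregation are omitted for brevity):

def process_event(tgn, event):
    """Schematic per-event memory flow for an edge-wise event."""
    # 1. Turn the interaction into one message per involved node
    src_msg, tgt_msg = tgn.compute_messages(
        event.source, event.target, event.time, event.features
    )
    # 2. Update each node's memory with its message
    #    (in batched training, messages are aggregated first)
    tgn.update_memory(event.source, src_msg, event.time)
    tgn.update_memory(event.target, tgt_msg, event.time)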

Message Function

When an interaction occurs between nodes, messages are computed to update their memories. Here's how the message functions work:

import numpy as np


class TemporalGraphNetwork:
    def __init__(self, node_features, edge_features, memory_dimension):
        self.node_features = node_features
        self.edge_features = edge_features
        # Per-node memory vectors, plus the timestamp of each node's last update
        self.memory = {node_id: np.zeros(memory_dimension) for node_id in node_features}
        self.last_update = {node_id: 0.0 for node_id in node_features}
        
    def compute_messages(self, source, target, time, edge_features):
        """Compute messages for source and target nodes."""
        # Get current memory states
        source_memory = self.memory[source]
        target_memory = self.memory[target]
        
        # Compute time elapsed since each node's last memory update
        source_delta = time - self.last_update[source]
        target_delta = time - self.last_update[target]
        
        # Compute messages for both nodes (roles swapped for the target)
        source_message = self.message_function(
            source_memory=source_memory,
            target_memory=target_memory,
            delta_time=source_delta,
            edge_features=edge_features
        )
        
        target_message = self.message_function(
            source_memory=target_memory,
            target_memory=source_memory,
            delta_time=target_delta,
            edge_features=edge_features
        )
        
        return source_message, target_message

    def message_function(self, source_memory, target_memory, delta_time, edge_features):
        """Simple concatenation-based message function."""
        return np.concatenate([
            source_memory,
            target_memory,
            [delta_time],
            edge_features
        ])
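
A quick usage sketch for the class above, with made-up toy numbers; the message is simply the concatenation of both memories, the time gap, and the edge features:

# Hypothetical toy setup: two nodes with 2-d features and 4-d memory
node_features = {0: np.array([1.0, 0.0]), 1: np.array([0.0, 1.0])}
tgn = TemporalGraphNetwork(node_features, edge_features=None, memory_dimension=4)

# An interaction from node 0 to node 1 at time t=5.0 with a 3-d edge feature
src_msg, tgt_msg = tgn.compute_messages(
    source=0, target=1, time=5.0, edge_features=np.array([0.2, 0.5, 0.1])
)
print(src_msg.shape)  # (4 + 4 + 1 + 3,) = (12,)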

Message Aggregator

When multiple events involve the same node in a batch, their messages need to be aggregated:

    def aggregate_messages(self, messages, aggregation_type="last"):
        """Aggregate multiple messages for the same node."""
        if aggregation_type == "last":
            return messages[-1]  # Keep only the most recent message
        elif aggregation_type == "mean":
            return np.mean(messages, axis=0)  # Average all messages
        raise ValueError(f"Unknown aggregation type: {aggregation_type}")
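
For instance, if a node received three messages within one batch (toy vectors, reusing the tgn instance from the earlier sketch):

messages = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([1.0, 1.0])]
tgn.aggregate_messages(messages, aggregation_type="last")  # -> array([1., 1.])
tgn.aggregate_messages(messages, aggregation_type="mean")  # -> roughly array([0.667, 0.667])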

Memory Updater

The memory updater is typically implemented using a GRU or LSTM to update node memories based on aggregated messages:

    def update_memory(self, node_id, message, time):
        """Update node memory using a GRU-like update."""
        current_memory = self.memory[node_id]
        
        # GRU-like update equations; W_update, W_reset, and W_candidate are
        # learned weight matrices (one way to initialize them is sketched below)
        update_gate = self._sigmoid(
            self.W_update @ np.concatenate([current_memory, message])
        )
        reset_gate = self._sigmoid(
            self.W_reset @ np.concatenate([current_memory, message])
        )
        
        candidate_memory = np.tanh(
            self.W_candidate @ np.concatenate([
                reset_gate * current_memory,
                message
            ])
        )
        
        # Blend the old memory with the candidate and record the update time
        self.memory[node_id] = (
            update_gate * current_memory +
            (1 - update_gate) * candidate_memory
        )
        self.last_update[node_id] = time

    def _sigmoid(self, x):
        """Element-wise logistic sigmoid."""
        return 1.0 / (1.0 + np.exp(-x))
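
The updater above assumes the GRU weight matrices W_update, W_reset, and W_candidate already exist on the object. One minimal way to create them, assuming plain NumPy and scaled random initialization (a hypothetical helper, not part of the original snippet), is:

    def _init_gru_weights(self, memory_dimension, message_dimension):
        """Hypothetical helper: random weights for the GRU-like memory updater."""
        in_dim = memory_dimension + message_dimension
        scale = 1.0 / np.sqrt(in_dim)
        self.W_update = np.random.randn(memory_dimension, in_dim) * scale
        self.W_reset = np.random.randn(memory_dimension, in_dim) * scale
        self.W_candidate = np.random.randn(memory_dimension, in_dim) * scale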

Embedding Module

The embedding module generates node embeddings using the current memory state and graph structure:

    def compute_embedding(self, node_id, time, num_layers=1):
        """Compute a node embedding using temporal graph attention.

        get_temporal_neighbors and embedding_update are helper methods not
        shown here; a sketch of compute_temporal_attention follows this snippet.
        """
        current_embedding = self.memory[node_id]
        
        for layer in range(num_layers):
            # Get neighbors that interacted with this node before `time`
            neighbors = self.get_temporal_neighbors(node_id, time)
            
            # Compute attention scores over the neighbors' memories
            attention_scores = self.compute_temporal_attention(
                query=current_embedding,
                keys=[self.memory[n] for n in neighbors],
                times=[self.last_update[n] for n in neighbors]
            )
            
            # Aggregate neighbor information as an attention-weighted sum
            neighbor_info = sum(
                score * self.memory[neighbor]
                for score, neighbor in zip(attention_scores, neighbors)
            )
            
            # Combine the node's own embedding with the aggregated neighborhood
            current_embedding = self.embedding_update(
                current_embedding,
                neighbor_info
            )
            
        return current_embedding
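
The helper compute_temporal_attention is not defined above. Here is one simplified way it could look, using plain dot-product attention over neighbor memories; a full implementation would also feed a learned encoding of the time gaps into the scores (see Time Encoding below):

    def compute_temporal_attention(self, query, keys, times):
        """Simplified dot-product attention over neighbor memories (sketch only)."""
        if len(keys) == 0:
            return []
        # Raw similarity between the node's embedding and each neighbor's memory;
        # `times` would feed a learned time encoding in a full implementation
        scores = np.array([query @ key for key in keys])
        # Softmax-normalize into attention weights
        scores = np.exp(scores - scores.max())
        return scores / scores.sum()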

Training Process

Training TGN requires careful handling of temporal dependencies. Here's the main training loop:

    def train(self, temporal_edges, batch_size=200):
        """Train the model on temporal edges.

        Schematic loop: batches and edges are assumed to expose .source,
        .target, .time, .features, and .labels; the backward pass assumes an
        autodiff framework such as PyTorch sits underneath compute_loss.
        """
        message_store = {}  # Raw messages carried over from the previous batch
        
        for batch in create_batches(temporal_edges, batch_size):
            # 1. Update memories using messages stored from the *previous* batch,
            #    so the interactions being predicted now cannot leak into memory
            self.process_message_store(message_store)
            
            # 2. Compute embeddings
            source_embeddings = [
                self.compute_embedding(edge.source, edge.time)
                for edge in batch
            ]
            target_embeddings = [
                self.compute_embedding(edge.target, edge.time)
                for edge in batch
            ]
            
            # 3. Compute loss
            loss = self.compute_loss(
                source_embeddings,
                target_embeddings,
                batch.labels
            )
            
            # 4. Backward pass and parameter update (framework-dependent)
            loss.backward()
            
            # 5. Store raw messages for the next batch
            for edge in batch:
                message_store[edge.source] = (
                    edge.source,
                    edge.target,
                    edge.time,
                    edge.features
                )
                message_store[edge.target] = (
                    edge.target,
                    edge.source,
                    edge.time,
                    edge.features
                )
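
The loop leaves the data structures implicit. A hypothetical shape for the inputs, just to make the assumed interface concrete (the names and the create_batches helper are illustrative, not part of the original code):

from collections import namedtuple

# Hypothetical edge record matching the fields the loop expects
TemporalEdge = namedtuple("TemporalEdge", ["source", "target", "time", "features"])

def create_batches(edges, batch_size):
    """Yield time-ordered chunks of edges (simplified; a full version would
    also attach labels, e.g. negative samples, to each batch)."""
    edges = sorted(edges, key=lambda e: e.time)
    for i in range(0, len(edges), batch_size):
        yield edges[i:i + batch_size]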

Key Advantages and Applications

  1. Memory Efficiency: TGN maintains a compact memory state for each node instead of storing the entire history.
  2. Temporal Awareness: The framework naturally handles time-stamped events and evolving graphs.
  3. Flexibility: Different message, aggregation, and memory update functions can be used based on the specific application.
  4. Scalability: The batched training process allows for efficient processing of large temporal graphs.

TGNs have shown remarkable success in various domains:

  • Dynamic link prediction in social networks
  • User-item interaction modeling in recommender systems
  • Temporal knowledge graph completion
  • Financial fraud detection
  • Traffic prediction

Implementation Considerations

  • Batch Size: Smaller batches give more frequent memory updates but slow down training.
  • Memory Dimension: Larger memory can capture more complex patterns but requires more computation.
  • Neighbor Sampling: Sampling recent neighbors often works better than uniform sampling.
  • Time Encoding: Different time encoding strategies can be used depending on the temporal patterns in the data; a common functional encoding is sketched below.
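
On the last point, a common choice is a cosine-based functional time encoding in the spirit of TGAT and the reference TGN implementation. A minimal sketch, with fixed frequencies standing in for what would normally be learned parameters:

import numpy as np

def time_encoding(delta_t, dimension=8):
    """Map a time gap to a vector of cosines at multiple frequencies (sketch)."""
    # Geometrically spaced frequencies; learned in a real implementation
    frequencies = 1.0 / (10.0 ** np.linspace(0, 4, dimension))
    return np.cos(delta_t * frequencies)

# Example: encode the gap between an event and a node's last memory update
print(time_encoding(5.0))  # 8-dimensional encoding of delta_t = 5.0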