# Source code for src.algorithms.deep.models.GCNEncoder

from typing import Callable, Union

import torch
from torch import nn
from torch_geometric.nn import GCNConv


class GCNEncoder(nn.Module):
    """Graph Convolutional Network Encoder for Graph Autoencoder

    Stacks two ``GCNConv`` layers: the first is built as
    ``GCNConv(in_channels, latent_dim * 2)`` and the second as
    ``GCNConv(latent_dim * 2, latent_dim)``. The activation is applied only
    between the two layers; the output of the second layer is returned
    without a further non-linearity.

    :param in_channels: Number of input channels
    :type in_channels: int
    :param latent_dim: Latent dimension
    :type latent_dim: int
    :param activation: Activation function. Either an ``nn.Module`` instance,
        which is used directly, or a zero-argument callable (typically an
        ``nn.Module`` subclass such as ``nn.ReLU``), which is called once to
        build the activation.
    :type activation: nn.Module or Callable[[], nn.Module]
    """

    def __init__(
        self,
        in_channels: int,
        latent_dim: int,
        activation: Union[nn.Module, Callable[[], nn.Module]] = nn.ReLU,
    ) -> None:
        super().__init__()
        hidden_dim = latent_dim * 2
        self.conv1: GCNConv = GCNConv(in_channels, hidden_dim)
        self.conv2: GCNConv = GCNConv(hidden_dim, latent_dim)
        self.act: nn.Module = self._resolve_activation(activation)

    @staticmethod
    def _resolve_activation(
        activation: Union[nn.Module, Callable[[], nn.Module]],
    ) -> nn.Module:
        """Return a ready-to-use activation module.

        :param activation: An ``nn.Module`` instance or a zero-argument
            callable producing one
        :type activation: nn.Module or Callable[[], nn.Module]
        :return: ``activation`` itself when it already is a module instance,
            otherwise the result of calling ``activation()``
        :rtype: nn.Module
        """
        # An instance must not be called here: calling a module invokes its
        # forward pass, which would fail without an input tensor.
        if isinstance(activation, nn.Module):
            return activation
        return activation()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Forward pass

        :param x: Node features
        :type x: torch.Tensor
        :param edge_index: Edge index, passed unchanged to both ``GCNConv``
            layers
        :type edge_index: torch.Tensor
        :return: Node embeddings (output of the second convolution, with no
            activation applied to it)
        :rtype: torch.Tensor
        """
        hidden = self.act(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)