"""Adapted from https://github.com/kavehhassani/mvgrl/blob/master/node/train.py
"""
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch_geometric.nn import GCNConv
class GCN(nn.Module):
	"""Two-layer Graph Convolutional Network (GCN) encoder.

	Each layer is a ``GCNConv`` followed by its own PReLU activation; both
	layers produce ``out_channels`` features.

	:param in_channels: Number of input features
	:type in_channels: int
	:param out_channels: Number of output features of each GCN layer
	:type out_channels: int
	"""

	def __init__(self, in_channels: int, out_channels: int):
		super(GCN, self).__init__()
		self.conv1: GCNConv = GCNConv(in_channels, out_channels)
		self.prelu1: nn.PReLU = nn.PReLU()
		self.conv2: GCNConv = GCNConv(out_channels, out_channels)
		self.prelu2: nn.PReLU = nn.PReLU()

	def forward(
		self,
		x: torch.Tensor,
		edge_index: torch.Tensor,
		edge_weight: Optional[torch.Tensor] = None,
	) -> Tuple[torch.Tensor, torch.Tensor]:
		"""Forward pass through both GCN layers.

		:param x: Input node features
		:type x: torch.Tensor
		:param edge_index: Edge index tensor
		:type edge_index: torch.Tensor
		:param edge_weight: Optional edge weights, forwarded unchanged to both
			``GCNConv`` layers
		:type edge_weight: torch.Tensor, optional
		:return: Node embeddings after the first layer and after the second layer
		:rtype: Tuple[torch.Tensor, torch.Tensor]
		"""
		h1 = self.prelu1(self.conv1(x, edge_index, edge_weight))
		# The second layer reuses the same graph (and weights) as the first.
		h2 = self.prelu2(self.conv2(h1, edge_index, edge_weight))
		return h1, h2
 
class Projection(nn.Module):
	"""Projection head: a square linear layer followed by a PReLU activation.

	:param latent_dim: Dimension of the latent space (both input and output size)
	:type latent_dim: int
	"""

	def __init__(self, latent_dim: int):
		super(Projection, self).__init__()
		self.fc: nn.Linear = nn.Linear(latent_dim, latent_dim)
		self.prelu: nn.PReLU = nn.PReLU()

	def forward(self, h: torch.Tensor) -> torch.Tensor:
		"""Project embeddings through ``fc`` and then ``prelu``.

		:param h: Embeddings whose last dimension is ``latent_dim``
		:type h: torch.Tensor
		:return: Projected embeddings, same shape as ``h``
		:rtype: torch.Tensor
		"""
		return self.prelu(self.fc(h))
 
class Readout(nn.Module):
	"""Graph-level readout over the outputs of a two-layer GCN.

	Mean-pools each layer's node embeddings over dimension ``-2``, concatenates
	the two pooled vectors, maps them back to ``latent_dim`` features with a
	linear layer and applies a sigmoid.

	:param latent_dim: Feature dimension of each layer's node embeddings
	:type latent_dim: int
	"""

	def __init__(self, latent_dim: int):
		super(Readout, self).__init__()
		# Input is the concatenation of the two pooled layers, hence 2 * latent_dim.
		self.fc: nn.Linear = nn.Linear(latent_dim * 2, latent_dim)

	def forward(self, h1: torch.Tensor, h2: torch.Tensor) -> torch.Tensor:
		"""Pool node embeddings into a single graph embedding.

		:param h1: Node embeddings at the first GCN layer (pooled over dim ``-2``)
		:type h1: torch.Tensor
		:param h2: Node embeddings at the second GCN layer (pooled over dim ``-2``)
		:type h2: torch.Tensor
		:return: Graph embedding of size ``latent_dim`` with values in ``(0, 1)``
		:rtype: torch.Tensor
		"""
		pooled = torch.cat([h1.mean(dim=-2), h2.mean(dim=-2)], dim=-1)
		return torch.sigmoid(self.fc(pooled))
 
class Discriminator(nn.Module):
	"""Bilinear discriminator scoring (graph embedding, node embedding) pairs.

	:param in_channels: Size of both the graph and the node embeddings
	:type in_channels: int
	"""

	def __init__(self, in_channels: int):
		super(Discriminator, self).__init__()
		self.bilinear: nn.Bilinear = nn.Bilinear(in_channels, in_channels, 1)

	def forward(
		self,
		ha: torch.Tensor,
		hb: torch.Tensor,
		Ha: torch.Tensor,
		Hb: torch.Tensor,
		Ha_corrupted: torch.Tensor,
		Hb_corrupted: torch.Tensor,
	) -> torch.Tensor:
		"""Score cross-view pairs for the mutual-information objective.

		Each graph embedding is paired with the node embeddings of the *other*
		view. Clean pairs come first, followed by the corrupted pairs.

		:param ha: Graph embedding of the original view
		:type ha: torch.Tensor
		:param hb: Graph embedding of the diffused view
		:type hb: torch.Tensor
		:param Ha: Node embedding of the original view
		:type Ha: torch.Tensor
		:param Hb: Node embedding of the diffused view
		:type Hb: torch.Tensor
		:param Ha_corrupted: Node embedding of the corrupted original view
		:type Ha_corrupted: torch.Tensor
		:param Hb_corrupted: Node embedding of the corrupted diffused view
		:type Hb_corrupted: torch.Tensor
		:return: Scores for ``(hb, Ha)``, ``(ha, Hb)``, ``(hb, Ha_corrupted)`` and
			``(ha, Hb_corrupted)``, concatenated along the last dimension
		:rtype: torch.Tensor
		"""
		# Broadcast each graph embedding to one copy per node embedding.
		ha = ha.expand_as(Ha)
		hb = hb.expand_as(Hb)
		scores = [
			self.bilinear(hb, Ha),
			self.bilinear(ha, Hb),
			self.bilinear(hb, Ha_corrupted),
			self.bilinear(ha, Hb_corrupted),
		]
		# squeeze() drops the size-1 output dim. For a single-node graph it would
		# yield a 0-d tensor that torch.cat rejects; atleast_1d restores a 1-d shape
		# and leaves every other shape untouched.
		return torch.cat([torch.atleast_1d(s.squeeze()) for s in scores], dim=-1)
 
class MVGRLModel(nn.Module):
	"""Multi-View Graph Representation Learning (MVGRL) model.

	Encodes the original graph (``gcn_real``) and the diffused graph
	(``gcn_diff``) with separate GCNs, and scores cross-view graph/node
	embedding pairs with a bilinear discriminator.

	:param in_channels: Number of input features
	:type in_channels: int
	:param latent_dim: Dimension of the latent space
	:type latent_dim: int
	"""

	def __init__(self, in_channels: int, latent_dim: int):
		super(MVGRLModel, self).__init__()
		self.gcn_real: GCN = GCN(in_channels, latent_dim)
		self.gcn_diff: GCN = GCN(in_channels, latent_dim)
		self.readout: Readout = Readout(latent_dim)
		# Projection heads are currently disabled (identity maps); replace with
		# Projection(latent_dim) to enable them.
		self.projector_nodes: nn.Module = nn.Identity()
		self.projector_graph: nn.Module = nn.Identity()
		self.discriminator: Discriminator = Discriminator(latent_dim)

	def forward(
		self,
		x: torch.Tensor,
		edge_index: torch.Tensor,
		diff_edge_index: torch.Tensor,
		diff_edge_weight: torch.Tensor,
		corrupted_idx: Optional[torch.Tensor] = None,
	) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
		"""Forward pass, a=alpha (original view), b=beta (diffused view).

		:param x: Input node features; rows are indexed by ``corrupted_idx``
		:type x: torch.Tensor
		:param edge_index: Edge index tensor of the original graph
		:type edge_index: torch.Tensor
		:param diff_edge_index: Diffused edge index tensor
		:type diff_edge_index: torch.Tensor
		:param diff_edge_weight: Diffused edge weight tensor
		:type diff_edge_weight: torch.Tensor
		:param corrupted_idx: Row indices of ``x`` used to build the corrupted
			features; a random permutation of ``x``'s rows is drawn when ``None``
		:type corrupted_idx: torch.Tensor, optional
		:return: ``(disc_out, ha + hb, Ha + Hb)``: the discriminator scores, the
			sum of both views' graph embeddings and the sum of both views' node
			embeddings
		:rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
		"""
		# Graph and node embeddings of both views
		ha1, ha2 = self.gcn_real(x, edge_index)
		hb1, hb2 = self.gcn_diff(x, diff_edge_index, diff_edge_weight)
		Ha = self.projector_nodes(ha2)
		Hb = self.projector_nodes(hb2)
		ha = self.projector_graph(self.readout(ha1, ha2))
		hb = self.projector_graph(self.readout(hb1, hb2))

		# Corrupted views: same graphs, with the rows of x shuffled.
		if corrupted_idx is None:
			# Draw on x's device so indexing needs no host-to-device copy.
			corrupted_idx = torch.randperm(x.size(0), device=x.device)
		x_corrupted = x[corrupted_idx]
		_, ha2_corrupted = self.gcn_real(x_corrupted, edge_index)
		_, hb2_corrupted = self.gcn_diff(x_corrupted, diff_edge_index, diff_edge_weight)
		Ha_corrupted = self.projector_nodes(ha2_corrupted)
		Hb_corrupted = self.projector_nodes(hb2_corrupted)

		# Discriminator output
		disc_out = self.discriminator(ha, hb, Ha, Hb, Ha_corrupted, Hb_corrupted)
		return disc_out, ha + hb, Ha + Hb

	def encode(
		self,
		x: torch.Tensor,
		edge_index: torch.Tensor,
		diff_edge_index: torch.Tensor,
		diff_edge_weight: torch.Tensor,
	) -> torch.Tensor:
		"""Compute node embeddings without corruption or discrimination.

		:param x: Input node features
		:type x: torch.Tensor
		:param edge_index: Edge index tensor of the original graph
		:type edge_index: torch.Tensor
		:param diff_edge_index: Diffused edge index tensor
		:type diff_edge_index: torch.Tensor
		:param diff_edge_weight: Diffused edge weight tensor
		:type diff_edge_weight: torch.Tensor
		:return: Sum of the (projected) second-layer node embeddings of both views
		:rtype: torch.Tensor
		"""
		_, ha2 = self.gcn_real(x, edge_index)
		_, hb2 = self.gcn_diff(x, diff_edge_index, diff_edge_weight)
		return self.projector_nodes(ha2) + self.projector_nodes(hb2)