import numpy as np
import torch
from torch import nn
from torch_geometric.utils import dense_to_sparse
from tqdm import trange

from algorithms.deep.DeepAlgorithm import DeepAlgorithm
from algorithms.deep.models.MVGRLModel import MVGRLModel
from algorithms.deep.utils import get_clusters, compute_diffusion_matrix
from graph import Graph
class MVGRL(DeepAlgorithm):
	"""Multi-View Graph Representation Learning algorithm

	The model is fed two structural views of the graph: the original edge
	index and a sparsified diffusion matrix computed from the adjacency
	matrix. It is trained with a binary discriminator loss, and clusters are
	derived from the learned node embeddings.

	:param graph: Graph object
	:type graph: Graph
	:param num_clusters: Number of clusters to produce
	:type num_clusters: int
	:param lr: Learning rate
	:type lr: float
	:param latent_dim: Latent dimension
	:type latent_dim: int
	:param epochs: Number of epochs to run
	:type epochs: int
	:param use_pretrained: Boolean flag to indicate if pretrained model should be used
	:type use_pretrained: bool
	:param save_model: Boolean flag to indicate if the model should be saved after training
	:type save_model: bool
	"""

	def __init__(self, graph: Graph, num_clusters: int, lr: float = .001, latent_dim: int = 16, epochs: int = 100, use_pretrained: bool = True, save_model: bool = False):
		"""Constructor method
		"""
		super().__init__(graph, num_clusters=num_clusters, lr=lr, latent_dim=latent_dim, epochs=epochs, use_pretrained=use_pretrained, save_model=save_model)
		self.evaluation_clustering_tries = 400
		self.model: MVGRLModel = MVGRLModel(in_channels=graph.features.shape[1], latent_dim=latent_dim)
		if self.use_pretrained:
			self._load_pretrained()
		self.optimizer: torch.optim.Optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
		# Second view: the diffusion matrix of the adjacency matrix.
		diff = torch.from_numpy(compute_diffusion_matrix(graph.adj_matrix)).float()
		# Remove small values for computational time: the sparse matrix shrinks
		# from ~6M to ~300k edges on Cora.
		diff[diff < 1e-5] = 0
		self.diff_edge_index, diff_edge_weight = dense_to_sparse(diff)
		self.diff_edge_weight = diff_edge_weight.float()

	def _train(self) -> None:
		"""Trains the model

		Each epoch performs one optimizer step on the discriminator loss. It
		then re-clusters the current embeddings and shows the loss and the
		evaluation metrics on the progress bar.
		"""
		num_nodes = self.x_t.size(0)
		device = self.x_t.device
		# Target layout: 2N ones followed by 2N zeros (N = num_nodes).
		# Assumes the discriminator emits positive logits first and corrupted
		# (negative) logits second -- TODO confirm against MVGRLModel.
		labels = torch.cat([
			torch.ones(num_nodes * 2, dtype=torch.float, device=device),
			torch.zeros(num_nodes * 2, dtype=torch.float, device=device),
		], dim=-1)
		criterion = nn.BCEWithLogitsLoss()
		for _ in (pbar := trange(self.epochs, desc="MVGRL Training")):
			self.model.train()
			self.optimizer.zero_grad()
			# Draw a fresh node permutation every epoch so the corrupted
			# (negative) samples vary during training instead of being fixed.
			corrupted_labels = torch.randperm(num_nodes, device=device)
			# Training the model
			discriminator_output, _, _ = self.model(self.x_t, self.edge_index_t, self.diff_edge_index, self.diff_edge_weight, corrupted_labels)
			loss = criterion(discriminator_output, labels)
			loss.backward()
			self.optimizer.step()
			# Evaluation
			self.model.eval()
			self.clusters = get_clusters(self._encode_nodes(), self.num_clusters)
			evaluation = self.evaluate()
			pbar.set_postfix({"Loss": loss.item(), **dict(evaluation)})

	def _encode_nodes(self) -> np.ndarray:
		"""Encodes the node features using the model

		Runs without autograd tracking because the result is handed back as a
		NumPy array and never back-propagated through.

		:return: Node embeddings
		:rtype: numpy.ndarray
		"""
		with torch.no_grad():
			embeddings = self.model.encode(self.x_t, self.edge_index_t, self.diff_edge_index, self.diff_edge_weight)
		# .cpu() is a no-op for CPU tensors and is required before .numpy() on GPU tensors.
		return embeddings.detach().cpu().numpy()

	def __str__(self):
		"""Returns the string representation of the algorithm object
		:return: String representation of the algorithm object
		:rtype: str
		"""
		return "MVGRL algorithm object"