Source code for src.graph.Graph

import matplotlib.cm as mplcm
import matplotlib.colors as mplcolors
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

class Graph:
    """Graph described by a dense adjacency matrix plus optional node data.

    On construction the adjacency matrix is converted to an edge list
    (``edge_index`` / ``edge_weight``) and mirrored into a NetworkX graph
    (``nx_graph``) that is used for drawing.

    :param adj_matrix: Adjacency matrix of the graph of shape (n, n) where n is the number of nodes.
        Every non-zero entry ``adj_matrix[u, v]`` is an edge ``u -> v`` with that entry as its weight
    :type adj_matrix: np.ndarray
    :param directed: Boolean flag to indicate if the graph is directed or undirected
    :type directed: bool
    :param features: Features of the nodes of shape (n, f) where f is the number of features per node.
        Defaults to a single constant feature (``1.0``) per node when omitted
    :type features: np.ndarray
    :param labels: Labels of the nodes of shape (n,); stored as given, may be ``None``
    :type labels: np.ndarray
    :param dataset_name: Name of the dataset (if known)
    :type dataset_name: str
    """

    def __init__(self, adj_matrix: np.ndarray, directed: bool = False, features: np.ndarray = None,
                 labels: np.ndarray = None, dataset_name: str = ""):
        """Constructor method
        """
        self.adj_matrix: np.ndarray = adj_matrix
        # Row 0 holds the source node of every non-zero entry, row 1 the target node.
        self.edge_index: np.ndarray = np.argwhere(adj_matrix).transpose()
        # Weight of edge i is the adjacency entry at (source_i, target_i).
        self.edge_weight: np.ndarray = adj_matrix[self.edge_index[0], self.edge_index[1]]
        self.features: np.ndarray = features if features is not None else np.ones((adj_matrix.shape[0], 1))
        self.labels: np.ndarray = labels
        self.dataset_name: str = dataset_name
        self.directed: bool = directed
        self.nx_graph: nx.Graph | nx.DiGraph = nx.DiGraph() if directed else nx.Graph()
        self._create_nx_graph()

    def _create_nx_graph(self) -> None:
        """Adds nodes and weighted edges to the NetworkX graph object.

        One node is created per row of ``self.features`` and that row is stored as the
        node's ``features`` attribute. Every column of ``self.edge_index`` becomes an edge
        whose ``weight`` attribute is the matching entry of ``self.edge_weight``.
        """
        # Pass (node, attribute-dict) pairs: a ``features=`` keyword would attach the
        # complete feature matrix to every single node instead of the node's own row.
        self.nx_graph.add_nodes_from(
            (node, {"features": row}) for node, row in enumerate(self.features)
        )
        # In an undirected graph (u, v) and (v, u) address the same edge, so the entry
        # that appears later in edge_index determines the stored weight.
        self.nx_graph.add_weighted_edges_from(
            zip(self.edge_index[0], self.edge_index[1], self.edge_weight)
        )

    def draw(self, clusters: list[int] = None, draw_labels: bool = False) -> None:
        """Draws the graph with the NetworkX drawing functions.

        If clusters are provided, every node is colored according to its cluster; otherwise
        (``None`` or an empty sequence) all nodes are drawn in dark gray. Node positions come
        from ``nx.spring_layout``, which is called without a seed. The method does not call
        ``plt.show()``; displaying or saving the figure is left to the caller.

        :param clusters: Cluster label for each node; labels are used as indices into a palette
            of ``max(clusters) + 1`` colors
        :type clusters: list[int]
        :param draw_labels: Boolean flag to indicate if the node labels should be drawn
        :type draw_labels: bool
        """
        if clusters is not None and len(clusters) > 0:
            # One evenly spaced rainbow color per cluster id in 0..max(clusters).
            num_colors = max(clusters) + 1
            color_norm = mplcolors.Normalize(vmin=0, vmax=num_colors - 1)
            scalar_map = mplcm.ScalarMappable(norm=color_norm, cmap=plt.get_cmap('gist_rainbow'))
            palette = [scalar_map.to_rgba(i) for i in range(num_colors)]
            colors = [palette[cluster] for cluster in clusters]
        else:
            colors = "darkgray"

        # Smaller markers keep larger graphs readable.
        node_size = 50 if len(self.features) < 50 else 5
        pos = nx.spring_layout(self.nx_graph)
        nx.draw_networkx_nodes(self.nx_graph, pos, edgecolors="darkgray", linewidths=.5,
                               node_color=colors, node_size=node_size)
        nx.draw_networkx_edges(self.nx_graph, pos, width=1, alpha=0.5, edge_color="gray")
        if draw_labels:
            nx.draw_networkx_labels(self.nx_graph, pos)
        plt.tight_layout()
        plt.axis("off")

    def __str__(self):
        """Returns the string representation of the graph object

        The edge count is the number of non-zero adjacency entries, so an undirected edge
        that is stored symmetrically in the matrix is counted twice.

        :return: String representation of the graph object (number of nodes and edges)
        :rtype: str
        """
        return f"Graph with {self.adj_matrix.shape[0]} nodes and {self.edge_index.shape[1]} edges."

    def __repr__(self):
        """Returns the string representation of the graph object

        :return: Same text as ``str(self)`` (number of nodes and edges)
        :rtype: str
        """
        return self.__str__()

    def __getitem__(self, key: int) -> np.ndarray:
        """Returns the part of the adjacency matrix selected by the given key

        :param key: Index passed straight to the adjacency matrix (a node index selects its row)
        :type key: int
        :return: ``self.adj_matrix[key]``
        :rtype: np.ndarray
        """
        return self.adj_matrix[key]