From 26317e01f997416319908d62e28f641ec94a2c63 Mon Sep 17 00:00:00 2001 From: "your.user.name" Date: Mon, 21 Oct 2024 09:31:12 -0700 Subject: [PATCH] making readme more clear --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 1979d84e..987f790d 100644 --- a/README.md +++ b/README.md @@ -50,19 +50,19 @@ from topomodelx.nn.simplicial.san import SAN from topomodelx.utils.sparse import from_sparse # Step 1: Load the Karate Club dataset -dataset = tnx.datasets.karate_club(complex_type="simplicial") +karate_club_complex = tnx.datasets.karate_club(complex_type="simplicial") # Step 2: Prepare Laplacians and node/edge features -laplacian_down = from_sparse(dataset.down_laplacian_matrix(rank=1)) -laplacian_up = from_sparse(dataset.up_laplacian_matrix(rank=1)) -incidence_0_1 = from_sparse(dataset.incidence_matrix(rank=1)) +laplacian_down = from_sparse(karate_club_complex.down_laplacian_matrix(rank=1)) +laplacian_up = from_sparse(karate_club_complex.up_laplacian_matrix(rank=1)) +incidence_0_1 = from_sparse(karate_club_complex.incidence_matrix(rank=1)) -x_0 = torch.tensor(np.stack(list(dataset.get_simplex_attributes("node_feat").values()))) -x_1 = torch.tensor(np.stack(list(dataset.get_simplex_attributes("edge_feat").values()))) +x_0 = torch.tensor(np.stack(list(karate_club_complex.get_simplex_attributes("node_feat").values()))) +x_1 = torch.tensor(np.stack(list(karate_club_complex.get_simplex_attributes("edge_feat").values()))) x = x_1 + torch.sparse.mm(incidence_0_1.T, x_0) # Step 3: Define the network -class Network(torch.nn.Module): +class TNN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels): super().__init__() self.base_model = SAN(in_channels, hidden_channels, n_layers=2) @@ -73,7 +73,7 @@ class Network(torch.nn.Module): return torch.sigmoid(self.linear(x)) # Step 4: Initialize the network and perform a forward pass -model = Network(in_channels=x.shape[-1], hidden_channels=16, out_channels=2) +model = TNN(in_channels=x.shape[-1], hidden_channels=16, out_channels=2) y_hat_edge = model(x, laplacian_up=laplacian_up, laplacian_down=laplacian_down) ```