diff --git a/asset/graph/correct_reference.png b/asset/graph/correct_reference.png new file mode 100644 index 00000000..de7fe420 Binary files /dev/null and b/asset/graph/correct_reference.png differ diff --git a/asset/graph/inverse_edge.png b/asset/graph/inverse_edge.png new file mode 100644 index 00000000..8be87a36 Binary files /dev/null and b/asset/graph/inverse_edge.png differ diff --git a/asset/graph/wrong_reference.png b/asset/graph/wrong_reference.png new file mode 100644 index 00000000..03b78071 Binary files /dev/null and b/asset/graph/wrong_reference.png differ diff --git a/doc/source/api/layers.rst b/doc/source/api/layers.rst index f09fa456..269a6ccb 100644 --- a/doc/source/api/layers.rst +++ b/doc/source/api/layers.rst @@ -214,6 +214,12 @@ Variadic .. autofunction:: variadic_sample +.. autofunction:: variadic_meshgrid + +.. autofunction:: variadic_to_padded + +.. autofunction:: padded_to_variadic + Tensor Reduction ^^^^^^^^^^^^^^^^ .. autofunction:: masked_mean diff --git a/doc/source/api/metrics.rst b/doc/source/api/metrics.rst index 446022ba..bee8d738 100644 --- a/doc/source/api/metrics.rst +++ b/doc/source/api/metrics.rst @@ -26,9 +26,21 @@ R2 ^^ .. autofunction:: r2 -Variadic Accuracy -^^^^^^^^^^^^^^^^^ -.. autofunction:: variadic_accuracy +Accuracy +^^^^^^^^ +.. autofunction:: accuracy + +Matthews Correlation Coefficient +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autofunction:: matthews_corrcoef + +Pearson Correlation Coefficient +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autofunction:: pearsonr + +Spearman Correlation Coefficient +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autofunction:: spearmanr Chemical Metrics diff --git a/doc/source/notes/reference.rst b/doc/source/notes/reference.rst index 86277aed..e35538e2 100644 --- a/doc/source/notes/reference.rst +++ b/doc/source/notes/reference.rst @@ -30,6 +30,7 @@ the result is not desired. The edges are masked out correctly, but the values of inverse indexes are wrong. .. 
code:: python + with graph.edge(): graph.inv_edge_index = torch.tensor(inv_edge_index) g1 = graph.edge_mask([0, 2, 3]) @@ -55,34 +56,4 @@ since the corresponding inverse edge has been masked out. :width: 33% We can use ``graph.node_reference()`` and ``graph.graph_reference()`` for references -to nodes and graphs respectively. - -Use Cases in Proteins ---------------------- - -In :class:`data.Protein`, the mapping ``atom2residue`` is implemented as -references. The intuition is that references enable flexible indexing on either atoms -or residues, while maintaining the correspondence between two views. - -The following example shows how to track a specific residue with ``atom2residue`` in -the atom view. For a protein, we first create a mask for atoms in a glutamine (GLN). - -.. code:: python - - protein = data.Protein.from_sequence("KALKQMLDMG") - is_glutamine = protein.residue_type[protein.atom2residue] == protein.residue2id["GLN"] - with protein.node(): - protein.is_glutamine = is_glutamine - -We then apply a mask to the protein residue sequence. In the output protein, -``atom2residue`` is able to map the masked atoms back to the glutamine residue. - -.. code:: python - - p1 = protein[3:6] - residue_type = p1.residue_type[p1.atom2residue[p1.is_glutamine]] - print([p1.id2residue[r] for r in residue_type.tolist()]) - -.. code:: bash - - ['GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN', 'GLN'] \ No newline at end of file +to nodes and graphs respectively. 
\ No newline at end of file diff --git a/doc/source/notes/variadic.rst b/doc/source/notes/variadic.rst index 2eedcc96..e0c671f0 100644 --- a/doc/source/notes/variadic.rst +++ b/doc/source/notes/variadic.rst @@ -113,6 +113,7 @@ Naturally, the prediction over nodes also forms a variadic tensor with ``num_nod :func:`variadic_topk `, :func:`variadic_randperm `, :func:`variadic_sample `, + :func:`variadic_meshgrid `, :func:`variadic_log_softmax `, :func:`variadic_cross_entropy `, diff --git a/torchdrug/data/graph.py b/torchdrug/data/graph.py index 07d1c60d..dbedfdc9 100644 --- a/torchdrug/data/graph.py +++ b/torchdrug/data/graph.py @@ -699,6 +699,17 @@ def edge_mask(self, index): num_relation=self.num_relation, meta_dict=meta_dict, **data_dict) def line_graph(self): + """ + Construct a line graph of this graph. + The node feature of the line graph is inherited from the edge feature of the original graph. + + In the line graph, each node corresponds to an edge in the original graph. + For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph, + there is a directed edge (a, b) -> (b, c) in the line graph. + + Returns: + Graph + """ node_in, node_out = self.edge_list.t()[:2] edge_index = torch.arange(self.num_edge, device=self.device) edge_in = edge_index[node_out.argsort()] @@ -1627,6 +1638,17 @@ def subbatch(self, index): return self.graph_mask(index, compact=True) def line_graph(self): + """ + Construct a packed line graph of this packed graph. + The node features of the line graphs are inherited from the edge features of the original graphs. + + In the line graph, each node corresponds to an edge in the original graph. + For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph, + there is a directed edge (a, b) -> (b, c) in the line graph. 
+ + Returns: + PackedGraph + """ node_in, node_out = self.edge_list.t()[:2] edge_index = torch.arange(self.num_edge, device=self.device) edge_in = edge_index[node_out.argsort()] diff --git a/torchdrug/layers/functional/functional.py b/torchdrug/layers/functional/functional.py index 19dbe0ec..acb7d771 100644 --- a/torchdrug/layers/functional/functional.py +++ b/torchdrug/layers/functional/functional.py @@ -375,6 +375,9 @@ def variadic_sort(input, size, descending=False): input (Tensor): input of shape :math:`(B, ...)` size (LongTensor): size of sets of shape :math:`(N,)` descending (bool, optional): return ascending or descending order + + Returns: + (Tensor, LongTensor): sorted values and indexes """ index2sample = _size_to_index(size) index2sample = index2sample.view([-1] + [1] * (input.ndim - 1)) @@ -445,6 +448,21 @@ def variadic_sample(input, size, num_sample): def variadic_meshgrid(input1, size1, input2, size2): + """ + Compute the Cartesian product for two batches of sets with variadic sizes. + + Suppose there are :math:`N` sets in each input, + and the sizes of all sets are summed to :math:`B_1` and :math:`B_2` respectively. + + Parameters: + input1 (Tensor): input of shape :math:`(B_1, ...)` + size1 (LongTensor): size of :attr:`input1` of shape :math:`(N,)` + input2 (Tensor): input of shape :math:`(B_2, ...)` + size2 (LongTensor): size of :attr:`input2` of shape :math:`(N,)` + + Returns: + (Tensor, Tensor): the first and the second elements in the Cartesian product + """ grid_size = size1 * size2 local_index = variadic_arange(grid_size) local_inner_size = size2.repeat_interleave(grid_size) @@ -456,6 +474,19 @@ def variadic_meshgrid(input1, size1, input2, size2): def variadic_to_padded(input, size, value=0): + """ + Convert a variadic tensor to a padded tensor. + + Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`. 
+ + Parameters: + input (Tensor): input of shape :math:`(B, ...)` + size (LongTensor): size of sets of shape :math:`(N,)` + value (scalar): fill value for padding + + Returns: + (Tensor, BoolTensor): padded tensor and mask + """ num_sample = len(size) max_size = size.max() starts = torch.arange(num_sample, device=size.device) * max_size @@ -469,6 +500,16 @@ def padded_to_variadic(padded, size): + """ + Convert a padded tensor to a variadic tensor. + + Parameters: + padded (Tensor): padded tensor of shape :math:`(N, ...)` + size (LongTensor): size of sets of shape :math:`(N,)` + + Returns: + Tensor: the variadic tensor + """ num_sample, max_size = padded.shape[:2] starts = torch.arange(num_sample, device=size.device) * max_size ends = starts + size diff --git a/torchdrug/transforms/transform.py b/torchdrug/transforms/transform.py index bc86d129..3245ca98 100644 --- a/torchdrug/transforms/transform.py +++ b/torchdrug/transforms/transform.py @@ -11,6 +11,7 @@ class TargetNormalize(object): """ Normalize the target values in a sample. + + Parameters: mean (dict of float): mean of targets std (dict of float): standard deviation of targets