
fit the training data #4

Open · Chen-Cai-OSU opened this issue Nov 8, 2020 · 24 comments

@Chen-Cai-OSU commented Nov 8, 2020

Hi Fabian,

When I use the SE(3)-Transformer on my dataset, I find it quite difficult for the model to fit the training data. To understand why, I created a simple task in the following way.

I generate a few hundred point clouds sampled on the surface of an ellipsoid in 3D (centered at (0,0,0)) and construct a KNN (k=5) graph for each point cloud. The goal is to predict the first eigenvector of the covariance matrix of each point cloud, which is a type-1 feature. I am using the following model:

ModuleList(
  (0): GSE3Res(
    (GMAB): ModuleDict(
      (v): GConvSE3Partial(structure=[(8, 0), (8, 1)])
      (k): GConvSE3Partial(structure=[(8, 0)])
      (q): G1x1SE3(structure=[(8, 0)])
      (attn): GMABSE3(n_heads=8, structure=[(8, 0), (8, 1)])
    )
    (project): G1x1SE3(structure=[(32, 0), (32, 1)])
    (add): GSum(structure=[(32, 0), (32, 1)])
  )
  (1): GNormSE3(num_layers=0, nonlin=ReLU(inplace=True))
  (2): GConvSE3(structure=[(32, 0), (32, 1)], self_interaction=True)
  (3): GNormSE3(num_layers=0, nonlin=ReLU(inplace=True))
  (4): GConvSE3(structure=[(32, 0), (32, 1)], self_interaction=True)
  (5): GNormSE3(num_layers=0, nonlin=ReLU(inplace=True))
  (6): GConvSE3(structure=[(32, 0), (32, 1)], self_interaction=True)
  (7): GNormSE3(num_layers=0, nonlin=ReLU(inplace=True))
  (8): GConvSE3(structure=[(1, 1)], self_interaction=True)
  (9): GAvgVecPooling(
    (pool): AvgPooling()
  )
)

However, when I train with the Adam optimizer (learning rate 1e-3) to minimize the MSE loss between the predicted and true eigenvectors, the training loss plateaus at roughly 0.05, and the cosine similarity between the predicted and true eigenvectors is roughly 0.65 (which means the angle is more than 45 degrees).
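
For concreteness, a minimal sketch of this objective and metric with placeholder tensors (the names here are illustrative, not the actual training script):

import math
import torch
import torch.nn.functional as F

pred = torch.randn(8, 3)    # placeholder predicted eigenvectors, (batch, 3)
target = torch.randn(8, 3)  # placeholder ground-truth eigenvectors

loss = F.mse_loss(pred, target)                          # training objective
cos = F.cosine_similarity(pred, target, dim=-1).mean()   # reported metric
angle = math.degrees(math.acos(max(-1.0, min(1.0, cos.item()))))  # cos 0.65 -> ~49 deg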

The model has 643,296 parameters, which is probably not large, but my dataset is also tiny (200 neighborhood graphs constructed from the point clouds), so I am a bit surprised that the model cannot fit the data exactly. (I tried using more layers, but CUDA memory quickly runs out.)

Are there places I should pay special attention to when using an SE(3)-Transformer? Maybe because the equivariance constraint is placed on the kernel, the model is less flexible and therefore harder to fit to the training data? Should I try increasing the model size, or different optimizers and learning rates? I can share the data if needed.

Thank you!

@FabianFuchsML (Owner)

Hi Chen,

That's a great toy experiment! It's also quite different from the ones we ran, but here are my thoughts:

  • Type-1 feature prediction is different from type-0 prediction for multiple reasons. An important one is that for type 1 you can't have (unconstrained) fully connected layers at the end; in a way, you could argue that type-0 prediction is less constrained than type-1 prediction. Attention (vs. convolution) should alleviate some of this. While what works best always depends on the dataset, my default recommendation would be to use only attention layers and no pure convolution layers (it seems you opted for 1 attention + 4 convolution layers).

  • I would also recommend really looking into the learnable parameters and where they come from. Remember that whenever you have invariant features, you can run them through any kind of (learnable) nonlinearity (e.g., fully connected layers) without breaking equivariance. This lets you add capacity to the network.

  • A good check (which you probably did already) is to rotate both the data input (and therefore the predictions) and the ground-truth vectors, and see whether predictions and ground truth rotate in the same way; see the sketch after this list.

  • I see you do global pooling in the last layer. You could alternatively let each point predict the output with a separate loss and train on those losses (and only average the predictions at inference time).

  • How many points do you have? I am wondering whether, with 5 nearest neighbors, you actually transport information efficiently throughout the entire point cloud. You could also try things like fully connected graphs or random connections, as this particular problem asks primarily for global context (whereas local information almost doesn't matter).

  • A few hundred point clouds is actually not that little if you just want to test the model's capability to overfit the training set. Maybe try with 10? Also, learning rates etc. obviously always make a difference as well.
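
Regarding the rotation check above, a minimal sketch, using a trivially equivariant stand-in for the model (the centroid of a point cloud rotates with the cloud), since the exact network is not shown here:

import torch
from scipy.stats import special_ortho_group

def model(points):
    # Stand-in for an SE(3)-equivariant network with a type-1 output.
    return points.mean(dim=0, keepdim=True)

points = torch.randn(200, 3)
R = torch.tensor(special_ortho_group.rvs(3), dtype=torch.float32)

# Rotating the input should rotate the prediction in exactly the same way.
assert torch.allclose(model(points @ R.T), model(points) @ R.T, atol=1e-5)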

@Chen-Cai-OSU (Author)

Hi Fabian,

Thanks for all the great suggestions! I would like to report some empirical observations.

I reduced the number of KNN graphs to 1. In that case, I am able to fit the data exactly. However, even if I slightly increase the number of graphs to 5, I am no longer able to fit the data. It also seems that the more graphs I have, the harder it is for the SE(3)-Transformer to fit them (the cosine similarity decreases as I increase the number of graphs).

I spent two hours trying different k, different graph sizes, varying architectures (switching from GConvSE3 to GSE3Res), and learning rate decay; however, I didn't see significant improvement.

According to this recent paper by Haggai Maron and colleagues (https://arxiv.org/abs/2010.02449), SE(3)-Transformers are universal, so I would like to try my best to solve this toy problem before moving to real data.

Let me know if you want to try it yourself. I can share the script.

@FabianFuchsML (Owner)

Hi Chen, yes, it would be very interesting to have a look at your script. Feel free to send it over!

@Chen-Cai-OSU (Author) commented Nov 11, 2020

# Created at 2020-11-11
# Summary: syn data shared with Fabian

import os
import pickle

import dgl
import networkx as nx
import scipy.linalg  # needed for scipy.linalg.eigh below
import torch
from scipy.stats import special_ortho_group
from tqdm import tqdm

try:  # make sure load and save works for different versions of dgl
    from dgl import save_graphs, load_graphs
except ImportError:
    from dgl.data import load_graphs, save_graphs

import numpy as np
from scipy.spatial import distance
from sklearn.preprocessing import normalize

export_dir = os.path.join('pca-data', '')


def save_info(path, info):
    """ Save dataset related information into disk.

    Parameters
    ----------
    path : str
        File to save information.
    info : dict
        A python dict storing information to save on disk.
    """
    with open(path, "wb") as pf:
        pickle.dump(info, pf)


def load_info(path):
    """ Load dataset related information from disk.

    Parameters
    ----------
    path : str
        File to load information from.

    Returns
    -------
    info : dict
        A python dict storing information loaded from disk.
    """
    with open(path, "rb") as pf:
        info = pickle.load(pf)
    return info


class Data(object):
    def __init__(self, n=200, k=180, verbose=False):
        self.n = n
        self.k = k
        self.verbose = verbose
        self.node_attributes = ['x', 'y', 'z', 'dummy1', 'dummy2', 'dummy3']

    def _get_nx(self):
        """ generate points on a ellopsoid """
        points = np.random.rand(self.n, 3) - 0.5
        points = normalize(points, axis=1)
        points -= np.mean(points, axis=0)

        scale_matrix = np.diag(np.random.rand(3)) * 5  # np.array([[1,0,0],[0,2,0],[0,0,3]])
        rot = special_ortho_group.rvs(3)
        points = points @ scale_matrix @ rot

        D = distance.squareform(distance.pdist(points))
        closest = np.argsort(D, axis=1)
        closest = closest[:, 1:self.k + 1]
        g = nx.Graph()

        nodes = []
        for i in range(self.n):
            node = (i, {'x': points[i][0], 'y': points[i][1], 'z': points[i][2], 'dummy1': 1, 'dummy2': 1, 'dummy3': 1})
            nodes.append(node)
        g.add_nodes_from(nodes)

        edges = []
        for i in range(self.n):
            for j in closest[i]:
                edges.append((i, j))
        g.add_edges_from(edges)

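        # Sample covariance of the centered point cloud; its eigenvectors
        # (the principal axes of the ellipsoid) are the regression targets.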
        cov = points.T @ points * (1 / (self.n-1))
        vals, vecs = scipy.linalg.eigh(cov)
        return g, vecs

    def read_gml_dgl(self, n_feat=10):
        #          DGLGraph(num_nodes=21, num_edges=420,
        #          ndata_schemes={'x': Scheme(shape=(0,), dtype=torch.float32), 'f': Scheme(shape=(16, 1), dtype=torch.float32)}
        #          edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(0,), dtype=torch.float32)})
        nxg, vecs = self._get_nx()
        label = {}
        label['vecs'] = torch.tensor(vecs)

        try:
            g = dgl.from_networkx(nxg, node_attrs=self.node_attributes)
        except AttributeError:
            g = dgl.DGLGraph()
            g.from_networkx(nxg, node_attrs=self.node_attributes)

        node_data = torch.stack([g.ndata[k] for k in self.node_attributes], 1).float()
        g.ndata['x'] = node_data
        for k in self.node_attributes:
            if k != 'x':
                g.ndata.pop(k)

        src, dst = g.edges()
        pos = g.ndata['x'][:, :3]
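        # 'f': type-0 node features (the three dummy scalars), shape (n, 3, 1);
        # 'd': per-edge relative position vectors, the model's geometric input.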
        g.ndata['f'] = g.ndata['x'][:, 3:, None]
        g.edata['d'] = pos[dst, :] - pos[src, :]
        try:
            n_edges = g.num_edges()
        except AttributeError:  # dgl 0.4.3.post2 has no num_edges()
            n_edges = g.number_of_edges()
        g.edata['w'] = torch.rand(n_edges, 0)

        return g, label

    def get_data(self, lib='dgl'):
        g, label_dict = self.read_gml_dgl()
        return g, label_dict


class SynDGLDataset():
    def __init__(self, root=export_dir, n_graph=2, n_pts=200, n_nbrs=200, reload=False):
        # modified from https://bit.ly/3jx6Eue
        self.save_path = root
        self.n_graph = n_graph
        self.reload = reload
        self.name = f'pcd-cov-{n_graph}'
        self.n_pts = n_pts
        self.n_nbrs = n_nbrs

        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)
        self.graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
        self.info_path = os.path.join(self.save_path, self.name + '_info.pkl')

    def get_graphs(self):
        self.graphs = []
        self.labels = {}
        for idx in tqdm(range(1, self.n_graph + 1)):
            g, label_dict = Data(n=self.n_pts, k=self.n_nbrs).get_data(lib='dgl')
            self.graphs.append(g)
            self.labels[idx] = label_dict

    def save(self):
        if self.has_cache():
            return

        self.get_graphs()
        # save graphs and labels
        save_graphs(self.graph_path, self.graphs)  # {'labels': self.labels}
        # save other information in python dict
        save_info(self.info_path, self.labels)

    def load(self):
        print(f'load pca-cov from {self.graph_path}')
        graphs, _ = load_graphs(self.graph_path)
        label_dict = load_info(self.info_path)
        return graphs, label_dict

    def has_cache(self):
        ret = os.path.exists(self.graph_path) and os.path.exists(self.info_path)
        if self.reload: ret = False
        if ret: print(f'{self.graph_path} exist')
        return ret


if __name__ == '__main__':
    for n in [5]:
        D = SynDGLDataset(n_graph=n, n_pts=200, n_nbrs=200, reload=True)
        D.save()
        graphs, label_dict = D.load()
        print(graphs)

@Chen-Cai-OSU (Author) commented Nov 11, 2020

I am using the following versions:
dgl: 0.4.3.post2
torch: 1.6.0

You can modify n_pts and n_nbrs as needed. Let me know if you need any clarification. Thank you!

@FabianFuchsML (Owner)

Awesome, thank you! Hopefully, I will find the time to play around with it next week, but no promises.

@FabianFuchsML (Owner)

Another suggestion for analysing what's happening: you could try adding a few fully connected layers on each of the per-point outputs. This will obviously break the equivariance, but it could help you analyse where the ability to overfit breaks down. You can also directly compare that to applying the same number of fully connected layers (again in a per-point fashion) directly to the inputs.
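
A minimal sketch of this debugging probe (the shapes and the per_point_out placeholder are assumptions, not the repo's API):

import torch
import torch.nn as nn

# Non-equivariant probe: a small MLP applied independently to each point's
# 3-d output vector. Breaking equivariance on purpose helps localise where
# the overfitting ability is lost.
probe = nn.Sequential(
    nn.Linear(3, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, 3),
)

per_point_out = torch.randn(200, 3)      # placeholder per-point type-1 outputs
pred = probe(per_point_out).mean(dim=0)  # pool per-point predictions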

@Chen-Cai-OSU (Author)

> Awesome, thank you! Hopefully, I will find the time to play around with it next week, but no promises.

No worries. I guess right now everyone is busy with ICLR rebuttals :-)

@Chen-Cai-OSU (Author) commented Nov 14, 2020

Hi Fabian, I have another question regarding the time and memory consumption of the SE(3)-Transformer (and of equivariant NNs in general). It seems that equivariance comes at the price of slower runtime and higher memory usage.

In the paper, you mention that it takes 2.5 days to train the nets for QM9. If that is for a single regression task (or is it the total time for all 6 tasks?), it comes to roughly 72 min per epoch. In contrast, I ran a simple example on QM9 (https://github.com/rusty1s/pytorch_geometric/blob/master/examples/qm9_nn_conv.py) and it took 3 min per epoch.

I was wondering: is roughly 20x more time the right scale here? It also seems that equivariant nets in general are memory-expensive. Would you mind pointing out the sources of the slow speed and high memory usage? I am interested in improving them if it is not too hard.

@blondegeek commented Nov 15, 2020

Hi @Chen-Cai-OSU and @FabianFuchsML,

Multiple things to check here:

  1. Make sure the model's L=1 output convention and the coordinate convention you use for your vector match. For example, in e3nn, L=1 features are given in the order (y, z, x) to match the conventions for real spherical harmonics, whereas Cartesian coordinates are (x, y, z).
  2. One of the reasons the model is not fitting may be that the question is not symmetrically well-posed. There are actually degenerate answers for the first eigenvector -- it can be v OR -v (and the degeneracy can be higher depending on the shape). You can instead try to predict an L=2 feature, which is capable of representing a double-headed ray (similar to a vector but symmetric under 180-degree rotations). For example, you would just plug the vector from your eigensolver into the expressions for the L=2 spherical harmonics (xy, yz, 2z^2-x^2-y^2, zx, x^2-y^2) to get the appropriate prediction coefficients; see the sketch after this list.
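
A minimal sketch of point 2, ignoring the normalization constants that the true real spherical harmonics carry:

import numpy as np

def vec_to_l2(v):
    # Encode a direction as (unnormalized) L=2 components. These are
    # quadratic in v and hence invariant under v -> -v, sidestepping the
    # eigenvector sign ambiguity. Order follows the expressions quoted above.
    x, y, z = v
    return np.array([x * y,
                     y * z,
                     2 * z**2 - x**2 - y**2,
                     z * x,
                     x**2 - y**2])

v = np.array([0.0, 0.6, 0.8])
assert np.allclose(vec_to_l2(v), vec_to_l2(-v))  # double-headed: sign drops out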

Regarding memory and runtime of equivariant networks (at least for e3nn), these bottlenecks are primarily due to the combinatorial nature of the geometric tensor product (contracting two representation indices with Clebsch-Gordan coefficients) and the fact that there are no readily available CUDA kernels for these operations. There are many ways around this; for example, one can create specialized geometric tensor product modules that do not compute all possible products, but rather a subset.
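
To make the combinatorial nature concrete, here is a sketch of a single tensor-product path l1 x l2 -> l3, assuming e3nn's o3.wigner_3j for the (Clebsch-Gordan-like) coefficients. A full layer repeats this for every admissible (l1, l2, l3) triple and every channel pair, which is where the cost comes from:

import torch
from e3nn import o3  # assumes a recent e3nn; the API differs across versions

l1, l2, l3 = 1, 1, 2
x = torch.randn(10, 2 * l1 + 1)  # batch of type-l1 features
y = torch.randn(10, 2 * l2 + 1)  # batch of type-l2 features

# Wigner 3j symbols play the role of Clebsch-Gordan coefficients here.
C = o3.wigner_3j(l1, l2, l3)     # shape (2*l1+1, 2*l2+1, 2*l3+1)
out = torch.einsum('ijk,ni,nj->nk', C, x, y)  # type-l3 output, shape (10, 5)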

Hope that helps a bit!
Tess

@Chen-Cai-OSU (Author)

Hi @blondegeek!

Thanks a lot for all the suggestions and explanations.
For 1), I have checked that the output rotates properly when I rotate the input. For 2), I am using (negative) absolute cosine similarity as the loss and metric, so the "up-to-sign" problem should already be handled. However, I am not able to fit even three point clouds (200 points each), which is a bit puzzling. I will wait for Fabian to take a look at the dataset.
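
For reference, a sign-invariant loss along these lines might look like the following (a sketch, not the exact training script):

import torch
import torch.nn.functional as F

def sign_invariant_loss(pred, target):
    # Negative absolute cosine similarity: identical for target and -target,
    # so the eigenvector sign ambiguity does not penalize the model.
    return -F.cosine_similarity(pred, target, dim=-1).abs().mean()

pred = torch.randn(4, 3)
target = torch.randn(4, 3)
assert torch.isclose(sign_invariant_loss(pred, target),
                     sign_invariant_loss(pred, -target))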

Would you mind elaborating on the "combinatorial nature of the geometric tensor product"? Do you mean that calculating the tensor product of type-a and type-b (a, b = 0, 1, 2, ...) irreducible representations is very expensive? How do you choose the subset of products? Thank you!

@blondegeek commented Nov 15, 2020 via email

@FabianFuchsML (Owner)

Hi both!

Some additional remarks about speed (I will put this somewhere in the readme):

  • Constructing equivariant layers is indeed computationally expensive, one of the bottlenecks being the computation of the spherical harmonics.

  • We did put significant effort into speeding up the computation of spherical harmonics, in part by parallelising the computation on the GPU (the paper goes into more depth about what we did). A key challenge here is that they depend on the relative distances and hence are different for each point cloud / graph.

  • Further speeding up SE3-Transformer type approaches (approaches working with irreducible representations, that is) and potentially making them more memory efficient would certainly be a big step forward. If anyone is interested in researching this direction, we can only encourage them.

Here are some ideas about speeding up the SE3-Transformer:

  • Depending on the dataset and your system requirements, it might be possible to cache the spherical harmonics / basis vectors. This could speed things up tremendously, as it essentially addresses all the bottlenecks at once (see the sketch after this list).

  • A not-so-purist alternative/version of the above is voxelising the input. This destroys exact equivariance but makes it much easier to cache the spherical harmonics, as it reduces the overall number of evaluations.

  • Lower-hanging fruit: check which parts of the network you can make slimmer for a specific task. E.g., the number of degrees makes a big difference in speed & memory, and its effect on performance saturates.
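
A minimal sketch of the caching/voxelisation idea above: quantize relative positions so that nearby edges share one basis evaluation (compute_basis below is a hypothetical stand-in for the repo's actual basis computation, not its API):

import torch

_cache = {}

def compute_basis(d):
    # Hypothetical stand-in for the expensive per-edge spherical-harmonic /
    # basis computation; the real entry point may look different.
    return torch.outer(d, d)

def cached_basis(d, resolution=0.05):
    # Voxelise the relative position so nearby edges share one evaluation.
    # This trades exact equivariance for speed, as discussed above.
    key = tuple((d / resolution).round().long().tolist())
    if key not in _cache:
        _cache[key] = compute_basis(d)
    return _cache[key]

basis = cached_basis(torch.tensor([0.31, -0.12, 0.07]))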

Best
Fabian

@FabianFuchsML (Owner)

@blondegeek

> Regarding memory and runtime of equivariant networks (at least for e3nn), these bottlenecks are primarily due to the combinatorial nature of the geometric tensor product (contracting two representation indices with Clebsch-Gordan coefficients) and the fact that there are no readily available CUDA kernels for these operations. There are many ways around this; for example, one can create specialized geometric tensor product modules that do not compute all possible products, but rather a subset.

This is super interesting! It sounds sort of similar to what I found when I spent some time digging into the bottlenecks - but then also not quite. I wish I could remember my findings more precisely. In the beginning, the bottleneck was definitely purely the spherical harmonics. But after speeding those up by shifting the computations to the GPU (all the credit here goes to Daniel), the bottleneck was split roughly equally between multiple parts - one of them being constructing the basis vectors from the spherical harmonics and the Clebsch-Gordan coefficients. It sounds like there is some potential if one wanted to get into CUDA programming.

@Chen-Cai-OSU (Author) commented Nov 16, 2020

Thanks @blondegeek for the reference and explanations.

> Have you been able to successfully overfit to one example? That should help debug whether the task is set up correctly.

Yes. For a single point cloud I can overfit. For two point clouds, I can also overfit. But starting from 3 point clouds, I cannot overfit anymore :-(

> Another thing to be aware of: even if your loss function allows for both -v and v to be "correct", the network will ALWAYS output the linear combination of the two degenerate possibilities, which in this case is zero.

I didn't understand why this is the case. From your paper, I understand that it's easy to convert a rectangle into a square, since the latter has more symmetry, but not the other way around. But how does this relate to predicting eigenvectors? If my loss is set up to encourage the output to be either v or -v, why would the network want to output 0? I guess this is a subtle point that I haven't grasped yet.

@FabianFuchsML (Owner) commented Nov 19, 2020

Hi Chen,

I had a little time to look through your code today. It's a great toy example, and I would love to try to debug / crack it, but I am not too optimistic that I will find the time to get to it. I looked at how you sample points on an ellipsoid in _get_nx(). At first, you seem to sample on a sphere. Then there is a line points -= np.mean(points, axis=0), which moves the sphere away from the origin. Is this on purpose? You did state at some point that you were trying to sample from an ellipsoid centered at (0,0,0). Maybe I am overlooking something and you re-center it later, but I would recommend working with ellipsoids that are centered at 0 (as you said), as this seems less prone to errors down the line. Also, did you visualise the ellipsoids together with their principal axes?

@Chen-Cai-OSU (Author)

> Then there is a line points -= np.mean(points, axis=0), which moves the sphere away from the origin. Is this on purpose?

Hi Fabian, I remember trying both 1) points -= np.mean(points, axis=0) and 2) not re-centering. np.mean(points, axis=0) is very close to the center anyway; it's just the sample mean. I tried both versions and didn't see significant differences in terms of fitting the training data.

@Chen-Cai-OSU (Author)

[Screenshot attached]

I believe the code is correct.

@blondegeek commented Nov 20, 2020

Ok, so here's how I would do it with e3nn and torch_geometric (because I'm not familiar with the se3-transformer codebase or dgl).

Here's a simple network fitting a single example. Note that I'm not fitting to eigenvectors but rather to the rotated scale matrix, which is essentially the matrix you are taking the eigenvectors of.

Here's the same network fitting on multiple training examples. It needs more training, but it's starting to get the idea.

@Chen-Cai-OSU (Author) commented Nov 20, 2020

> Ok, so here's how I would do it with e3nn and torch_geometric (because I'm not familiar with the se3-transformer codebase or dgl).
>
> Here's a simple network fitting a single example. Note that I'm not fitting to eigenvectors but rather to the rotated scale matrix, which is essentially the matrix you are taking the eigenvectors of.
>
> Here's the same network fitting on multiple training examples. It needs more training, but it's starting to get the idea.

Thanks @blondegeek for the nice notebooks! I am starting to try out TFN now.

My quick question is about predicting the eigenvectors. Is it a bad idea to use an equivariant NN to predict the eigenvectors in this case? Is it because both v and -v are right answers, so the NN will tend to output 0? I still don't understand why this is the case. Even if I set the loss (to maximize the absolute cosine similarity between the predicted and true eigenvectors) to account for the "up-to-sign" problem, is this still doomed to fail?

@blondegeek

The key issue is that eigenvector solvers are not symmetry preserving; they "pick" an eigenvector, typically based on a random initialization or a similarly arbitrary procedure. This becomes especially problematic for symmetric structures.

Let's consider two higher-symmetry cases. Say the scaling matrix is the identity, torch.eye(3). What are the principal eigenvectors? They are degenerate: a sphere is radially symmetric, so any three orthogonal directions are equally valid and can be in any order.

What if the scaling matrix is something like torch.eye(3) * [1, 1, 2], so that the ellipsoid is radially symmetric about one axis? You have a similar problem: there is no unique way to choose the eigenvectors in a rotation-equivariant manner.

So the issue is more than "just a sign" -- the question is symmetrically ill-posed. Principal axes are not vectors; they are double-headed rays, and you need L=2 features to describe them. A 3x3 matrix can handle both the spherically symmetric case (it will just predict a matrix with only a trace part, i.e., a multiple of the identity) and any less symmetric case.
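
A quick numerical illustration of this ill-posedness: an infinitesimal perturbation of the radially symmetric covariance completely changes which eigenvectors the solver returns, so no continuous (let alone equivariant) function can match the solver's arbitrary choice.

import numpy as np

cov = np.diag([1.0, 1.0, 2.0])  # ellipsoid radially symmetric about the z-axis
_, vecs = np.linalg.eigh(cov)

eps = 1e-9 * np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
_, vecs_perturbed = np.linalg.eigh(cov + eps)

# The eigenvectors of the degenerate eigenvalue differ drastically between
# the two essentially identical inputs.
print(vecs[:, :2])
print(vecs_perturbed[:, :2])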

Hope that helps!

@Chen-Cai-OSU (Author)

Thanks @blondegeek for the further clarification!
I understand the case of eigenvalues with multiplicity, where any vector in the eigenspace can be taken as an eigenvector.

> Principal axes are not vectors; they are double-headed rays, and you need L=2 features to describe them

What exactly is a double-headed ray? I saw this slide (at around 20 min) in your talk: https://sites.google.com/view/equiv-data-aug/home
[Screenshot of the slide attached]

I can find the pseudovector on Wikipedia, but I didn't find any good references on double-headed rays and spirals. I am familiar with covariant/contravariant tensors but have never heard of double-headed rays and spirals. Would you mind pointing out some references?

@blondegeek commented Nov 21, 2020 via email

@Chen-Cai-OSU (Author)

@blondegeek Hi Tess, I am using e3nn now (it's nice that the change-of-basis matrix can be handled by to_irrep_transformation), but I had some issues verifying the equivariance of the 3x3 matrix (Rs_out = [(1, 0, 1), (1, 2, 1)]).

Would you mind taking a look at e3nn/e3nn#149? Many thanks!
