
RFC: consider adding the specialized methods for missing matrix products #877

Open
NeilGirdhar opened this issue Dec 25, 2024 · 4 comments
Labels
RFC Request for comments. Feature requests and proposed changes. topic: Linear Algebra Linear algebra.

Comments

@NeilGirdhar

NeilGirdhar commented Dec 25, 2024

Background

The matmul function combines seven cases of matrix products in one function. These seven cases can be cleanly organized as instances of four possibly vectorized operations:

  • vector dot product (does not support vectorization, but vecdot covers this perfectly),
  • matrix-matrix product (covers vectorization completely),
  • matrix-vector product (only supports vectorization in the matrix term), and
  • vector-matrix product (only supports vectorization in the matrix term).
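The gap in the last two cases can be seen concretely in NumPy, used here only as a convenient backend (the shapes are arbitrary). matmul promotes a 1-D operand to a matrix and removes the added axis afterwards, so a single vector broadcasts against a stack of matrices, but a stack of vectors is read as one 2-D matrix:

```python
import numpy as np  # NumPy is assumed here purely as a concrete array library

A = np.ones((3, 4))       # a single 3x4 matrix
As = np.ones((2, 3, 4))   # a stack of two 3x4 matrices
v = np.ones(4)            # a single length-4 vector
vs = np.ones((2, 4))      # a stack of two length-4 vectors

dot_shape = (v @ v).shape           # vector dot product -> ()
mv_shape = (A @ v).shape            # matrix-vector product -> (3,)
batched_mv_shape = (As @ v).shape   # only the matrix side is stacked -> (2, 3)

# A stack of vectors is interpreted as a single 2x4 matrix, so the core
# dimensions (3, 4) @ (2, 4) do not line up and matmul raises ValueError.
try:
    As @ vs
    stacked_vectors_ok = True
except ValueError:
    stacked_vectors_ok = False

# Workaround: turn each vector into a column, multiply, then drop the column axis.
workaround_shape = (As @ vs[..., None])[..., 0].shape  # -> (2, 3)
```

The workaround is exactly the kind of axis bookkeeping that a dedicated, fully vectorized function would hide.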

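Unfortunately, the decomposition isn't clean: the last two operations have missing vectorization cases. For example, a matrix-vector product that is vectorized over both arguments can be written as:

```python
from typing import Any
from array_api_compat import array_namespace

RealArray = Any  # stand-in for a real-valued array type alias

def matrix_vector_mul(x: RealArray, y: RealArray) -> RealArray:
    """Return the matrix-vector product.

    This is xp.einsum("...ij,...j->...i", x, y).
    """
    xp = array_namespace(x, y)
    # Insert a unit axis so y has shape (..., 1, j) and broadcasts over the rows of x.
    y = xp.reshape(y, (*y.shape[:-1], 1, y.shape[-1]))
    # Multiply elementwise and contract over the shared last axis j.
    return xp.sum(x * y, axis=-1)
```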

A matrix-vector function like this, together with an analogous vector-matrix function, would subsume cases 3 and 4 in the list above.
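The subsumption claim can be checked numerically. The sketch below uses NumPy as the backend and a hypothetical helper, broadcast_matvec, that applies the same broadcast-and-sum idea as matrix_vector_mul; it agrees with matmul where matmul applies, and with the einsum definition in the stacked-vector case that matmul rejects:

```python
import numpy as np  # NumPy is assumed as a concrete backend for this check

rng = np.random.default_rng(0)
As = rng.standard_normal((2, 3, 4))  # stack of two 3x4 matrices
v = rng.standard_normal(4)           # a single vector
vs = rng.standard_normal((2, 4))     # one vector per matrix

def broadcast_matvec(x, y):
    # Hypothetical helper: (..., i, j) * (..., 1, j), summed over j.
    return np.sum(x * y[..., None, :], axis=-1)

single = broadcast_matvec(As, v)    # the case matmul already handles
stacked = broadcast_matvec(As, vs)  # the case matmul rejects

matches_matmul = np.allclose(single, As @ v)
matches_einsum = np.allclose(stacked, np.einsum("...ij,...j->...i", As, vs))
```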

Proposal

Consider adding matrix_vector_mul and vector_matrix_mul.
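For the second proposed function, here is a minimal sketch of a counterpart to matrix_vector_mul, assuming the analogous definition einsum("...j,...jk->...k", x, y). It is written against NumPy for brevity, and the xp parameter is a hypothetical stand-in for namespace lookup; the body only uses expand_dims, broadcasting multiplication, and sum, which array API namespaces also provide:

```python
import numpy as np

def vector_matrix_mul(x, y, xp=np):
    """Return the vector-matrix product, einsum("...j,...jk->...k", x, y).

    Vectorized over both arguments. For complex inputs this does not
    conjugate x (unlike NumPy's vecmat); it is intended for real arrays.
    """
    # Give x a trailing unit axis, (..., j, 1), so it broadcasts across the k columns of y.
    x = xp.expand_dims(x, axis=-1)
    # Multiply elementwise and contract over the shared axis j (now second to last).
    return xp.sum(x * y, axis=-2)

rng = np.random.default_rng(0)
xs = rng.standard_normal((2, 3))     # a stack of two length-3 vectors
Ms = rng.standard_normal((2, 3, 5))  # a stack of two 3x5 matrices
result = vector_matrix_mul(xs, Ms)   # one product per (vector, matrix) pair, shape (2, 5)
```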

@NeilGirdhar NeilGirdhar changed the title from "Add matrix_vector_product?" to "Consider adding the missing matrix products" Dec 25, 2024
@kgryte kgryte changed the title from "Consider adding the missing matrix products" to "RFC: consider adding the specialized methods for missing matrix products" Dec 26, 2024
@kgryte kgryte added the RFC Request for comments. Feature requests and proposed changes. label Dec 26, 2024
@rgommers
Member

It looks like this overlaps with matvec and vecmat, which were just added in NumPy 2.2.0: numpy/numpy#25675. Can you please review that PR and the mailing list discussion linked in the PR description, and confirm that that is what you're proposing here @NeilGirdhar?

@rgommers rgommers added the topic: Linear Algebra Linear algebra. label Dec 26, 2024
@NeilGirdhar
Author

NeilGirdhar commented Dec 26, 2024

That's exactly what I want, yes. Will these be added to the Array API? Also, will matmul get the complex conjugation behavior?

@rgommers
Member

> That's exactly what I want, yes. Will these be added to the Array API?

Great, thanks for confirming. I think what is needed first is to add those functions to other libraries. We can't standardize something that is only present in NumPy. Usually other libraries are happy to match NumPy if the API looks clean and there are no blockers like it not working well for accelerators or JITs (which I wouldn't expect to be an issue in this case). So it's "only" a matter of asking JAX, PyTorch et al. if these functions can be added and then someone implementing them.

@jakevdp

jakevdp commented Dec 26, 2024

JAX has already added the new functions in jax-ml/jax#25390.

@kgryte kgryte added this to Proposals Dec 26, 2024
@kgryte kgryte moved this to Stage 0 in Proposals Dec 26, 2024
Projects
Status: Stage 0

4 participants