Add some sort of batched reduction to desc.batching #1564

Open
f0uriest opened this issue Feb 4, 2025 · 0 comments
f0uriest commented Feb 4, 2025

Not for this PR, but at some point it might be worth implementing some sort of batched reduction in `desc.batching`.

Originally posted by @f0uriest in #1440 (comment)

Basically, the use case is when we need to evaluate some expensive function over a bunch of inputs and then reduce the results (e.g. sum them). Currently the two main options are a single loop, which is slow on GPU, or a full vmap, which means materializing the full array of outputs in memory (or hoping the compiler fuses things so that it isn't, but still). One could potentially improve the performance of the loop by unrolling part of it, though that increases compile time. Ideally you want something that takes advantage of the fact that the loop iterations are independent.
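For concreteness, the two existing options might look roughly like this in plain JAX. This is a minimal sketch assuming the reduction is a sum; `loop_reduce` and `vmap_reduce` are hypothetical names, not functions in `desc.batching`. The `unroll` argument of `jax.lax.scan` is one way to get the partial unrolling mentioned above.

```python
import jax
import jax.numpy as jnp


def loop_reduce(fun, x, unroll=1):
    """Option 1 (hypothetical helper): sequential loop over the leading axis of x.

    Memory stays at roughly one output of fun, but iterations run one after
    another, which underuses a GPU. unroll > 1 asks lax.scan to unroll that
    many iterations per loop step (less loop overhead, longer compile time).
    """
    # Zero accumulator with the same shape/dtype as a single output of fun.
    out_struct = jax.eval_shape(fun, x[0])
    init = jnp.zeros(out_struct.shape, out_struct.dtype)

    def body(carry, xi):
        # Add this input's contribution to the running sum.
        return carry + fun(xi), None

    total, _ = jax.lax.scan(body, init, x, unroll=unroll)
    return total


def vmap_reduce(fun, x):
    """Option 2 (hypothetical helper): evaluate fun on every input at once, then sum.

    Fully parallel, but jax.vmap(fun)(x) holds all n outputs in memory unless
    the compiler happens to fuse the sum into the evaluation.
    """
    return jnp.sum(jax.vmap(fun)(x), axis=0)
```

The trade-off is the one described above: the scan keeps memory flat but serializes independent work, while the vmap parallelizes everything at the cost of an array whose leading dimension is the number of inputs.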

Semantically, it should be equivalent to the following:

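```python
import jax.numpy as jnp

def unbatched_reduce(fun, x, reduction=jnp.add):
    # Evaluate fun on each input in turn and fold the results with reduction.
    out = 0
    for xi in x:
        out = reduction(out, fun(xi))
    return out
```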
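One possible shape for the batched version, sketched in plain JAX: a `lax.scan` over chunks of `batch_size` inputs, with `jax.vmap` evaluating all inputs of a chunk in parallel. Peak memory is then governed by `batch_size` rather than by the total number of inputs. `batched_reduce_sum` and `batch_size` are hypothetical names, not an existing `desc.batching` API, and the sketch assumes the reduction is a sum; for another associative reduction one would swap `jnp.sum` and `+` for a matching pair such as `jnp.max` and `jnp.maximum`. For summation it should agree with `unbatched_reduce(fun, x)` up to floating-point reordering.

```python
import jax
import jax.numpy as jnp


def batched_reduce_sum(fun, x, batch_size):
    """Sum fun(x[i]) over the leading axis of x, batch_size inputs at a time.

    Hypothetical sketch (assumes reduction == addition), not part of
    desc.batching. Only about batch_size outputs of fun are live at once.
    """
    n = x.shape[0]
    n_chunks, rem = divmod(n, batch_size)
    # Zero accumulator with the same shape/dtype as a single output of fun.
    out_struct = jax.eval_shape(fun, x[0])
    total = jnp.zeros(out_struct.shape, out_struct.dtype)

    if n_chunks > 0:
        # Reshape the evenly divisible part into (n_chunks, batch_size, ...).
        chunks = x[: n_chunks * batch_size].reshape(
            (n_chunks, batch_size) + x.shape[1:]
        )

        def body(carry, chunk):
            # Vectorized evaluation within one chunk, then reduce that chunk.
            return carry + jnp.sum(jax.vmap(fun)(chunk), axis=0), None

        # Sequential loop over chunks; chunks are independent of each other.
        total, _ = jax.lax.scan(body, total, chunks)

    if rem > 0:
        # Leftover inputs that do not fill a whole chunk.
        tail = x[n_chunks * batch_size :]
        total = total + jnp.sum(jax.vmap(fun)(tail), axis=0)
    return total
```

Setting `batch_size=1` recovers the sequential loop and `batch_size=x.shape[0]` recovers the full vmap, so the chunk size acts as the knob between the two existing options.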