Skip to content

Commit

Permalink
Merge branch 'release' into polynomial
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Mar 1, 2025
2 parents 9c9a077 + f5a67a5 commit bd5c605
Show file tree
Hide file tree
Showing 29 changed files with 86 additions and 463 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
## Latest Changes

## 0.3.0-rc1

### Breaking Changes
- In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed `math_dtype` and `output_dtype` respectively. Adding consistency with the rest of the library.
- In `cuex.equivariant_tensor_product`, the arguments `algorithm`, `precision`, `use_custom_primitive` and `use_custom_kernels` are removed. This is to avoid a proliferation of arguments that are not used in all the implementations. An argument `impl: str` is added instead to select the implementation.
Expand All @@ -16,6 +18,7 @@
### Added
- JAX Bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) have only one index. It handles batched/shared/indexed input/output. The indexed input/output are handled by atomic operations.
- Added `__mul__` to `cue.EquivariantTensorProduct` to allow rescaling the equivariant tensor product.
- Added a uniform 1d kernel with scatter/gather fusion under `cuet.primitives.tensor_product.TensorProductUniform4x1dIndexed` and `cuet.primitives.tensor_product.TensorProductUniform3x1dIndexed`.


## 0.2.0 (2025-01-24)
Expand Down
478 changes: 49 additions & 429 deletions LICENSES/Third_party_attr.txt

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ pip install cuequivariance-torch
pip install cuequivariance # Installs only the core non-ML components

# CUDA kernels for different CUDA versions
pip install cuequivariance-ops-torch-cu11
pip install cuequivariance-ops-torch-cu12
pip install cuequivariance-ops-jax-cu12
pip install cuequivariance-ops-torch-cu12 # or cu11
```

## License
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.0
0.3.0rc4
2 changes: 1 addition & 1 deletion cuequivariance/cuequivariance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion cuequivariance/cuequivariance/misc/linalg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion cuequivariance/cuequivariance/operation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion cuequivariance/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ classifiers = [

[tool.hatch.version]
path = "cuequivariance/VERSION"
pattern = "(?P<version>\\d+\\.\\d+\\.\\d+)"
pattern = "(?P<version>\\d+\\.\\d+\\.\\d+(?:[a-z]+\\d+)?)"
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion cuequivariance_jax/cuequivariance_jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -389,7 +389,9 @@ def segmented_polynomial_jvp(
jvp_buffer_index,
polynomial.jvp([not isinstance(t, ad.Zero) for t in tangents]),
math_dtype,
name + "_jvp",
name
+ "_jvp"
+ "".join("0" if isinstance(t, ad.Zero) else "1" for t in tangents),
impl=impl,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -25,7 +25,10 @@


def sanitize_string(s):
return re.sub(r"[^A-Za-z_]", "", s)
s = re.sub(r"[^A-Za-z0-9_]", "", s)
if s == "" or s[0].isdigit():
s = "_" + s
return s


def segmented_polynomial_ops_impl(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion cuequivariance_jax/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ classifiers = [

[tool.hatch.version]
path = "cuequivariance_jax/VERSION"
pattern = "(?P<version>\\d+\\.\\d+\\.\\d+)"
pattern = "(?P<version>\\d+\\.\\d+\\.\\d+(?:[a-z]+\\d+)?)"
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,6 @@ def __init__(
):
super().__init__()

if use_fallback is not True:
# There is a bug in the CUDA kernel for TransposeIrrepsLayout
# TODO remove this when the bug is fixed
use_fallback = True

info = _transpose_info(segments, device=device)
self.f = None

Expand Down
2 changes: 1 addition & 1 deletion cuequivariance_torch/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ classifiers = [

[tool.hatch.version]
path = "cuequivariance_torch/VERSION"
pattern = "(?P<version>\\d+\\.\\d+\\.\\d+)"
pattern = "(?P<version>\\d+\\.\\d+\\.\\d+(?:[a-z]+\\d+)?)"
2 changes: 1 addition & 1 deletion docs/api/cuequivariance.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.. SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
.. SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion docs/api/cuequivariance_jax.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.. SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
.. SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.. SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
.. SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/beta.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.. SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
.. SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/index.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.. SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
.. SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
Expand Down

0 comments on commit bd5c605

Please sign in to comment.