Skip to content

Commit

Permalink
Make Blas flags check lazy
Browse files Browse the repository at this point in the history
It replaces the old warning that does not actually apply by a more informative and actionable one.

This warning was for Ops that might use the alternative blas_headers, which rely on the Numpy C-API.

However, regular PyTensor user has not used this for a while. The only Op that would use C-code with this alternative headers is the GEMM Op which is not included in current rewrites. Instead Dot22 or Dot22Scalar are introduced, which refuse to generate C-code altogether if the blas flags are missing.
  • Loading branch information
ricardoV94 committed Jan 23, 2025
1 parent a0fe30d commit 89af81b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
7 changes: 7 additions & 0 deletions pytensor/link/c/cmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2947,6 +2947,13 @@ def check_libs(
except Exception as e:
_logger.debug(e)
_logger.debug("Failed to identify blas ldflags. Will leave them empty.")
warnings.warn(
"PyTensor could not link to a BLAS installation. Operations that might benefit from BLAS will be severely degraded.\n"
"This usually happens when PyTensor is installed via pip. We recommend it be installed via conda/mamba/pixi instead.\n"
"Alternatively, you can use a experimental backend such as Numba or JAX that perform their own BLAS optimizations, "
"by setting `pytensor.config.mode == 'NUMBA'` or passing `mode='NUMBA'` when compiling a PyTensor function.",
UserWarning,
)
return ""


Expand Down
9 changes: 5 additions & 4 deletions pytensor/tensor/blas_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,11 @@ def blas_header_text():

blas_code = ""
if not config.blas__ldflags:
# This code can only be reached by compiling a function with a manually specified GEMM Op.
# Normal PyTensor usage will end up with Dot22 or Dot22Scalar instead,
# which opt out of C-code completely if the blas flags are missing
_logger.warning("Using NumPy C-API based implementation for BLAS functions.")

# Include the Numpy version implementation of [sd]gemm_.
current_filedir = Path(__file__).parent
blas_common_filepath = current_filedir / "c_code/alt_blas_common.h"
Expand Down Expand Up @@ -1003,10 +1008,6 @@ def blas_header_text():
return header + blas_code


if not config.blas__ldflags:
_logger.warning("Using NumPy C-API based implementation for BLAS functions.")


def mkl_threads_text():
"""C header for MKL threads interface"""
header = """
Expand Down

0 comments on commit 89af81b

Please sign in to comment.