Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic Int precision and stencil regeneration change #104

Merged
merged 10 commits into from
Feb 12, 2025
2 changes: 1 addition & 1 deletion ndsl/boilerplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _get_factories(

compilation_config = CompilationConfig(
backend=backend,
rebuild=True,
rebuild=False,
validate_args=True,
format_source=False,
device_sync=False,
Expand Down
4 changes: 2 additions & 2 deletions ndsl/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ndsl.dsl.caches.cache_location import identify_code_path
from ndsl.dsl.caches.codepath import FV3CodePath
from ndsl.dsl.gt4py_utils import is_gpu_backend
from ndsl.dsl.typing import floating_point_precision
from ndsl.dsl.typing import get_precision
from ndsl.optional_imports import cupy as cp


Expand Down Expand Up @@ -264,7 +264,7 @@ def __init__(
"compiler", "cuda", "syncdebug", value=dace_debug_env_var
)

if floating_point_precision() == 32:
if get_precision() == 32:
# When using 32-bit float, we flip the default dtypes to be all
# C, e.g. 32 bit.
dace.Config.set(
Expand Down
43 changes: 33 additions & 10 deletions ndsl/dsl/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,41 @@
DTypes = Union[bool, np.bool_, int, np.int32, np.int64, float, np.float32, np.float64]


# Depreciated version of get_precision, but retained for a PACE dependency
def floating_point_precision() -> int:
return int(os.getenv("PACE_FLOAT_PRECISION", "64"))


def get_precision() -> int:
return int(os.getenv("PACE_FLOAT_PRECISION", "64"))


# We redefine the type as a way to distinguish
# the model definition of a float to other usage of the
# common numpy type in the rest of the code.
NDSL_32BIT_FLOAT_TYPE = np.float32
NDSL_64BIT_FLOAT_TYPE = np.float64
NDSL_32BIT_INT_TYPE = np.int32
NDSL_64BIT_INT_TYPE = np.int64


def global_set_floating_point_precision():
"""Set the global floating point precision for all reference
to Float in the codebase. Defaults to 64 bit."""
global Float
precision_in_bit = floating_point_precision()
def global_set_precision() -> type:
"""Set the global precision for all references of
Float and Int in the codebase. Defaults to 64 bit."""
global Float, Int
precision_in_bit = get_precision()
if precision_in_bit == 64:
return NDSL_64BIT_FLOAT_TYPE
return NDSL_64BIT_FLOAT_TYPE, NDSL_64BIT_INT_TYPE
elif precision_in_bit == 32:
return NDSL_32BIT_FLOAT_TYPE
return NDSL_32BIT_FLOAT_TYPE, NDSL_32BIT_INT_TYPE
else:
NotImplementedError(
raise NotImplementedError(
f"{precision_in_bit} bit precision not implemented or tested"
)


# Default float and int types
Float = global_set_floating_point_precision()
Int = np.int_
Float, Int = global_set_precision()
Bool = np.bool_

FloatField = Field[gtscript.IJK, Float]
Expand All @@ -68,10 +74,27 @@ def global_set_floating_point_precision():
FloatFieldK = Field[gtscript.K, Float]
FloatFieldK64 = Field[gtscript.K, np.float64]
FloatFieldK32 = Field[gtscript.K, np.float32]

IntField = Field[gtscript.IJK, Int]
IntField64 = Field[gtscript.IJK, np.int64]
IntField32 = Field[gtscript.IJK, np.int32]
IntFieldI = Field[gtscript.I, Int]
IntFieldI64 = Field[gtscript.I, np.int64]
IntFieldI32 = Field[gtscript.I, np.int32]
IntFieldJ = Field[gtscript.J, Int]
IntFieldJ64 = Field[gtscript.J, np.int64]
IntFieldJ32 = Field[gtscript.J, np.int32]
IntFieldIJ = Field[gtscript.IJ, Int]
IntFieldIJ64 = Field[gtscript.IJ, np.int64]
IntFieldIJ32 = Field[gtscript.IJ, np.int32]
IntFieldK = Field[gtscript.K, Int]
IntFieldK64 = Field[gtscript.K, np.int64]
IntFieldK32 = Field[gtscript.K, np.int32]

BoolField = Field[gtscript.IJK, Bool]
BoolFieldI = Field[gtscript.I, Bool]
BoolFieldJ = Field[gtscript.J, Bool]
BoolFieldK = Field[gtscript.K, Bool]
BoolFieldIJ = Field[gtscript.IJ, Bool]

Index3D = Tuple[int, int, int]
Expand Down
Loading