Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QOL] Logging, Type Hints and Quantity helpers #108

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions ndsl/comm/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Dict, List, Optional, TypeVar, cast

from ndsl.comm.comm_abc import Comm, ReductionOperator, Request
from ndsl.logging import ndsl_log


T = TypeVar("T")
Expand Down Expand Up @@ -43,70 +42,46 @@ def Get_size(self) -> int:
return self._comm.Get_size()

def bcast(self, value: Optional[T], root=0) -> T:
ndsl_log.debug("bcast from root %s on rank %s", root, self._comm.Get_rank())
return self._comm.bcast(value, root=root)

def barrier(self):
ndsl_log.debug("barrier on rank %s", self._comm.Get_rank())
self._comm.barrier()

def Barrier(self):
pass

def Scatter(self, sendbuf, recvbuf, root=0, **kwargs):
ndsl_log.debug("Scatter on rank %s with root %s", self._comm.Get_rank(), root)
self._comm.Scatter(sendbuf, recvbuf, root=root, **kwargs)

def Gather(self, sendbuf, recvbuf, root=0, **kwargs):
ndsl_log.debug("Gather on rank %s with root %s", self._comm.Get_rank(), root)
self._comm.Gather(sendbuf, recvbuf, root=root, **kwargs)

def allgather(self, sendobj: T) -> List[T]:
ndsl_log.debug("allgather on rank %s", self._comm.Get_rank())
return self._comm.allgather(sendobj)

def Send(self, sendbuf, dest, tag: int = 0, **kwargs):
ndsl_log.debug("Send on rank %s with dest %s", self._comm.Get_rank(), dest)
self._comm.Send(sendbuf, dest, tag=tag, **kwargs)

def sendrecv(self, sendbuf, dest, **kwargs):
ndsl_log.debug("sendrecv on rank %s with dest %s", self._comm.Get_rank(), dest)
return self._comm.sendrecv(sendbuf, dest, **kwargs)

def Isend(self, sendbuf, dest, tag: int = 0, **kwargs) -> Request:
ndsl_log.debug("Isend on rank %s with dest %s", self._comm.Get_rank(), dest)
return self._comm.Isend(sendbuf, dest, tag=tag, **kwargs)

def Recv(self, recvbuf, source, tag: int = 0, **kwargs):
ndsl_log.debug("Recv on rank %s with source %s", self._comm.Get_rank(), source)
self._comm.Recv(recvbuf, source, tag=tag, **kwargs)

def Irecv(self, recvbuf, source, tag: int = 0, **kwargs) -> Request:
ndsl_log.debug("Irecv on rank %s with source %s", self._comm.Get_rank(), source)
return self._comm.Irecv(recvbuf, source, tag=tag, **kwargs)

def Split(self, color, key) -> "Comm":
ndsl_log.debug(
"Split on rank %s with color %s, key %s", self._comm.Get_rank(), color, key
)
return self._comm.Split(color, key)

def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
ndsl_log.debug(
"allreduce on rank %s with operator %s", self._comm.Get_rank(), op
)
return self._comm.allreduce(sendobj, self._op_mapping[op])

def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> T:
ndsl_log.debug(
"Allreduce on rank %s with operator %s", self._comm.Get_rank(), op
)
return self._comm.Allreduce(sendobj_or_inplace, recvobj, self._op_mapping[op])

def Allreduce_inplace(self, recvobj: T, op: ReductionOperator) -> T:
ndsl_log.debug(
"Allreduce (in place) on rank %s with operator %s",
self._comm.Get_rank(),
op,
)
return self._comm.Allreduce(mpi4py.MPI.IN_PLACE, recvobj, self._op_mapping[op])
5 changes: 4 additions & 1 deletion ndsl/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def __get__(self, obj, objtype=None) -> SDFGEnabledCallable:
def orchestrate(
*,
obj: object,
config: DaceConfig,
config: Optional[DaceConfig],
method_to_orchestrate: str = "__call__",
dace_compiletime_args: Optional[Sequence[str]] = None,
):
Expand All @@ -455,6 +455,9 @@ def orchestrate(
dace_compiletime_args: list of names of arguments to be flagged has
dace.compiletime for orchestration to behave
"""
if config is None:
raise ValueError("DaCe config cannot be None")

if dace_compiletime_args is None:
dace_compiletime_args = []

Expand Down
3 changes: 3 additions & 0 deletions ndsl/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ndsl.dsl.stencil_config import CompilationConfig, RunMode, StencilConfig
from ndsl.dsl.typing import Float, Index3D, cast_to_index3d
from ndsl.initialization.sizer import GridSizer, SubtileGridSizer
from ndsl.logging import ndsl_log
from ndsl.quantity import Quantity
from ndsl.testing.comparison import LegacyMetric

Expand Down Expand Up @@ -374,6 +375,8 @@ def nothing_function(*args, **kwargs):
setattr(self, "__call__", nothing_function)

def __call__(self, *args, **kwargs) -> None:
if self.stencil_config.verbose:
ndsl_log.debug(f"Running {self._func_name}")
args_list = list(args)
_convert_quantities_to_storage(args_list, kwargs)
args = tuple(args_list)
Expand Down
1 change: 1 addition & 0 deletions ndsl/dsl/stencil_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ class StencilConfig(Hashable):
compare_to_numpy: bool = False
compilation_config: CompilationConfig = CompilationConfig()
dace_config: Optional[DaceConfig] = None
verbose: bool = False

def __post_init__(self):
self.backend_opts = {
Expand Down
36 changes: 34 additions & 2 deletions ndsl/logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import logging
import os
import sys
from typing import Annotated

from mpi4py import MPI

Expand All @@ -18,7 +21,7 @@
}


def _ndsl_logger():
def _ndsl_logger() -> logging.Logger:
name_log = logging.getLogger(__name__)
name_log.setLevel(LOGLEVEL)

Expand All @@ -36,4 +39,33 @@ def _ndsl_logger():
return name_log


ndsl_log = _ndsl_logger()
def _ndsl_logger_on_rank_0() -> logging.Logger:
name_log = logging.getLogger(f"{__name__}_on_rank_0")
name_log.setLevel(LOGLEVEL)

rank = MPI.COMM_WORLD.Get_rank()

if rank == 0:
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(LOGLEVEL)
formatter = logging.Formatter(
fmt=(
f"%(asctime)s|%(levelname)s|rank {MPI.COMM_WORLD.Get_rank()}|"
"%(name)s:%(message)s"
),
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
name_log.addHandler(handler)
else:
name_log.disabled = True
return name_log


ndsl_log: Annotated[
logging.Logger, "NDSL Python logger, logs on all rank"
] = _ndsl_logger()

ndsl_log_on_rank_0: Annotated[
logging.Logger, "NDSL Python logger, logs on rank 0 only"
] = _ndsl_logger_on_rank_0()
5 changes: 5 additions & 0 deletions ndsl/quantity/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import matplotlib.pyplot as plt
import numpy as np
from mpi4py import MPI

import ndsl.constants as constants
from ndsl.dsl.typing import Float, is_float
Expand Down Expand Up @@ -152,6 +153,10 @@ def from_data_array(
gt4py_backend=gt4py_backend,
)

def to_netcdf(self, name: str, rank: int = -1) -> None:
if rank < 0 or MPI.COMM_WORLD.Get_rank() == rank:
self.data_array.to_netcdf(f"{name}__r{rank}.nc4")

def halo_spec(self, n_halo: int) -> QuantityHaloSpec:
return QuantityHaloSpec(
n_halo,
Expand Down