diff --git a/ai_models/__main__.py b/ai_models/__main__.py index 4986923..3ee43c5 100644 --- a/ai_models/__main__.py +++ b/ai_models/__main__.py @@ -97,7 +97,7 @@ def _main(): "--output", default="file", help="Where to output the results", - choices=available_outputs(), + choices=sorted(available_outputs()), ) parser.add_argument( diff --git a/ai_models/outputs/__init__.py b/ai_models/outputs/__init__.py index bc9dcc9..2b304e1 100644 --- a/ai_models/outputs/__init__.py +++ b/ai_models/outputs/__init__.py @@ -9,6 +9,7 @@ import logging import climetlab as cml +import entrypoints import numpy as np LOG = logging.getLogger(__name__) @@ -112,18 +113,15 @@ def write(self, *args, **kwargs): pass -OUTPUTS = dict( - file=FileOutput, - none=NoneOutput, -) - - def get_output(name, owner, *args, **kwargs): - result = OUTPUTS[name](owner, *args, **kwargs) + result = available_outputs()[name].load()(owner, *args, **kwargs) if kwargs.get("hindcast_reference_year") is not None: result = HindcastReLabel(owner, result, **kwargs) return result def available_outputs(): - return sorted(OUTPUTS.keys()) + result = {} + for e in entrypoints.get_group_all("ai_models.output"): + result[e.name] = e + return result diff --git a/setup.py b/setup.py index 7bb2b73..7ee30ab 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,10 @@ def read(fname): "mars=ai_models.inputs:MarsInput", "cds=ai_models.inputs:CdsInput", ], + "ai_models.output": [ + "file=ai_models.outputs:FileOutput", + "none=ai_models.outputs:NoneOutput", + ], }, classifiers=[ "Development Status :: 3 - Alpha",