Skip to content

Commit

Permalink
Change outputs to plugin format
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Feb 13, 2024
1 parent 449b605 commit 350e350
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ai_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _main():
"--output",
default="file",
help="Where to output the results",
choices=available_outputs(),
choices=sorted(available_outputs()),
)

parser.add_argument(
Expand Down
14 changes: 6 additions & 8 deletions ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging

import climetlab as cml
import entrypoints
import numpy as np

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -112,18 +113,15 @@ def write(self, *args, **kwargs):
pass


OUTPUTS = dict(
file=FileOutput,
none=NoneOutput,
)


def get_output(name, owner, *args, **kwargs):
result = OUTPUTS[name](owner, *args, **kwargs)
result = available_outputs()[name].load()(owner, *args, **kwargs)
if kwargs.get("hindcast_reference_year") is not None:
result = HindcastReLabel(owner, result, **kwargs)
return result


def available_outputs():
return sorted(OUTPUTS.keys())
result = {}
for e in entrypoints.get_group_all("ai_models.output"):
result[e.name] = e
return result
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def read(fname):
"mars=ai_models.inputs:MarsInput",
"cds=ai_models.inputs:CdsInput",
],
"ai_models.output": [
"file=ai_models.outputs:FileOutput",
"none=ai_models.outputs:NoneOutput",
],
},
classifiers=[
"Development Status :: 3 - Alpha",
Expand Down

0 comments on commit 350e350

Please sign in to comment.