From ec969b66a06e3e5c2f8c041108c8c42c4d39e0fd Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 12 Mar 2024 14:11:17 +0000 Subject: [PATCH 01/21] Add RemoteModel --- ai_models/__main__.py | 11 ++++++----- ai_models/model.py | 18 ------------------ ai_models/remote.py | 38 ++++++++++++++++++++++++++++++++++---- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/ai_models/__main__.py b/ai_models/__main__.py index 148910f..caad1df 100644 --- a/ai_models/__main__.py +++ b/ai_models/__main__.py @@ -271,7 +271,11 @@ def _main(argv): def run(cfg: dict, model_args: list): - model = load_model(cfg["model"], **cfg, model_args=model_args) + if cfg["remote_execution"]: + from .remote import RemoteModel + model = RemoteModel(**cfg, model_args=model_args) + else: + model = load_model(cfg["model"], **cfg, model_args=model_args) if cfg["fields"]: model.print_fields() @@ -289,10 +293,7 @@ def run(cfg: dict, model_args: list): sys.exit(0) try: - if cfg["remote_execution"]: - model.remote(cfg, model_args) - else: - model.run() + model.run() except FileNotFoundError as e: LOG.exception(e) LOG.error( diff --git a/ai_models/model.py b/ai_models/model.py index 3f81be9..636abe8 100644 --- a/ai_models/model.py +++ b/ai_models/model.py @@ -24,7 +24,6 @@ from .checkpoint import peek from .inputs import get_input from .outputs import get_output -from .remote import RemoteClient from .stepper import Stepper LOG = logging.getLogger(__name__) @@ -458,23 +457,6 @@ def write_input_fields( check=True, ) - def remote(self, cfg: dict, model_args: list): - with tempfile.TemporaryDirectory() as tmpdirname: - input_file = os.path.join(tmpdirname, "input.grib") - output_file = os.path.join(tmpdirname, "output.grib") - self.all_fields.save(input_file) - - client = RemoteClient( - input_file=input_file, - output_file=output_file, - ) - - client.run(cfg, model_args) - - ds = cml.load_source("file", output_file) - for field in ds: - self.write(None, template=field) - def load_model(name, **kwargs): return available_models()[name].load()(**kwargs) diff --git a/ai_models/remote.py b/ai_models/remote.py index e047bd0..55af89d 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -1,15 +1,48 @@ import logging import os import sys +import tempfile import time +from functools import cached_property from urllib.parse import urljoin +import climetlab as cml import requests from multiurl import download, robust +from .model import Model + LOG = logging.getLogger(__name__) +class RemoteModel(Model): + def __init__(self, **kwargs): + kwargs["download_assets"] = False + + super().__init__(**kwargs) + + self.cfg = kwargs + self.client = RemoteClient() + + def run(self): + with tempfile.TemporaryDirectory() as tmpdirname: + input_file = os.path.join(tmpdirname, "input.grib") + output_file = os.path.join(tmpdirname, "output.grib") + self.all_fields.save(input_file) + + self.client.input_file = input_file + self.client.output_file = output_file + + self.client.run(self.cfg) + + ds = cml.load_source("file", output_file) + for field in ds: + self.write(None, template=field) + + def parse_model_args(self, args): + return None + + class BearerAuth(requests.auth.AuthBase): def __init__(self, token): self.token = token @@ -61,10 +94,7 @@ def __init__( self.input_file = input_file self._timeout = 300 - def run(self, cfg: dict, model_args: list): - cfg.pop("remote_execution", None) - cfg["model_args"] = model_args - + def run(self, cfg: dict): # upload file with open(self.input_file, "rb") as f: LOG.info("Uploading input file to remote") From f4ef3b023ea595d4a7e08c070dca775f24a4206a Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 12 Mar 2024 16:10:53 +0000 Subject: [PATCH 02/21] Add remote parameter lookup --- ai_models/__main__.py | 1 + ai_models/model.py | 1 - ai_models/remote.py | 76 +++++++++++++++++++++++++++++++++++-------- 3 files changed, 63 insertions(+), 15 deletions(-) diff --git a/ai_models/__main__.py b/ai_models/__main__.py index caad1df..e86c722 100644 --- a/ai_models/__main__.py +++ b/ai_models/__main__.py @@ -273,6 +273,7 @@ def _main(argv): def run(cfg: dict, model_args: list): if cfg["remote_execution"]: from .remote import RemoteModel + model = RemoteModel(**cfg, model_args=model_args) else: model = load_model(cfg["model"], **cfg, model_args=model_args) diff --git a/ai_models/model.py b/ai_models/model.py index 636abe8..5593043 100644 --- a/ai_models/model.py +++ b/ai_models/model.py @@ -10,7 +10,6 @@ import logging import os import sys -import tempfile import time from collections import defaultdict from functools import cached_property diff --git a/ai_models/remote.py b/ai_models/remote.py index 55af89d..1e54c5c 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -3,7 +3,7 @@ import sys import tempfile import time -from functools import cached_property +from functools import cache, cached_property from urllib.parse import urljoin import climetlab as cml @@ -17,12 +17,13 @@ class RemoteModel(Model): def __init__(self, **kwargs): - kwargs["download_assets"] = False - - super().__init__(**kwargs) - self.cfg = kwargs - self.client = RemoteClient() + self.cfg["download_assets"] = False + self.cfg["assets_extra_dir"] = None + self._param = {} + self.api = RemoteClient() + + super().__init__(**self.cfg) def run(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -30,10 +31,10 @@ def run(self): output_file = os.path.join(tmpdirname, "output.grib") self.all_fields.save(input_file) - self.client.input_file = input_file - self.client.output_file = output_file + self.api.input_file = input_file + self.api.output_file = output_file - self.client.run(self.cfg) + self.api.run(self.cfg) ds = cml.load_source("file", output_file) for field in ds: @@ -42,6 +43,41 @@ def run(self): def parse_model_args(self, args): return None + def __getattr__(self, name): + return self.get_param(name) + + @cache + def get_param(self, name): + return self.api.get_param(self.cfg["model"], name).get(name, None) + + @cached_property + def param_level_ml(self): + return self.get_param("param_level_ml") or ([], []) + + @cached_property + def param_level_pl(self): + return self.get_param("param_level_pl") or ([], []) + + @cached_property + def param_sfc(self): + return self.get_param("param_sfc") or [] + + @cached_property + def lagged(self): + return self.get_param("lagged") or False + + @cached_property + def version(self): + return self.get_param("version") or 1 + + @cached_property + def grib_extra_metadata(self): + return self.get_param("grib_extra_metadata") or {} + + @cached_property + def retrieve(self): + return self.get_param("retrieve") or {} + class BearerAuth(requests.auth.AuthBase): def __init__(self, token): @@ -133,17 +169,29 @@ def run(self, cfg: dict): LOG.debug("Result written to %s", self.output_file) - def _request(self, type, href, data=None, json=None, auth=None): - r = robust(type, retry_after=self._timeout)( + def get_param(self, model, param): + if isinstance(param, str): + return self._request( + requests.get, f"metadata/{model}/{param}", with_status=False + ) + else: + return self._request( + requests.post, f"metadata/{model}", json=param, with_status=False + ) + + def _request(self, type, href, data=None, json=None, auth=None, with_status=True): + r = robust(type, retry_after=30)( urljoin(self.url, href), json=json, data=data, auth=self.auth, timeout=self._timeout, ) - - status, href = self._update_state(r) - return status, href + if with_status: + status, href = self._update_state(r) + return status, href + else: + return r.json() def _update_state(self, response: requests.Response): if response.status_code == 401: From 9a13c504663afe5c0274250e5d197683a4cfa480 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 12 Mar 2024 17:30:41 +0000 Subject: [PATCH 03/21] Rename RemoteClient to RemoteAPI --- ai_models/remote.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ai_models/remote.py b/ai_models/remote.py index 1e54c5c..1f468f5 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -21,7 +21,7 @@ def __init__(self, **kwargs): self.cfg["download_assets"] = False self.cfg["assets_extra_dir"] = None self._param = {} - self.api = RemoteClient() + self.api = RemoteAPI() super().__init__(**self.cfg) @@ -88,7 +88,7 @@ def __call__(self, r): return r -class RemoteClient: +class RemoteAPI: def __init__( self, input_file: str = None, From 65a147818c2f2d71dbe1258b5c109f545f5a264c Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 12 Mar 2024 18:06:22 +0000 Subject: [PATCH 04/21] Improve remote request logging --- ai_models/remote.py | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/ai_models/remote.py b/ai_models/remote.py index 1f468f5..40cdbc8 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -112,7 +112,7 @@ def __init__( if url is None: url = os.getenv("AI_MODELS_REMOTE_URL", "https://ai-models.ecmwf.int") - LOG.info("Using remote %s", url) + LOG.info("Using remote server %s", url) token = token or os.getenv("AI_MODELS_REMOTE_TOKEN", None) @@ -132,29 +132,32 @@ def __init__( def run(self, cfg: dict): # upload file - with open(self.input_file, "rb") as f: - LOG.info("Uploading input file to remote") - status, href = self._request(requests.post, "upload", data=f) + with open(self.input_file, "rb") as file: + LOG.info("Uploading input file to remote server") + _, status, href = self._request(requests.post, "upload", data=file) if status != "success": LOG.error(status) sys.exit(1) - # submit job - status, href = self._request(requests.post, href, json=cfg) + # submit task + id, status, href = self._request(requests.post, href, json=cfg) + + LOG.info("Request submitted") + LOG.info("Request id: %s", id) if status != "queued": LOG.error(status) sys.exit(1) - LOG.info("Job status: queued") + LOG.info("Request is queued") last_status = status while True: - status, href = self._request(requests.get, href) + _, status, href = self._request(requests.get, href) if status != last_status: - LOG.info("Job status: %s", status) + LOG.info("Request is %s", status) last_status = status if status == "failed": @@ -163,7 +166,7 @@ def run(self, cfg: dict): if status == "ready": break - time.sleep(4) + time.sleep(5) download(urljoin(self.url, href), target=self.output_file) @@ -180,29 +183,33 @@ def get_param(self, model, param): ) def _request(self, type, href, data=None, json=None, auth=None, with_status=True): - r = robust(type, retry_after=30)( + response = robust(type, retry_after=30)( urljoin(self.url, href), json=json, data=data, auth=self.auth, timeout=self._timeout, ) + + if response.status_code == 401: + LOG.error("Unauthorized Access. Check your token.") + sys.exit(1) + if with_status: - status, href = self._update_state(r) - return status, href + id, status, href = self._update_state(response) + return id, status, href else: - return r.json() + return response.json() def _update_state(self, response: requests.Response): - if response.status_code == 401: - return "Unauthorized Access", None - try: data = response.json() href = data["href"] status = data["status"].lower() + id = data["id"] except Exception: status = f"{response.status_code} {response.url} {response.text}" href = None + id = None - return status, href + return id, status, href From 95f85f09a1cd14226425b08a7035223a60e92260 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Wed, 13 Mar 2024 16:40:12 +0000 Subject: [PATCH 05/21] Refactor remote parameter cache --- ai_models/remote.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/ai_models/remote.py b/ai_models/remote.py index 40cdbc8..a5d89af 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -3,7 +3,7 @@ import sys import tempfile import time -from functools import cache, cached_property +from functools import cached_property from urllib.parse import urljoin import climetlab as cml @@ -20,11 +20,16 @@ def __init__(self, **kwargs): self.cfg = kwargs self.cfg["download_assets"] = False self.cfg["assets_extra_dir"] = None + + self.model = self.cfg["model"] self._param = {} self.api = RemoteAPI() super().__init__(**self.cfg) + def __getattr__(self, name): + return self.get_parameter(name) + def run(self): with tempfile.TemporaryDirectory() as tmpdirname: input_file = os.path.join(tmpdirname, "input.grib") @@ -43,40 +48,41 @@ def run(self): def parse_model_args(self, args): return None - def __getattr__(self, name): - return self.get_param(name) + def get_parameter(self, name): + if (param := self._param.get(name, None)) is not None: + return param - @cache - def get_param(self, name): - return self.api.get_param(self.cfg["model"], name).get(name, None) + self._param.update(self.api.metadata(self.model, name)) + + return self._param[name] @cached_property def param_level_ml(self): - return self.get_param("param_level_ml") or ([], []) + return self.get_parameter("param_level_ml") or ([], []) @cached_property def param_level_pl(self): - return self.get_param("param_level_pl") or ([], []) + return self.get_parameter("param_level_pl") or ([], []) @cached_property def param_sfc(self): - return self.get_param("param_sfc") or [] + return self.get_parameter("param_sfc") or [] @cached_property def lagged(self): - return self.get_param("lagged") or False + return self.get_parameter("lagged") or False @cached_property def version(self): - return self.get_param("version") or 1 + return self.get_parameter("version") or 1 @cached_property def grib_extra_metadata(self): - return self.get_param("grib_extra_metadata") or {} + return self.get_parameter("grib_extra_metadata") or {} @cached_property def retrieve(self): - return self.get_param("retrieve") or {} + return self.get_parameter("retrieve") or {} class BearerAuth(requests.auth.AuthBase): @@ -172,15 +178,17 @@ def run(self, cfg: dict): LOG.debug("Result written to %s", self.output_file) - def get_param(self, model, param): + def metadata(self, model, param) -> dict: if isinstance(param, str): return self._request( requests.get, f"metadata/{model}/{param}", with_status=False ) - else: + elif isinstance(param, (list, dict)): return self._request( requests.post, f"metadata/{model}", json=param, with_status=False ) + else: + raise ValueError("param must be a string, list, or dict with 'param' key.") def _request(self, type, href, data=None, json=None, auth=None, with_status=True): response = robust(type, retry_after=30)( From 2ad1ec27e5cf7d9965c79c000fde23fce617ad6f Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Wed, 13 Mar 2024 17:09:55 +0000 Subject: [PATCH 06/21] Preload remote parameters --- ai_models/remote.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ai_models/remote.py b/ai_models/remote.py index a5d89af..4839cc9 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -25,6 +25,8 @@ def __init__(self, **kwargs): self._param = {} self.api = RemoteAPI() + self.load_parameters() + super().__init__(**self.cfg) def __getattr__(self, name): @@ -48,6 +50,24 @@ def run(self): def parse_model_args(self, args): return None + def load_parameters(self): + params = self.api.metadata( + self.model, + [ + "expver", + "version", + "grid", + "area", + "param_level_ml", + "param_level_pl", + "param_sfc", + "lagged", + "grib_extra_metadata", + "retrieve", + ], + ) + self._param.update(params) + def get_parameter(self, name): if (param := self._param.get(name, None)) is not None: return param From 2ffa1ba76275fcfeaf90c8f147e2727726cfec6c Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Thu, 14 Mar 2024 13:24:25 +0000 Subject: [PATCH 07/21] Refactor api request func --- ai_models/remote.py | 62 ++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/ai_models/remote.py b/ai_models/remote.py index 4839cc9..83cf935 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -160,57 +160,54 @@ def run(self, cfg: dict): # upload file with open(self.input_file, "rb") as file: LOG.info("Uploading input file to remote server") - _, status, href = self._request(requests.post, "upload", data=file) + data = self._request(requests.post, "upload", data=file) - if status != "success": - LOG.error(status) + if data["status"] != "success": + LOG.error(data["status"]) sys.exit(1) # submit task - id, status, href = self._request(requests.post, href, json=cfg) + data = self._request(requests.post, data["href"], json=cfg) LOG.info("Request submitted") - LOG.info("Request id: %s", id) - if status != "queued": - LOG.error(status) + if data["status"] != "queued": + LOG.error(data["status"]) sys.exit(1) + LOG.info("Request id: %s", data["id"]) LOG.info("Request is queued") - last_status = status + + last_status = data["status"] while True: - _, status, href = self._request(requests.get, href) + data = self._request(requests.get, data["href"]) - if status != last_status: - LOG.info("Request is %s", status) - last_status = status + if data["status"] != last_status: + LOG.info("Request is %s", data["status"]) + last_status = data["status"] - if status == "failed": + if data["status"] == "failed": sys.exit(1) - if status == "ready": + if data["status"] == "ready": break time.sleep(5) - download(urljoin(self.url, href), target=self.output_file) + download(urljoin(self.url, data["href"]), target=self.output_file) LOG.debug("Result written to %s", self.output_file) def metadata(self, model, param) -> dict: if isinstance(param, str): - return self._request( - requests.get, f"metadata/{model}/{param}", with_status=False - ) + return self._request(requests.get, f"metadata/{model}/{param}") elif isinstance(param, (list, dict)): - return self._request( - requests.post, f"metadata/{model}", json=param, with_status=False - ) + return self._request(requests.post, f"metadata/{model}", json=param) else: raise ValueError("param must be a string, list, or dict with 'param' key.") - def _request(self, type, href, data=None, json=None, auth=None, with_status=True): + def _request(self, type, href, data=None, json=None, auth=None): response = robust(type, retry_after=30)( urljoin(self.url, href), json=json, @@ -223,21 +220,12 @@ def _request(self, type, href, data=None, json=None, auth=None, with_status=True LOG.error("Unauthorized Access. Check your token.") sys.exit(1) - if with_status: - id, status, href = self._update_state(response) - return id, status, href - else: - return response.json() - - def _update_state(self, response: requests.Response): try: data = response.json() - href = data["href"] - status = data["status"].lower() - id = data["id"] - except Exception: - status = f"{response.status_code} {response.url} {response.text}" - href = None - id = None - return id, status, href + if status := data.get("status"): + data["status"] = status.lower() + + return data + except Exception: + return {"status": f"{response.url} {response.status_code} {response.text}"} From 4d065b89b9bf532589241dddee7e5464b3ceddec Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Thu, 14 Mar 2024 13:38:17 +0000 Subject: [PATCH 08/21] Log remote error reason if there is one --- ai_models/remote.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ai_models/remote.py b/ai_models/remote.py index 83cf935..27de98a 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -164,6 +164,8 @@ def run(self, cfg: dict): if data["status"] != "success": LOG.error(data["status"]) + if reason := data.get("reason"): + LOG.error(reason) sys.exit(1) # submit task @@ -173,6 +175,8 @@ def run(self, cfg: dict): if data["status"] != "queued": LOG.error(data["status"]) + if reason := data.get("reason"): + LOG.error(reason) sys.exit(1) LOG.info("Request id: %s", data["id"]) @@ -188,6 +192,8 @@ def run(self, cfg: dict): last_status = data["status"] if data["status"] == "failed": + if reason := data.get("reason"): + LOG.error(reason) sys.exit(1) if data["status"] == "ready": From 89138c1c75ba9ed33a0262ca3d035ef13097bca3 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Fri, 15 Mar 2024 16:10:09 +0000 Subject: [PATCH 09/21] Print remote models when queried with --models --remote --- ai_models/__main__.py | 14 +++++++++++++- ai_models/remote.py | 10 +++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/ai_models/__main__.py b/ai_models/__main__.py index e86c722..175bf3f 100644 --- a/ai_models/__main__.py +++ b/ai_models/__main__.py @@ -231,7 +231,19 @@ def _main(argv): args, unknownargs = parser.parse_known_args(argv) if args.models: - for p in sorted(available_models()): + if args.remote_execution: + from .remote import RemoteAPI + + api = RemoteAPI() + models = api.models() + if len(models) == 0: + print(f"No remote models available on {api.url}") + sys.exit(0) + print(f"Models available on remote server {api.url}:") + else: + models = available_models() + + for p in sorted(models): print(p) sys.exit(0) diff --git a/ai_models/remote.py b/ai_models/remote.py index 27de98a..c2b1c9c 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -213,6 +213,14 @@ def metadata(self, model, param) -> dict: else: raise ValueError("param must be a string, list, or dict with 'param' key.") + def models(self): + results = self._request(requests.get, "models") + + if not isinstance(results, list): + return [] + + return results + def _request(self, type, href, data=None, json=None, auth=None): response = robust(type, retry_after=30)( urljoin(self.url, href), @@ -229,7 +237,7 @@ def _request(self, type, href, data=None, json=None, auth=None): try: data = response.json() - if status := data.get("status"): + if isinstance(data, dict) and (status := data.get("status")): data["status"] = status.lower() return data From 52137851cfaa664dbd9c5f3683a4e49706e5f231 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Mon, 18 Mar 2024 14:20:35 +0000 Subject: [PATCH 10/21] Patch retrieve requests for remote models --- ai_models/remote.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ai_models/remote.py b/ai_models/remote.py index c2b1c9c..d8dee94 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -50,6 +50,10 @@ def run(self): def parse_model_args(self, args): return None + def patch_retrieve_request(self, request): + patched = self.api.patch_retrieve_request(self.cfg, request) + request.update(patched) + def load_parameters(self): params = self.api.metadata( self.model, @@ -221,6 +225,14 @@ def models(self): return results + def patch_retrieve_request(self, cfg, request): + cfg["patchrequest"] = request + result = self._request(requests.post, "patch", json=cfg) + if status := result.get("status"): + LOG.error(status) + sys.exit(1) + return result + def _request(self, type, href, data=None, json=None, auth=None): response = robust(type, retry_after=30)( urljoin(self.url, href), From 095c2c35da76ff038b1a01b28400e0a267111041 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Mon, 25 Mar 2024 15:57:29 +0000 Subject: [PATCH 11/21] Show remote progress bar --- ai_models/remote.py | 32 +++++++++++++++++++++++++++----- setup.py | 1 + 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/ai_models/remote.py b/ai_models/remote.py index d8dee94..a11103f 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -9,6 +9,7 @@ import climetlab as cml import requests from multiurl import download, robust +from tqdm import tqdm from .model import Model @@ -187,21 +188,42 @@ def run(self, cfg: dict): LOG.info("Request is queued") last_status = data["status"] + pbar = None while True: data = self._request(requests.get, data["href"]) - if data["status"] != last_status: - LOG.info("Request is %s", data["status"]) - last_status = data["status"] + if data["status"] == "ready": + if pbar is not None: + pbar.close() + LOG.info("Request is ready") + break if data["status"] == "failed": + LOG.error("Request failed") if reason := data.get("reason"): LOG.error(reason) sys.exit(1) - if data["status"] == "ready": - break + if data["status"] != last_status: + LOG.info("Request is %s", data["status"]) + last_status = data["status"] + + if progress := data.get("progress"): + if pbar is None: + pbar = tqdm( + total=progress.get("total", 0), + unit="steps", + ncols=70, + leave=False, + initial=1, + bar_format="{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} {unit}{postfix}", + ) + if eta := progress.get("eta"): + pbar.set_postfix_str(f"ETA: {eta}") + if status := progress.get("status"): + pbar.set_description(status.strip().capitalize()) + pbar.update(progress.get("step", 0) - pbar.n) time.sleep(5) diff --git a/setup.py b/setup.py index 6e97df8..8ef8ae1 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ def read(fname): "gputil", "earthkit-meteo", "pyyaml", + "tqdm", ], extras_require={ "provenance": [ From 8fbc4b5cae3d66d68ebbf67a572df0849782913e Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 26 Mar 2024 13:32:46 +0000 Subject: [PATCH 12/21] Fix remote api config priority constructor arg > env var > config > default value --- ai_models/remote.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/ai_models/remote.py b/ai_models/remote.py index a11103f..48493d0 100644 --- a/ai_models/remote.py +++ b/ai_models/remote.py @@ -138,16 +138,17 @@ def __init__( with open(configfile, "r") as f: config = safe_load(f) or {} - url = config.get("url", None) - token = config.get("token", None) - - if url is None: - url = os.getenv("AI_MODELS_REMOTE_URL", "https://ai-models.ecmwf.int") - LOG.info("Using remote server %s", url) + url = ( + url + or os.getenv("AI_MODELS_REMOTE_URL") + or config.get("url") + or "https://ai-models.ecmwf.int" + ) + LOG.info("Using remote server %s", url) - token = token or os.getenv("AI_MODELS_REMOTE_TOKEN", None) + token = token or os.getenv("AI_MODELS_REMOTE_TOKEN") or config.get("token") - if token is None: + if not token: LOG.error( "Missing remote token. Set it in %s or in env AI_MODELS_REMOTE_TOKEN", configfile, From 13c730d8c21c3151e37fb849bf1854155e037592 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 26 Mar 2024 14:10:14 +0000 Subject: [PATCH 13/21] Refactor remote.py to remote subpackage --- ai_models/remote/__init__.py | 4 + ai_models/{remote.py => remote/api.py} | 99 ------------------------ ai_models/remote/model.py | 102 +++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 99 deletions(-) create mode 100644 ai_models/remote/__init__.py rename ai_models/{remote.py => remote/api.py} (67%) create mode 100644 ai_models/remote/model.py diff --git a/ai_models/remote/__init__.py b/ai_models/remote/__init__.py new file mode 100644 index 0000000..3a3d0ab --- /dev/null +++ b/ai_models/remote/__init__.py @@ -0,0 +1,4 @@ +from .api import RemoteAPI +from .model import RemoteModel + +__all__ = ["RemoteAPI", "RemoteModel"] diff --git a/ai_models/remote.py b/ai_models/remote/api.py similarity index 67% rename from ai_models/remote.py rename to ai_models/remote/api.py index 48493d0..d14412a 100644 --- a/ai_models/remote.py +++ b/ai_models/remote/api.py @@ -1,115 +1,16 @@ import logging import os import sys -import tempfile import time -from functools import cached_property from urllib.parse import urljoin -import climetlab as cml import requests from multiurl import download, robust from tqdm import tqdm -from .model import Model - LOG = logging.getLogger(__name__) -class RemoteModel(Model): - def __init__(self, **kwargs): - self.cfg = kwargs - self.cfg["download_assets"] = False - self.cfg["assets_extra_dir"] = None - - self.model = self.cfg["model"] - self._param = {} - self.api = RemoteAPI() - - self.load_parameters() - - super().__init__(**self.cfg) - - def __getattr__(self, name): - return self.get_parameter(name) - - def run(self): - with tempfile.TemporaryDirectory() as tmpdirname: - input_file = os.path.join(tmpdirname, "input.grib") - output_file = os.path.join(tmpdirname, "output.grib") - self.all_fields.save(input_file) - - self.api.input_file = input_file - self.api.output_file = output_file - - self.api.run(self.cfg) - - ds = cml.load_source("file", output_file) - for field in ds: - self.write(None, template=field) - - def parse_model_args(self, args): - return None - - def patch_retrieve_request(self, request): - patched = self.api.patch_retrieve_request(self.cfg, request) - request.update(patched) - - def load_parameters(self): - params = self.api.metadata( - self.model, - [ - "expver", - "version", - "grid", - "area", - "param_level_ml", - "param_level_pl", - "param_sfc", - "lagged", - "grib_extra_metadata", - "retrieve", - ], - ) - self._param.update(params) - - def get_parameter(self, name): - if (param := self._param.get(name, None)) is not None: - return param - - self._param.update(self.api.metadata(self.model, name)) - - return self._param[name] - - @cached_property - def param_level_ml(self): - return self.get_parameter("param_level_ml") or ([], []) - - @cached_property - def param_level_pl(self): - return self.get_parameter("param_level_pl") or ([], []) - - @cached_property - def param_sfc(self): - return self.get_parameter("param_sfc") or [] - - @cached_property - def lagged(self): - return self.get_parameter("lagged") or False - - @cached_property - def version(self): - return self.get_parameter("version") or 1 - - @cached_property - def grib_extra_metadata(self): - return self.get_parameter("grib_extra_metadata") or {} - - @cached_property - def retrieve(self): - return self.get_parameter("retrieve") or {} - - class BearerAuth(requests.auth.AuthBase): def __init__(self, token): self.token = token diff --git a/ai_models/remote/model.py b/ai_models/remote/model.py new file mode 100644 index 0000000..7fdd8ee --- /dev/null +++ b/ai_models/remote/model.py @@ -0,0 +1,102 @@ +import os +import tempfile +from functools import cached_property + +import climetlab as cml + +from ..model import Model +from .api import RemoteAPI + + +class RemoteModel(Model): + def __init__(self, **kwargs): + self.cfg = kwargs + self.cfg["download_assets"] = False + self.cfg["assets_extra_dir"] = None + + self.model = self.cfg["model"] + self._param = {} + self.api = RemoteAPI() + + self.load_parameters() + + super().__init__(**self.cfg) + + def __getattr__(self, name): + return self.get_parameter(name) + + def run(self): + with tempfile.TemporaryDirectory() as tmpdirname: + input_file = os.path.join(tmpdirname, "input.grib") + output_file = os.path.join(tmpdirname, "output.grib") + self.all_fields.save(input_file) + + self.api.input_file = input_file + self.api.output_file = output_file + + self.api.run(self.cfg) + + ds = cml.load_source("file", output_file) + for field in ds: + self.write(None, template=field) + + def parse_model_args(self, args): + return None + + def patch_retrieve_request(self, request): + patched = self.api.patch_retrieve_request(self.cfg, request) + request.update(patched) + + def load_parameters(self): + params = self.api.metadata( + self.model, + [ + "expver", + "version", + "grid", + "area", + "param_level_ml", + "param_level_pl", + "param_sfc", + "lagged", + "grib_extra_metadata", + "retrieve", + ], + ) + self._param.update(params) + + def get_parameter(self, name): + if (param := self._param.get(name, None)) is not None: + return param + + self._param.update(self.api.metadata(self.model, name)) + + return self._param[name] + + @cached_property + def param_level_ml(self): + return self.get_parameter("param_level_ml") or ([], []) + + @cached_property + def param_level_pl(self): + return self.get_parameter("param_level_pl") or ([], []) + + @cached_property + def param_sfc(self): + return self.get_parameter("param_sfc") or [] + + @cached_property + def lagged(self): + return self.get_parameter("lagged") or False + + @cached_property + def version(self): + return self.get_parameter("version") or 1 + + @cached_property + def grib_extra_metadata(self): + return self.get_parameter("grib_extra_metadata") or {} + + @cached_property + def retrieve(self): + return self.get_parameter("retrieve") or {} From 77a2489e3a796c8efae734d8d84c79cb827f6006 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 2 Apr 2024 13:46:22 +0000 Subject: [PATCH 14/21] Refactor remote api config to config.py --- ai_models/remote/api.py | 36 +++++++++++--------------------- ai_models/remote/config.py | 42 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 24 deletions(-) create mode 100644 ai_models/remote/config.py diff --git a/ai_models/remote/api.py b/ai_models/remote/api.py index d14412a..cf5d64e 100644 --- a/ai_models/remote/api.py +++ b/ai_models/remote/api.py @@ -8,6 +8,8 @@ from multiurl import download, robust from tqdm import tqdm +from .config import API_URL, CONFIG_PATH, load_config + LOG = logging.getLogger(__name__) @@ -28,37 +30,23 @@ def __init__( url: str = None, token: str = None, ): - root = os.path.join(os.path.expanduser("~"), ".config", "ai-models") - os.makedirs(root, exist_ok=True) - - configfile = os.path.join(root, "api.yaml") - - if os.path.exists(configfile): - from yaml import safe_load - - with open(configfile, "r") as f: - config = safe_load(f) or {} + config = load_config() - url = ( - url - or os.getenv("AI_MODELS_REMOTE_URL") - or config.get("url") - or "https://ai-models.ecmwf.int" + self.url = ( + url or os.getenv("AI_MODELS_REMOTE_URL") or config.get("url") or API_URL ) - LOG.info("Using remote server %s", url) + self.token = token or os.getenv("AI_MODELS_REMOTE_TOKEN") or config.get("token") - token = token or os.getenv("AI_MODELS_REMOTE_TOKEN") or config.get("token") - - if not token: + if not self.token: LOG.error( - "Missing remote token. Set it in %s or in env AI_MODELS_REMOTE_TOKEN", - configfile, + "Missing remote token. Set it in %s or env AI_MODELS_REMOTE_TOKEN", + CONFIG_PATH, ) sys.exit(1) - self.url = url - self.token = token - self.auth = BearerAuth(token) + LOG.info("Using remote server %s", self.url) + + self.auth = BearerAuth(self.token) self.output_file = output_file self.input_file = input_file self._timeout = 300 diff --git a/ai_models/remote/config.py b/ai_models/remote/config.py new file mode 100644 index 0000000..47667cb --- /dev/null +++ b/ai_models/remote/config.py @@ -0,0 +1,42 @@ +import logging +import os + +API_URL = "https://ai-models.ecmwf.int" + +ROOT_PATH = os.path.join(os.path.expanduser("~"), ".config", "ai-models") +CONFIG_PATH = os.path.join(ROOT_PATH, "api.yaml") + +LOG = logging.getLogger(__name__) + + +def config_exists(): + return os.path.exists(CONFIG_PATH) + + +def create_config(): + if config_exists(): + return + + try: + os.makedirs(ROOT_PATH, exist_ok=True) + with open(CONFIG_PATH, "w") as f: + f.write("token: \n") + f.write(f"url: {API_URL}\n") + except Exception as e: + LOG.error(f"Failed to create config {CONFIG_PATH}") + LOG.error(e, exc_info=True) + + +def load_config() -> dict: + from yaml import safe_load + + if not config_exists(): + create_config() + + try: + with open(CONFIG_PATH, "r") as f: + return safe_load(f) or {} + except Exception as e: + LOG.error(f"Failed to read config {CONFIG_PATH}") + LOG.error(e, exc_info=True) + return {} From 4f4a72070bfc0b69c11b1cc785b5bc62bafe1413 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 2 Apr 2024 14:09:42 +0000 Subject: [PATCH 15/21] Create remote config file on install --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 8ef8ae1..3a43480 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,8 @@ import setuptools +from ai_models.remote.config import config_exists, create_config + def read(fname): file_path = os.path.join(os.path.dirname(__file__), fname) @@ -28,6 +30,9 @@ def read(fname): assert version +if not config_exists(): + create_config() + setuptools.setup( name="ai-models", From c9c487f25c3802508fa6f1f850cd09d480b42418 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Thu, 4 Apr 2024 13:53:54 +0000 Subject: [PATCH 16/21] Check if remote model exists on server --- ai_models/__main__.py | 2 +- ai_models/remote/model.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/ai_models/__main__.py b/ai_models/__main__.py index 175bf3f..e97caff 100644 --- a/ai_models/__main__.py +++ b/ai_models/__main__.py @@ -216,7 +216,7 @@ def _main(argv): parser.add_argument( "model", metavar="MODEL", - choices=available_models(), + choices=available_models() if "--remote" not in argv else None, help="The model to run", ) diff --git a/ai_models/remote/model.py b/ai_models/remote/model.py index 7fdd8ee..9b9624c 100644 --- a/ai_models/remote/model.py +++ b/ai_models/remote/model.py @@ -1,4 +1,6 @@ +import logging import os +import sys import tempfile from functools import cached_property @@ -7,6 +9,8 @@ from ..model import Model from .api import RemoteAPI +LOG = logging.getLogger(__name__) + class RemoteModel(Model): def __init__(self, **kwargs): @@ -18,6 +22,13 @@ def __init__(self, **kwargs): self._param = {} self.api = RemoteAPI() + if self.model not in self.api.models(): + LOG.error(f"Model '{self.model}' not available on remote server.") + LOG.error( + "Rerun the command with --models --remote to list available remote models." + ) + sys.exit(1) + self.load_parameters() super().__init__(**self.cfg) From 83c6bc5e400d460cf48e87099003c528bcf893b6 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Thu, 4 Apr 2024 14:51:29 +0000 Subject: [PATCH 17/21] Safe parameter lookup --- ai_models/remote/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ai_models/remote/model.py b/ai_models/remote/model.py index 9b9624c..7e3eb38 100644 --- a/ai_models/remote/model.py +++ b/ai_models/remote/model.py @@ -77,12 +77,12 @@ def load_parameters(self): self._param.update(params) def get_parameter(self, name): - if (param := self._param.get(name, None)) is not None: + if (param := self._param.get(name)) is not None: return param self._param.update(self.api.metadata(self.model, name)) - return self._param[name] + return self._param.get(name) @cached_property def param_level_ml(self): From 0824583802e43cfcb7e3827c298efe7e3cb087c4 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Fri, 5 Apr 2024 10:42:36 +0000 Subject: [PATCH 18/21] Update remote api url --- ai_models/__main__.py | 2 +- ai_models/remote/api.py | 5 ++++- ai_models/remote/config.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ai_models/__main__.py b/ai_models/__main__.py index e97caff..f75e548 100644 --- a/ai_models/__main__.py +++ b/ai_models/__main__.py @@ -239,7 +239,7 @@ def _main(argv): if len(models) == 0: print(f"No remote models available on {api.url}") sys.exit(0) - print(f"Models available on remote server {api.url}:") + print(f"Models available on remote server {api.url}") else: models = available_models() diff --git a/ai_models/remote/api.py b/ai_models/remote/api.py index cf5d64e..031477c 100644 --- a/ai_models/remote/api.py +++ b/ai_models/remote/api.py @@ -35,6 +35,9 @@ def __init__( self.url = ( url or os.getenv("AI_MODELS_REMOTE_URL") or config.get("url") or API_URL ) + if not self.url.endswith("/"): + self.url += "/" + self.token = token or os.getenv("AI_MODELS_REMOTE_TOKEN") or config.get("token") if not self.token: @@ -66,7 +69,7 @@ def run(self, cfg: dict): # submit task data = self._request(requests.post, data["href"], json=cfg) - LOG.info("Request submitted") + LOG.info("Inference request submitted") if data["status"] != "queued": LOG.error(data["status"]) diff --git a/ai_models/remote/config.py b/ai_models/remote/config.py index 47667cb..61af0d3 100644 --- a/ai_models/remote/config.py +++ b/ai_models/remote/config.py @@ -1,7 +1,7 @@ import logging import os -API_URL = "https://ai-models.ecmwf.int" +API_URL = "https://ai-models.ecmwf.int/api/v1/" ROOT_PATH = os.path.join(os.path.expanduser("~"), ".config", "ai-models") CONFIG_PATH = os.path.join(ROOT_PATH, "api.yaml") From bd0ececf993faae44ab2e94fe0ab36db91b9a4be Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 9 Apr 2024 13:56:37 +0000 Subject: [PATCH 19/21] Include model version in metadata request --- ai_models/remote/api.py | 10 +++++++--- ai_models/remote/model.py | 5 ++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ai_models/remote/api.py b/ai_models/remote/api.py index 031477c..53aa03f 100644 --- a/ai_models/remote/api.py +++ b/ai_models/remote/api.py @@ -124,11 +124,15 @@ def run(self, cfg: dict): LOG.debug("Result written to %s", self.output_file) - def metadata(self, model, param) -> dict: + def metadata(self, model, model_version, param) -> dict: if isinstance(param, str): - return self._request(requests.get, f"metadata/{model}/{param}") + return self._request( + requests.get, f"metadata/{model}/{model_version}/{param}" + ) elif isinstance(param, (list, dict)): - return self._request(requests.post, f"metadata/{model}", json=param) + return self._request( + requests.post, f"metadata/{model}/{model_version}", json=param + ) else: raise ValueError("param must be a string, list, or dict with 'param' key.") diff --git a/ai_models/remote/model.py b/ai_models/remote/model.py index 7e3eb38..9a2fc4c 100644 --- a/ai_models/remote/model.py +++ b/ai_models/remote/model.py @@ -19,6 +19,7 @@ def __init__(self, **kwargs): self.cfg["assets_extra_dir"] = None self.model = self.cfg["model"] + self.model_version = self.cfg.get("model_version", "latest") self._param = {} self.api = RemoteAPI() @@ -61,6 +62,7 @@ def patch_retrieve_request(self, request): def load_parameters(self): params = self.api.metadata( self.model, + self.model_version, [ "expver", "version", @@ -80,7 +82,8 @@ def get_parameter(self, name): if (param := self._param.get(name)) is not None: return param - self._param.update(self.api.metadata(self.model, name)) + _param = self.api.metadata(self.model, self.model_version, name) + self._param.update(_param) return self._param.get(name) From 0c7a3c01065e29694ed964b9b4c3dbf21cbae3bd Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 9 Apr 2024 15:00:24 +0000 Subject: [PATCH 20/21] Check if remote model needs retrieve requests patched --- ai_models/remote/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ai_models/remote/model.py b/ai_models/remote/model.py index 9a2fc4c..642488b 100644 --- a/ai_models/remote/model.py +++ b/ai_models/remote/model.py @@ -56,6 +56,9 @@ def parse_model_args(self, args): return None def patch_retrieve_request(self, request): + if not self.remote_has_patch: + return + patched = self.api.patch_retrieve_request(self.cfg, request) request.update(patched) @@ -74,6 +77,7 @@ def load_parameters(self): "lagged", "grib_extra_metadata", "retrieve", + "remote_has_patch", # this is a custom parameter that checks if the remote model implemented patch_retrieve_request ], ) self._param.update(params) From bf8dc311c5f720498babab6c5343f81614d09a45 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 9 Apr 2024 15:02:25 +0000 Subject: [PATCH 21/21] Fix remote assets extra dir --- ai_models/remote/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ai_models/remote/model.py b/ai_models/remote/model.py index 642488b..1a4b7ee 100644 --- a/ai_models/remote/model.py +++ b/ai_models/remote/model.py @@ -16,7 +16,6 @@ class RemoteModel(Model): def __init__(self, **kwargs): self.cfg = kwargs self.cfg["download_assets"] = False - self.cfg["assets_extra_dir"] = None self.model = self.cfg["model"] self.model_version = self.cfg.get("model_version", "latest") @@ -77,7 +76,7 @@ def load_parameters(self): "lagged", "grib_extra_metadata", "retrieve", - "remote_has_patch", # this is a custom parameter that checks if the remote model implemented patch_retrieve_request + "remote_has_patch", # custom parameter, checks if remote model need patches ], ) self._param.update(params)