From d5f2fa0b8224ce871510d2a03df9c492cc3ce196 Mon Sep 17 00:00:00 2001 From: Paul Koch Date: Mon, 19 Aug 2024 11:15:05 -0400 Subject: [PATCH] retry azure commands if there's an http failure --- docs/benchmarks/ebm-benchmark.ipynb | 26 ++-- .../powerlift/powerlift/run_azure/__main__.py | 137 +++++++++++------- 2 files changed, 97 insertions(+), 66 deletions(-) diff --git a/docs/benchmarks/ebm-benchmark.ipynb b/docs/benchmarks/ebm-benchmark.ipynb index f07b598cd..c3d6d9e72 100644 --- a/docs/benchmarks/ebm-benchmark.ipynb +++ b/docs/benchmarks/ebm-benchmark.ipynb @@ -161,7 +161,7 @@ " # catboost doesn't like missing categoricals, so make them a category\n", " col_data = X[col]\n", " if str(col_data.dtype) == \"category\" and col_data.isnull().any():\n", - " X[col] = col_data.cat.add_categories('NaN').fillna('NaN')\n", + " X[col] = col_data.cat.add_categories('nan').fillna('nan')\n", " \n", " cat_bools = meta[\"categorical_mask\"]\n", " cat_cols = [i for i, val in enumerate(cat_bools) if val]\n", @@ -178,21 +178,13 @@ " \n", " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=stratification, random_state=seed)\n", "\n", - " # Build preprocessor\n", - " cat_ohe_step = (\"ohe\", OneHotEncoder(handle_unknown=\"ignore\"))\n", - " cat_pipe = Pipeline([cat_ohe_step])\n", - " num_pipe = Pipeline([(\"identity\", FunctionTransformer())])\n", - " transformers = [(\"cat\", cat_pipe, cat_cols), (\"num\", num_pipe, num_cols)]\n", - " ct = Pipeline(\n", - " [\n", - " (\"ct\", ColumnTransformer(transformers=transformers, sparse_threshold=0)),\n", - " (\n", - " \"missing\",\n", - " SimpleImputer(add_indicator=True, strategy=\"most_frequent\"),\n", - " ),\n", - " ]\n", - " )\n", - "\n", + " # Build optional preprocessor for use by methods below\n", + " # missing categoricals already handled above by making new \"nan\" category\n", + " cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=True, dtype=np.int16)\n", + " num_imputer = SimpleImputer(strategy=\"mean\")\n", + " transformers = [(\"cat\", cat_encoder, cat_cols), (\"num\", num_imputer, num_cols)]\n", + " ct = ColumnTransformer(transformers=transformers, sparse_threshold=0) # densify\n", + " \n", " # Specify method\n", " if trial.task.problem in [\"binary\", \"multiclass\"]:\n", " fit_params = {\"X\":X_train, \"y\":y_train}\n", @@ -264,7 +256,7 @@ " global_counter = 0\n", " \n", " # Train\n", - " print(f\"FIT: {global_counter}, {trial.task.origin}, {trial.task.name}, {trial.method.name}, \", end=\"\")\n", + " print(f\"FIT: {global_counter}, {trial.task.origin}, {trial.task.name}, {trial.method.name}\")\n", " with warnings.catch_warnings():\n", " warnings.filterwarnings(\"ignore\")\n", " start_time = time()\n", diff --git a/python/powerlift/powerlift/run_azure/__main__.py b/python/powerlift/powerlift/run_azure/__main__.py index 3becbddff..2d22d0c56 100644 --- a/python/powerlift/powerlift/run_azure/__main__.py +++ b/python/powerlift/powerlift/run_azure/__main__.py @@ -69,6 +69,7 @@ def run_trials( python startup.py """ + import time from azure.mgmt.containerinstance.models import ( ContainerGroup, Container, @@ -90,64 +91,95 @@ def run_trials( ) resource_group_name = azure_json["resource_group"] - aci_client = ContainerInstanceManagementClient( - credential, azure_json["subscription_id"] - ) - res_client = ResourceManagementClient(credential, azure_json["subscription_id"]) - resource_group = res_client.resource_groups.get(resource_group_name) # Run until completion. - container_counter = 0 n_tasks = len(tasks) n_containers = min(n_tasks, n_running_containers) results = {x: None for x in range(n_containers)} container_group_names = set() + resource_group = None + worker_id = None while len(tasks) != 0: - params = tasks.pop(0) - worker_id = _wait_for_completed_worker(aci_client, resource_group_name, results) - - experiment_id, trial_ids, uri, timeout, raise_exception, image = params - env_vars = [ - EnvironmentVariable(name="EXPERIMENT_ID", value=str(experiment_id)), - EnvironmentVariable( - name="TRIAL_IDS", value=",".join([str(x) for x in trial_ids]) - ), - EnvironmentVariable(name="DB_URL", secure_value=uri), - EnvironmentVariable(name="TIMEOUT", value=timeout), - EnvironmentVariable(name="RAISE_EXCEPTION", value=raise_exception), - ] - container_resource_requests = ResourceRequests( - cpu=num_cores, - memory_in_gb=mem_size_gb, - ) - container_resource_requirements = ResourceRequirements( - requests=container_resource_requests - ) - container_name = f"powerlift-container-{container_counter}" - container_counter += 1 - container = Container( - name=container_name, - image=image, - resources=container_resource_requirements, - command=["/bin/sh", "-c", startup_script.replace("\r\n", "\n")], - environment_variables=env_vars, - ) - container_group = ContainerGroup( - location=resource_group.location, - containers=[container], - os_type=OperatingSystemTypes.linux, - restart_policy=ContainerGroupRestartPolicy.never, - ) - container_group_name = f"powerlift-container-group-{worker_id}-{batch_id}" + experiment_id, trial_ids, uri, timeout, raise_exception, image = tasks.pop(0) + while True: + try: + if resource_group is None: + aci_client = ContainerInstanceManagementClient( + credential, azure_json["subscription_id"] + ) + res_client = ResourceManagementClient( + credential, azure_json["subscription_id"] + ) + resource_group = res_client.resource_groups.get(resource_group_name) - # begin_create_or_update returns LROPoller, - # but this is only indicates when the containter is started - aci_client.container_groups.begin_create_or_update( - resource_group.name, container_group_name, container_group - ) + # worker_id might be non-None if there was an exception we are retrying + if worker_id is None: + worker_id = _wait_for_completed_worker( + aci_client, resource_group_name, results + ) + container_group_name = ( + f"powerlift-container-group-{batch_id}-{worker_id}" + ) + else: + # if we previously started a container group but had an + # error we don't know the state of the container group, + # so delete it, if it exists, and restart + + # TODO: if begin_create_or_update wasn't reached then + # there will be no container group to delete so this + # will fail, but I don't know yet what exception it + # will fail with, so wrap with a try except here + # to allow that failure + delete_poller = aci_client.container_groups.begin_delete( + resource_group_name, container_group_name + ) + while not delete_poller.done(): + time.sleep(1) + + env_vars = [ + EnvironmentVariable(name="EXPERIMENT_ID", value=str(experiment_id)), + EnvironmentVariable( + name="TRIAL_IDS", value=",".join([str(x) for x in trial_ids]) + ), + EnvironmentVariable(name="DB_URL", secure_value=uri), + EnvironmentVariable(name="TIMEOUT", value=timeout), + EnvironmentVariable(name="RAISE_EXCEPTION", value=raise_exception), + ] + container_resource_requests = ResourceRequests( + cpu=num_cores, + memory_in_gb=mem_size_gb, + ) + container_resource_requirements = ResourceRequirements( + requests=container_resource_requests + ) + container = Container( + name="powerlift-container", + image=image, + resources=container_resource_requirements, + command=["/bin/sh", "-c", startup_script.replace("\r\n", "\n")], + environment_variables=env_vars, + ) + container_group = ContainerGroup( + location=resource_group.location, + containers=[container], + os_type=OperatingSystemTypes.linux, + restart_policy=ContainerGroupRestartPolicy.never, + ) + + # begin_create_or_update returns LROPoller, + # but this is only indicates when the containter is started + aci_client.container_groups.begin_create_or_update( + resource_group.name, container_group_name, container_group + ) + + break + except: # HttpResponseError normally, but I've seen others + resource_group = None + time.sleep(1) - container_group_names.add(container_group_name) results[worker_id] = container_group_name + worker_id = None + container_group_names.add(container_group_name) # Wait for all container groups to complete while ( @@ -157,8 +189,15 @@ def run_trials( # Delete all container groups if delete_group_container_on_complete: + delete_pollers = [] for container_group_name in container_group_names: - aci_client.container_groups.begin_delete( + delete_poller = aci_client.container_groups.begin_delete( resource_group_name, container_group_name ) + delete_pollers.append(delete_poller) + + for delete_poller in delete_pollers: + while not delete_poller.done(): + time.sleep(1) + return None