retry azure commands if there's an http failure
paulbkoch committed Aug 19, 2024
1 parent ecd319a commit d5f2fa0
Showing 2 changed files with 97 additions and 66 deletions.
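At a high level, the change wraps each Azure management call in a retry loop that rebuilds the client after a failure. A minimal sketch of that pattern, assuming only azure-core (the helper name `retry_azure_call` and its parameters are illustrative, and the actual commit catches all exceptions, not only `HttpResponseError`):

```python
# Illustrative sketch of the retry pattern this commit introduces; the
# helper name and parameters are hypothetical, not part of the commit.
import time

from azure.core.exceptions import HttpResponseError


def retry_azure_call(make_client, call, delay=1.0):
    client = None
    while True:
        try:
            if client is None:
                # rebuild the client from scratch after any failure
                client = make_client()
            return call(client)
        except HttpResponseError:
            # discard the possibly broken client and try again
            client = None
            time.sleep(delay)
```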
26 changes: 9 additions & 17 deletions docs/benchmarks/ebm-benchmark.ipynb
@@ -161,7 +161,7 @@
" # catboost doesn't like missing categoricals, so make them a category\n",
" col_data = X[col]\n",
" if str(col_data.dtype) == \"category\" and col_data.isnull().any():\n",
" X[col] = col_data.cat.add_categories('NaN').fillna('NaN')\n",
" X[col] = col_data.cat.add_categories('nan').fillna('nan')\n",
" \n",
" cat_bools = meta[\"categorical_mask\"]\n",
" cat_cols = [i for i, val in enumerate(cat_bools) if val]\n",
@@ -178,21 +178,13 @@
" \n",
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=stratification, random_state=seed)\n",
"\n",
" # Build preprocessor\n",
" cat_ohe_step = (\"ohe\", OneHotEncoder(handle_unknown=\"ignore\"))\n",
" cat_pipe = Pipeline([cat_ohe_step])\n",
" num_pipe = Pipeline([(\"identity\", FunctionTransformer())])\n",
" transformers = [(\"cat\", cat_pipe, cat_cols), (\"num\", num_pipe, num_cols)]\n",
" ct = Pipeline(\n",
" [\n",
" (\"ct\", ColumnTransformer(transformers=transformers, sparse_threshold=0)),\n",
" (\n",
" \"missing\",\n",
" SimpleImputer(add_indicator=True, strategy=\"most_frequent\"),\n",
" ),\n",
" ]\n",
" )\n",
"\n",
" # Build optional preprocessor for use by methods below\n",
" # missing categoricals already handled above by making new \"nan\" category\n",
" cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=True, dtype=np.int16)\n",
" num_imputer = SimpleImputer(strategy=\"mean\")\n",
" transformers = [(\"cat\", cat_encoder, cat_cols), (\"num\", num_imputer, num_cols)]\n",
" ct = ColumnTransformer(transformers=transformers, sparse_threshold=0) # densify\n",
" \n",
" # Specify method\n",
" if trial.task.problem in [\"binary\", \"multiclass\"]:\n",
" fit_params = {\"X\":X_train, \"y\":y_train}\n",
@@ -264,7 +256,7 @@
" global_counter = 0\n",
" \n",
" # Train\n",
" print(f\"FIT: {global_counter}, {trial.task.origin}, {trial.task.name}, {trial.method.name}, \", end=\"\")\n",
" print(f\"FIT: {global_counter}, {trial.task.origin}, {trial.task.name}, {trial.method.name}\")\n",
" with warnings.catch_warnings():\n",
" warnings.filterwarnings(\"ignore\")\n",
" start_time = time()\n",
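For reference, the revised preprocessing can be exercised end to end on toy data. A self-contained sketch, assuming scikit-learn >= 1.2 for OneHotEncoder's `sparse_output` parameter (the notebook derives `cat_cols` and `num_cols` from the dataset's categorical mask; the toy DataFrame here is illustrative):

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

X = pd.DataFrame({
    "color": pd.Categorical(["red", "blue", None]),
    "size": [1.0, np.nan, 3.0],
})
# missing categoricals become an explicit "nan" category, as in the notebook
X["color"] = X["color"].cat.add_categories("nan").fillna("nan")

cat_cols, num_cols = [0], [1]  # positional column indices
ct = ColumnTransformer(
    transformers=[
        ("cat", OneHotEncoder(handle_unknown="ignore", sparse_output=True, dtype=np.int16), cat_cols),
        ("num", SimpleImputer(strategy="mean"), num_cols),
    ],
    sparse_threshold=0,  # densify the combined output
)
X_dense = ct.fit_transform(X)  # 3 one-hot columns for color + imputed size
```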
137 changes: 88 additions & 49 deletions python/powerlift/powerlift/run_azure/__main__.py
@@ -69,6 +69,7 @@ def run_trials(
python startup.py
"""

import time
from azure.mgmt.containerinstance.models import (
ContainerGroup,
Container,
@@ -90,64 +91,95 @@
)

resource_group_name = azure_json["resource_group"]
aci_client = ContainerInstanceManagementClient(
credential, azure_json["subscription_id"]
)
res_client = ResourceManagementClient(credential, azure_json["subscription_id"])
resource_group = res_client.resource_groups.get(resource_group_name)

# Run until completion.
container_counter = 0
n_tasks = len(tasks)
n_containers = min(n_tasks, n_running_containers)
results = {x: None for x in range(n_containers)}
container_group_names = set()
resource_group = None
worker_id = None
while len(tasks) != 0:
params = tasks.pop(0)
worker_id = _wait_for_completed_worker(aci_client, resource_group_name, results)

experiment_id, trial_ids, uri, timeout, raise_exception, image = params
env_vars = [
EnvironmentVariable(name="EXPERIMENT_ID", value=str(experiment_id)),
EnvironmentVariable(
name="TRIAL_IDS", value=",".join([str(x) for x in trial_ids])
),
EnvironmentVariable(name="DB_URL", secure_value=uri),
EnvironmentVariable(name="TIMEOUT", value=timeout),
EnvironmentVariable(name="RAISE_EXCEPTION", value=raise_exception),
]
container_resource_requests = ResourceRequests(
cpu=num_cores,
memory_in_gb=mem_size_gb,
)
container_resource_requirements = ResourceRequirements(
requests=container_resource_requests
)
container_name = f"powerlift-container-{container_counter}"
container_counter += 1
container = Container(
name=container_name,
image=image,
resources=container_resource_requirements,
command=["/bin/sh", "-c", startup_script.replace("\r\n", "\n")],
environment_variables=env_vars,
)
container_group = ContainerGroup(
location=resource_group.location,
containers=[container],
os_type=OperatingSystemTypes.linux,
restart_policy=ContainerGroupRestartPolicy.never,
)
container_group_name = f"powerlift-container-group-{worker_id}-{batch_id}"
experiment_id, trial_ids, uri, timeout, raise_exception, image = tasks.pop(0)
while True:
try:
if resource_group is None:
aci_client = ContainerInstanceManagementClient(
credential, azure_json["subscription_id"]
)
res_client = ResourceManagementClient(
credential, azure_json["subscription_id"]
)
resource_group = res_client.resource_groups.get(resource_group_name)

# begin_create_or_update returns an LROPoller,
# but this only indicates when the container is started
aci_client.container_groups.begin_create_or_update(
resource_group.name, container_group_name, container_group
)
# worker_id might be non-None if we are retrying after an exception
if worker_id is None:
worker_id = _wait_for_completed_worker(
aci_client, resource_group_name, results
)
container_group_name = (
f"powerlift-container-group-{batch_id}-{worker_id}"
)
else:
# if we previously started a container group but then hit an
# error, we don't know the container group's state, so delete
# it if it exists and restart

# TODO: if begin_create_or_update wasn't reached then there is
# no container group to delete, so this will fail, but I don't
# know yet what exception it will fail with, so wrap it in a
# try/except here to allow that failure
delete_poller = aci_client.container_groups.begin_delete(
resource_group_name, container_group_name
)
while not delete_poller.done():
time.sleep(1)

env_vars = [
EnvironmentVariable(name="EXPERIMENT_ID", value=str(experiment_id)),
EnvironmentVariable(
name="TRIAL_IDS", value=",".join([str(x) for x in trial_ids])
),
EnvironmentVariable(name="DB_URL", secure_value=uri),
EnvironmentVariable(name="TIMEOUT", value=timeout),
EnvironmentVariable(name="RAISE_EXCEPTION", value=raise_exception),
]
container_resource_requests = ResourceRequests(
cpu=num_cores,
memory_in_gb=mem_size_gb,
)
container_resource_requirements = ResourceRequirements(
requests=container_resource_requests
)
container = Container(
name="powerlift-container",
image=image,
resources=container_resource_requirements,
command=["/bin/sh", "-c", startup_script.replace("\r\n", "\n")],
environment_variables=env_vars,
)
container_group = ContainerGroup(
location=resource_group.location,
containers=[container],
os_type=OperatingSystemTypes.linux,
restart_policy=ContainerGroupRestartPolicy.never,
)

# begin_create_or_update returns an LROPoller,
# but this only indicates when the container is started
aci_client.container_groups.begin_create_or_update(
resource_group.name, container_group_name, container_group
)

break
except Exception: # HttpResponseError normally, but I've seen others
resource_group = None
time.sleep(1)

container_group_names.add(container_group_name)
results[worker_id] = container_group_name
worker_id = None
container_group_names.add(container_group_name)

# Wait for all container groups to complete
while (
@@ -157,8 +189,15 @@

# Delete all container groups
if delete_group_container_on_complete:
delete_pollers = []
for container_group_name in container_group_names:
aci_client.container_groups.begin_delete(
delete_poller = aci_client.container_groups.begin_delete(
resource_group_name, container_group_name
)
delete_pollers.append(delete_poller)

for delete_poller in delete_pollers:
while not delete_poller.done():
time.sleep(1)

return None
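The deletion logic above polls the LROPoller returned by begin_delete rather than firing and forgetting. A stand-alone sketch of that wait pattern (client construction is elided and the function name is illustrative):

```python
# Sketch of the begin_delete/poll pattern used above; aci_client is assumed
# to be a ContainerInstanceManagementClient, as constructed in run_trials.
import time


def delete_container_groups(aci_client, resource_group_name, group_names):
    # start all deletions first so they proceed concurrently on the service side
    pollers = [
        aci_client.container_groups.begin_delete(resource_group_name, name)
        for name in group_names
    ]
    for poller in pollers:
        # poller.wait() would block equivalently; polling done() mirrors
        # the commit's style
        while not poller.done():
            time.sleep(1)
```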
