Skip to content

Commit

Permalink
Use new task wrappers in task server
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Mar 7, 2024
1 parent 47dc23d commit 3902ba2
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 83 deletions.
4 changes: 4 additions & 0 deletions colmena/models/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ class ColmenaTask:
name: str
"""Name used to identify the function"""

@property
def __name__(self) -> str:
    """Alias of ``name``, exposed under the attribute plain Python functions use for their name

    Returns:
        The value of ``self.name``
    """
    return self.name

def function(self, *args, **kwargs) -> Any:
    """Function provided by the Colmena user

    This version has no implementation and raises regardless of the arguments it is given.

    Args:
        args: Positional arguments, unused here
        kwargs: Keyword arguments, unused here
    Raises:
        NotImplementedError: Always
    """
    # NOTE(review): presumably overridden by the concrete task classes (e.g., PythonTask) -- confirm
    raise NotImplementedError()
Expand Down
81 changes: 14 additions & 67 deletions colmena/task_server/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
"""Base classes for the Task Server and associated functions"""
import logging
import os
import platform
from inspect import isgeneratorfunction
from abc import ABCMeta, abstractmethod
from concurrent.futures import Future
from inspect import signature
from multiprocessing import Process
from time import perf_counter
from typing import Optional, Callable, Collection
from typing import Collection, Optional, Callable, Union

from colmena.exceptions import KillSignalException, TimeoutException
from colmena.models.tasks import ColmenaTask, PythonGeneratorTask, PythonTask
from colmena.models import Result, FailureInformation
from colmena.proxy import resolve_proxies_async, store_proxy_stats
from colmena.queue.base import ColmenaQueues

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -164,68 +161,18 @@ def process_queue(self, topic: str, task: Result):
future.add_done_callback(lambda x: self.perform_callback(x, task, topic))


def run_and_record_timing(func: Callable, result: Result) -> Result:
"""Run a function and also return the runtime
def convert_to_colmena_task(function: Union[Callable, ColmenaTask]) -> ColmenaTask:
"""Wrap user-supplified functions in the task model wrapper, if needed
Args:
func: Function to invoke
result: Result object describing task request
function: User-provided function
Returns:
Result object with the serialized result
The function wrapped in the appropriate ColmenaTask subclass, or returned unchanged if it is already a ColmenaTask
"""
# Mark that compute has started on the worker
result.mark_compute_started()

# Unpack the inputs
result.time.deserialize_inputs = result.deserialize()

# Start resolving any proxies in the input asynchronously
start_time = perf_counter()
input_proxies = []
for arg in result.args:
input_proxies.extend(resolve_proxies_async(arg))
for value in result.kwargs.values():
input_proxies.extend(resolve_proxies_async(value))
result.time.async_resolve_proxies = perf_counter() - start_time

# Execute the function
start_time = perf_counter()
success = True
try:
if '_resources' in result.kwargs:
logger.warning('`_resources` provided as a kwargs. Unexpected things are about to happen')
if '_resources' in signature(func).parameters:
output = func(*result.args, **result.kwargs, _resources=result.resources)
else:
output = func(*result.args, **result.kwargs)
except BaseException as e:
output = None
success = False
result.failure_info = FailureInformation.from_exception(e)
finally:
end_time = perf_counter()

# Store the results
result.set_result(output, end_time - start_time)
if not success:
result.success = False

# Add the worker information into the tasks, if available
worker_info = {}
# TODO (wardlt): Move this information into a separate, parsl-specific wrapper
for tag in ['PARSL_WORKER_RANK', 'PARSL_WORKER_POOL_ID']:
if tag in os.environ:
worker_info[tag] = os.environ[tag]
worker_info['hostname'] = platform.node()
result.worker_info = worker_info

result.mark_compute_ended()

# Re-pack the results. Will store the proxy statistics
result.time.serialize_results, _ = result.serialize()

# Get the statistics for the proxy resolution
for proxy in input_proxies:
store_proxy_stats(proxy, result.time.proxy)

return result

if isinstance(function, ColmenaTask):
return function
elif isgeneratorfunction(function):
return PythonGeneratorTask(function)
else:
return PythonTask(function)
14 changes: 6 additions & 8 deletions colmena/task_server/globus.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
Tasks and results are communicated to/from the endpoint through a cloud service secured using Globus Auth."""

import logging
from functools import partial, update_wrapper
from typing import Dict, Callable, Optional, Tuple
from concurrent.futures import Future

from globus_compute_sdk import Client, Executor

from colmena.task_server.base import run_and_record_timing, FutureBasedTaskServer
from colmena.task_server.base import convert_to_colmena_task, FutureBasedTaskServer
from colmena.queue.python import PipeQueues

from colmena.models import Result
Expand Down Expand Up @@ -58,13 +57,12 @@ def __init__(self,
# Create a function with the latest version of the wrapper function
self.registered_funcs: Dict[str, Tuple[str, str]] = {} # Function name -> (funcX id, endpoints)
for func, endpoint in methods.items():
# Make a wrapped version of the function
func_name = func.__name__
new_func = partial(run_and_record_timing, func)
update_wrapper(new_func, func)
func_fxid = self.fx_client.register_function(new_func)
# Register a wrapped version of the function
task = convert_to_colmena_task(func)
func_fxid = self.fx_client.register_function(task)

# Store the information for the function
self.registered_funcs[func_name] = (func_fxid, endpoint)
self.registered_funcs[task.name] = (func_fxid, endpoint)

self._batch_options = dict(
batch_size=batch_size,
Expand Down
16 changes: 8 additions & 8 deletions colmena/task_server/parsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import platform
import shutil
from concurrent.futures import Future
from functools import partial, update_wrapper
from functools import partial
from pathlib import Path
from tempfile import mkdtemp
from time import perf_counter
Expand All @@ -22,7 +22,7 @@
from colmena.queue.base import ColmenaQueues
from colmena.models import Result, FailureInformation, ResourceRequirements
from colmena.proxy import resolve_proxies_async
from colmena.task_server.base import run_and_record_timing, FutureBasedTaskServer
from colmena.task_server.base import convert_to_colmena_task, FutureBasedTaskServer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -352,14 +352,13 @@ def __init__(self, methods: List[Union[Callable, Tuple[Callable, Dict]]],
options = {'executors': default_executors}
logger.info(f'Using default executors for {function.__name__}: {default_executors}')

# Make the Parsl app
name = function.__name__
# Convert the function to a Colmena task
function = convert_to_colmena_task(method)
name = function.name

# If the function is an executable, just wrap it
# If the function is not an executable, submit it as a single task
if not isinstance(function, ExecutableTask):
wrapped_function = partial(run_and_record_timing, function)
wrapped_function = update_wrapper(wrapped_function, function)
app = PythonApp(wrapped_function, **options)
app = PythonApp(function, **options)
self.methods_[name] = (app, 'basic')
else:
logger.info(f'Building a chain of apps for an ExecutableTask, {function.__name__}')
Expand Down Expand Up @@ -406,6 +405,7 @@ def _submit(self, task: Result, topic: str) -> Optional[Future]:
elif func_type == 'exec':
# For executable functions, we have a different route for returning results
exec_app, post_app = self.exec_apps_[method]
# TODO (wardlt): Use a join_app rather than callback?
future.add_done_callback(lambda x: _preprocess_callback(x, serialized_inputs, task, self, topic, exec_app, post_app))
return None # `None` prevents the Task Server from adding its own callback
else:
Expand Down

0 comments on commit 3902ba2

Please sign in to comment.