Skip to content

Commit

Permalink
Merge branch 'main' into javiermtorres/issue-492-runtime-resources
Browse files Browse the repository at this point in the history
  • Loading branch information
javiermtorres authored Jan 29, 2025
2 parents 658f734 + 89ae11c commit 25684bb
Show file tree
Hide file tree
Showing 15 changed files with 460 additions and 228 deletions.
11 changes: 11 additions & 0 deletions lumigator/python/mzai/backend/backend/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from backend.services.datasets import DatasetService
from backend.services.experiments import ExperimentService
from backend.services.jobs import JobService
from backend.services.workflows import WorkflowService
from backend.settings import settings


Expand Down Expand Up @@ -74,6 +75,16 @@ def get_experiment_service(
ExperimentServiceDep = Annotated[ExperimentService, Depends(get_experiment_service)]


def get_workflow_service(
session: DBSessionDep, job_service: JobServiceDep, dataset_service: DatasetServiceDep
) -> WorkflowService:
job_repo = JobRepository(session)
return WorkflowService(job_repo, job_service, dataset_service)


WorkflowServiceDep = Annotated[WorkflowService, Depends(get_workflow_service)]


def get_mistral_completion_service() -> MistralCompletionService:
return MistralCompletionService()

Expand Down
13 changes: 3 additions & 10 deletions lumigator/python/mzai/backend/backend/api/router.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
from fastapi import APIRouter

from backend.api.routes import (
completions,
datasets,
experiments,
experiments_new,
health,
jobs,
models,
)
from backend.api.routes import completions, datasets, experiments, health, jobs, models, workflows
from backend.api.tags import Tags

API_V1_PREFIX = "/api/v1"
Expand All @@ -20,6 +12,7 @@
api_router.include_router(experiments.router, prefix="/experiments", tags=[Tags.EXPERIMENTS])
api_router.include_router(completions.router, prefix="/completions", tags=[Tags.COMPLETIONS])
api_router.include_router(models.router, prefix="/models", tags=[Tags.MODELS])
# TODO: Workflows route is not yet ready so it is excluded from the OpenAPI schema
api_router.include_router(
experiments_new.router, prefix="/experiments_new", tags=[Tags.EXPERIMENTS_NEW]
workflows.router, prefix="/workflows", tags=[Tags.WORKFLOWS], include_in_schema=False
)
57 changes: 56 additions & 1 deletion lumigator/python/mzai/backend/backend/api/routes/experiments.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from http import HTTPStatus
from uuid import UUID

from fastapi import APIRouter, BackgroundTasks, status
from lumigator_schemas.experiments import (
ExperimentCreate,
ExperimentIdCreate,
ExperimentResponse,
ExperimentResultDownloadResponse,
ExperimentResultResponse,
Expand All @@ -12,11 +14,19 @@
JobEvalCreate,
)

from backend.api.deps import JobServiceDep
from backend.api.deps import ExperimentServiceDep, JobServiceDep
from backend.services.exceptions.base_exceptions import ServiceError
from backend.services.exceptions.experiment_exceptions import ExperimentNotFoundError

router = APIRouter()


def experiment_exception_mappings() -> dict[type[ServiceError], HTTPStatus]:
return {
ExperimentNotFoundError: status.HTTP_404_NOT_FOUND,
}


@router.post("/", status_code=status.HTTP_201_CREATED)
def create_experiment(
service: JobServiceDep, request: ExperimentCreate, background_tasks: BackgroundTasks
Expand Down Expand Up @@ -60,3 +70,48 @@ def get_experiment_result_download(
return ExperimentResultDownloadResponse.model_validate(
service.get_job_result_download(experiment_id).model_dump()
)


####################################################################################################
# "new" routes
####################################################################################################
# These "experiments_new" routes are not yet ready to be exposed in the OpenAPI schema,
# because it is designed for the API when 'workflows' are supported
# TODO: Eventually this route will become the / route,
# but right now it is a placeholder while we build up the Workflows routes
# It's not included in the OpenAPI schema for now so it's not visible in the docs
@router.post("/new", status_code=status.HTTP_201_CREATED, include_in_schema=False)
def create_experiment_id(
service: ExperimentServiceDep, request: ExperimentIdCreate
) -> ExperimentResponse:
"""Create an experiment ID."""
return ExperimentResponse.model_validate(service.create_experiment(request).model_dump())


# TODO: FIXME this should not need the /all suffix.
# See further discussion https://github.com/mozilla-ai/lumigator/pull/728/files/2c960962c365d72e0714a16333884f0f209d214e#r1932176937
@router.get("/new/all", include_in_schema=False)
def list_experiments_new(
service: ExperimentServiceDep,
skip: int = 0,
limit: int = 100,
) -> ListingResponse[ExperimentResponse]:
"""List all experiments."""
return ListingResponse[ExperimentResponse].model_validate(
service.list_experiments(skip, limit).model_dump()
)


@router.get("/new/{experiment_id}", include_in_schema=False)
def get_experiment_new(service: ExperimentServiceDep, experiment_id: UUID) -> ExperimentResponse:
"""Get an experiment by ID."""
return ExperimentResponse.model_validate(service.get_experiment(experiment_id).model_dump())


@router.get("/new/{experiment_id}/workflows", include_in_schema=False)
def get_workflows(service: ExperimentServiceDep, experiment_id: UUID) -> ListingResponse[UUID]:
"""TODO: this endpoint should handle passing in an experiment id and the returning a list
of all the workflows associated with that experiment. Until workflows are stored and associated
with experiments, this is not yet implemented.
"""
raise NotImplementedError

This file was deleted.

79 changes: 79 additions & 0 deletions lumigator/python/mzai/backend/backend/api/routes/workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from uuid import UUID

from fastapi import APIRouter, BackgroundTasks, status
from lumigator_schemas.extras import ListingResponse
from lumigator_schemas.jobs import JobResponse
from lumigator_schemas.workflows import (
WorkflowCreate,
WorkflowDetailsResponse,
WorkflowResponse,
WorkflowResultDownloadResponse,
)

from backend.api.deps import WorkflowServiceDep

router = APIRouter()


@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_workflow(
service: WorkflowServiceDep, request: WorkflowCreate, background_tasks: BackgroundTasks
) -> WorkflowResponse:
"""A workflow is a single execution for an experiment.
A workflow is a collection of 1 or more jobs.
It must be associated with an experiment id,
which means you must already have created an experiment and have that ID in the request.
"""
return WorkflowResponse.model_validate(service.create_workflow(request, background_tasks))


@router.get("/{workflow_id}")
def get_workflow(service: WorkflowServiceDep, workflow_id: UUID) -> WorkflowResponse:
"""TODO: The workflow objects are currently not saved in the database so it can't be retrieved.
In order to get all the info about a workflow,
you need to get all the jobs for an experiment and make some decisions about how to use them.
This means you can't yet easily compile a list of all workflows for an experiment.
"""
raise NotImplementedError


# TODO: currently experiment_id=workflow_id, but this will change
@router.get("/{experiment_id}/jobs", include_in_schema=False)
def get_workflow_jobs(
service: WorkflowServiceDep, experiment_id: UUID
) -> ListingResponse[JobResponse]:
"""Get all jobs for a workflow.
TODO: this will likely eventually be merged with the get_workflow endpoint, once implemented
"""
# TODO right now this command expects that the workflow_id is the same as the experiment_id
return ListingResponse[JobResponse].model_validate(
service.get_workflow_jobs(experiment_id).model_dump()
)


@router.get("/{workflow_id}/details")
def get_workflow_details(
service: WorkflowServiceDep,
workflow_id: UUID,
) -> WorkflowDetailsResponse:
"""TODO:Return the results metadata for a run if available in the DB.
This should retrieve the metadata for the job or jobs that were run in the workflow and compile
them into a single response that can be used to populate the UI.
Currently this looks like taking the average results for the
inference job (tok/s, gen length, etc) and the
average results for the evaluation job (ROUGE, BLEU, etc) and
returning them in a single response.
For detailed results you would want to use the get_workflow_details endpoint.
"""
raise NotImplementedError


@router.get("/{workflow_id}/details")
def get_experiment_result_download(
service: WorkflowServiceDep,
workflow_id: UUID,
) -> WorkflowResultDownloadResponse:
"""Return experiment results file URL for downloading."""
return WorkflowResultDownloadResponse.model_validate(
service.get_workflow_result_download(workflow_id).model_dump()
)
6 changes: 3 additions & 3 deletions lumigator/python/mzai/backend/backend/api/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class Tags(str, Enum):
JOBS = "jobs"
COMPLETIONS = "completions"
EXPERIMENTS = "experiments"
EXPERIMENTS_NEW = "experiments_new"
WORKFLOWS = "workflows"
MODELS = "models"


Expand All @@ -25,8 +25,8 @@ class Tags(str, Enum):
"description": "Create and manage experiments.",
},
{
"name": Tags.EXPERIMENTS_NEW,
"description": "Create and manage experiments (new).",
"name": Tags.WORKFLOWS,
"description": "Create and manage workflows.",
},
{
"name": Tags.JOBS,
Expand Down
2 changes: 1 addition & 1 deletion lumigator/python/mzai/backend/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from backend.api.router import api_router
from backend.api.routes.completions import completion_exception_mappings
from backend.api.routes.datasets import dataset_exception_mappings
from backend.api.routes.experiments_new import experiment_exception_mappings
from backend.api.routes.experiments import experiment_exception_mappings
from backend.api.routes.jobs import job_exception_mappings
from backend.api.tags import TAGS_METADATA
from backend.services.exceptions.base_exceptions import ServiceError
Expand Down
Loading

0 comments on commit 25684bb

Please sign in to comment.