Skip to content

Commit

Permalink
fix: ensure pydantic models serialize to json-compatible dicts. (#346)
Browse files Browse the repository at this point in the history
add regression test

closes #345
  • Loading branch information
dwinston committed Nov 3, 2023
1 parent cbead62 commit bd9df67
Show file tree
Hide file tree
Showing 25 changed files with 346 additions and 55 deletions.
4 changes: 3 additions & 1 deletion components/nmdc_runtime/workflow_execution_activity/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def insert_into_keys(
workflow: Workflow, data_objects: list[DataObject]
) -> dict[str, Any]:
"""Insert data object url into correct workflow input field."""
workflow_dict = workflow.dict()
workflow_dict = workflow.model_dump(
mode="json",
)
for key in workflow_dict["inputs"]:
for do in data_objects:
if workflow_dict["inputs"][key] == str(do.data_object_type):
Expand Down
2 changes: 1 addition & 1 deletion nmdc_runtime/api/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ def generate_secret(length=12):
def json_clean(data, model, exclude_unset=False) -> dict:
"""Run data through a JSON serializer for a pydantic model."""
if not isinstance(data, (dict, BaseModel)):
raise TypeError("`data` must be a pydantic model or its .dict()")
raise TypeError("`data` must be a pydantic model or its .model_dump()")
m = model(**data) if isinstance(data, dict) else data
return json.loads(m.json(exclude_unset=exclude_unset))
6 changes: 4 additions & 2 deletions nmdc_runtime/api/endpoints/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def create_object(
"""
id_supplied = supplied_object_id(
mdb, client_site, object_in.dict(exclude_unset=True)
mdb, client_site, object_in.model_dump(mode="json", exclude_unset=True)
)
drs_id = local_part(
id_supplied if id_supplied is not None else generate_one_id(mdb, S3_ID_NS)
Expand Down Expand Up @@ -255,7 +255,9 @@ def update_object(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"client authorized for different site_id than {object_mgr_site}",
)
doc_object_patched = merge(doc, object_patch.dict(exclude_unset=True))
doc_object_patched = merge(
doc, object_patch.model_dump(mode="json", exclude_unset=True)
)
mdb.operations.replace_one({"id": object_id}, doc_object_patched)
return doc_object_patched

Expand Down
8 changes: 6 additions & 2 deletions nmdc_runtime/api/endpoints/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,16 @@ def update_operation(
detail=f"client authorized for different site_id than {site_id_op}",
)
op_patch_metadata = merge(
op_patch.dict(exclude_unset=True).get("metadata", {}),
op_patch.model_dump(mode="json", exclude_unset=True).get("metadata", {}),
pick(["site_id", "job", "model"], doc_op.get("metadata", {})),
)
doc_op_patched = merge(
doc_op,
assoc(op_patch.dict(exclude_unset=True), "metadata", op_patch_metadata),
assoc(
op_patch.model_dump(mode="json", exclude_unset=True),
"metadata",
op_patch_metadata,
),
)
mdb.operations.replace_one({"id": op_id}, doc_op_patched)
return doc_op_patched
Expand Down
10 changes: 5 additions & 5 deletions nmdc_runtime/api/endpoints/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def run_query(
id=qid,
saved_at=saved_at,
)
mdb.queries.insert_one(query.dict(exclude_unset=True))
mdb.queries.insert_one(query.model_dump(mode="json", exclude_unset=True))
cmd_response = _run_query(query, mdb)
return unmongo(cmd_response.dict(exclude_unset=True))
return unmongo(cmd_response.model_dump(mode="json", exclude_unset=True))


@router.get("/queries/{query_id}", response_model=Query)
Expand Down Expand Up @@ -107,7 +107,7 @@ def rerun_query(
check_can_delete(user)

cmd_response = _run_query(query, mdb)
return unmongo(cmd_response.dict(exclude_unset=True))
return unmongo(cmd_response.model_dump(mode="json", exclude_unset=True))


def _run_query(query, mdb) -> CommandResponse:
Expand All @@ -131,12 +131,12 @@ def _run_query(query, mdb) -> CommandResponse:
detail="Failed to back up to-be-deleted documents. operation aborted.",
)

q_response = mdb.command(query.cmd.dict(exclude_unset=True))
q_response = mdb.command(query.cmd.model_dump(mode="json", exclude_unset=True))
cmd_response: CommandResponse = command_response_for(q_type)(**q_response)
query_run = (
QueryRun(qid=query.id, ran_at=ran_at, result=cmd_response)
if cmd_response.ok
else QueryRun(qid=query.id, ran_at=ran_at, error=cmd_response)
)
mdb.query_runs.insert_one(query_run.dict(exclude_unset=True))
mdb.query_runs.insert_one(query_run.model_dump(mode="json", exclude_unset=True))
return cmd_response
6 changes: 5 additions & 1 deletion nmdc_runtime/api/endpoints/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,9 @@ def post_run_event(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Supplied run_event.run.id does not match run_id given in request URL.",
)
mdb.run_events.insert_one(run_event.dict())
mdb.run_events.insert_one(
run_event.model_dump(
mode="json",
)
)
return _get_run_summary(run_event.run.id, mdb)
4 changes: 3 additions & 1 deletion nmdc_runtime/api/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def data_objects(
req: DataObjectListRequest = Depends(),
mdb: MongoDatabase = Depends(get_mongo_db),
):
filter_ = list_request_filter_to_mongo_filter(req.dict(exclude_unset=True))
filter_ = list_request_filter_to_mongo_filter(
req.model_dump(mode="json", exclude_unset=True)
)
max_page_size = filter_.pop("max_page_size", None)
page_token = filter_.pop("page_token", None)
req = ListRequest(
Expand Down
12 changes: 10 additions & 2 deletions nmdc_runtime/api/endpoints/sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ def create_site(
status_code=status.HTTP_409_CONFLICT,
detail=f"site with supplied id {site.id} already exists",
)
mdb.sites.insert_one(site.dict())
mdb.sites.insert_one(
site.model_dump(
mode="json",
)
)
refresh_minter_requesters_from_sites()
rv = mdb.users.update_one(
{"username": user.username},
Expand Down Expand Up @@ -165,7 +169,11 @@ def put_object_in_site(
},
}
)
mdb.operations.insert_one(op.dict())
mdb.operations.insert_one(
op.model_dump(
mode="json",
)
)
return op


Expand Down
22 changes: 17 additions & 5 deletions nmdc_runtime/api/endpoints/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ async def login_for_access_token(
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.dict())
access_token_expires = timedelta(
**ACCESS_TOKEN_EXPIRES.model_dump(
mode="json",
)
)
access_token = create_access_token(
data={"sub": f"user:{user.username}"}, expires_delta=access_token_expires
)
Expand All @@ -51,7 +55,11 @@ async def login_for_access_token(
headers={"WWW-Authenticate": "Bearer"},
)
# TODO make below an absolute time
access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.dict())
access_token_expires = timedelta(
**ACCESS_TOKEN_EXPIRES.model_dump(
mode="json",
)
)
access_token = create_access_token(
data={"sub": f"client:{form_data.client_id}"},
expires_delta=access_token_expires,
Expand All @@ -69,7 +77,9 @@ async def login_for_access_token(
return {
"access_token": access_token,
"token_type": "bearer",
"expires": ACCESS_TOKEN_EXPIRES.dict(),
"expires": ACCESS_TOKEN_EXPIRES.model_dump(
mode="json",
),
}


Expand All @@ -95,8 +105,10 @@ def create_user(
check_can_create_user(requester)
mdb.users.insert_one(
UserInDB(
**user_in.dict(),
**user_in.model_dump(
mode="json",
),
hashed_password=get_password_hash(user_in.password),
).dict(exclude_unset=True)
).model_dump(mode="json", exclude_unset=True)
)
return mdb.users.find_one({"username": user_in.username})
20 changes: 14 additions & 6 deletions nmdc_runtime/api/endpoints/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,11 @@ def _create_object(
mdb: MongoDatabase, object_in: DrsObjectIn, mgr_site, drs_id, self_uri
):
drs_obj = DrsObject(
**object_in.dict(exclude_unset=True), id=drs_id, self_uri=self_uri
**object_in.model_dump(exclude_unset=True, mode="json"),
id=drs_id,
self_uri=self_uri,
)
doc = drs_obj.dict(exclude_unset=True)
doc = drs_obj.model_dump(exclude_unset=True, mode="json")
doc["_mgr_site"] = mgr_site # manager site
try:
mdb.objects.insert_one(doc)
Expand Down Expand Up @@ -526,16 +528,22 @@ def _claim_job(job_id: str, mdb: MongoDatabase, site: Site):
"workflow": job.workflow,
"config": job.config,
}
).dict(exclude_unset=True),
).model_dump(mode="json", exclude_unset=True),
"site_id": site.id,
"model": dotted_path_for(JobOperationMetadata),
},
}
)
mdb.operations.insert_one(op.dict())
mdb.jobs.replace_one({"id": job.id}, job.dict(exclude_unset=True))
mdb.operations.insert_one(
op.model_dump(
mode="json",
)
)
mdb.jobs.replace_one(
{"id": job.id}, job.model_dump(mode="json", exclude_unset=True)
)

return op.dict(exclude_unset=True)
return op.model_dump(mode="json", exclude_unset=True)


@lru_cache
Expand Down
10 changes: 7 additions & 3 deletions nmdc_runtime/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,9 @@ def ensure_initial_resources_on_boot():
collection_boot = import_module(f"nmdc_runtime.api.boot.{collection_name}")

for model in collection_boot.construct():
doc = model.dict()
doc = model.model_dump(
mode="json",
)
mdb[collection_name].replace_one({"id": doc["id"]}, doc, upsert=True)

username = os.getenv("API_ADMIN_USER")
Expand All @@ -247,7 +249,7 @@ def ensure_initial_resources_on_boot():
username=username,
hashed_password=get_password_hash(os.getenv("API_ADMIN_PASS")),
site_admin=[os.getenv("API_SITE_ID")],
).dict(exclude_unset=True),
).model_dump(mode="json", exclude_unset=True),
upsert=True,
)
mdb.users.create_index("username")
Expand All @@ -268,7 +270,9 @@ def ensure_initial_resources_on_boot():
),
)
],
).dict(),
).model_dump(
mode="json",
),
upsert=True,
)

Expand Down
18 changes: 14 additions & 4 deletions nmdc_runtime/api/models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ def _add_run_requested_event(run_spec: RunUserSpec, mdb: MongoDatabase, user: Us
time=now(as_str=True),
inputs=run_spec.inputs,
)
mdb.run_events.insert_one(event.dict())
mdb.run_events.insert_one(
event.model_dump(
mode="json",
)
)
return run_id


Expand All @@ -113,7 +117,9 @@ def _add_run_started_event(run_id: str, mdb: MongoDatabase):
job=requested.job,
type=RunEventType.STARTED,
time=now(as_str=True),
).dict()
).model_dump(
mode="json",
)
)
return run_id

Expand All @@ -134,7 +140,9 @@ def _add_run_fail_event(run_id: str, mdb: MongoDatabase):
job=requested.job,
type=RunEventType.FAIL,
time=now(as_str=True),
).dict()
).model_dump(
mode="json",
)
)
return run_id

Expand All @@ -156,6 +164,8 @@ def _add_run_complete_event(run_id: str, mdb: MongoDatabase, outputs: List[str])
type=RunEventType.COMPLETE,
time=now(as_str=True),
outputs=outputs,
).dict()
).model_dump(
mode="json",
)
)
return run_id
2 changes: 1 addition & 1 deletion nmdc_runtime/core/exceptions/token.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from core.exceptions import CustomException
from nmdc_runtime.core.exceptions import CustomException


class DecodeTokenException(CustomException):
Expand Down
13 changes: 11 additions & 2 deletions nmdc_runtime/minter/adapters/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def mint(self, req_mint: MintingRequest) -> list[Identifier]:
)
)
for id_ in ids:
self.db[id_.id] = id_.dict()
self.db[id_.id] = id_.model_dump(
mode="json",
)
return ids

def bind(self, req_bind: BindingRequest) -> Identifier:
Expand Down Expand Up @@ -184,7 +186,14 @@ def mint(self, req_mint: MintingRequest) -> list[Identifier]:
)
for id_name in not_taken
]
self.db["minter.id_records"].insert_many([i.dict() for i in ids])
self.db["minter.id_records"].insert_many(
[
i.model_dump(
mode="json",
)
for i in ids
]
)
collected.extend(ids)
if len(collected) == req_mint.how_many:
break
Expand Down
8 changes: 7 additions & 1 deletion nmdc_runtime/minter/entrypoints/fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ def mint_ids(
requester = Entity(id=site.id)
try:
minted = s.mint(
MintingRequest(service=service, requester=requester, **req_mint.dict())
MintingRequest(
service=service,
requester=requester,
**req_mint.model_dump(
mode="json",
),
)
)
return [d.id for d in minted]
except MinterError as e:
Expand Down
6 changes: 5 additions & 1 deletion nmdc_runtime/site/drsobjects/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def claim_metadata_ingest_jobs(
)
jobs = []
while True:
rv = client.list_jobs(lr.dict()).json()
rv = client.list_jobs(
lr.model_dump(
mode="json",
)
).json()
jobs.extend(rv["resources"])
if "next_page_token" not in rv:
break
Expand Down
Loading

0 comments on commit bd9df67

Please sign in to comment.