From 78b7f24178dd0b6f5c6c9a5dba0900b16a139043 Mon Sep 17 00:00:00 2001 From: Jennifer Jiang-Kells Date: Thu, 13 Jun 2024 15:31:41 +0100 Subject: [PATCH] Fix tests --- tests/conftest.py | 15 ++++++++++----- tests/test_service_with_func.py | 2 +- tests/test_strategy.py | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b02b6ae..8d689a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,8 @@ from pydantic import BaseModel from healthchain.base import BaseStrategy, BaseUseCase, UseCaseType +from healthchain.fhir_resources.bundle_resources import Bundle_EntryModel, BundleModel +from healthchain.models.data.cdsfhirdata import CdsFhirData from healthchain.models.requests.cdsrequest import CDSRequest from healthchain.use_cases.cds import ClinicalDecisionSupportStrategy from healthchain.clients import EHRClient @@ -25,7 +27,10 @@ class synth_data: class MockDataGenerator: def __init__(self) -> None: - self.data = synth_data(context={}, prefetch=MockBundle()) + self.data = CdsFhirData( + context={}, prefetch=BundleModel(entry=[Bundle_EntryModel()]) + ) + # self.data = synth_data(context={}, prefetch=MockBundle()) self.workflow = None def set_workflow(self, workflow): @@ -39,17 +44,17 @@ def cds_strategy(): @pytest.fixture def valid_data(): - return synth_data( + return CdsFhirData( context={"userId": "Practitioner/123", "patientId": "123"}, - prefetch=MockBundle(), + prefetch=BundleModel(entry=[Bundle_EntryModel()]), ) @pytest.fixture def invalid_data(): - return synth_data( + return CdsFhirData( context={"invalidId": "Practitioner", "patientId": "123"}, - prefetch=MockBundle(), + prefetch=BundleModel(entry=[Bundle_EntryModel()]), ) diff --git a/tests/test_service_with_func.py b/tests/test_service_with_func.py index 3ccf7db..c7f9a64 100644 --- a/tests/test_service_with_func.py +++ b/tests/test_service_with_func.py @@ -21,7 +21,7 @@ def load_data(self): def llm(self, text: str): return [ Card( - summary=self.data_generator.data.prefetch.condition, + summary="test", indicator="info", source={"label": "website"}, ) diff --git a/tests/test_strategy.py b/tests/test_strategy.py index b0a79bb..43cc5b9 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -13,7 +13,7 @@ def test_valid_data_request_construction(cds_strategy, valid_data): mock_init.assert_called_once_with( hook=Workflow.patient_view.value, context=PatientViewContext(userId="Practitioner/123", patientId="123"), - prefetch={}, + prefetch={"entry": [{}]}, )