From 9970e0b573e1721e9c31475a03d6cfa11433c230 Mon Sep 17 00:00:00 2001 From: jenniferjiangkells Date: Thu, 14 Nov 2024 15:31:48 +0000 Subject: [PATCH] Update tests and docs --- README.md | 14 +- docs/cookbook/notereader_sandbox.md | 4 +- docs/quickstart.md | 14 +- example_use.py | 140 ------------------ tests/components/conftest.py | 19 +++ tests/components/test_cardcreator.py | 65 +++++--- tests/pipeline/prebuilt/test_medicalcoding.py | 5 +- tests/pipeline/prebuilt/test_summarization.py | 29 ++-- tests/pipeline/test_pipeline.py | 2 +- tests/pipeline/test_pipeline_load.py | 59 ++++---- 10 files changed, 146 insertions(+), 205 deletions(-) delete mode 100644 example_use.py diff --git a/README.md b/README.md index bdc7def..bb9f307 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,11 @@ Pre-built pipelines are use case specific end-to-end workflows that already have from healthchain.pipeline import MedicalCodingPipeline from healthchain.models import CdaRequest -pipeline = MedicalCodingPipeline.load("./path/to/model") +# Load from model ID +pipeline = MedicalCodingPipeline.from_model_id("en_core_sci_md", source="spacy") + +# Or load from local model +pipeline = MedicalCodingPipeline.from_local_model("./path/to/model", source="spacy") cda_data = CdaRequest(document="") output = pipeline(cda_data) @@ -129,7 +133,9 @@ from typing import List @hc.sandbox class MyCDS(ClinicalDecisionSupport): def __init__(self) -> None: - self.pipeline = SummarizationPipeline.load("./path/to/model") + self.pipeline = SummarizationPipeline.from_model_id( + "DISLab/SummLlama3.2-3B", source="huggingface" + ) self.data_generator = CdsDataGenerator() # Sets up an instance of a mock EHR client of the specified workflow @@ -162,7 +168,9 @@ from healthchain.models import CcdData, CdaRequest, CdaResponse @hc.sandbox class NotereaderSandbox(ClinicalDocumentation): def __init__(self): - self.pipeline = MedicalCodingPipeline.load("./path/to/model") + self.pipeline = MedicalCodingPipeline.from_model_id( + "en_core_sci_md", source="spacy" + ) # Load an existing CDA file @hc.ehr(workflow="sign-note-inpatient") diff --git a/docs/cookbook/notereader_sandbox.md b/docs/cookbook/notereader_sandbox.md index 12a2fc7..0ccdf9b 100644 --- a/docs/cookbook/notereader_sandbox.md +++ b/docs/cookbook/notereader_sandbox.md @@ -18,7 +18,9 @@ from healthchain.models import ( class NotereaderSandbox(ClinicalDocumentation): def __init__(self): self.cda_path = "./resources/uclh_cda.xml" - self.pipeline = MedicalCodingPipeline.load("./resources/models/medcat_model.zip") + self.pipeline = MedicalCodingPipeline.from_local_model( + "./resources/models/medcat_model.zip", source="spacy" + ) @hc.ehr(workflow="sign-note-inpatient") def load_data_in_client(self) -> CcdData: diff --git a/docs/quickstart.md b/docs/quickstart.md index 5df03bf..a204122 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -100,7 +100,15 @@ For a full list of available prebuilt pipelines and details on how to configure from healthchain.pipeline import MedicalCodingPipeline from healthchain.models import CdaRequest -pipeline = MedicalCodingPipeline.load("./path/to/model") +# Load from pre-built chain +chain = ChatPromptTemplate.from_template("Summarize: {text}") | ChatOpenAI() +pipeline = MedicalCodingPipeline.load(chain, source="langchain") + +# Or load from model ID +pipeline = MedicalCodingPipeline.from_model_id("DISLab/SummLlama3-8B", source="huggingface") + +# Or load from local model +pipeline = MedicalCodingPipeline.from_local_model("./path/to/model", source="spacy") cda_data = CdaRequest(document="") output = pipeline(cda_data) @@ -128,7 +136,9 @@ from healthchain.models import CdaRequest, CdaResponse, CcdData class MyCoolSandbox(ClinicalDocumentation): def __init__(self) -> None: # Load your pipeline - self.pipeline = MedicalCodingPipeline.load("./path/to/model") + self.pipeline = MedicalCodingPipeline.from_local_model( + "./path/to/model", source="spacy" + ) @hc.ehr(workflow="sign-note-inpatient") def load_data_in_client(self) -> CcdData: diff --git a/example_use.py b/example_use.py deleted file mode 100644 index aaf44f4..0000000 --- a/example_use.py +++ /dev/null @@ -1,140 +0,0 @@ -import healthchain as hc -import logging - -from healthchain.models.data.ccddata import CcdData -from healthchain.models.data.concept import ( - AllergyConcept, - Concept, - MedicationConcept, - ProblemConcept, - Quantity, -) -from healthchain.use_cases import ClinicalDecisionSupport -from healthchain.data_generators import CdsDataGenerator -from healthchain.models import Card, CdsFhirData, CDSRequest - -from healthchain.use_cases.clindoc import ClinicalDocumentation -from langchain_openai import ChatOpenAI -from langchain_core.prompts import PromptTemplate -from langchain_core.output_parsers import StrOutputParser - -from typing import List -from dotenv import load_dotenv - -load_dotenv() - - -log = logging.getLogger("healthchain") -log.setLevel(logging.DEBUG) - - -@hc.sandbox -class MyCoolSandbox(ClinicalDecisionSupport): - def __init__(self, testing=True): - self.testing = testing - self.chain = self._init_llm_chain() - self.data_generator = CdsDataGenerator() - - def _init_llm_chain(self): - prompt = PromptTemplate.from_template( - "Extract conditions from the FHIR resource below and summarize in one sentence using simple language \n'''{text}'''" - ) - model = ChatOpenAI(model="gpt-4o") - parser = StrOutputParser() - - chain = prompt | model | parser - return chain - - @hc.ehr(workflow="patient-view") - def load_data_in_client(self) -> CdsFhirData: - data = self.data_generator.generate() - return data - - @hc.api - def my_service(self, request: CDSRequest) -> List[Card]: - if self.testing: - result = "test" - else: - result = self.chain.invoke(str(request.prefetch)) - - return Card( - summary="Patient summary", - indicator="info", - source={"label": "openai"}, - detail=result, - ) - - -@hc.sandbox -class NotereaderSandbox(ClinicalDocumentation): - def __init__(self): - self.overwrite = True - - @hc.ehr(workflow="sign-note-inpatient") - def load_data_in_client(self) -> CcdData: - # data = self.data_generator.generate() - # return data - - with open("./resources/uclh_cda.xml", "r") as file: - xml_string = file.read() - - return CcdData(cda_xml=xml_string) - - @hc.api - def my_service(self, ccd_data: CcdData) -> CcdData: - print(ccd_data.problems) - print(ccd_data.medications) - print(ccd_data.allergies) - - new_problem = ProblemConcept( - code="38341003", - code_system="2.16.840.1.113883.6.96", - code_system_name="SNOMED CT", - display_name="Hypertension", - ) - new_other_problem = ProblemConcept( - code="12341", - code_system="2.16.840.1.113883.6.96", - code_system_name="SNOMED CT", - display_name="Diabetes", - ) - new_allergy = AllergyConcept( - code="70618", - code_system="2.16.840.1.113883.6.96", - code_system_name="SNOMED CT", - display_name="Allergy to peanuts", - ) - another_allergy = AllergyConcept( - code="12344", - code_system="2.16.840.1.113883.6.96", - code_system_name="SNOMED CT", - display_name="CATS", - ) - new_medication = MedicationConcept( - code="197361", - code_system="2.16.840.1.113883.6.88", - code_system_name="RxNorm", - display_name="Lisinopril 10 MG Oral Tablet", - dosage=Quantity(value=10, unit="mg"), - route=Concept( - code="26643006", - code_system="2.16.840.1.113883.6.96", - code_system_name="SNOMED CT", - display_name="Oral", - ), - ) - ccd_data.problems = [new_problem, new_other_problem] - ccd_data.allergies = [new_allergy, another_allergy] - ccd_data.medications = [new_medication] - - print(ccd_data.note) - - return ccd_data - - -if __name__ == "__main__": - # cds = MyCoolSandbox() - # cds.start_sandbox() - - cds = NotereaderSandbox() - cds.start_sandbox() diff --git a/tests/components/conftest.py b/tests/components/conftest.py index ea070c2..28c7d0f 100644 --- a/tests/components/conftest.py +++ b/tests/components/conftest.py @@ -1,5 +1,6 @@ import pytest from healthchain.io.containers.document import Document +from healthchain.pipeline.components import CdsCardCreator from tests.pipeline.conftest import mock_spacy_nlp # noqa: F401 @@ -14,3 +15,21 @@ def sample_lookup(): @pytest.fixture def sample_document(): return Document(data="This is a sample text for testing.") + + +@pytest.fixture +def basic_creator(): + return CdsCardCreator() + + +@pytest.fixture +def custom_template_creator(): + template = """ + { + "summary": "Custom: {{ model_output }}", + "indicator": "warning", + "source": {{ default_source | tojson }}, + "detail": "{{ model_output }}" + } + """ + return CdsCardCreator(template=template) diff --git a/tests/components/test_cardcreator.py b/tests/components/test_cardcreator.py index eae35af..7932b7c 100644 --- a/tests/components/test_cardcreator.py +++ b/tests/components/test_cardcreator.py @@ -4,24 +4,6 @@ from healthchain.models.responses.cdsresponse import Card, Source, IndicatorEnum -@pytest.fixture -def basic_creator(): - return CdsCardCreator() - - -@pytest.fixture -def custom_template_creator(): - template = """ - { - "summary": "Custom: {{ model_output }}", - "indicator": "warning", - "source": {{ default_source | tojson }}, - "detail": "{{ model_output }}" - } - """ - return CdsCardCreator(template=template) - - def test_default_template_rendering(basic_creator): content = "Test message" card = basic_creator.create_card(content) @@ -109,3 +91,50 @@ def test_invalid_input_configuration(): "Either model output (source and task) or content need to be provided" in str(exc_info.value) ) + + +def test_template_file_loading(tmp_path): + # Create a temporary template file + template_file = tmp_path / "test_template.json" + template_content = """ + { + "summary": "File Template: {{ model_output }}", + "indicator": "warning", + "source": {{ default_source | tojson }}, + "detail": "{{ model_output }}" + } + """ + template_file.write_text(template_content) + + creator = CdsCardCreator(template_path=template_file) + card = creator.create_card("Test message") + + assert card.summary == "File Template: Test message" + assert card.indicator == IndicatorEnum.warning + assert card.source == Source(label="Card Generated by HealthChain") + assert card.detail == "Test message" + + +def test_nonexistent_template_file(caplog): + creator = CdsCardCreator(template_path="nonexistent_template.json") + card = creator.create_card("Test message") + assert card.summary == "Test message" + assert card.indicator == IndicatorEnum.info + assert card.source == Source(label="Card Generated by HealthChain") + assert card.detail == "Test message" + assert "Error loading template" in caplog.text + + +def test_invalid_template_file(tmp_path): + # Create a temporary template file with invalid JSON + template_file = tmp_path / "invalid_template.json" + template_content = """ + { + "summary": {{ invalid_json }}, + "indicator": "warning" + """ # Invalid JSON + template_file.write_text(template_content) + + with pytest.raises(ValueError): + creator = CdsCardCreator(template_path=template_file) + creator.create_card("Test message") diff --git a/tests/pipeline/prebuilt/test_medicalcoding.py b/tests/pipeline/prebuilt/test_medicalcoding.py index 9df58b5..aedb6be 100644 --- a/tests/pipeline/prebuilt/test_medicalcoding.py +++ b/tests/pipeline/prebuilt/test_medicalcoding.py @@ -37,8 +37,9 @@ def test_coding_pipeline(mock_cda_connector, mock_spacy_nlp): ModelConfig( source=ModelSource.SPACY, model="en_core_sci_sm", + task="ner", path=None, - kwargs={"task": "ner"}, + kwargs={}, ) ) mock_spacy_nlp.return_value.assert_called_once() @@ -67,7 +68,7 @@ def test_full_coding_pipeline_integration(mock_spacy_nlp, test_cda_request): "healthchain.pipeline.mixins.ModelRoutingMixin.get_model_component", mock_spacy_nlp, ): - pipeline = MedicalCodingPipeline.load( + pipeline = MedicalCodingPipeline.from_local_model( "./spacy/path/to/production/model", source="spacy" ) diff --git a/tests/pipeline/prebuilt/test_summarization.py b/tests/pipeline/prebuilt/test_summarization.py index e34df6c..a79fd2a 100644 --- a/tests/pipeline/prebuilt/test_summarization.py +++ b/tests/pipeline/prebuilt/test_summarization.py @@ -60,6 +60,7 @@ def test_summarization_pipeline( source=ModelSource.HUGGINGFACE.value, task="summarization", template=pipeline._output_template, + template_path=pipeline._output_template_path, delimiter="\n", ) @@ -90,22 +91,28 @@ def test_summarization_pipeline( assert "card-creation" in pipeline._stages -def test_full_summarization_pipeline_integration(mock_hf_transformer, test_cds_request): +def test_full_summarization_pipeline_integration( + mock_hf_transformer, test_cds_request, tmp_path +): # Use mock LLM object for now with patch( "healthchain.pipeline.mixins.ModelRoutingMixin.get_model_component", mock_hf_transformer, ): - template = """ - { - "summary": "This is a test summary", - "indicator": "warning", - "source": {{ default_source | tojson }}, - "detail": "{{ model_output }}" - } - """ - pipeline = SummarizationPipeline.load( - "llama3", source="huggingface", template=template + # Create a temporary template file + template_file = tmp_path / "card_template.json" + template_content = """ + { + "summary": "This is a test summary", + "indicator": "warning", + "source": {{ default_source | tojson }}, + "detail": "{{ model_output }}" + } + """ + template_file.write_text(template_content) + + pipeline = SummarizationPipeline.from_model_id( + "llama3", source="huggingface", template_path=template_file ) cds_response = pipeline(test_cds_request) diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 6929c93..03351d0 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -257,7 +257,7 @@ def test_pipeline_class_and_representation(mock_basic_pipeline): assert "comp1" in repr_string assert "comp2" in repr_string - loaded_pipeline = Pipeline.load("dummy_path") + loaded_pipeline = Pipeline.from_local_model("dummy_path", source="spacy") assert isinstance(loaded_pipeline, Pipeline) diff --git a/tests/pipeline/test_pipeline_load.py b/tests/pipeline/test_pipeline_load.py index a384cc1..273336c 100644 --- a/tests/pipeline/test_pipeline_load.py +++ b/tests/pipeline/test_pipeline_load.py @@ -1,11 +1,10 @@ import pytest from pathlib import Path from healthchain.pipeline.base import ModelSource -from healthchain.io.containers import DataContainer -def test_load_huggingface_model(mock_basic_pipeline): - pipeline = mock_basic_pipeline.load( +def test_from_model_id_huggingface(mock_basic_pipeline): + pipeline = mock_basic_pipeline.from_model_id( "meta-llama/Llama-2-7b", task="text-generation", device="cuda", batch_size=32 ) @@ -19,8 +18,8 @@ def test_load_huggingface_model(mock_basic_pipeline): } -def test_load_spacy_model(mock_basic_pipeline): - pipeline = mock_basic_pipeline.load( +def test_from_model_id_spacy(mock_basic_pipeline): + pipeline = mock_basic_pipeline.from_model_id( "en_core_sci_md", source="spacy", disable=["parser", "ner"] ) @@ -30,8 +29,10 @@ def test_load_spacy_model(mock_basic_pipeline): assert pipeline._model_config.kwargs == {"disable": ["parser", "ner"]} -def test_load_local_model(mock_basic_pipeline): - pipeline = mock_basic_pipeline.load("./models/custom_spacy_model", source="spacy") +def test_from_local_model_spacy(mock_basic_pipeline): + pipeline = mock_basic_pipeline.from_local_model( + "./models/custom_spacy_model", source="spacy" + ) assert pipeline._model_config.source == ModelSource.SPACY assert pipeline._model_config.model == "custom_spacy_model" @@ -47,7 +48,7 @@ def test_load_with_template(mock_basic_pipeline): } """ - pipeline = mock_basic_pipeline.load( + pipeline = mock_basic_pipeline.from_model_id( "gpt-3.5-turbo", source="huggingface", template=template ) @@ -56,24 +57,22 @@ def test_load_with_template(mock_basic_pipeline): assert pipeline._model_config.model == "gpt-3.5-turbo" -def test_load_callable_chain(mock_basic_pipeline, mock_chain): +def test_from_model_id_invalid_source(mock_basic_pipeline): + with pytest.raises(ValueError, match="not a valid ModelSource"): + mock_basic_pipeline.from_model_id("model", source="invalid_source") + + +def test_load_callable(mock_basic_pipeline, mock_chain): pipeline = mock_basic_pipeline.load(mock_chain, temperature=0.7, max_tokens=100) assert pipeline._model_config.source == ModelSource.LANGCHAIN assert pipeline._model_config.model == mock_chain assert pipeline._model_config.kwargs == {"temperature": 0.7, "max_tokens": 100} - with pytest.raises( - ValueError, match="LangChain models must be passed directly as chain objects" - ): + with pytest.raises(ValueError, match="Pipeline must be a callable object"): mock_basic_pipeline.load("langchain", source="langchain") -def test_load_invalid_source(mock_basic_pipeline): - with pytest.raises(ValueError, match="Unsupported model source"): - mock_basic_pipeline.load("model", source="invalid_source") - - def test_load_with_simple_callable(mock_basic_pipeline): # Create a simple callable def simple_chain(input_text: str) -> str: @@ -86,15 +85,21 @@ def simple_chain(input_text: str) -> str: assert pipeline._model_config.kwargs == {"temperature": 0.7} -def test_load_preserves_pipeline_functionality(mock_basic_pipeline): - pipeline = mock_basic_pipeline.load("test-model") +def test_load_with_template_path(mock_basic_pipeline, tmp_path): + # Create a temporary template file + template_file = tmp_path / "test_template.json" + template_content = """ + { + "summary": "Test summary", + "detail": "{{ model_output }}" + } + """ + template_file.write_text(template_content) - # Add a simple component - @pipeline.add_node(name="test_component") - def test_component(data: DataContainer): - data.data = "processed" - return data + pipeline = mock_basic_pipeline.from_model_id( + "gpt-3.5-turbo", source="huggingface", template_path=template_file + ) - # Test that the pipeline still works - result = pipeline("test") - assert result.data == "processed" + assert pipeline._output_template_path == template_file + assert pipeline._model_config.source == ModelSource.HUGGINGFACE + assert pipeline._model_config.model == "gpt-3.5-turbo"