Skip to content

Commit

Permalink
Update tests and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jenniferjiangkells committed Nov 14, 2024
1 parent 3977f75 commit 9970e0b
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 205 deletions.
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ Pre-built pipelines are use case specific end-to-end workflows that already have
from healthchain.pipeline import MedicalCodingPipeline
from healthchain.models import CdaRequest

pipeline = MedicalCodingPipeline.load("./path/to/model")
# Load from model ID
pipeline = MedicalCodingPipeline.from_model_id("en_core_sci_md", source="spacy")

# Or load from local model
pipeline = MedicalCodingPipeline.from_local_model("./path/to/model", source="spacy")

cda_data = CdaRequest(document="<CDA XML content>")
output = pipeline(cda_data)
Expand Down Expand Up @@ -129,7 +133,9 @@ from typing import List
@hc.sandbox
class MyCDS(ClinicalDecisionSupport):
def __init__(self) -> None:
self.pipeline = SummarizationPipeline.load("./path/to/model")
self.pipeline = SummarizationPipeline.from_model_id(
"DISLab/SummLlama3.2-3B", source="huggingface"
)
self.data_generator = CdsDataGenerator()

# Sets up an instance of a mock EHR client of the specified workflow
Expand Down Expand Up @@ -162,7 +168,9 @@ from healthchain.models import CcdData, CdaRequest, CdaResponse
@hc.sandbox
class NotereaderSandbox(ClinicalDocumentation):
def __init__(self):
self.pipeline = MedicalCodingPipeline.load("./path/to/model")
self.pipeline = MedicalCodingPipeline.from_model_id(
"en_core_sci_md", source="spacy"
)

# Load an existing CDA file
@hc.ehr(workflow="sign-note-inpatient")
Expand Down
4 changes: 3 additions & 1 deletion docs/cookbook/notereader_sandbox.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ from healthchain.models import (
class NotereaderSandbox(ClinicalDocumentation):
def __init__(self):
self.cda_path = "./resources/uclh_cda.xml"
self.pipeline = MedicalCodingPipeline.load("./resources/models/medcat_model.zip")
self.pipeline = MedicalCodingPipeline.from_local_model(
"./resources/models/medcat_model.zip", source="spacy"
)

@hc.ehr(workflow="sign-note-inpatient")
def load_data_in_client(self) -> CcdData:
Expand Down
14 changes: 12 additions & 2 deletions docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,15 @@ For a full list of available prebuilt pipelines and details on how to configure
from healthchain.pipeline import MedicalCodingPipeline
from healthchain.models import CdaRequest

pipeline = MedicalCodingPipeline.load("./path/to/model")
# Load from pre-built chain
chain = ChatPromptTemplate.from_template("Summarize: {text}") | ChatOpenAI()
pipeline = MedicalCodingPipeline.load(chain, source="langchain")

# Or load from model ID
pipeline = MedicalCodingPipeline.from_model_id("DISLab/SummLlama3-8B", source="huggingface")

# Or load from local model
pipeline = MedicalCodingPipeline.from_local_model("./path/to/model", source="spacy")

cda_data = CdaRequest(document="<CDA XML content>")
output = pipeline(cda_data)
Expand Down Expand Up @@ -128,7 +136,9 @@ from healthchain.models import CdaRequest, CdaResponse, CcdData
class MyCoolSandbox(ClinicalDocumentation):
def __init__(self) -> None:
# Load your pipeline
self.pipeline = MedicalCodingPipeline.load("./path/to/model")
self.pipeline = MedicalCodingPipeline.from_local_model(
"./path/to/model", source="spacy"
)

@hc.ehr(workflow="sign-note-inpatient")
def load_data_in_client(self) -> CcdData:
Expand Down
140 changes: 0 additions & 140 deletions example_use.py

This file was deleted.

19 changes: 19 additions & 0 deletions tests/components/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from healthchain.io.containers.document import Document
from healthchain.pipeline.components import CdsCardCreator
from tests.pipeline.conftest import mock_spacy_nlp # noqa: F401


Expand All @@ -14,3 +15,21 @@ def sample_lookup():
@pytest.fixture
def sample_document():
return Document(data="This is a sample text for testing.")


@pytest.fixture
def basic_creator():
return CdsCardCreator()


@pytest.fixture
def custom_template_creator():
template = """
{
"summary": "Custom: {{ model_output }}",
"indicator": "warning",
"source": {{ default_source | tojson }},
"detail": "{{ model_output }}"
}
"""
return CdsCardCreator(template=template)
65 changes: 47 additions & 18 deletions tests/components/test_cardcreator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,6 @@
from healthchain.models.responses.cdsresponse import Card, Source, IndicatorEnum


@pytest.fixture
def basic_creator():
return CdsCardCreator()


@pytest.fixture
def custom_template_creator():
template = """
{
"summary": "Custom: {{ model_output }}",
"indicator": "warning",
"source": {{ default_source | tojson }},
"detail": "{{ model_output }}"
}
"""
return CdsCardCreator(template=template)


def test_default_template_rendering(basic_creator):
content = "Test message"
card = basic_creator.create_card(content)
Expand Down Expand Up @@ -109,3 +91,50 @@ def test_invalid_input_configuration():
"Either model output (source and task) or content need to be provided"
in str(exc_info.value)
)


def test_template_file_loading(tmp_path):
# Create a temporary template file
template_file = tmp_path / "test_template.json"
template_content = """
{
"summary": "File Template: {{ model_output }}",
"indicator": "warning",
"source": {{ default_source | tojson }},
"detail": "{{ model_output }}"
}
"""
template_file.write_text(template_content)

creator = CdsCardCreator(template_path=template_file)
card = creator.create_card("Test message")

assert card.summary == "File Template: Test message"
assert card.indicator == IndicatorEnum.warning
assert card.source == Source(label="Card Generated by HealthChain")
assert card.detail == "Test message"


def test_nonexistent_template_file(caplog):
creator = CdsCardCreator(template_path="nonexistent_template.json")
card = creator.create_card("Test message")
assert card.summary == "Test message"
assert card.indicator == IndicatorEnum.info
assert card.source == Source(label="Card Generated by HealthChain")
assert card.detail == "Test message"
assert "Error loading template" in caplog.text


def test_invalid_template_file(tmp_path):
# Create a temporary template file with invalid JSON
template_file = tmp_path / "invalid_template.json"
template_content = """
{
"summary": {{ invalid_json }},
"indicator": "warning"
""" # Invalid JSON
template_file.write_text(template_content)

with pytest.raises(ValueError):
creator = CdsCardCreator(template_path=template_file)
creator.create_card("Test message")
5 changes: 3 additions & 2 deletions tests/pipeline/prebuilt/test_medicalcoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ def test_coding_pipeline(mock_cda_connector, mock_spacy_nlp):
ModelConfig(
source=ModelSource.SPACY,
model="en_core_sci_sm",
task="ner",
path=None,
kwargs={"task": "ner"},
kwargs={},
)
)
mock_spacy_nlp.return_value.assert_called_once()
Expand Down Expand Up @@ -67,7 +68,7 @@ def test_full_coding_pipeline_integration(mock_spacy_nlp, test_cda_request):
"healthchain.pipeline.mixins.ModelRoutingMixin.get_model_component",
mock_spacy_nlp,
):
pipeline = MedicalCodingPipeline.load(
pipeline = MedicalCodingPipeline.from_local_model(
"./spacy/path/to/production/model", source="spacy"
)

Expand Down
29 changes: 18 additions & 11 deletions tests/pipeline/prebuilt/test_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_summarization_pipeline(
source=ModelSource.HUGGINGFACE.value,
task="summarization",
template=pipeline._output_template,
template_path=pipeline._output_template_path,
delimiter="\n",
)

Expand Down Expand Up @@ -90,22 +91,28 @@ def test_summarization_pipeline(
assert "card-creation" in pipeline._stages


def test_full_summarization_pipeline_integration(mock_hf_transformer, test_cds_request):
def test_full_summarization_pipeline_integration(
mock_hf_transformer, test_cds_request, tmp_path
):
# Use mock LLM object for now
with patch(
"healthchain.pipeline.mixins.ModelRoutingMixin.get_model_component",
mock_hf_transformer,
):
template = """
{
"summary": "This is a test summary",
"indicator": "warning",
"source": {{ default_source | tojson }},
"detail": "{{ model_output }}"
}
"""
pipeline = SummarizationPipeline.load(
"llama3", source="huggingface", template=template
# Create a temporary template file
template_file = tmp_path / "card_template.json"
template_content = """
{
"summary": "This is a test summary",
"indicator": "warning",
"source": {{ default_source | tojson }},
"detail": "{{ model_output }}"
}
"""
template_file.write_text(template_content)

pipeline = SummarizationPipeline.from_model_id(
"llama3", source="huggingface", template_path=template_file
)

cds_response = pipeline(test_cds_request)
Expand Down
2 changes: 1 addition & 1 deletion tests/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def test_pipeline_class_and_representation(mock_basic_pipeline):
assert "comp1" in repr_string
assert "comp2" in repr_string

loaded_pipeline = Pipeline.load("dummy_path")
loaded_pipeline = Pipeline.from_local_model("dummy_path", source="spacy")
assert isinstance(loaded_pipeline, Pipeline)


Expand Down
Loading

0 comments on commit 9970e0b

Please sign in to comment.