diff --git a/README.md b/README.md index fa38872..5741808 100644 --- a/README.md +++ b/README.md @@ -47,11 +47,11 @@ nlp_pipeline = Pipeline[Document]() # Add TextPreProcessor component preprocessor = TextPreProcessor(tokenizer="spacy") -nlp_pipeline.add(preprocessor) +nlp_pipeline.add_node(preprocessor) # Add Model component (assuming we have a pre-trained model) model = Model(model_path="path/to/pretrained/model") -nlp_pipeline.add(model) +nlp_pipeline.add_node(model) # Add TextPostProcessor component postprocessor = TextPostProcessor( @@ -60,7 +60,7 @@ postprocessor = TextPostProcessor( "high blood pressure": "hypertension" } ) -nlp_pipeline.add(postprocessor) +nlp_pipeline.add_node(postprocessor) # Build the pipeline nlp = nlp_pipeline.build() diff --git a/docs/api/component.md b/docs/api/component.md index c3f4a9b..8bd62f7 100644 --- a/docs/api/component.md +++ b/docs/api/component.md @@ -1,6 +1,6 @@ # Component -::: healthchain.pipeline.components.basecomponent +::: healthchain.pipeline.components.base ::: healthchain.pipeline.components.preprocessors -::: healthchain.pipeline.components.models +::: healthchain.pipeline.components.model ::: healthchain.pipeline.components.postprocessors diff --git a/docs/api/connectors.md b/docs/api/connectors.md new file mode 100644 index 0000000..c633cc1 --- /dev/null +++ b/docs/api/connectors.md @@ -0,0 +1,5 @@ +# Connectors + +::: healthchain.io.base +::: healthchain.io.cdaconnector +::: healthchain.io.cdsfhirconnector diff --git a/docs/api/pipeline.md b/docs/api/pipeline.md index fc67c23..8a1ace0 100644 --- a/docs/api/pipeline.md +++ b/docs/api/pipeline.md @@ -1,3 +1,3 @@ # Pipeline -::: healthchain.pipeline.basepipeline +::: healthchain.pipeline.base diff --git a/docs/cookbook/cds_sandbox.md b/docs/cookbook/cds_sandbox.md index 54ddbe0..92f7e6b 100644 --- a/docs/cookbook/cds_sandbox.md +++ b/docs/cookbook/cds_sandbox.md @@ -1,13 +1,13 @@ # Build a CDS sandbox -A CDS sandbox which uses `gpt-4o` to summarise patient 
information from synthetically generated FHIR resources received from the `patient-view` CDS hook. +A CDS sandbox which uses `gpt-4o` to summarise patient information from synthetically generated FHIR resources received from the `patient-view` CDS hook. [NEEDS UPDATING] ```python import healthchain as hc from healthchain.use_cases import ClinicalDecisionSupport from healthchain.data_generators import CdsDataGenerator -from healthchain.models import Card, CdsFhirData, CDSRequest +from healthchain.models import Card, CdsFhirData, CDSRequest, CDSResponse from langchain_openai import ChatOpenAI from langchain_core.prompts import PromptTemplate @@ -37,12 +37,16 @@ class CdsSandbox(ClinicalDecisionSupport): return data @hc.api - def my_service(self, request: CDSRequest) -> List[Card]: + def my_service(self, request: CDSRequest) -> CDSResponse: result = self.chain.invoke(str(request.prefetch)) - return Card( - summary="Patient summary", - indicator="info", - source={"label": "openai"}, - detail=result, + return CDSResponse( + cards=[ + Card( + summary="Patient summary", + indicator="info", + source={"label": "openai"}, + detail=result, + ) + ] ) ``` diff --git a/docs/cookbook/notereader_sandbox.md b/docs/cookbook/notereader_sandbox.md index 58875fc..12a2fc7 100644 --- a/docs/cookbook/notereader_sandbox.md +++ b/docs/cookbook/notereader_sandbox.md @@ -18,6 +18,7 @@ from healthchain.models import ( class NotereaderSandbox(ClinicalDocumentation): def __init__(self): self.cda_path = "./resources/uclh_cda.xml" + self.pipeline = MedicalCodingPipeline.load("./resources/models/medcat_model.zip") @hc.ehr(workflow="sign-note-inpatient") def load_data_in_client(self) -> CcdData: @@ -27,38 +28,7 @@ class NotereaderSandbox(ClinicalDocumentation): return CcdData(cda_xml=xml_string) @hc.api - def my_service(self, ccd_data: CcdData) -> CcdData: - - # Apply extraction method from ccd_data.note - - new_problem = ProblemConcept( - code="38341003", - code_system="2.16.840.1.113883.6.96", - 
code_system_name="SNOMED CT", - display_name="Hypertension", - ) - new_allergy = AllergyConcept( - code="70618", - code_system="2.16.840.1.113883.6.96", - code_system_name="SNOMED CT", - display_name="Allergy to peanuts", - ) - new_medication = MedicationConcept( - code="197361", - code_system="2.16.840.1.113883.6.88", - code_system_name="RxNorm", - display_name="Lisinopril 10 MG Oral Tablet", - dosage=Quantity(value=10, unit="mg"), - route=Concept( - code="26643006", - code_system="2.16.840.1.113883.6.96", - code_system_name="SNOMED CT", - display_name="Oral", - ), - ) - ccd_data.problems = [new_problem] - ccd_data.allergies = [new_allergy] - ccd_data.medications = [new_medication] - - return ccd_data + def my_service(self, request: CdaRequest) -> CdaResponse: + response = self.pipeline(request) + return response ``` diff --git a/docs/quickstart.md b/docs/quickstart.md index ba9eea7..5df03bf 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -6,36 +6,34 @@ After [installing HealthChain](installation.md), get up to speed quickly with th ### Pipeline 🛠️ -The `Pipeline` module in HealthChain provides a flexible way to build and manage processing pipelines for NLP and ML tasks that can easily interface with -parsers and connectors to integrate with electronic health record (EHR) systems. +HealthChain Pipelines provide a flexible way to build and manage processing pipelines for NLP and ML tasks that can easily integrate with electronic health record (EHR) systems. You can build pipelines with three different approaches: #### 1. Build Your Own Pipeline with Inline Functions -This is the most flexible approach, ideal for quick experiments and prototyping. Initialize a pipeline type hinted with the container type you want to process, then add components to your pipeline with the `@add` decorator. +This is the most flexible approach, ideal for quick experiments and prototyping. 
Initialize a pipeline type hinted with the container type you want to process, then add components to your pipeline with the `@add_node` decorator. Compile the pipeline with `.build()` to use it. ```python from healthchain.pipeline import Pipeline -from healthchain.io.containers import Document +from healthchain.io import Document nlp_pipeline = Pipeline[Document]() -@nlp_pipeline.add +@nlp_pipeline.add_node def tokenize(doc: Document) -> Document: doc.tokens = doc.text.split() return doc -@nlp_pipeline.add +@nlp_pipeline.add_node def pos_tag(doc: Document) -> Document: - # Dummy POS tagging doc.pos_tags = ["NOUN" if token[0].isupper() else "VERB" for token in doc.tokens] return doc -# Build and use the pipeline nlp = nlp_pipeline.build() + doc = Document("Patient has a fracture of the left femur.") doc = nlp(doc) @@ -46,51 +44,74 @@ print(doc.pos_tags) # ['NOUN', 'VERB', 'VERB', 'VERB', 'VERB', 'VERB'] ``` -#### 2. Build Your Own Pipeline with Components and Models +#### 2. Build Your Own Pipeline with Components, Models, and Connectors -Components are stateful - they're classes instead of functions. They can be useful for grouping related processing steps together, or wrapping specific models. +Components are stateful - they're classes instead of functions. They can be useful for grouping related processing steps together, setting configurations, or wrapping specific model loading steps. HealthChain comes with a few pre-built components, but you can also easily add your own. You can find more details on the [Components](./reference/pipeline/component.md) and [Models](./reference/pipeline/models/models.md) documentation pages. -Add components to your pipeline with the `.add()` method and compile with `.build()`. +Add components to your pipeline with the `.add_node()` method and compile with `.build()`. 
```python from healthchain.pipeline import Pipeline -from healthchain.io.containers import Document from healthchain.pipeline.components import TextPreProcessor, Model, TextPostProcessor +from healthchain.io import Document pipeline = Pipeline[Document]() -pipeline.add(TextPreProcessor()) -pipeline.add(Model(model_path="path/to/model")) -pipeline.add(TextPostProcessor()) +pipeline.add_node(TextPreProcessor()) +pipeline.add_node(Model(model_path="path/to/model")) +pipeline.add_node(TextPostProcessor()) pipe = pipeline.build() + doc = Document("Patient presents with hypertension.") -doc = pipe(doc) +output = pipe(doc) +``` + +Let's go one step further! You can use [Connectors](./reference/pipeline/connectors/connectors.md) to work directly with [CDA](https://www.hl7.org.uk/standards/hl7-standards/cda-clinical-document-architecture/) and [FHIR](https://hl7.org/fhir/) data received from healthcare system APIs. Add Connectors to your pipeline with the `.add_input()` and `.add_output()` methods. + +```python +from healthchain.pipeline import Pipeline +from healthchain.pipeline.components import Model +from healthchain.io import CdaConnector +from healthchain.models import CdaRequest + +pipeline = Pipeline() +cda_connector = CdaConnector() + +pipeline.add_input(cda_connector) +pipeline.add_node(Model(model_path="path/to/model")) +pipeline.add_output(cda_connector) + +pipe = pipeline.build() + +cda_data = CdaRequest(document="") +output = pipe(cda_data) ``` #### 3. Use Prebuilt Pipelines -Prebuilt pipelines are pre-configured collections of `Components` and `Models`. They are configured for specific use cases, offering the highest level of abstraction. This is the easiest way to get started if you already know the use case you want to build for. +Prebuilt pipelines are pre-configured collections of Components, Models, and Connectors. They are built for specific use cases, offering the highest level of abstraction. 
This is the easiest way to get started if you already know the use case you want to build for. For a full list of available prebuilt pipelines and details on how to configure and customize them, see the [Pipelines](./reference/pipeline/pipeline.md) documentation page. ```python from healthchain.pipeline import MedicalCodingPipeline +from healthchain.models import CdaRequest pipeline = MedicalCodingPipeline.load("./path/to/model") -doc = Document("Patient diagnosed with myocardial infarction.") -doc = pipeline(doc) +cda_data = CdaRequest(document="") +output = pipeline(cda_data) ``` ### Sandbox 🧪 -Once you've built your pipeline, you might want to experiment with how you want your pipeline to interact with different health systems. A sandbox helps you stage and test the end-to-end workflow of your pipeline application where real-time EHR integrations are involved. +Once you've built your pipeline, you might want to experiment with how it interacts with different healthcare systems. A sandbox helps you stage and test the end-to-end workflow of your pipeline application where real-time EHR integrations are involved. -Running a sandbox will start a `FastAPI` server with standardized API endpoints and create a sandboxed environment for you to interact with your application. +Running a sandbox will start a [FastAPI](https://fastapi.tiangolo.com/) server with pre-defined standardized endpoints and create a sandboxed environment for you to interact with your application. -To create a sandbox, initialize a class that inherits from a type of `UseCase` and decorate it with the `@hc.sandbox` decorator. +To create a sandbox, initialize a class that inherits from a type of [UseCase](./reference/sandbox/use_cases/use_cases.md) and decorate it with the `@hc.sandbox` decorator. Every sandbox also requires a **client** function marked by `@hc.ehr` and a **service** function marked by `@hc.api`. A **workflow** must be specified when creating an EHR client. 
@@ -101,6 +122,7 @@ import healthchain as hc from healthchain.use_cases import ClinicalDocumentation from healthchain.pipeline import MedicalCodingPipeline +from healthchain.models import CdaRequest, CdaResponse, CcdData @hc.sandbox class MyCoolSandbox(ClinicalDocumentation): @@ -117,9 +139,9 @@ class MyCoolSandbox(ClinicalDocumentation): return CcdData(cda_xml=xml_string) @hc.api - def my_service(self, ccd_data: CcdData) -> CcdData: + def my_service(self, request: CdaRequest) -> CdaResponse: # Run your pipeline - results = self.pipeline(ccd_data) + results = self.pipeline(request) return results if __name__ == "__main__": @@ -137,13 +159,6 @@ healthchain run my_sandbox.py This will start a server by default at `http://127.0.0.1:8000`, and you can interact with the exposed endpoints at `/docs`. Data generated from your sandbox runs is saved at `./output/` by default. -Then run: - -```bash -cd streamlist_demo -streamlit run app.py -``` - ## Utilities ⚙️ ### Data Generator @@ -151,13 +166,14 @@ You can use the data generator to generate synthetic data for your sandbox runs. The `.generate()` is dependent on use case and workflow. For example, `CdsDataGenerator` will generate synthetic [FHIR](https://hl7.org/fhir/) data suitable for the workflow specified by the use case. -We're currently working on generating synthetic [CDA](https://www.hl7.org.uk/standards/hl7-standards/cda-clinical-document-architecture/) data. If you're interested in contributing, please [reach out](https://discord.gg/UQC6uAepUz)! +We're working on generating synthetic [CDA](https://www.hl7.org.uk/standards/hl7-standards/cda-clinical-document-architecture/) data. If you're interested in contributing, please [reach out](https://discord.gg/UQC6uAepUz)! 
[(Full Documentation on Data Generators)](./reference/utilities/data_generator.md) === "Within client" ```python import healthchain as hc + from healthchain.use_cases import ClinicalDecisionSupport from healthchain.models import CdsFhirData from healthchain.data_generators import CdsDataGenerator diff --git a/docs/reference/pipeline/component.md b/docs/reference/pipeline/component.md index 6738e09..8a44531 100644 --- a/docs/reference/pipeline/component.md +++ b/docs/reference/pipeline/component.md @@ -16,7 +16,7 @@ Components are the building blocks of the healthchain pipeline. They are designe You can create your own custom components by extending the `BaseComponent` class and implementing the `__call__` method. ```python -from healthchain.pipeline.basecomponent import BaseComponent +from healthchain.pipeline.base import BaseComponent class MyCustomComponent(BaseComponent): def __init__(self, **kwargs): diff --git a/docs/reference/pipeline/connectors/cdaconnector.md b/docs/reference/pipeline/connectors/cdaconnector.md new file mode 100644 index 0000000..aaa51fa --- /dev/null +++ b/docs/reference/pipeline/connectors/cdaconnector.md @@ -0,0 +1,55 @@ +# CDA Connector + +The `CdaConnector` handles Clinical Document Architecture (CDA) documents, serving as both an input and output connector in the pipeline. It parses CDA documents, extracting free-text notes and relevant structured clinical data into a `Document` object, and can return an annotated CDA document as output. + +This connector is particularly useful for clinical documentation improvement (CDI) workflows where CDA documents need to be processed and updated with additional structured data. 
+ +[(Full Documentation on Clinical Documentation)](../../sandbox/use_cases/clindoc.md) + +## Usage + +```python +from healthchain.io import CdaConnector, Document +from healthchain.models import CdaRequest +from healthchain.pipeline import Pipeline + +# Create a pipeline with CdaConnector +pipeline = Pipeline() + +cda_connector = CdaConnector() +pipeline.add_input(cda_connector) +pipeline.add_output(cda_connector) + +# Example CDA request +cda_request = CdaRequest(document="") + +# Example 1: Simple pipeline execution +pipe = pipeline.build() +cda_response = pipe(cda_request) +print(cda_response) +# Output: CdaResponse(document='') + +# Example 2: Accessing CDA data inside a pipeline node +@pipeline.add_node +def example_pipeline_node(document: Document) -> Document: + print(document.ccd_data) + return document + +pipe = pipeline.build() +cda_response = pipe(cda_request) +# Output: CcdData object... +``` + +## Accessing data inside your pipeline + +Data parsed from the CDA document is stored in the `Document.ccd_data` attribute as a `CcdData` object, as shown in the example above. + +[(CcdData Reference)](../../../api/data_models.md#healthchain.models.data.ccddata.CcdData) + +## Configuration + +The `overwrite` parameter in the `CdaConnector` constructor determines whether existing data in the document should be overwritten. This can be useful for readability with very long CDA documents when the receiving system does not require the full document. + +```python +cda_connector = CdaConnector(overwrite=True) +``` diff --git a/docs/reference/pipeline/connectors/cdsfhirconnector.md b/docs/reference/pipeline/connectors/cdsfhirconnector.md new file mode 100644 index 0000000..2d86c07 --- /dev/null +++ b/docs/reference/pipeline/connectors/cdsfhirconnector.md @@ -0,0 +1,63 @@ +# CDS FHIR Connector + +The `CdsFhirConnector` handles FHIR data in the context of Clinical Decision Support (CDS) services, serving as both an input and output connector in the pipeline. 
+ +Note that this is not meant to be used as a generic FHIR connector, but specifically designed for use with the [CDS Hooks specification](https://cds-hooks.org/). + +[(Full Documentation on Clinical Decision Support)](../../sandbox/use_cases/cds.md) + +## Usage + +```python +from healthchain.io import CdsFhirConnector, Document +from healthchain.models import CDSRequest +from healthchain.pipeline import Pipeline + +# Create a pipeline with CdsFhirConnector +pipeline = Pipeline() + +cds_fhir_connector = CdsFhirConnector() +pipeline.add_input(cds_fhir_connector) +pipeline.add_output(cds_fhir_connector) + +# Example CDS request +cds_request = CDSRequest( + hook="patient-view", + hookInstance="d1577c69-dfbe-44ad-ba6d-3e05e953b2ea", + context={ + "userId": "Practitioner/123", + "patientId": "Patient/456" + }, + prefetch={ + "patient": { + "resourceType": "Patient", + "id": "456", + "name": [{"family": "Doe", "given": ["John"]}], + "birthDate": "1970-01-01" + } + } +) + +# Example 1: Simple pipeline execution +pipe = pipeline.build() +cds_response = pipe(cds_request) +print(cds_response) +# Output: CDSResponse with cards... + +# Example 2: Accessing FHIR data inside a pipeline node +@pipeline.add_node +def example_pipeline_node(document: Document) -> Document: + print(document.fhir_resources) + return document + +pipe = pipeline.build() +cds_response = pipe(cds_request) +# Output: CdsFhirData object... + +``` + +## Accessing data inside your pipeline + +Data parsed from the FHIR resources is stored in the `Document.fhir_resources` attribute as a `CdsFhirData` object, as shown in the example above. 
+ +[(CdsFhirData Reference)](../../../api/data_models.md#healthchain.models.data.cdsfhirdata) diff --git a/docs/reference/pipeline/connectors/connectors.md b/docs/reference/pipeline/connectors/connectors.md new file mode 100644 index 0000000..a1aa644 --- /dev/null +++ b/docs/reference/pipeline/connectors/connectors.md @@ -0,0 +1,49 @@ +# Connectors + +Connectors transform your data into a format that can be understood by healthcare systems such as EHRs. They allow your pipelines to work directly with data in HL7 interoperability standard formats, such as [CDA](https://www.hl7.org.uk/standards/hl7-standards/cda-clinical-document-architecture/) or [FHIR](https://hl7.org/fhir/), without the headache of parsing and validating the data yourself. + +Connectors are what give you the power to build *end-to-end* pipelines that interact with real-time healthcare systems. + +## Available connectors + +Connectors make certain assumptions about the data they receive depending on the use case to convert it to an appropriate internal data format and container. + +Some connectors require the same instance to be used for both input and output, while others may be input or output only. + +| Connector | Input | Output | Internal Data Representation | Access it by... | Same instance I/O? | +|-----------|-------|--------|-------------------------|----------------|--------------------------| +| [**CdaConnector**](cdaconnector.md) | `CdaRequest` :material-arrow-right: `Document` | `Document` :material-arrow-right: `CdaResponse` | [**CcdData**](../../../api/data_models.md#healthchain.models.data.ccddata.CcdData) | `.ccd_data` | ✅ | +| [**CdsFhirConnector**](cdsfhirconnector.md) | `CDSRequest` :material-arrow-right: `Document` | `Document` :material-arrow-right: `CDSResponse` | [**CdsFhirData**](../../../api/data_models.md#healthchain.models.data.cdsfhirdata.CdsFhirData) | `.fhir_resources` | ✅ | + +!!! 
example "CdaConnector Example" + The `CdaConnector` expects a `CdaRequest` object as input and outputs a `CdaResponse` object. The connector converts the input data into a `Document` object because CDAs are usually represented as a document object. + + This `Document` object contains a `.ccd_data` attribute, which stores the structured data from the CDA document in a `CcdData` object. Any free-text notes are stored in the `Document.text` attribute. + + Because CDAs are annotated documents, the same `CdaConnector` instance must be used for both input and output operations in the pipeline. + +## Use Cases +Each connector can be mapped to a specific use case in the sandbox module. + +| Connector | Use Case | +|-----------|----------| +| `CdaConnector` | [**Clinical Documentation**](../../sandbox/use_cases/clindoc.md) | +| `CdsFhirConnector` | [**Clinical Decision Support**](../../sandbox/use_cases/cds.md) | + +## Adding connectors to your pipeline + +To add connectors to your pipeline, use the `.add_input()` and `.add_output()` methods. + +```python +from healthchain.pipeline import Pipeline +from healthchain.io import CdaConnector + +pipeline = Pipeline() +# In this example, we're using the same connector instance for input and output +cda_connector = CdaConnector() + +pipeline.add_input(cda_connector) +pipeline.add_output(cda_connector) +``` + +Connectors are currently intended for development and testing purposes only. They are not production-ready, although this is something we want to work towards on our long-term roadmap. If there is a specific connector you would like to see, please feel free to [open an issue](https://github.com/dotimplement/healthchain/issues) or [contact us](https://discord.gg/UQC6uAepUz)! 
diff --git a/docs/reference/pipeline/pipeline.md b/docs/reference/pipeline/pipeline.md index 635a2cb..b302c98 100644 --- a/docs/reference/pipeline/pipeline.md +++ b/docs/reference/pipeline/pipeline.md @@ -1,6 +1,6 @@ # Pipeline -HealthChain pipelines provide a simple interface to test, version, and connect your pipeline to common healthcare data standards, such as [CDAs (Clinical Document Architecture)](https://www.hl7.org.uk/standards/hl7-standards/cda-clinical-document-architecture/) and [FHIR (Fast Healthcare Interoperability Resources)](https://build.fhir.org/). +HealthChain pipelines provide a simple interface to test, version, and connect your pipeline to common healthcare data standards, such as [CDA (Clinical Document Architecture)](https://www.hl7.org.uk/standards/hl7-standards/cda-clinical-document-architecture/) and [FHIR (Fast Healthcare Interoperability Resources)](https://build.fhir.org/). Depending on your need, you can either go top down, where you use prebuilt pipelines and customize them to your needs, or bottom up, where you build your own pipeline from scratch. 
@@ -11,33 +11,33 @@ HealthChain comes with a set of prebuilt pipelines that are out-of-the-box imple | Pipeline | Container | Compatible Connector | Description | Example Use Case | |----------|-----------|-----------|-------------|------------------| | [**MedicalCodingPipeline**](./prebuilt_pipelines/medicalcoding.md) | `Document` | `CdaConnector` | An NLP pipeline that processes free-text clinical notes into structured data | Automatically generating SNOMED CT codes from clinical notes | -| **SummarizationPipeline** [TODO] | `Document` | `FhirConnector` | An NLP pipeline for summarizing clinical notes | Generating discharge summaries from patient history and notes | +| **SummarizationPipeline** [TODO] | `Document` | `CdsFhirConnector` | An NLP pipeline for summarizing clinical notes | Generating discharge summaries from patient history and notes | | **QAPipeline** [TODO] | `Document` | N/A | A Question Answering pipeline suitable for conversational AI applications | Developing a chatbot to answer patient queries about their medical records | -| **ClassificationPipeline** [TODO] | `Tabular` | `FhirConnector` | A pipeline for machine learning classification tasks | Predicting patient readmission risk based on historical health data | +| **ClassificationPipeline** [TODO] | `Tabular` | `CdsFhirConnector` | A pipeline for machine learning classification tasks | Predicting patient readmission risk based on historical health data | -Pipeline inputs and outputs are defined by the container type. +Prebuilt pipelines are end-to-end workflows with Connectors built into them. They interact with raw data received from EHR interfaces, usually CDA or FHIR data from specific [use cases](../sandbox/use_cases/use_cases.md). 
```python from healthchain.pipeline import Pipeline -from healthchain.io.containers import Document +from healthchain.models import CdaRequest pipeline = MedicalCodingPipeline.load('/path/to/model') -doc = Document("Patient is diagnosed with diabetes") -doc = pipeline(doc) +cda_request = CdaRequest(document="") +cda_response = pipeline(cda_request) ``` ### Customizing Prebuilt Pipelines To customize a prebuilt pipeline, you can use the [pipeline management methods](#pipeline-management) to add, remove, and replace components. For example, you may want to change the model being used. [TODO] -If you need even more control and don't mind writing more code, you can subclass `BasePipeline` and implement your own pipeline logic. +If you need more control and don't mind writing more code, you can subclass `BasePipeline` and implement your own pipeline logic. -[(BasePipeline API Reference)](../../api/pipeline.md#healthchain.pipeline.basepipeline.BasePipeline) +[(BasePipeline API Reference)](../../api/pipeline.md#healthchain.pipeline.base.BasePipeline) ## Freestyle 🕺 -To build your own pipeline, you can start with an empty pipeline and add components to it. Initialize your pipeline with the appropriate container type, such as `Document` or `Tabular`. +To build your own pipeline, you can start with an empty pipeline and add components to it. Initialize your pipeline with the appropriate container type, such as `Document` or `Tabular`. This is not essential, but it allows the pipeline to enforce type safety. (If you don't specify the container type, it will be inferred from the first component added.) You can see the full list of available containers at the [Container](./data_container.md) page. @@ -46,9 +46,12 @@ from healthchain.pipeline import Pipeline from healthchain.io.containers import Document pipeline = Pipeline[Document]() + +# Or if you live dangerously +# pipeline = Pipeline() ``` -To use a built pipeline, compile it by running `.build()` on it. 
This will return a compiled pipeline that you can run on your data. +To use a built pipeline, compile it by running `.build()`. This will return a compiled pipeline that you can run on your data. ```python pipe = pipeline.build() @@ -57,37 +60,43 @@ doc = pipe(Document("Patient is diagnosed with diabetes")) print(doc.entities) ``` -There are three types of nodes you can add to your pipeline: +### Adding Nodes + +There are three types of nodes you can add to your pipeline with the method `.add_node()`: - Inline Functions - Components - Custom Components -### Inline Functions +#### Inline Functions -Inline functions are simple functions that take in a container and return a container. They are defined directly within the `.add()` method. +Inline functions are simple functions that take in a container and return a container. ```python -@pipeline.add() +@pipeline.add_node def remove_stopwords(doc: Document) -> Document: stopwords = {"the", "a", "an", "in", "on", "at"} doc.tokens = [token for token in doc.tokens if token not in stopwords] return doc + +# Equivalent to: +pipeline.add_node(remove_stopwords) ``` -### Components +#### Components Components are pre-configured building blocks that perform specific tasks. They are defined as separate classes and can be reused across multiple pipelines. +You can see the full list of available components at the [Components](./component.md) page. + ```python from healthchain.pipeline import TextPreProcessor preprocessor = TextPreProcessor(tokenizer="spacy", lowercase=True) -pipeline.add(preprocessor) +pipeline.add_node(preprocessor) ``` -You can see the full list of available components at the [Components](./component.md) page. -### Custom Components +#### Custom Components Custom components are classes that implement the `BaseComponent` interface. You can use them to add custom processing logic to your pipeline. 
@@ -104,16 +113,29 @@ class RemoveStopwords(BaseComponent): return doc stopwords = ["the", "a", "an", "in", "on", "at"] -pipeline.add(RemoveStopwords(stopwords)) +pipeline.add_node(RemoveStopwords(stopwords)) ``` -[(BaseComponent API Reference)](../../api/component.md#healthchain.pipeline.components.basecomponent.BaseComponent) +[(BaseComponent API Reference)](../../api/component.md#healthchain.pipeline.components.base.BaseComponent) + +### Adding Connectors 🔗 + +Connectors are added to the pipeline using the `.add_input()` and `.add_output()` methods. You can learn more about connectors at the [Connectors](./connectors/connectors.md) documentation page. + +```python +from healthchain.io import CdaConnector + +cda_connector = CdaConnector() + +pipeline.add_input(cda_connector) +pipeline.add_output(cda_connector) +``` ## Pipeline Management 🔨 #### Adding -Use `.add()` to add a component to the pipeline. By default, the component will be added to the end of the pipeline and named as the function name provided. +Use `.add_node()` to add a component to the pipeline. By default, the component will be added to the end of the pipeline and named as the function name provided. You can specify the position of the component using the `position` parameter. Available positions are: @@ -128,7 +150,7 @@ When using `"after"` or `"before"`, you must also specify the `reference` parame You can also specify the `stage` parameter to add the component to a specific stage group of the pipeline. ```python -@pipeline.add(position="after", reference="tokenize", stage="preprocessing") +@pipeline.add_node(position="after", reference="tokenize", stage="preprocessing") def remove_stopwords(doc: Document) -> Document: stopwords = {"the", "a", "an", "in", "on", "at"} doc.tokens = [token for token in doc.tokens if token not in stopwords] @@ -138,7 +160,7 @@ def remove_stopwords(doc: Document) -> Document: You can specify dependencies between components using the `dependencies` parameter.
This is useful if you want to ensure that a component is run after another component. ```python -@pipeline.add(dependencies=["tokenize"]) +@pipeline.add_node(dependencies=["tokenize"]) def remove_stopwords(doc: Document) -> Document: stopwords = {"the", "a", "an", "in", "on", "at"} doc.tokens = [token for token in doc.tokens if token not in stopwords] diff --git a/docs/reference/sandbox/sandbox.md b/docs/reference/sandbox/sandbox.md index c9a8de6..f23b11b 100644 --- a/docs/reference/sandbox/sandbox.md +++ b/docs/reference/sandbox/sandbox.md @@ -1,7 +1,7 @@ # Sandbox Designing your pipeline to integrate well in a healthcare context is an essential step to turning it into an application that -could potentially be adapted for real-world use. As a developer who has years of experience deploying healthcare NLP solutions into hospitals, I know how painful and slow this process can be in reality. +could potentially be adapted for real-world use. As a developer who has years of experience deploying healthcare NLP solutions into hospitals, I know how painful and slow this process can be. A sandbox makes this process easier. It provides a staging environment to debug, test, track, and interact with your application in realistic deployment scenarios without having to gain access to such environments, especially ones that are tightly integrated with local EHR configurations. Think of it as integration testing in healthcare systems. @@ -19,7 +19,7 @@ For a given sandbox run: To create a sandbox, initialize a class that inherits from a type of `UseCase` and decorate it with the `@hc.sandbox` decorator. `UseCase` loads in the blueprint of the API endpoints for the specified use case, and `@hc.sandbox` orchestrates these interactions. -Every sandbox also requires a **client** function marked by `@hc.ehr` and a **service** function marked by `@hc.api`. Every client function must specify a **workflow** that informs the sandbox how your data will be formatted. 
For more information on workflows, see the [Use Cases](./use_cases/use_cases.md) documentation. +Every sandbox also requires a [**Client**](./client.md) function marked by `@hc.ehr` and a [**Service**](./service.md) function marked by `@hc.api`. Every client function must specify a **workflow** that informs the sandbox how your data will be formatted. For more information on workflows, see the [Use Cases](./use_cases/use_cases.md) documentation. !!! success "For each sandbox you need to specify..." @@ -28,44 +28,31 @@ Every sandbox also requires a **client** function marked by `@hc.ehr` and a **se - client function - workflow of client -```bash -pip install torch transformers -``` ```python import healthchain as hc +from healthchain.pipeline import SummarizationPipeline from healthchain.use_cases import ClinicalDecisionSupport from healthchain.data_generators import CdsDataGenerator -from healthchain.models import Card, CDSRequest, CdsFhirData -from transformers import pipeline +from healthchain.models import CDSRequest, CdsFhirData, CDSResponse -from typing import List @hc.sandbox class MyCoolSandbox(ClinicalDecisionSupport): def __init__(self): self.data_generator = CdsDataGenerator() - self.pipeline = pipeline('summarization') + self.pipeline = SummarizationPipeline('gpt-4o') - @hc.ehr(workflow="patient-view") + @hc.ehr(workflow="encounter-discharge") def load_data_in_client(self) -> CdsFhirData: - with open('/path/to/data.json', "r") as file: - fhir_json = file.read() - - return CdsFhirData(**fhir_json) + cds_fhir_data = self.data_generator.generate() + return cds_fhir_data @hc.api - def my_service(self, request: CDSRequest) -> List[Card]: - results = self.pipeline(str(request.prefetch)) - return [ - Card( - summary="Patient summary", - indicator="info", - source={"label": "transformers"}, - detail=results[0]['summary_text'], - ) - ] + def my_service(self, request: CDSRequest) -> CDSResponse: + cds_response = self.pipeline(request) + return cds_response if 
__name__ == "__main__": cds = MyCoolSandbox() diff --git a/docs/reference/sandbox/service.md b/docs/reference/sandbox/service.md index ca74663..5bcf7c6 100644 --- a/docs/reference/sandbox/service.md +++ b/docs/reference/sandbox/service.md @@ -4,7 +4,7 @@ A service is typically an API of a third-party system that returns data to the c When you decorate a function with `@hc.api` in a sandbox, the function is mounted standardized API endpoint an EHR client can make requests to. This can be defined by healthcare interoperability standards, such as HL7, or the EHR provider. HealthChain will start a [FastAPI](https://fastapi.tiangolo.com/) server with these APIs pre-defined for you. -Your service function must accept and return models appropriate for your use case. Typically the service function should accept a `Request` model and return a use case specific model, such as a list of `Card` for CDS. [This will be updated in the future] +Your service function receives use case specific request data as input and returns the response data. We recommend you initialize your pipeline in the class `__init__` method. 
@@ -16,27 +16,24 @@ Here are minimal examples for each use case: from healthchain.use_cases import ClinicalDocumentation from healthchain.pipeline import MedicalCodingPipeline - from healthchain.models import CcdData + from healthchain.models import CcdData, CdaRequest, CdaResponse @hc.sandbox class MyCoolSandbox(ClinicalDocumentation): - def __init__(self) -> None: - # Load your pipeline + def __init__(self): self.pipeline = MedicalCodingPipeline.load("./path/to/model") @hc.ehr(workflow="sign-note-inpatient") def load_data_in_client(self) -> CcdData: - # Load your data with open('/path/to/data.xml', "r") as file: xml_string = file.read() return CcdData(cda_xml=xml_string) @hc.api - def my_service(self, ccd_data: CcdData) -> CcdData: - # Run your pipeline - results = self.pipeline(ccd_data) - return results + def my_service(self, request: CdaRequest) -> CdaResponse: + response = self.pipeline(request) + return response ``` === "CDS" @@ -44,14 +41,13 @@ Here are minimal examples for each use case: import healthchain as hc from healthchain.use_cases import ClinicalDecisionSupport - from healthchain.pipeline import Pipeline - from healthchain.models import Card, CDSRequest, CdsFhirData - from typing import List + from healthchain.pipeline import SummarizationPipeline + from healthchain.models import CDSRequest, CDSResponse, CdsFhirData @hc.sandbox class MyCoolSandbox(ClinicalDecisionSupport): def __init__(self): - self.pipeline = Pipeline.load("./path/to/pipeline") + self.pipeline = SummarizationPipeline.load("mode-name") @hc.ehr(workflow="patient-view") def load_data_in_client(self) -> CdsFhirData: @@ -61,14 +57,7 @@ Here are minimal examples for each use case: return CdsFhirData(**fhir_json) @hc.api - def my_service(self, request: CDSRequest) -> List[Card]: - result = self.pipeline(str(request.prefetch)) - return [ - Card( - summary="Patient summary", - indicator="info", - source={"label": "openai"}, - detail=result, - ) - ] + def my_service(self, request: 
CDSRequest) -> CDSResponse: + response = self.pipeline(request) + return response ``` diff --git a/docs/reference/sandbox/use_cases/cds.md b/docs/reference/sandbox/use_cases/cds.md index 7620d60..85750cc 100644 --- a/docs/reference/sandbox/use_cases/cds.md +++ b/docs/reference/sandbox/use_cases/cds.md @@ -2,18 +2,28 @@ ## Clinical Decision Support (CDS) -CDS workflows are based on [CDS Hooks](https://cds-hooks.org/). CDS Hooks is an [HL7](https://cds-hooks.hl7.org) published specification for clinical decision support. For more information you can consult the [official documentation](https://cds-hooks.org/). +CDS workflows are based on [CDS Hooks](https://cds-hooks.org/). CDS Hooks is an [HL7](https://cds-hooks.hl7.org) published specification for clinical decision support. CDS hooks communicate using [FHIR (Fast Healthcare Interoperability Resources)](https://hl7.org/fhir/). For more information you can consult the [official documentation](https://cds-hooks.org/). | When | Where | What you receive | What you send back | | :-------- | :-----| :-------------------------- |----------------------------| | Triggered at certain events during a clinician's workflow, e.g. when a patient record is opened. | EHR | The context of the event and FHIR resources that are requested by your service. e.g. patient ID, `Encounter` and `Patient`. | “Cards” displaying text, actionable suggestions, or links to launch a [SMART](https://smarthealthit.org/) app from within the workflow. | +## Data Flow -CDS hooks communicate using [HL7 FHIR (Fast Healthcare Interoperability Resources)](https://hl7.org/fhir/). FHIR data are represented internally as `CdsFhirData` in HealthChain, so a CDS client must return a `CdsFhirData` object.
+| Stage | Input | Internal Data Representation | Output | +|-------|-------|------------------------------|--------| +| Client | N/A | N/A | `CdsFhirData` | +| Service | `CDSRequest` | `CdsFhirData` | `CDSResponse` | -CDS service functions receive `CdsRequest` and return a list of `Card`. [Improved documentation coming soon] -[(Card API Reference | ](../../../api/use_cases.md#healthchain.models.responses.cdsresponse.Card)[CdsFhirData API Reference)](../../../api/data_models.md#healthchain.models.data.cdsfhirdata) +[CdsFhirConnector](../../pipeline/connectors/cdsfhirconnector.md) handles the conversion of `CDSRequest` :material-swap-horizontal: `CdsFhirData` :material-swap-horizontal: `CDSResponse` in a HealthChain pipeline. + +Attributes of `CdsFhirData` are: + +- `context` +- `prefetch` + +[(CdsFhirData API Reference)](../../../api/data_models.md#healthchain.models.data.cdsfhirdata) ## Supported Workflows diff --git a/docs/reference/sandbox/use_cases/clindoc.md b/docs/reference/sandbox/use_cases/clindoc.md index d9e9cc3..bb18f10 100644 --- a/docs/reference/sandbox/use_cases/clindoc.md +++ b/docs/reference/sandbox/use_cases/clindoc.md @@ -7,7 +7,18 @@ The `ClinicalDocumentation` use case implements a real-time Clinical Documentati | :-------- | :-----| :-------------------------- |----------------------------| | Triggered when a clinician opts in to a CDI functionality and signs or pends a note after writing it. | Specific modules in EHR where clinical documentation takes place, such as NoteReader in Epic. | A CDA document which contains continuity of care data and free-text data, e.g. a patient's problem list and the progress note that the clinician has entered in the EHR. | A CDA document which contains additional structured data extracted and returned by your CDI service. | -A `ClinicalDocumentation` function receives and returns `CcdData`.
Attributes of `CcdData` are: + +## Data Flow + +| Stage | Input | Internal Data Representation | Output | +|-------|-------|------------------------------|--------| +| Client | N/A | N/A | `CcdData` | +| Service | `CdaRequest` | `CcdData` | `CdaResponse` | + + +[CdaConnector](../../pipeline/connectors/cdaconnector.md) handles the conversion of `CdaRequests` :material-swap-horizontal: `CcdData` :material-swap-horizontal: `CdaResponse` in a HealthChain pipeline. + +Attributes of `CcdData` are: - `problems` - `allergies` diff --git a/healthchain/fhir_resources/__init__.py b/healthchain/fhir_resources/__init__.py index e69de29..aa2fce5 100644 --- a/healthchain/fhir_resources/__init__.py +++ b/healthchain/fhir_resources/__init__.py @@ -0,0 +1,21 @@ +from .bundleresources import Bundle +from .condition import Condition +from .patient import Patient +from .practitioner import Practitioner +from .procedure import Procedure +from .documentreference import DocumentReference +from .encounter import Encounter +from .medicationadministration import MedicationAdministration +from .medicationrequest import MedicationRequest + +__all__ = [ + "Bundle", + "Condition", + "Patient", + "Practitioner", + "Procedure", + "DocumentReference", + "Encounter", + "MedicationAdministration", + "MedicationRequest", +] diff --git a/healthchain/fhir_resources/bundleresources.py b/healthchain/fhir_resources/bundleresources.py index 603651a..63b94af 100644 --- a/healthchain/fhir_resources/bundleresources.py +++ b/healthchain/fhir_resources/bundleresources.py @@ -1,4 +1,4 @@ -from pydantic import Field, BaseModel, field_validator +from pydantic import Field, BaseModel, model_validator from typing import List, Literal, Any from healthchain.fhir_resources.resourceregistry import ImplementedResourceRegistry @@ -13,20 +13,66 @@ class BundleEntry(BaseModel): description="The Resource for the entry. The purpose/meaning of the resource is determined by the Bundle.type. 
This is allowed to be a Parameters resource if and only if it is referenced by something else within the Bundle that provides context/meaning.", ) - @field_validator("resource_field") - def check_enum(cls, value): - if value.__class__.__name__ not in implemented_resources: + @model_validator(mode="before") + @classmethod + def validate_and_convert_resource(cls, values): + """ + Validates and converts the resource field in the BundleEntry. + + This method performs the following tasks: + 1. Checks if the resource is None, in which case it returns the values unchanged. + 2. If the resource is already a Pydantic BaseModel, it verifies that it's an implemented resource type. + 3. If the resource is a dictionary, it checks for the presence of a 'resourceType' key and validates that it's an implemented resource type. + 4. Dynamically imports the appropriate resource class based on the resourceType. + 5. Recursively converts nested dictionaries to the appropriate Pydantic models. + + Args: + cls: The class on which this method is called. + values (dict): A dictionary containing the field values of the BundleEntry. + + Returns: + dict: The validated and potentially modified values dictionary. + + Raises: + ValueError: If the resource is invalid or of an unsupported type. + """ + resource = values.get("resource") + + if resource is None: + return values # Return unchanged if resource is None + + if isinstance(resource, BaseModel): + # If it's already a Pydantic model (e.g., Patient), use it directly + if resource.__class__.__name__ not in implemented_resources: + raise ValueError( + f"Invalid resource type: {resource.__class__.__name__}. Must be one of {implemented_resources}." 
+ ) + return values + + if not isinstance(resource, dict) or "resourceType" not in resource: + raise ValueError( + "Invalid resource: must be a dictionary with a 'resourceType' key or a valid FHIR resource model" + ) + + resource_type = resource["resourceType"] + if resource_type not in implemented_resources: raise ValueError( - f"Invalid value class: {value.__class__.__name__}. Must be one of {implemented_resources}." + f"Invalid resourceType: {resource_type}. Must be one of {implemented_resources}." ) - return value + # Import the appropriate resource class dynamically + module = __import__("healthchain.fhir_resources", fromlist=[resource_type]) + resource_class = getattr(module, resource_type) + + # Convert the dictionary to the appropriate Pydantic model + values["resource"] = resource_class(**resource) + return values class Bundle(BaseModel): resourceType: Literal["Bundle"] = "Bundle" entry_field: List[BundleEntry] = Field( - default=None, + default_factory=list, alias="entry", description="An entry in a bundle resource - will either contain a resource or information about a resource (transactions and history only).", ) diff --git a/healthchain/io/__init__.py b/healthchain/io/__init__.py index 4520c23..03b145b 100644 --- a/healthchain/io/__init__.py +++ b/healthchain/io/__init__.py @@ -1,3 +1,13 @@ from healthchain.io.containers import DataContainer, Document, Tabular +from healthchain.io.base import BaseConnector +from healthchain.io.cdaconnector import CdaConnector +from healthchain.io.cdsfhirconnector import CdsFhirConnector -__all__ = ["Document", "DataContainer", "Tabular"] +__all__ = [ + "Document", + "DataContainer", + "Tabular", + "BaseConnector", + "CdaConnector", + "CdsFhirConnector", +] diff --git a/healthchain/io/base.py b/healthchain/io/base.py new file mode 100644 index 0000000..b13b02d --- /dev/null +++ b/healthchain/io/base.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import Generic, TypeVar +from 
healthchain.io.containers import DataContainer + +T = TypeVar("T") + + +class BaseConnector(Generic[T], ABC): + """ + Abstract base class for all connectors in the pipeline. + + This class should be subclassed to create specific connectors. + Subclasses must implement the input and output methods. + """ + + @abstractmethod + def input(self, data: DataContainer[T]) -> DataContainer[T]: + """ + Convert input data to the pipeline's internal format. + + Args: + data (DataContainer[T]): The input data to be converted. + + Returns: + DataContainer[T]: The converted data. + """ + pass + + @abstractmethod + def output(self, data: DataContainer[T]) -> DataContainer[T]: + """ + Convert pipeline's internal format to output data. + + Args: + data (DataContainer[T]): The data to be converted for output. + + Returns: + DataContainer[T]: The converted output data. + """ + pass diff --git a/healthchain/io/cdaconnector.py b/healthchain/io/cdaconnector.py new file mode 100644 index 0000000..26ce99a --- /dev/null +++ b/healthchain/io/cdaconnector.py @@ -0,0 +1,118 @@ +import logging +from healthchain.io.containers import Document +from healthchain.io.base import BaseConnector +from healthchain.cda_parser import CdaAnnotator +from healthchain.models.data.ccddata import CcdData +from healthchain.models.requests.cdarequest import CdaRequest +from healthchain.models.responses.cdaresponse import CdaResponse + +log = logging.getLogger(__name__) + + +class CdaConnector(BaseConnector): + """ + CDAConnector class for handling CDA (Clinical Document Architecture) documents. + + This connector is responsible for parsing CDA documents, extracting relevant + clinical data, and updating the document with new information. It serves as + both an input and output connector in the pipeline. + + Attributes: + overwrite (bool): Flag to determine if existing data should be overwritten + when updating the CDA document. + cda_doc (CdaAnnotator): The parsed CDA document. 
+ + Methods: + input: Parses the input CDA document and extracts clinical data. + output: Updates the CDA document with new data and returns the response. + """ + + def __init__(self, overwrite: bool = False): + self.overwrite = overwrite + self.cda_doc = None + + def input(self, in_data: CdaRequest) -> Document: + """ + Parse the input CDA document and extract clinical data. + + This method takes a CdaRequest object containing the CDA document as input, + parses it using the CdaAnnotator, and extracts relevant clinical data. + The extracted data is then used to create a CcdData object and a healthchain + Document object, which is returned. + + Args: + in_data (CdaRequest): The input request containing the CDA document. + + Returns: + Document: A Document object containing the extracted clinical data + and the original note text. + + """ + self.cda_doc = CdaAnnotator.from_xml(in_data.document) + + # TODO: Temporary fix for the note section, this might be more of a concern for the Annotator class + if isinstance(self.cda_doc.note, dict): + note_text = " ".join(str(value) for value in self.cda_doc.note.values()) + elif isinstance(self.cda_doc.note, str): + note_text = self.cda_doc.note + else: + log.warning("Note section is not a string or dictionary") + note_text = "" + + ccd_data = CcdData( + problems=self.cda_doc.problem_list, + medications=self.cda_doc.medication_list, + allergies=self.cda_doc.allergy_list, + note=note_text, + ) + + return Document(data=ccd_data.note, ccd_data=ccd_data) + + def output(self, out_data: Document) -> CdaResponse: + """ + Update the CDA document with new data and return the response. + + This method takes a Document object containing updated clinical data, + updates the CDA document with this new information, and returns a + CdaResponse object with the updated CDA document. + + Args: + out_data (Document): A Document object containing the updated + clinical data (problems, allergies, medications). 
+ + Returns: + CdaResponse: A response object containing the updated CDA document. + + Note: + The method updates the CDA document with new problems, allergies, + and medications if they are present in the input Document object. + The update behavior (overwrite or append) is determined by the + `overwrite` attribute of the CdaConnector instance. + """ + # Update the CDA document with the results + if out_data.ccd_data.problems: + log.debug( + f"Updating CDA document with {len(out_data.ccd_data.problems)} problem(s)." + ) + self.cda_doc.add_to_problem_list( + out_data.ccd_data.problems, overwrite=self.overwrite + ) + if out_data.ccd_data.allergies: + log.debug( + f"Updating CDA document with {len(out_data.ccd_data.allergies)} allergy(ies)." + ) + self.cda_doc.add_to_allergy_list( + out_data.ccd_data.allergies, overwrite=self.overwrite + ) + if out_data.ccd_data.medications: + log.debug( + f"Updating CDA document with {len(out_data.ccd_data.medications)} medication(s)." + ) + self.cda_doc.add_to_medication_list( + out_data.ccd_data.medications, overwrite=self.overwrite + ) + + # Export the updated CDA document + response_document = self.cda_doc.export() + + return CdaResponse(document=response_document) diff --git a/healthchain/io/cdsfhirconnector.py b/healthchain/io/cdsfhirconnector.py new file mode 100644 index 0000000..184502d --- /dev/null +++ b/healthchain/io/cdsfhirconnector.py @@ -0,0 +1,96 @@ +import logging + +from healthchain.io.containers import Document +from healthchain.io.base import BaseConnector +from healthchain.models.data.cdsfhirdata import CdsFhirData +from healthchain.models.requests.cdsrequest import CDSRequest +from healthchain.models.responses.cdsresponse import CDSResponse + +log = logging.getLogger(__name__) + + +class CdsFhirConnector(BaseConnector): + """ + CdsFhirConnector class for handling FHIR (Fast Healthcare Interoperability Resources) documents + for CDS Hooks. 
+ + This connector facilitates the conversion between CDSRequest objects and Document objects, + as well as the creation of CDSResponse objects from processed Documents. + + Attributes: + hook_name (str): The name of the CDS Hook being used. + """ + + def __init__(self, hook_name: str): + self.hook_name = hook_name + + def input(self, in_data: CDSRequest) -> Document: + """ + Converts a CDSRequest object into a Document object containing FHIR resources. + + This method takes a CDSRequest object as input, extracts the context and prefetch data, + and creates a CdsFhirData object. It then returns a Document object with the stringified + prefetch data as the main data content and the CdsFhirData object in the fhir_resources field. + + Args: + in_data (CDSRequest): The input CDSRequest object containing context and prefetch data. + + Returns: + Document: A Document object with the following attributes: + - data: A string representation of the prefetch data. + - fhir_resources: A CdsFhirData object containing the context and prefetch data. + + Raises: + ValueError: If neither prefetch nor fhirServer is provided in the input data. + NotImplementedError: If fhirServer is provided, as this functionality is not yet implemented. + ValueError: If the provided prefetch data is invalid. + + Note: + - The method currently only supports prefetch data and does not handle FHIR server interactions. + - Future implementations may involve more detailed processing, such as parsing + notes depending on the hook configuration. + """ + if in_data.prefetch is None and in_data.fhirServer is None: + raise ValueError( + "Either prefetch or fhirServer must be provided to extract FHIR data!" 
+ ) + + if in_data.fhirServer is not None: + raise NotImplementedError("FHIR server is not implemented yet!") + + try: + cds_fhir_data = CdsFhirData.create( + context=in_data.context.model_dump(), prefetch=in_data.prefetch + ) + except Exception as e: + raise ValueError("Invalid prefetch data provided: {e}!") from e + + return Document( + data=str(cds_fhir_data.model_dump_prefetch()), fhir_resources=cds_fhir_data + ) + + def output(self, out_data: Document) -> CDSResponse: + """ + Generates a CDSResponse object from a processed Document object. + + This method takes a Document object that has been processed and potentially + contains CDS cards and system actions. It creates and returns a CDSResponse + object based on the contents of the Document. + + Args: + out_data (Document): A Document object potentially containing CDS cards + and system actions. + + Returns: + CDSResponse: A response object containing CDS cards and optional system actions. + If no cards are found in the Document, an empty list of cards is returned. + + Note: + - If out_data.cds_cards is None, a warning is logged and an empty list of cards is returned. + - System actions (out_data.cds_actions) are included in the response if present. 
+ """ + if out_data.cds_cards is None: + log.warning("No CDS cards found in Document, returning empty list of cards") + return CDSResponse(cards=[]) + + return CDSResponse(cards=out_data.cds_cards, systemActions=out_data.cds_actions) diff --git a/healthchain/io/containers.py b/healthchain/io/containers.py index 59cfb87..ec8fa47 100644 --- a/healthchain/io/containers.py +++ b/healthchain/io/containers.py @@ -1,10 +1,19 @@ import json import pandas as pd -from typing import Dict, TypeVar, Generic, List, Any, Iterator +from typing import Dict, Optional, TypeVar, Generic, List, Any, Iterator from dataclasses import dataclass, field from spacy.tokens import Doc as SpacyDoc +from healthchain.models.data.ccddata import CcdData +from healthchain.models.data.cdsfhirdata import CdsFhirData +from healthchain.models.data.concept import ( + AllergyConcept, + MedicationConcept, + ProblemConcept, +) +from healthchain.models.responses.cdsresponse import Action, Card + T = TypeVar("T") @@ -56,14 +65,19 @@ class Document(DataContainer[str]): A container for document data, optionally wrapping a spaCy Doc object. This class extends DataContainer to specifically handle textual document data. - It provides functionality to work with raw text, tokenized text, and spaCy Doc objects. + It provides functionality to work with raw text, tokenized text, spaCy Doc objects, + and structured clinical data. Attributes: data (str): The raw text content of the document. + preprocessed_text (str): The preprocessed version of the text. tokens (List[str]): A list of individual tokens extracted from the text. pos_tags (List[str]): A list of part-of-speech tags corresponding to the tokens. entities (List[str]): A list of named entities identified in the text. - preprocessed_text (str): The preprocessed version of the text. + ccd_data (Optional[CcdData]): An optional CcdData object containing structured clinical data. + fhir_resources (Optional[CdsFhirData]): Optional FHIR resources data. 
+ cds_cards (Optional[List[Card]]): Optional list of CDS cards. + cds_actions (Optional[List[Action]]): Optional list of CDS actions. text (str): The current text content, which may be updated when setting a spaCy Doc. _doc (SpacyDoc): An internal reference to the spaCy Doc object, if set. @@ -75,6 +89,7 @@ class Document(DataContainer[str]): word_count() -> int: Returns the number of tokens in the document. char_count() -> int: Returns the number of characters in the text. get_entities() -> List[Dict[str, Any]]: Returns a list of entities with their details. + update_ccd(new_problems: List[ProblemConcept], new_medications: List[MedicationConcept], new_allergies: List[AllergyConcept], overwrite: bool): Updates the existing CcdData object. __iter__() -> Iterator[str]: Allows iteration over the document's tokens. __len__() -> int: Returns the word count of the document. @@ -86,11 +101,14 @@ class Document(DataContainer[str]): certain attributes and methods that depend on it. """ - # TODO: review this + preprocessed_text: str = field(default="") tokens: List[str] = field(default_factory=list) pos_tags: List[str] = field(default_factory=list) entities: List[str] = field(default_factory=list) - preprocessed_text: str = field(default="") + ccd_data: Optional[CcdData] = field(default=None) + fhir_resources: Optional[CdsFhirData] = field(default=None) + cds_cards: Optional[List[Card]] = field(default=None) + cds_actions: Optional[List[Action]] = field(default=None) def __post_init__(self): self.text = self.data @@ -133,6 +151,37 @@ def get_entities(self) -> List[Dict[str, Any]]: for ent in self._doc.ents ] + def update_ccd( + self, + new_problems: List[ProblemConcept], + new_medications: List[MedicationConcept], + new_allergies: List[AllergyConcept], + overwrite: bool = False, + ) -> None: + """ + Updates the existing CcdData object with new data. + + Args: + new_problems (List[ProblemConcept]): List of new problem concepts to add or update. 
+ new_medications (List[MedicationConcept]): List of new medication concepts to add or update. + new_allergies (List[AllergyConcept]): List of new allergy concepts to add or update. + overwrite (bool, optional): If True, replaces existing data; if False, appends new data. Defaults to False. + + Note: + If no CcdData object exists yet, an empty one is created before the new data is applied. + """ + if self.ccd_data is None: + self.ccd_data = CcdData() + + if overwrite: + self.ccd_data.problems = new_problems + self.ccd_data.medications = new_medications + self.ccd_data.allergies = new_allergies + else: + self.ccd_data.problems.extend(new_problems) + self.ccd_data.medications.extend(new_medications) + self.ccd_data.allergies.extend(new_allergies) + def __iter__(self) -> Iterator[str]: return iter(self.tokens) diff --git a/healthchain/models/data/cdsfhirdata.py b/healthchain/models/data/cdsfhirdata.py index e3b2e97..f0e2433 100644 --- a/healthchain/models/data/cdsfhirdata.py +++ b/healthchain/models/data/cdsfhirdata.py @@ -1,3 +1,5 @@ +import copy + from pydantic import BaseModel, Field from typing import Dict @@ -6,15 +8,46 @@ class CdsFhirData(BaseModel): """ - Data model for CDS FHIR data, this matches the expected fields in CDSRequests + Data model for CDS FHIR data, matching the expected fields in CDSRequests. + + Attributes: + context (Dict): A dictionary containing contextual information for the CDS request. + prefetch (Bundle): A Bundle object containing prefetched FHIR resources. + + Methods: + create(cls, context: Dict, prefetch: Dict): Class method to create a CdsFhirData instance. + model_dump(*args, **kwargs): Returns a dictionary representation of the model. + model_dump_json(*args, **kwargs): Returns a JSON string representation of the model. + model_dump_prefetch(*args, **kwargs): Returns a dictionary representation of the prefetch Bundle.
""" context: Dict = Field(default={}) prefetch: Bundle + @classmethod + def create(cls, context: Dict, prefetch: Dict): + # deep copy to avoid modifying the original prefetch data + prefetch_copy = copy.deepcopy(prefetch) + bundle = Bundle(**prefetch_copy) + return cls(context=context, prefetch=bundle) + def model_dump(self, *args, **kwargs): kwargs.setdefault("exclude_unset", True) kwargs.setdefault("exclude_none", True) kwargs.setdefault("by_alias", True) return super().model_dump(*args, **kwargs) + + def model_dump_json(self, *args, **kwargs): + kwargs.setdefault("exclude_unset", True) + kwargs.setdefault("exclude_none", True) + kwargs.setdefault("by_alias", True) + + return super().model_dump_json(*args, **kwargs) + + def model_dump_prefetch(self, *args, **kwargs): + kwargs.setdefault("exclude_unset", True) + kwargs.setdefault("exclude_none", True) + kwargs.setdefault("by_alias", True) + + return self.prefetch.model_dump(*args, **kwargs) diff --git a/healthchain/models/responses/cdaresponse.py b/healthchain/models/responses/cdaresponse.py index 8f44f25..3f6b677 100644 --- a/healthchain/models/responses/cdaresponse.py +++ b/healthchain/models/responses/cdaresponse.py @@ -35,7 +35,7 @@ def model_dump_xml(self, *args, **kwargs) -> str: xml_dict = xmltodict.parse(self.document) document = search_key(xml_dict, "tns:Document") if document is None: - log.warning("Coudln't find document under namespace 'tns:Document") + log.warning("Couldn't find document under namespace 'tns:Document") return "" cda = base64.b64decode(document).decode("UTF-8") diff --git a/healthchain/models/responses/cdsresponse.py b/healthchain/models/responses/cdsresponse.py index 1c4d151..6470a24 100644 --- a/healthchain/models/responses/cdsresponse.py +++ b/healthchain/models/responses/cdsresponse.py @@ -177,8 +177,22 @@ def validate_suggestions(self) -> Self: class CDSResponse(BaseModel): """ - Http response + Represents the response from a CDS service. 
+ + This class models the structure of a CDS Hooks response, which includes + cards for displaying information or suggestions to the user, and optional + system actions that can be executed automatically. + + Attributes: + cards (List[Card]): A list of Card objects to be displayed to the end user. + Default is an empty list. + systemActions (Optional[List[Action]]): A list of Action objects representing + actions that the CDS Client should execute as part of performing + the decision support requested. This field is optional. + + For more information, see: + https://cds-hooks.org/specification/current/#cds-service-response """ cards: List[Card] = [] - systemActions: Optional[Action] = None + systemActions: Optional[List[Action]] = None diff --git a/healthchain/pipeline/__init__.py b/healthchain/pipeline/__init__.py index 57479a2..a4f4557 100644 --- a/healthchain/pipeline/__init__.py +++ b/healthchain/pipeline/__init__.py @@ -1,6 +1,6 @@ -from healthchain.pipeline.basepipeline import BasePipeline, Pipeline -from healthchain.pipeline.components.basecomponent import BaseComponent, Component -from healthchain.pipeline.components.models import Model +from healthchain.pipeline.base import BasePipeline, Pipeline +from healthchain.pipeline.components.base import BaseComponent, Component +from healthchain.pipeline.components.model import Model from healthchain.pipeline.components.preprocessors import TextPreProcessor from healthchain.pipeline.components.postprocessors import TextPostProcessor from healthchain.pipeline.medicalcodingpipeline import MedicalCodingPipeline diff --git a/healthchain/pipeline/basepipeline.py b/healthchain/pipeline/base.py similarity index 90% rename from healthchain/pipeline/basepipeline.py rename to healthchain/pipeline/base.py index e77c897..5df4d9e 100644 --- a/healthchain/pipeline/basepipeline.py +++ b/healthchain/pipeline/base.py @@ -16,8 +16,9 @@ from pydantic import BaseModel from dataclasses import dataclass, field +from healthchain.io.base 
import BaseConnector from healthchain.io.containers import DataContainer -from healthchain.pipeline.components.basecomponent import BaseComponent +from healthchain.pipeline.components.base import BaseComponent logger = logging.getLogger(__name__) @@ -64,6 +65,8 @@ def __init__(self): self._components: List[PipelineNode[T]] = [] self._stages: Dict[str, List[Callable]] = {} self._built_pipeline: Optional[Callable] = None + self._input_connector: Optional[BaseConnector[T]] = None + self._output_connector: Optional[BaseConnector[T]] = None def __repr__(self) -> str: components_repr = ", ".join( @@ -143,7 +146,45 @@ def stages(self, new_stages: Dict[str, List[Callable]]): """ self._stages = new_stages - def add( + def add_input(self, connector: BaseConnector[T]) -> None: + """ + Adds an input connector to the pipeline. + + This method sets the input connector for the pipeline, which is responsible + for processing the input data before it's passed to the pipeline components. + + Args: + connector (Connector[T]): The input connector to be added to the pipeline. + + Returns: + None + + Note: + Only one input connector can be set for the pipeline. If this method is + called multiple times, the last connector will overwrite the previous ones. + """ + self._input_connector = connector + + def add_output(self, connector: BaseConnector[T]) -> None: + """ + Adds an output connector to the pipeline. + + This method sets the output connector for the pipeline, which is responsible + for processing the output data after it has passed through all pipeline components. + + Args: + connector (Connector[T]): The output connector to be added to the pipeline. + + Returns: + None + + Note: + Only one output connector can be set for the pipeline. If this method is + called multiple times, the last connector will overwrite the previous ones. 
+ """ + self._output_connector = connector + + def add_node( self, component: Union[ BaseComponent[T], Callable[[DataContainer[T]], DataContainer[T]] @@ -158,7 +199,7 @@ def add( dependencies: List[str] = [], ) -> None: """ - Adds a component to the pipeline. + Adds a component node to the pipeline. Args: component (Union[BaseComponent[T], Callable[[DataContainer[T]], DataContainer[T]]], optional): @@ -457,11 +498,20 @@ def resolve_dependencies(): ordered_components = resolve_dependencies() def pipeline(data: Union[T, DataContainer[T]]) -> DataContainer[T]: + if self._input_connector: + data = self._input_connector.input(data) + if not isinstance(data, DataContainer): data = DataContainer(data) - return reduce(lambda d, comp: comp(d), ordered_components, data) - self._built_pipeline = pipeline + data = reduce(lambda d, comp: comp(d), ordered_components, data) + if self._output_connector: + data = self._output_connector.output(data) + + return data + + if self._built_pipeline is not pipeline: + self._built_pipeline = pipeline return pipeline diff --git a/healthchain/pipeline/components/__init__.py b/healthchain/pipeline/components/__init__.py index 47950ab..e63d15b 100644 --- a/healthchain/pipeline/components/__init__.py +++ b/healthchain/pipeline/components/__init__.py @@ -1,7 +1,7 @@ from .preprocessors import TextPreProcessor from .postprocessors import TextPostProcessor -from .models import Model -from .basecomponent import BaseComponent, Component +from .model import Model +from .base import BaseComponent, Component __all__ = [ "TextPreProcessor", diff --git a/healthchain/pipeline/components/basecomponent.py b/healthchain/pipeline/components/base.py similarity index 100% rename from healthchain/pipeline/components/basecomponent.py rename to healthchain/pipeline/components/base.py diff --git a/healthchain/pipeline/components/cdaparser.py b/healthchain/pipeline/components/cdaparser.py deleted file mode 100644 index e69de29..0000000 diff --git 
a/healthchain/pipeline/components/fhirparser.py b/healthchain/pipeline/components/fhirparser.py deleted file mode 100644 index e69de29..0000000 diff --git a/healthchain/pipeline/components/llm.py b/healthchain/pipeline/components/llm.py new file mode 100644 index 0000000..7331c5f --- /dev/null +++ b/healthchain/pipeline/components/llm.py @@ -0,0 +1,20 @@ +from healthchain.pipeline.components.base import Component +from healthchain.io.containers import Document +from typing import TypeVar, Generic + +T = TypeVar("T") + + +# TODO: implement this class +class LLM(Component[T], Generic[T]): + def __init__(self, model_name: str): + self.model = model_name + + def load_model(self): + pass + + def load_chain(self): + pass + + def __call__(self, doc: Document) -> Document: + return doc diff --git a/healthchain/pipeline/components/models.py b/healthchain/pipeline/components/model.py similarity index 89% rename from healthchain/pipeline/components/models.py rename to healthchain/pipeline/components/model.py index bc51baf..f3234e6 100644 --- a/healthchain/pipeline/components/models.py +++ b/healthchain/pipeline/components/model.py @@ -1,4 +1,4 @@ -from healthchain.pipeline.components.basecomponent import Component +from healthchain.pipeline.components.base import Component from healthchain.io.containers import Document from typing import TypeVar, Generic diff --git a/healthchain/pipeline/components/postprocessors.py b/healthchain/pipeline/components/postprocessors.py index 25561d9..3d01206 100644 --- a/healthchain/pipeline/components/postprocessors.py +++ b/healthchain/pipeline/components/postprocessors.py @@ -1,4 +1,4 @@ -from healthchain.pipeline.components.basecomponent import BaseComponent +from healthchain.pipeline.components.base import BaseComponent from healthchain.io.containers import Document from typing import TypeVar, Dict @@ -6,7 +6,7 @@ T = TypeVar("T") -class TextPostProcessor(BaseComponent): +class TextPostProcessor(BaseComponent[Document]): """ A component 
for post-processing text documents, specifically for refining entities. diff --git a/healthchain/pipeline/components/preprocessors.py b/healthchain/pipeline/components/preprocessors.py index 53359bd..3d18fa1 100644 --- a/healthchain/pipeline/components/preprocessors.py +++ b/healthchain/pipeline/components/preprocessors.py @@ -1,12 +1,12 @@ import re -from healthchain.pipeline.components.basecomponent import BaseComponent +from healthchain.pipeline.components.base import BaseComponent from healthchain.io.containers import Document from typing import Callable, List, TypeVar, Tuple T = TypeVar("T") -class TextPreProcessor(BaseComponent): +class TextPreProcessor(BaseComponent[Document]): """ A component for preprocessing text documents. diff --git a/healthchain/pipeline/medicalcodingpipeline.py b/healthchain/pipeline/medicalcodingpipeline.py index 8549ca9..5e0a2f6 100644 --- a/healthchain/pipeline/medicalcodingpipeline.py +++ b/healthchain/pipeline/medicalcodingpipeline.py @@ -1,18 +1,24 @@ -from healthchain.pipeline.basepipeline import BasePipeline +from healthchain.io.cdaconnector import CdaConnector +from healthchain.pipeline.base import BasePipeline from healthchain.pipeline.components.preprocessors import TextPreProcessor from healthchain.pipeline.components.postprocessors import TextPostProcessor -from healthchain.pipeline.components.models import Model +from healthchain.pipeline.components.model import Model # TODO: Implement this pipeline in full class MedicalCodingPipeline(BasePipeline): def configure_pipeline(self, model_path: str) -> None: + cda_connector = CdaConnector() + self.add_input(cda_connector) # Add preprocessing component - self.add(TextPreProcessor(), stage="preprocessing") + self.add_node(TextPreProcessor(), stage="preprocessing") # Add NER component - model = Model(model_path) - self.add(model, stage="ner+l") + model = Model( + model_path + ) # TODO: should converting the CcdData be a model concern? 
+ self.add_node(model, stage="ner+l") # Add postprocessing component - self.add(TextPostProcessor(), stage="postprocessing") + self.add_node(TextPostProcessor(), stage="postprocessing") + self.add_output(cda_connector) diff --git a/healthchain/pipeline/summarizationpipeline.py b/healthchain/pipeline/summarizationpipeline.py new file mode 100644 index 0000000..3277143 --- /dev/null +++ b/healthchain/pipeline/summarizationpipeline.py @@ -0,0 +1,19 @@ +from healthchain.io.cdsfhirconnector import CdsFhirConnector +from healthchain.pipeline.base import BasePipeline +from healthchain.pipeline.components.llm import LLM + + +# TODO: Implement this pipeline in full +class SummarizationPipeline(BasePipeline): + def configure_pipeline(self, model_name: str) -> None: + cds_fhir_connector = CdsFhirConnector(hook_name="encounter-discharge") + self.add_input(cds_fhir_connector) + + # Add summarization component + llm = LLM(model_name) + self.add_node(llm, stage="summarization") + + # Maybe you can have components that create cards + # self.add_node(CardCreator(), stage="card-creation") + + self.add_output(cds_fhir_connector) diff --git a/healthchain/use_cases/cds.py b/healthchain/use_cases/cds.py index d372855..88e6ff2 100644 --- a/healthchain/use_cases/cds.py +++ b/healthchain/use_cases/cds.py @@ -17,7 +17,6 @@ from healthchain.models import ( CDSRequest, CDSResponse, - Card, CDSService, CDSServiceInformation, ) @@ -162,14 +161,29 @@ def cds_discovery(self) -> CDSServiceInformation: def cds_service(self, id: str, request: CDSRequest) -> CDSResponse: """ - CDS service endpoint for FastAPI app, should be mounted to /cds-services/{id} + CDS service endpoint for FastAPI app, mounted to /cds-services/{id} + + This method handles the execution of a specific CDS service. It validates the + service configuration, checks the input parameters, executes the service + function, and ensures the correct response type is returned. Args: - id (str): The ID of the CDS service. 
+ id (str): The unique identifier of the CDS service to be executed. request (CDSRequest): The request object containing the input data for the CDS service. Returns: CDSResponse: The response object containing the cards generated by the CDS service. + + Raises: + AssertionError: If the service function is not properly configured. + TypeError: If the input or output types do not match the expected types. + + Note: + This method performs several checks to ensure the integrity of the service: + 1. Verifies that the service API is configured. + 2. Validates the signature of the service function. + 3. Ensures the service function accepts a CDSRequest as its first argument. + 4. Verifies that the service function returns a CDSResponse. """ # TODO: can register multiple services and fetch with id @@ -178,42 +192,28 @@ def cds_service(self, id: str, request: CDSRequest) -> CDSResponse: log.warning("CDS 'service_api' not configured, check class init.") return CDSResponse(cards=[]) - # Check service function signature - signature = inspect.signature(self._service_api.func) - assert ( - len(signature.parameters) == 2 - ), f"Incorrect number of arguments: {len(signature.parameters)} {signature}; CDS Service functions currently only accept 'self' and a single input argument." 
- - # Handle different input types - service_input = request - params = iter(inspect.signature(self._service_api.func).parameters.items()) - for name, param in params: - if name != "self": - if param.annotation == str: - service_input = request.model_dump_json(exclude_none=True) - elif param.annotation == Dict: - service_input = request.model_dump(exclude_none=True) - - # Call the service function - result = self._service_api.func(self, service_input) - - # Check the result return type - if result is None: + # Check that the first argument of self._service_api.func is of type CDSRequest + func_signature = inspect.signature(self._service_api.func) + params = list(func_signature.parameters.values()) + if len(params) < 2: # Only 'self' parameter + raise AssertionError( + "Service function must have at least one parameter besides 'self'" + ) + first_param = params[1] # Skip 'self' + if first_param.annotation == inspect.Parameter.empty: log.warning( - "CDS 'service_api' returned None, please check function definition." + "Service function parameter has no type annotation. Expected CDSRequest." 
+ ) + elif first_param.annotation != CDSRequest: + raise TypeError( + f"Expected first argument of service function to be CDSRequest, but got {first_param.annotation}" ) - return CDSResponse(cards=[]) - if not isinstance(result, list): - if isinstance(result, Card): - result = [result] - else: - raise TypeError(f"Expected a list, but got {type(result).__name__}") + # Call the service function + response = self._service_api.func(self, request) - for card in result: - if not isinstance(card, Card): - raise TypeError( - f"Expected a list of 'Card' objects, but found an item of type {type(card).__name__}" - ) + # Check that response is of type CDSResponse + if not isinstance(response, CDSResponse): + raise TypeError(f"Expected CDSResponse, but got {type(response).__name__}") - return CDSResponse(cards=result) + return response diff --git a/healthchain/use_cases/clindoc.py b/healthchain/use_cases/clindoc.py index 38f1ac0..dda2c60 100644 --- a/healthchain/use_cases/clindoc.py +++ b/healthchain/use_cases/clindoc.py @@ -17,7 +17,6 @@ validate_workflow, ) from healthchain.models import CdaRequest, CdaResponse, CcdData -from healthchain.cda_parser import CdaAnnotator from healthchain.apimethod import APIMethod @@ -141,13 +140,28 @@ def endpoints(self) -> Dict[str, Endpoint]: def process_notereader_document(self, request: CdaRequest) -> CdaResponse: """ - Process the NoteReader document. + Process the NoteReader document using the configured service API. + + This method handles the execution of the NoteReader service. It validates the + service configuration, checks the input parameters, executes the service + function, and ensures the correct response type is returned. Args: - request (CdaRequest): The CdaRequest object containing the document. + request (CdaRequest): The request object containing the CDA document to be processed. Returns: - CdaResponse: The CdaResponse object containing the processed document. 
+ CdaResponse: The response object containing the processed CDA document. + + Raises: + AssertionError: If the service function is not properly configured. + TypeError: If the output type does not match the expected CdaResponse type. + + Note: + This method performs several checks to ensure the integrity of the service: + 1. Verifies that the service API is configured. + 2. Validates the signature of the service function. + 3. Ensures the service function accepts a CdaRequest as its argument. + 4. Verifies that the service function returns a CdaResponse. """ # Check service_api if self._service_api is None: @@ -156,45 +170,28 @@ def process_notereader_document(self, request: CdaRequest) -> CdaResponse: # Check service function signature signature = inspect.signature(self._service_api.func) - assert ( - len(signature.parameters) == 2 - ), f"Incorrect number of arguments: {len(signature.parameters)} {signature}; service functions currently only accept 'self' and a single input argument." - - # Parse the CDA document - cda_doc = CdaAnnotator.from_xml(request.document) - ccd_data = CcdData( - problems=cda_doc.problem_list, - medications=cda_doc.medication_list, - allergies=cda_doc.allergy_list, - note=cda_doc.note, - ) + params = list(signature.parameters.values()) + if len(params) < 2: # Only 'self' parameter + raise AssertionError( + "Service function must have at least one parameter besides 'self'" + ) + first_param = params[1] # Skip 'self' + if first_param.annotation == inspect.Parameter.empty: + log.warning( + "Service function parameter has no type annotation. Expected CdaRequest." 
+ ) + elif first_param.annotation != CdaRequest: + raise TypeError( + f"Expected first argument of service function to be CdaRequest, but got {first_param.annotation}" + ) # Call the service function - result = self._service_api.func(self, ccd_data) + response = self._service_api.func(self, request) # Check return type - if not isinstance(result, CcdData): + if not isinstance(response, CdaResponse): raise TypeError( - f"Expected return type CcdData, got {type(result)} instead." - ) - - # Update the CDA document with the results - if result.problems: - log.debug(f"Updating CDA document with {len(result.problems)} problem(s).") - cda_doc.add_to_problem_list(result.problems, overwrite=self.overwrite) - if result.allergies: - log.debug( - f"Updating CDA document with {len(result.allergies)} allergy(ies)." - ) - cda_doc.add_to_allergy_list(result.allergies, overwrite=self.overwrite) - if result.medications: - log.debug( - f"Updating CDA document with {len(result.medications)} medication(s)." + f"Expected return type CdaResponse, got {type(response)} instead." 
) - cda_doc.add_to_medication_list(result.medications, overwrite=self.overwrite) - - # Export the updated CDA document - response_document = cda_doc.export() - response = CdaResponse(document=response_document) return response diff --git a/mkdocs.yml b/mkdocs.yml index c627905..bb918ff 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -22,10 +22,14 @@ nav: - Overview: reference/pipeline/pipeline.md - Data Container: reference/pipeline/data_container.md - Component: reference/pipeline/component.md - - Models: - - Overview: reference/pipeline/models/models.md - - Prebuilt: + - Connectors: + - Overview: reference/pipeline/connectors/connectors.md + - CDA Connector: reference/pipeline/connectors/cdaconnector.md + - CDS FHIR Connector: reference/pipeline/connectors/cdsfhirconnector.md + - Prebuilt Pipelines: - Medical Coding: reference/pipeline/prebuilt_pipelines/medicalcoding.md + - Models: + - Overview: reference/pipeline/models/models.md - Sandbox: - Overview: reference/sandbox/sandbox.md - Client: reference/sandbox/client.md @@ -42,6 +46,7 @@ nav: - api/pipeline.md - api/component.md - api/containers.md + - api/connectors.md - api/use_cases.md - api/cds_hooks.md - api/service.md diff --git a/tests/components/conftest.py b/tests/components/conftest.py deleted file mode 100644 index 61a3dae..0000000 --- a/tests/components/conftest.py +++ /dev/null @@ -1,9 +0,0 @@ -import pytest - - -@pytest.fixture -def sample_lookup(): - return { - "high blood pressure": "hypertension", - "heart attack": "myocardial infarction", - } diff --git a/tests/conftest.py b/tests/conftest.py index b63c82a..b70402b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ ) from healthchain.models.requests.cdarequest import CdaRequest from healthchain.models.responses.cdaresponse import CdaResponse +from healthchain.models.responses.cdsresponse import CDSResponse, Card from healthchain.service.soap.epiccdsservice import CDSServices from healthchain.use_cases.cds import ( 
ClinicalDecisionSupport, @@ -27,6 +28,8 @@ from healthchain.use_cases.clindoc import ClinicalDocumentation from healthchain.workflows import UseCaseType +# TODO: Tidy up fixtures + @pytest.fixture(autouse=True) def setup_caplog(caplog): @@ -137,10 +140,59 @@ def test_cds_request(): "hook": "patient-view", "hookInstance": "29e93987-c345-4cb7-9a92-b5136289c2a4", "context": {"userId": "Practitioner/123", "patientId": "123"}, + "prefetch": { + "resourceType": "Bundle", + "entry": [ + { + "resource": { + "resourceType": "Patient", + "id": "123", + "name": [{"family": "Doe", "given": ["John"]}], + "gender": "male", + "birthDate": "1970-01-01", + } + }, + ], + }, } return CDSRequest(**cds_dict) +@pytest.fixture +def test_cds_response_single_card(): + return CDSResponse( + cards=[ + Card( + summary="Test Card", + indicator="info", + source={"label": "Test Source"}, + detail="This is a test card for CDS response", + ) + ] + ) + + +@pytest.fixture +def test_cds_response_empty(): + return CDSResponse(cards=[]) + + +@pytest.fixture +def test_cds_response_multiple_cards(): + return CDSResponse( + cards=[ + Card( + summary="Test Card 1", indicator="info", source={"label": "Test Source"} + ), + Card( + summary="Test Card 2", + indicator="warning", + source={"label": "Test Source"}, + ), + ] + ) + + @pytest.fixture def mock_client_decorator(): def mock_client_decorator(func): @@ -318,8 +370,18 @@ def test_cda_request(): @pytest.fixture -def mock_cda_response(): - return CdaResponse(document="testing") +def test_cda_response(): + return CdaResponse( + document="Mock CDA Response Document", + error=None, + ) + + +@pytest.fixture +def test_cda_response_with_error(): + return CdaResponse( + document="", error="An error occurred while processing the CDA document" + ) @pytest.fixture diff --git a/tests/pipeline/conftest.py b/tests/pipeline/conftest.py new file mode 100644 index 0000000..fa905f0 --- /dev/null +++ b/tests/pipeline/conftest.py @@ -0,0 +1,217 @@ +import pytest +from 
unittest.mock import patch +from healthchain.io.cdaconnector import CdaConnector +from healthchain.io.cdsfhirconnector import CdsFhirConnector +from healthchain.io.containers import Document +from healthchain.models.data.ccddata import CcdData +from healthchain.models.data.concept import ( + AllergyConcept, + MedicationConcept, + ProblemConcept, +) +from healthchain.models.responses.cdaresponse import CdaResponse +from healthchain.pipeline.base import BasePipeline +from healthchain.models.responses.cdsresponse import CDSResponse, Card +from healthchain.models.data.cdsfhirdata import CdsFhirData + + +@pytest.fixture +def cda_connector(): + return CdaConnector() + + +@pytest.fixture +def cds_fhir_connector(): + return CdsFhirConnector(hook_name="patient-view") + + +@pytest.fixture +def sample_lookup(): + return { + "high blood pressure": "hypertension", + "heart attack": "myocardial infarction", + } + + +@pytest.fixture +def mock_cda_connector(): + with patch("healthchain.io.cdaconnector.CdaConnector") as mock: + connector_instance = mock.return_value + + # Mock the input method + connector_instance.input.return_value = Document( + data="Original note", + ccd_data=CcdData( + problems=[ + ProblemConcept( + code="38341003", + code_system="2.16.840.1.113883.6.96", + code_system_name="SNOMED CT", + display_name="Hypertension", + ) + ], + medications=[ + MedicationConcept( + code="123454", + code_system="2.16.840.1.113883.6.96", + code_system_name="SNOMED CT", + display_name="Aspirin", + ) + ], + allergies=[ + AllergyConcept( + code="70618", + code_system="2.16.840.1.113883.6.96", + code_system_name="SNOMED CT", + display_name="Allergy to peanuts", + ) + ], + note="Original note", + ), + ) + + # Mock the output method + connector_instance.output.return_value = CdaResponse( + document="Updated CDA" + ) + + yield mock + + +@pytest.fixture +def mock_cda_annotator(): + with patch("healthchain.io.cdaconnector.CdaAnnotator") as mock: + mock_instance = mock.return_value + 
mock_instance.from_xml.return_value = mock_instance + mock_instance.problem_list = [ + ProblemConcept( + code="38341003", + code_system="2.16.840.1.113883.6.96", + code_system_name="SNOMED CT", + display_name="Hypertension", + ) + ] + mock_instance.medication_list = [ + MedicationConcept( + code="123454", + code_system="2.16.840.1.113883.6.96", + code_system_name="SNOMED CT", + display_name="Aspirin", + ) + ] + mock_instance.allergy_list = [ + AllergyConcept( + code="70618", + code_system="2.16.840.1.113883.6.96", + code_system_name="SNOMED CT", + display_name="Allergy to peanuts", + ) + ] + mock_instance.note = "Sample Note" + yield mock + + +@pytest.fixture +def mock_basic_pipeline(): + class TestPipeline(BasePipeline): + def configure_pipeline(self, model_path: str) -> None: + pass + + return TestPipeline() + + +@pytest.fixture +def mock_model(): + with patch("healthchain.pipeline.components.model.Model") as mock: + model_instance = mock.return_value + model_instance.return_value = Document( + data="Processed note", + ccd_data=CcdData( + problems=[ + ProblemConcept( + code="38341003", + code_system="2.16.840.1.113883.6.96", + code_system_name="SNOMED CT", + display_name="Hypertension", + ) + ], + medications=[ + MedicationConcept( + code="123454", + code_system="2.16.840.1.113883.6.96", + code_system_name="SNOMED CT", + display_name="Aspirin", + ) + ], + allergies=[ + AllergyConcept( + code="70618", + code_system="2.16.840.1.113883.6.96", + code_system_name="SNOMED CT", + display_name="Allergy to peanuts", + ) + ], + note="Processed note", + ), + ) + yield mock + + +@pytest.fixture +def mock_llm(): + with patch("healthchain.pipeline.components.llm.LLM") as mock: + llm_instance = mock.return_value + llm_instance.return_value = Document( + data="Summarized discharge information", + cds_cards=[ + Card( + summary="Summarized discharge information", + detail="Patient John Doe was discharged. 
Encounter details...", + indicator="info", + source={"label": "Summarization LLM"}, + ) + ], + ) + yield mock + + +@pytest.fixture +def mock_cds_fhir_connector(): + with patch("healthchain.io.cdsfhirconnector.CdsFhirConnector") as mock: + connector_instance = mock.return_value + + # Mock the input method + connector_instance.input.return_value = Document( + data="Original FHIR data", + fhir_resources=CdsFhirData( + context={"patientId": "123", "encounterId": "456"}, + prefetch={ + "resourceType": "Bundle", + "entry": [ + { + "resource": { + "resourceType": "Patient", + "id": "123", + "name": [{"family": "Doe", "given": ["John"]}], + "gender": "male", + "birthDate": "1970-01-01", + } + }, + ], + }, + ), + ) + + # Mock the output method + connector_instance.output.return_value = CDSResponse( + cards=[ + Card( + summary="Summarized discharge information", + detail="Patient John Doe was discharged. Encounter details...", + indicator="info", + source={"label": "Summarization LLM"}, + ) + ] + ) + + yield mock diff --git a/tests/pipeline/prebuilt/test_medicalcoding.py b/tests/pipeline/prebuilt/test_medicalcoding.py new file mode 100644 index 0000000..d8859a6 --- /dev/null +++ b/tests/pipeline/prebuilt/test_medicalcoding.py @@ -0,0 +1,51 @@ +from unittest.mock import patch +from healthchain.models.requests.cdarequest import CdaRequest +from healthchain.models.responses.cdaresponse import CdaResponse +from healthchain.pipeline.medicalcodingpipeline import MedicalCodingPipeline + + +def test_coding_pipeline(mock_cda_connector, mock_model): + with patch( + "healthchain.pipeline.medicalcodingpipeline.CdaConnector", mock_cda_connector + ), patch("healthchain.pipeline.medicalcodingpipeline.Model", mock_model): + pipeline = MedicalCodingPipeline.load("./path/to/model") + + # Create a sample CdaRequest + cda_request = CdaRequest(document="Sample CDA") + + # Process the request through the pipeline + cda_response = pipeline(cda_request) + + # Assertions + assert 
isinstance(cda_response, CdaResponse) + assert cda_response.document == "Updated CDA" + + # Verify that CdaConnector methods were called correctly + mock_cda_connector.return_value.input.assert_called_once_with(cda_request) + mock_cda_connector.return_value.output.assert_called_once() + + # Verify that the Model was called + mock_model.assert_called_once() + mock_model.return_value.assert_called_once() + + # Verify the pipeline used the mocked input and output + input_doc = mock_cda_connector.return_value.input.return_value + assert input_doc.data == "Original note" + assert input_doc.ccd_data.problems[0].display_name == "Hypertension" + assert input_doc.ccd_data.medications[0].display_name == "Aspirin" + assert input_doc.ccd_data.allergies[0].display_name == "Allergy to peanuts" + + +def test_full_coding_pipeline_integration(mock_model, test_cda_request): + # Use mock model object for now + with patch("healthchain.pipeline.medicalcodingpipeline.Model", mock_model): + # this load method doesn't do anything yet + pipeline = MedicalCodingPipeline.load("./path/to/production/model") + + cda_response = pipeline(test_cda_request) + + assert isinstance(cda_response, CdaResponse) + + assert "Aspirin" in cda_response.document + assert "Hypertension" in cda_response.document + assert "Allergy to peanuts" in cda_response.document diff --git a/tests/pipeline/prebuilt/test_summarization.py b/tests/pipeline/prebuilt/test_summarization.py new file mode 100644 index 0000000..16f471f --- /dev/null +++ b/tests/pipeline/prebuilt/test_summarization.py @@ -0,0 +1,66 @@ +from unittest.mock import patch +from healthchain.models.responses.cdsresponse import CDSResponse +from healthchain.pipeline.summarizationpipeline import SummarizationPipeline + + +def test_summarization_pipeline(mock_cds_fhir_connector, mock_llm, test_cds_request): + with patch( + "healthchain.pipeline.summarizationpipeline.CdsFhirConnector", + mock_cds_fhir_connector, + ), 
patch("healthchain.pipeline.summarizationpipeline.LLM", mock_llm): + # This also doesn't do anything yet + pipeline = SummarizationPipeline.load("gpt-3.5-turbo") + + # Process the request through the pipeline + cds_response = pipeline(test_cds_request) + + # Assertions + assert isinstance(cds_response, CDSResponse) + assert len(cds_response.cards) == 1 + assert cds_response.cards[0].summary == "Summarized discharge information" + + # Verify that CdsFhirConnector methods were called correctly + mock_cds_fhir_connector.return_value.input.assert_called_once_with( + test_cds_request + ) + mock_cds_fhir_connector.return_value.output.assert_called_once() + + # Verify that the LLM was called + mock_llm.assert_called_once_with("gpt-3.5-turbo") + mock_llm.return_value.assert_called_once() + + # Verify the pipeline used the mocked input and output + input_data = mock_cds_fhir_connector.return_value.input.return_value + assert input_data.fhir_resources.context == { + "patientId": "123", + "encounterId": "456", + } + assert input_data.fhir_resources.model_dump_prefetch() == { + "resourceType": "Bundle", + "entry": [ + { + "resource": { + "resourceType": "Patient", + "id": "123", + "name": [{"family": "Doe", "given": ["John"]}], + "gender": "male", + "birthDate": "1970-01-01", + } + }, + ], + } + + +def test_full_summarization_pipeline_integration(mock_llm, test_cds_request): + # Use mock LLM object for now + with patch("healthchain.pipeline.summarizationpipeline.LLM", mock_llm): + pipeline = SummarizationPipeline.load("gpt-3.5-turbo") + + cds_response = pipeline(test_cds_request) + print(cds_response) + + assert isinstance(cds_response, CDSResponse) + assert len(cds_response.cards) == 1 + assert cds_response.cards[0].summary == "Summarized discharge information" + assert "Patient John Doe" in cds_response.cards[0].detail + assert "Encounter details" in cds_response.cards[0].detail diff --git a/tests/pipeline/test_cdaconnector.py b/tests/pipeline/test_cdaconnector.py new file 
mode 100644 index 0000000..da658e3 --- /dev/null +++ b/tests/pipeline/test_cdaconnector.py @@ -0,0 +1,60 @@ +from unittest.mock import Mock +from healthchain.models.data.concept import ( + AllergyConcept, + MedicationConcept, + ProblemConcept, +) +from healthchain.models.requests.cdarequest import CdaRequest +from healthchain.models.responses.cdaresponse import CdaResponse +from healthchain.models.data.ccddata import CcdData +from healthchain.io.containers import Document + + +def test_input(cda_connector, mock_cda_annotator): + mock_cda_doc = Mock() + mock_cda_doc.problem_list = [ProblemConcept(code="test")] + mock_cda_doc.medication_list = [MedicationConcept(code="test")] + mock_cda_doc.allergy_list = [AllergyConcept(code="test")] + mock_cda_doc.note = "Test note" + mock_cda_annotator.from_xml.return_value = mock_cda_doc + + input_data = CdaRequest(document="Test CDA") + result = cda_connector.input(input_data) + + assert isinstance(result, Document) + assert result.data == "Test note" + + assert isinstance(result.ccd_data, CcdData) + assert result.ccd_data.problems == [ProblemConcept(code="test")] + assert result.ccd_data.medications == [MedicationConcept(code="test")] + assert result.ccd_data.allergies == [AllergyConcept(code="test")] + assert result.ccd_data.note == "Test note" + + +def test_output(cda_connector): + cda_connector.cda_doc = Mock() + cda_connector.cda_doc.export.return_value = "Updated CDA" + + out_data = Document( + data="Updated note", + ccd_data=CcdData( + problems=[ProblemConcept(code="New Problem")], + medications=[MedicationConcept(code="New Medication")], + allergies=[AllergyConcept(code="New Allergy")], + note="Updated note", + ), + ) + + result = cda_connector.output(out_data) + + assert isinstance(result, CdaResponse) + assert result.document == "Updated CDA" + cda_connector.cda_doc.add_to_problem_list.assert_called_once_with( + [ProblemConcept(code="New Problem")], overwrite=False + ) + 
cda_connector.cda_doc.add_to_allergy_list.assert_called_once_with( + [AllergyConcept(code="New Allergy")], overwrite=False + ) + cda_connector.cda_doc.add_to_medication_list.assert_called_once_with( + [MedicationConcept(code="New Medication")], overwrite=False + ) diff --git a/tests/pipeline/test_cdsfhirconnector.py b/tests/pipeline/test_cdsfhirconnector.py new file mode 100644 index 0000000..dd81106 --- /dev/null +++ b/tests/pipeline/test_cdsfhirconnector.py @@ -0,0 +1,102 @@ +import pytest + +from healthchain.io.containers import Document +from healthchain.models.responses.cdsresponse import Action, CDSResponse, Card +from healthchain.models.data.cdsfhirdata import CdsFhirData + + +def test_input_with_valid_prefetch(cds_fhir_connector, test_cds_request): + # Use the valid prefetch data from test_cds_request + input_data = test_cds_request + + # Call the input method + result = cds_fhir_connector.input(input_data) + + # Assert the result + assert isinstance(result, Document) + assert result.data == str(input_data.prefetch) + assert isinstance(result.fhir_resources, CdsFhirData) + assert result.fhir_resources.context == input_data.context.model_dump() + assert result.fhir_resources.model_dump_prefetch() == input_data.prefetch + + +def test_output_with_cards(cds_fhir_connector): + # Prepare test data + cards = [ + Card( + summary="Test Card 1", + detail="This is a test card", + indicator="info", + source={"label": "Test Source"}, + ), + Card( + summary="Test Card 2", + detail="This is another test card", + indicator="warning", + source={"label": "Test Source"}, + ), + ] + actions = [ + Action( + type="create", + description="Create a new resource", + resource={"resourceType": "Patient", "id": "123"}, + resourceId="123", + ) + ] + out_data = Document(data="", cds_cards=cards, cds_actions=actions) + + # Call the output method + result = cds_fhir_connector.output(out_data) + + # Assert the result + assert isinstance(result, CDSResponse) + assert result.cards == cards + 
assert result.systemActions == actions + + +def test_output_without_cards(cds_fhir_connector, caplog): + # Prepare test data + out_data = Document(data="", cds_cards=None) + + # Call the output method + result = cds_fhir_connector.output(out_data) + + # Assert the result + assert isinstance(result, CDSResponse) + assert result.cards == [] + assert result.systemActions is None + assert ( + "No CDS cards found in Document, returning empty list of cards" in caplog.text + ) + + +def test_input_with_empty_request(cds_fhir_connector, test_cds_request): + # Prepare test data + input_data = test_cds_request + input_data.prefetch = None + input_data.fhirServer = None + + # Call the input method and expect a ValueError + with pytest.raises(ValueError) as exc_info: + cds_fhir_connector.input(input_data) + + # Assert the error message + assert ( + str(exc_info.value) + == "Either prefetch or fhirServer must be provided to extract FHIR data!" + ) + + +def test_input_with_fhir_server(cds_fhir_connector, test_cds_request): + # Prepare test data + input_data = test_cds_request + input_data.prefetch = None + input_data.fhirServer = "http://example.com/fhir" + + # Call the input method and expect a NotImplementedError + with pytest.raises(NotImplementedError) as exc_info: + cds_fhir_connector.input(input_data) + + # Assert the error message + assert str(exc_info.value) == "FHIR server is not implemented yet!" 
diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py new file mode 100644 index 0000000..6929c93 --- /dev/null +++ b/tests/pipeline/test_pipeline.py @@ -0,0 +1,282 @@ +import pytest +from pydantic import BaseModel, Field, ValidationError +from healthchain.pipeline.base import BaseComponent +from healthchain.io.containers import DataContainer +from healthchain.pipeline.base import Pipeline + + +# Mock classes and functions for testing +class MockComponent: + def __call__(self, data): + return data + + +class MockInputModel(BaseModel): + data: int = Field(gt=0) + + +class MockOutputModel(BaseModel): + data: int = Field(lt=15) + + +def mock_component(data: DataContainer) -> DataContainer: + data.data += 1 + return data + + +# Test adding components +def test_add_component(mock_basic_pipeline): + # Test basic component addition + mock_basic_pipeline.add_node(mock_component, name="test_component") + assert len(mock_basic_pipeline._components) == 1 + assert mock_basic_pipeline._components[0].name == "test_component" + + # Test adding components with positions and stages + mock_basic_pipeline.add_node( + mock_component, name="first", position="first", stage="preprocessing" + ) + mock_basic_pipeline.add_node( + mock_component, name="last", position="last", stage="other_processing" + ) + mock_basic_pipeline.add_node( + mock_component, + name="second", + position="after", + reference="first", + stage="other_processing", + ) + mock_basic_pipeline.add_node( + mock_component, name="third", position="before", reference="last" + ) + + assert len(mock_basic_pipeline._components) == 5 + assert mock_basic_pipeline._components[0].name == "first" + assert mock_basic_pipeline._components[0].stage == "preprocessing" + assert mock_basic_pipeline._components[1].name == "second" + assert mock_basic_pipeline._components[1].stage == "other_processing" + assert mock_basic_pipeline._components[2].name == "test_component" + assert 
mock_basic_pipeline._components[3].name == "third" + assert mock_basic_pipeline._components[-1].name == "last" + assert mock_basic_pipeline._components[-1].stage == "other_processing" + + # Test adding component with invalid position + with pytest.raises(ValueError): + mock_basic_pipeline.add_node(mock_component, name="invalid", position="middle") + + # Test adding component with missing reference + with pytest.raises(ValueError): + mock_basic_pipeline.add_node( + mock_component, name="invalid", position="after", reference="nonexistent" + ) + + # Test adding component with dependencies + mock_basic_pipeline.add_node(mock_component, name="dep1") + mock_basic_pipeline.add_node(mock_component, name="dep2") + mock_basic_pipeline.add_node( + mock_component, name="main", dependencies=["dep1", "dep2"] + ) + + assert mock_basic_pipeline._components[-1].name == "main" + assert mock_basic_pipeline._components[-1].dependencies == ["dep1", "dep2"] + + # Test adding component with input and output models + mock_basic_pipeline.add_node( + mock_component, + name="validated_component", + input_model=MockInputModel, + output_model=MockOutputModel, + ) + assert len(mock_basic_pipeline._components) == 9 + assert mock_basic_pipeline._components[-1].name == "validated_component" + + # Test adding component as a decorator + @mock_basic_pipeline.add_node(name="decorator_component", stage="processing") + def decorator_component(data: DataContainer) -> DataContainer: + data.data += 1 + return data + + assert len(mock_basic_pipeline._components) == 10 + assert mock_basic_pipeline._components[-1].name == "decorator_component" + assert mock_basic_pipeline._components[-1].stage == "processing" + + +# Test removing and replacing components +def test_remove_and_replace_component(mock_basic_pipeline, caplog): + mock_basic_pipeline.add_node(mock_component, name="test_component") + mock_basic_pipeline.remove("test_component") + assert len(mock_basic_pipeline._components) == 0 + + with 
pytest.raises(ValueError): + mock_basic_pipeline.remove("nonexistent_component") + + mock_basic_pipeline.add_node(mock_component, name="original") + + # Test replacing with a valid callable + def new_component(data: DataContainer) -> DataContainer: + return data + + mock_basic_pipeline.replace("original", new_component) + assert mock_basic_pipeline._components[0].func == new_component + + # Test replacing with an invalid callable (wrong signature) + def invalid_component(data): + return data + + with pytest.raises(ValueError): + mock_basic_pipeline.replace("original", invalid_component) + + # Test replacing with a BaseComponent + class NewComponent(BaseComponent): + def __call__(self, data: DataContainer) -> DataContainer: + return data + + new_base_component = NewComponent() + mock_basic_pipeline.replace("original", new_base_component) + assert mock_basic_pipeline._components[0].func == new_base_component + + # Test replacing a non-existent component + with pytest.raises(ValueError): + mock_basic_pipeline.replace("non_existent", new_component) + + # Test replacing with an invalid type + with pytest.raises(ValueError): + mock_basic_pipeline.replace("original", "not a component") + + +# Test building and executing pipeline +def test_build_and_execute_pipeline(mock_basic_pipeline): + mock_basic_pipeline.add_node(mock_component, name="comp1") + mock_basic_pipeline.add_node(mock_component, name="comp2") + + # Test that the pipeline automatically builds on first use + input_data = DataContainer(1) + result = mock_basic_pipeline(input_data) # This should trigger the automatic build + + assert result.data == 3 + assert ( + mock_basic_pipeline._built_pipeline is not None + ) # Check that the pipeline was built + + # Test that subsequent calls use the already built pipeline + result2 = mock_basic_pipeline(DataContainer(2)) + assert result2.data == 4 + + # Test explicit build method + mock_basic_pipeline._built_pipeline = None # Reset the built pipeline + explicit_pipeline = 
mock_basic_pipeline.build() + assert callable(explicit_pipeline) + + result3 = explicit_pipeline(DataContainer(3)) + assert result3.data == 5 + assert mock_basic_pipeline._built_pipeline is explicit_pipeline + + # Test circular dependency detection + mock_basic_pipeline.add_node(mock_component, name="comp3", dependencies=["comp4"]) + mock_basic_pipeline.add_node(mock_component, name="comp4", dependencies=["comp3"]) + + # Reset the built pipeline to force a rebuild + mock_basic_pipeline._built_pipeline = None + + with pytest.raises(ValueError, match="Circular dependency detected"): + mock_basic_pipeline( + DataContainer(1) + ) # This should trigger the build and raise the error + + # Also test that explicit build raises the same error + with pytest.raises(ValueError, match="Circular dependency detected"): + mock_basic_pipeline.build() + + +def test_pipeline_with_connectors(mock_basic_pipeline): + # Test with input and output connectors + class MockConnector: + def input(self, data): + data.data += 10 + return data + + def output(self, data): + data.data *= 2 + return data + + mock_basic_pipeline.add_input(MockConnector()) + mock_basic_pipeline.add_node(mock_component) + mock_basic_pipeline.add_output(MockConnector()) + + result = mock_basic_pipeline(DataContainer(1)) + assert result.data == 24 # (1 + 10 + 1) * 2 + + +# Test input and output model validation +def test_input_output_validation(mock_basic_pipeline): + def validated_component(data: DataContainer) -> DataContainer: + data.data = data.data * 2 + return data + + mock_basic_pipeline.add_node( + validated_component, + name="validated", + input_model=MockInputModel, + output_model=MockOutputModel, + ) + + pipeline_func = mock_basic_pipeline.build() + + valid_input = DataContainer(5) + result = pipeline_func(valid_input) + assert result.data == 10 + + invalid_input = DataContainer(-1) + with pytest.raises(ValidationError): + pipeline_func(invalid_input) + + # Test output validation + 
@mock_basic_pipeline.add_node( + name="invalid_output", input_model=MockInputModel, output_model=MockOutputModel + ) + def invalid_output_component(data: DataContainer) -> DataContainer: + data.data = data.data * 10 # This will produce an invalid output + return data + + mock_basic_pipeline._built_pipeline = None # Reset the built pipeline + pipeline_func = mock_basic_pipeline.build() + + with pytest.raises(ValidationError): + pipeline_func(DataContainer(5)) # 5 * 10 = 50, which is > 15 (invalid output) + + +# Test Pipeline class and representation +def test_pipeline_class_and_representation(mock_basic_pipeline): + pipeline = Pipeline() + assert hasattr(pipeline, "configure_pipeline") + pipeline.configure_pipeline("dummy_path") # Should not raise any exception + + mock_basic_pipeline.add_node(mock_component, name="comp1") + mock_basic_pipeline.add_node(mock_component, name="comp2") + + repr_string = repr(mock_basic_pipeline) + assert "comp1" in repr_string + assert "comp2" in repr_string + + loaded_pipeline = Pipeline.load("dummy_path") + assert isinstance(loaded_pipeline, Pipeline) + + +# Add a new test for the stages property +def test_stages_property(mock_basic_pipeline): + mock_basic_pipeline.add_node(mock_component, name="comp1", stage="stage1") + mock_basic_pipeline.add_node(mock_component, name="comp2", stage="stage2") + mock_basic_pipeline.add_node(mock_component, name="comp3", stage="stage1") + + stages_repr = mock_basic_pipeline.stages + assert "Pipeline Stages:" in stages_repr + assert "stage1:" in stages_repr + assert "stage2:" in stages_repr + assert "- mock_component" in stages_repr + + # Test setting stages + new_stages = { + "new_stage1": [mock_component], + "new_stage2": [mock_component, mock_component], + } + mock_basic_pipeline.stages = new_stages + assert mock_basic_pipeline._stages == new_stages diff --git a/tests/components/test_postprocessor.py b/tests/pipeline/test_postprocessor.py similarity index 100% rename from 
tests/components/test_postprocessor.py rename to tests/pipeline/test_postprocessor.py diff --git a/tests/components/test_preprocessor.py b/tests/pipeline/test_preprocessor.py similarity index 100% rename from tests/components/test_preprocessor.py rename to tests/pipeline/test_preprocessor.py diff --git a/tests/test_cds.py b/tests/test_cds.py index 4c5acc7..73df593 100644 --- a/tests/test_cds.py +++ b/tests/test_cds.py @@ -1,8 +1,8 @@ import pytest from unittest.mock import Mock -from healthchain.use_cases.cds import ClinicalDecisionSupport -from healthchain.models import Card +from healthchain.models.requests.cdsrequest import CDSRequest +from healthchain.models.responses.cdsresponse import CDSResponse def test_initialization(cds): @@ -14,8 +14,8 @@ def test_initialization(cds): assert "service_mount" in cds.endpoints -def test_cds_discovery_client_not_set(): - cds = ClinicalDecisionSupport() +def test_cds_discovery_client_not_set(cds): + cds._client = None info = cds.cds_discovery() assert info.services == [] @@ -27,73 +27,75 @@ def test_cds_discovery(cds): assert cds_info.services[0].hook == "hook1" -def test_cds_service_no_api_set(test_cds_request): - cds = ClinicalDecisionSupport() +def test_cds_service_valid_response( + cds, + test_cds_request, + test_cds_response_single_card, + test_cds_response_multiple_cards, +): + # Test when everything is valid + def valid_service_func_single_card(self, request: CDSRequest): + return test_cds_response_single_card + + cds._service_api = Mock(func=valid_service_func_single_card) + response = cds.cds_service("1", test_cds_request) - assert response.cards == [] + assert response == test_cds_response_single_card + def valid_service_func_multiple_cards(self, request: CDSRequest): + return test_cds_response_multiple_cards -def test_cds_service(cds, test_cds_request): - # test returning list of cards - request = test_cds_request - cds._service_api.func.return_value = [ - Card( - summary="example", - indicator="info", - 
source={"label": "test"}, - ) - ] - response = cds.cds_service("1", request) - assert len(response.cards) == 1 - assert response.cards[0].summary == "example" - assert response.cards[0].indicator == "info" + cds._service_api = Mock(func=valid_service_func_multiple_cards) + + response = cds.cds_service("1", test_cds_request) + assert response == test_cds_response_multiple_cards - # test returning single card - cds._service_api.func.return_value = Card( - summary="example", - indicator="info", - source={"label": "test"}, - ) - response = cds.cds_service("1", request) - assert len(response.cards) == 1 - assert response.cards[0].summary == "example" - assert response.cards[0].indicator == "info" +def test_cds_service_no_service_api(cds, test_cds_request): + # Test when _service_api is None + cds._service_api = None + response = cds.cds_service("test_id", test_cds_request) + assert isinstance(response, CDSResponse) + assert response.cards == [] -def test_cds_service_incorrect_return_type(cds, test_cds_request): - request = test_cds_request - cds._service_api.func.return_value = "this is not a valid return type" - with pytest.raises(TypeError): - cds.cds_service("1", request) +def test_cds_service_invalid(cds, test_cds_request, test_cds_response_empty): + # Test when service_api function has invalid signature + def invalid_service_signature(self, invalid_param: str): + return test_cds_response_empty -def func_zero_params(): - pass + cds._service_api = Mock(func=invalid_service_signature) + with pytest.raises( + TypeError, match="Expected first argument of service function to be CDSRequest" + ): + cds.cds_service("test_id", test_cds_request) -def func_two_params(self, param1, param2): - pass + # Test when service_api function has invalid number of parameters + def invalid_service_num_params(self): + return test_cds_response_empty + cds._service_api = Mock(func=invalid_service_num_params) -def func_one_param(self, param): - pass + with pytest.raises( + AssertionError, + 
match="Service function must have at least one parameter besides 'self'", + ): + cds.cds_service("test_id", test_cds_request) + # Test when service_api function returns invalid type + def invalid_service_return_type(self, request: CDSRequest): + return "Not a CDSResponse" -def test_cds_service_correct_number_of_parameters(cds, test_cds_request): - # Function with one parameter apart from 'self' - cds._service_api = Mock(func=func_one_param) + cds._service_api = Mock(func=invalid_service_return_type) - # Should not raise an assertion error - cds.cds_service("1", test_cds_request) + with pytest.raises(TypeError, match="Expected CDSResponse, but got str"): + cds.cds_service("test_id", test_cds_request) + # test no annotation - should not raise error + def valid_service_func_no_annotation(self, request): + return test_cds_response_empty -def test_cds_service_incorrect_number_of_parameters(cds, test_cds_request): - # Test with zero parameters apart from 'self' - cds._service_api = Mock(func=func_zero_params) - with pytest.raises(AssertionError): - cds.cds_service("1", test_cds_request) + cds._service_api = Mock(func=valid_service_func_no_annotation) - # Test with more than one parameter apart from 'self' - cds._service_api = Mock(func=func_two_params) - with pytest.raises(AssertionError): - cds.cds_service("1", test_cds_request) + assert cds.cds_service("test_id", test_cds_request) == test_cds_response_empty diff --git a/tests/test_clindoc.py b/tests/test_clindoc.py index 96d84e3..9952c9e 100644 --- a/tests/test_clindoc.py +++ b/tests/test_clindoc.py @@ -2,6 +2,9 @@ from unittest.mock import Mock +from healthchain.models.requests.cdarequest import CdaRequest +from healthchain.models.responses.cdaresponse import CdaResponse + def test_initialization(clindoc): assert clindoc._service_api is not None @@ -11,34 +14,71 @@ def test_initialization(clindoc): assert "service_mount" in clindoc.endpoints -def test_clindoc_notereader(clindoc, test_cda_request, test_ccd_data): - 
clindoc._service_api.func.return_value = test_ccd_data +def test_clindoc_notereader_service(clindoc, test_cda_request, test_cda_response): + def valid_service_func(self, request: CdaRequest): + return test_cda_response + + clindoc._service_api = Mock(func=valid_service_func) response = clindoc.process_notereader_document(test_cda_request) - assert "test" in response.document + assert ( + "Mock CDA Response Document" + in response.document + ) -def test_cds_service_incorrect_return_type(clindoc, test_cda_request): +def test_clindoc_service_incorrect_return_type(clindoc, test_cda_request): clindoc._service_api.func.return_value = "this is not a valid return type" with pytest.raises(TypeError): clindoc.process_notereader_document(test_cda_request) -def func_zero_params(): - pass +def test_process_notereader_document_no_service_api(clindoc, test_cda_request): + clindoc._service_api = None + response = clindoc.process_notereader_document(test_cda_request) + assert isinstance(response, CdaResponse) + assert response.document == "" + + +def test_process_notereader_document_invalid( + clindoc, test_cda_request, test_cda_response +): + # Test invalid parameter type + def invalid_service_func_invalid_param(self, invalid_param: str): + return test_cda_response + + clindoc._service_api = Mock(func=invalid_service_func_invalid_param) + with pytest.raises( + TypeError, match="Expected first argument of service function to be CdaRequest" + ): + clindoc.process_notereader_document(test_cda_request) -def func_two_params(self, param1, param2): - pass + # Test invalid return type + def invalid_service_func_invalid_return_type(self, request: CdaRequest): + return "Not a CdaResponse" + clindoc._service_api = Mock(func=invalid_service_func_invalid_return_type) -def test_cds_service_incorrect_number_of_parameters(clindoc, test_cda_request): - # Test with zero parameters apart from 'self' - clindoc._service_api = Mock(func=func_zero_params) - with pytest.raises(AssertionError): + with 
pytest.raises(TypeError, match="Expected return type CdaResponse"): clindoc.process_notereader_document(test_cda_request) - # Test with more than one parameter apart from 'self' - clindoc._service_api = Mock(func=func_two_params) - with pytest.raises(AssertionError): + # Test invalid number of parameters + def invalid_service_func(self): + return test_cda_response + + clindoc._service_api = Mock(func=invalid_service_func) + + with pytest.raises( + AssertionError, + match="Service function must have at least one parameter besides 'self'", + ): clindoc.process_notereader_document(test_cda_request) + + # test no annotation - should not raise error + def valid_service_func_no_annotation(self, request): + return test_cda_response + + clindoc._service_api = Mock(func=valid_service_func_no_annotation) + + assert clindoc.process_notereader_document(test_cda_request) == test_cda_response diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index dcbac3e..4822ef9 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,7 +1,8 @@ import pytest from pydantic import BaseModel, Field, ValidationError -from healthchain.pipeline.basepipeline import BasePipeline, Pipeline, BaseComponent +from healthchain.pipeline.base import BasePipeline, BaseComponent from healthchain.io.containers import DataContainer +from healthchain.pipeline.base import Pipeline # Mock classes and functions for testing @@ -36,25 +37,25 @@ def configure_pipeline(self, model_path: str) -> None: # Test adding components def test_add_component(basic_pipeline): # Test basic component addition - basic_pipeline.add(mock_component, name="test_component") + basic_pipeline.add_node(mock_component, name="test_component") assert len(basic_pipeline._components) == 1 assert basic_pipeline._components[0].name == "test_component" # Test adding components with positions and stages - basic_pipeline.add( + basic_pipeline.add_node( mock_component, name="first", position="first", stage="preprocessing" ) - 
basic_pipeline.add( + basic_pipeline.add_node( mock_component, name="last", position="last", stage="other_processing" ) - basic_pipeline.add( + basic_pipeline.add_node( mock_component, name="second", position="after", reference="first", stage="other_processing", ) - basic_pipeline.add( + basic_pipeline.add_node( mock_component, name="third", position="before", reference="last" ) @@ -70,18 +71,18 @@ def test_add_component(basic_pipeline): # Test adding component with invalid position with pytest.raises(ValueError): - basic_pipeline.add(mock_component, name="invalid", position="middle") + basic_pipeline.add_node(mock_component, name="invalid", position="middle") # Test adding component with missing reference with pytest.raises(ValueError): - basic_pipeline.add( + basic_pipeline.add_node( mock_component, name="invalid", position="after", reference="nonexistent" ) # Test adding component with dependencies - basic_pipeline.add(mock_component, name="dep1") - basic_pipeline.add(mock_component, name="dep2") - basic_pipeline.add(mock_component, name="main", dependencies=["dep1", "dep2"]) + basic_pipeline.add_node(mock_component, name="dep1") + basic_pipeline.add_node(mock_component, name="dep2") + basic_pipeline.add_node(mock_component, name="main", dependencies=["dep1", "dep2"]) assert basic_pipeline._components[-1].name == "main" assert basic_pipeline._components[-1].dependencies == ["dep1", "dep2"] @@ -89,14 +90,14 @@ def test_add_component(basic_pipeline): # Test removing and replacing components def test_remove_and_replace_component(basic_pipeline, caplog): - basic_pipeline.add(mock_component, name="test_component") + basic_pipeline.add_node(mock_component, name="test_component") basic_pipeline.remove("test_component") assert len(basic_pipeline._components) == 0 with pytest.raises(ValueError): basic_pipeline.remove("nonexistent_component") - basic_pipeline.add(mock_component, name="original") + basic_pipeline.add_node(mock_component, name="original") # Test replacing 
with a valid callable def new_component(data: DataContainer) -> DataContainer: @@ -132,8 +133,8 @@ def __call__(self, data: DataContainer) -> DataContainer: # Test building and executing pipeline def test_build_and_execute_pipeline(basic_pipeline): - basic_pipeline.add(mock_component, name="comp1") - basic_pipeline.add(mock_component, name="comp2") + basic_pipeline.add_node(mock_component, name="comp1") + basic_pipeline.add_node(mock_component, name="comp2") # Test that the pipeline automatically builds on first use input_data = DataContainer(1) @@ -158,8 +159,8 @@ def test_build_and_execute_pipeline(basic_pipeline): assert basic_pipeline._built_pipeline is explicit_pipeline # Test circular dependency detection - basic_pipeline.add(mock_component, name="comp3", dependencies=["comp4"]) - basic_pipeline.add(mock_component, name="comp4", dependencies=["comp3"]) + basic_pipeline.add_node(mock_component, name="comp3", dependencies=["comp4"]) + basic_pipeline.add_node(mock_component, name="comp4", dependencies=["comp3"]) # Reset the built pipeline to force a rebuild basic_pipeline._built_pipeline = None @@ -180,7 +181,7 @@ def validated_component(data: DataContainer) -> DataContainer: data.data = data.data * 2 return data - basic_pipeline.add( + basic_pipeline.add_node( validated_component, name="validated", input_model=MockInputModel, @@ -204,8 +205,8 @@ def test_pipeline_class_and_representation(basic_pipeline): assert hasattr(pipeline, "configure_pipeline") pipeline.configure_pipeline("dummy_path") # Should not raise any exception - basic_pipeline.add(mock_component, name="comp1") - basic_pipeline.add(mock_component, name="comp2") + basic_pipeline.add_node(mock_component, name="comp1") + basic_pipeline.add_node(mock_component, name="comp2") repr_string = repr(basic_pipeline) assert "comp1" in repr_string diff --git a/tests/test_service.py b/tests/test_service.py index 2b5d817..4838e5c 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -32,8 +32,8 @@ def 
test_cds_service(test_cds_request): @patch( "healthchain.use_cases.clindoc.ClinicalDocumentation.process_notereader_document" ) -def test_clindoc_process_document(mock_process, mock_cda_response, test_soap_request): - mock_process.return_value = mock_cda_response +def test_clindoc_process_document(mock_process, test_cda_response, test_soap_request): + mock_process.return_value = test_cda_response headers = {"Content-Type": "text/xml; charset=utf-8"} response = clindoc_client.post( diff --git a/tests/test_service_with_func.py b/tests/test_service_with_func.py index 614bd10..01612dc 100644 --- a/tests/test_service_with_func.py +++ b/tests/test_service_with_func.py @@ -3,6 +3,8 @@ from healthchain.clients import ehr from healthchain.decorators import sandbox, api +from healthchain.models.requests.cdsrequest import CDSRequest +from healthchain.models.responses.cdsresponse import CDSResponse from healthchain.use_cases import ClinicalDecisionSupport from healthchain.models import Card @@ -20,14 +22,17 @@ def load_data(self): return self.data_generator.data @api - def llm(self, text: str): - return [ - Card( - summary="test", - indicator="info", - source={"label": "website"}, - ) - ] + def test_service(self, request: CDSRequest): + return CDSResponse( + cards=[ + Card( + summary="Test Card", + indicator="info", + source={"label": "Test Source"}, + detail="This is a test card for CDS response", + ) + ] + ) cds = myCDS() @@ -54,7 +59,12 @@ def test_cds_service(test_cds_request): assert response.status_code == 200 assert response.json() == { "cards": [ - {"summary": "test", "indicator": "info", "source": {"label": "website"}} + { + "summary": "Test Card", + "indicator": "info", + "source": {"label": "Test Source"}, + "detail": "This is a test card for CDS response", + } ] }