From 1f9dd08a76764dd19541eff35a2cb02f308b5aa8 Mon Sep 17 00:00:00 2001 From: jenniferjiangkells Date: Thu, 7 Nov 2024 18:36:52 +0000 Subject: [PATCH] Added CdsCardCreator implementation --- healthchain/io/containers/document.py | 133 +++++++++++++--- healthchain/pipeline/base.py | 62 ++++++-- .../pipeline/components/cdscardcreator.py | 143 +++++++++++++++++- healthchain/pipeline/summarizationpipeline.py | 15 +- tests/components/test_cardcreator.py | 111 ++++++++++++++ tests/pipeline/test_containers.py | 95 +++++++++++- 6 files changed, 510 insertions(+), 49 deletions(-) create mode 100644 tests/components/test_cardcreator.py diff --git a/healthchain/io/containers/document.py b/healthchain/io/containers/document.py index 6632f27..ac576e8 100644 --- a/healthchain/io/containers/document.py +++ b/healthchain/io/containers/document.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass, field from typing import Any, Dict, Iterator, List, Optional, Union @@ -12,6 +13,8 @@ ProblemConcept, ) +logger = logging.getLogger(__name__) + @dataclass class NlpAnnotations: @@ -86,7 +89,8 @@ class ModelOutputs: Container for storing and managing third-party integration model outputs. This class stores outputs from different NLP/ML frameworks like Hugging Face - and LangChain, organizing them by task type. + and LangChain, organizing them by task type. It also maintains a list of + generated text outputs across frameworks. Attributes: _huggingface_results (Dict[str, Any]): Dictionary storing Hugging Face model @@ -95,29 +99,96 @@ class ModelOutputs: keyed by task name. Methods: - add_output(framework: str, task: str, output: Any): Adds a model output for a - specific framework and task. - get_output(framework: str, task: str, default: Any = None) -> Any: Gets the model - output for a specific framework and task. Returns default if not found. + add_output(source: str, task: str, output: Any): Adds a model output for a + specific source and task. For text generation tasks, also extracts and + stores the generated text. + get_output(source: str, task: str, default: Any = None) -> Any: Gets the model + output for a specific source and task. Returns default if not found. + get_generated_text() -> List[str]: Returns the list of generated text outputs """ _huggingface_results: Dict[str, Any] = field(default_factory=dict) _langchain_results: Dict[str, Any] = field(default_factory=dict) - def add_output(self, framework: str, task: str, output: Any): - if framework == "huggingface": + def add_output(self, source: str, task: str, output: Any): + if source == "huggingface": self._huggingface_results[task] = output - elif framework == "langchain": + elif source == "langchain": self._langchain_results[task] = output else: - raise ValueError(f"Unknown framework: {framework}") + raise ValueError(f"Unknown source: {source}") + + def get_output(self, source: str, task: str) -> Any: + if source == "huggingface": + return self._huggingface_results.get(task, {}) + elif source == "langchain": + return self._langchain_results.get(task, {}) + raise ValueError(f"Unknown source: {source}") + + def get_generated_text(self, source: str, task: str) -> List[str]: + """ + Returns generated text outputs for a given source and task. + + Handles different output formats for Hugging Face and LangChain. For + Hugging Face, it extracts the last message content from chat-style + outputs and common keys like "generated_text", "summary_text", and + "translation". For LangChain, it converts JSON outputs to strings, and returns + the output as is if it is already a string. - def get_output(self, framework: str, task: str, default: Any = None) -> Any: - if framework == "huggingface": - return self._huggingface_results.get(task, default) - elif framework == "langchain": - return self._langchain_results.get(task, default) - raise ValueError(f"Unknown framework: {framework}") + Args: + source (str): Framework name (e.g., "huggingface", "langchain"). + task (str): Task name for retrieving generated text. + + Returns: + List[str]: List of generated text outputs, or an empty list if none. + """ + generated_text = [] + + if source == "huggingface": + # Handle chat-style output format + output = self._huggingface_results.get(task) + if isinstance(output, list): + for entry in output: + text = entry.get("generated_text") + if isinstance(text, list): + last_msg = text[-1] + if isinstance(last_msg, dict) and "content" in last_msg: + generated_text.append(last_msg["content"]) + # Otherwise get common huggingface output keys + elif any( + key in entry + for key in ["generated_text", "summary_text", "translation"] + ): + generated_text.append( + text + or entry.get("summary_text") + or entry.get("translation") + ) + else: + logger.warning("HuggingFace output is not a list of dictionaries. ") + elif source == "langchain": + output = self._langchain_results.get(task) + # Check if output is a string + if isinstance(output, str): + generated_text.append(output) + # Try to convert JSON to string + elif isinstance(output, dict): + try: + import json + + output_str = json.dumps(output) + generated_text.append(output_str) + except Exception: + logger.warning( + "LangChain output is not a string and could not be converted to JSON string. " + "Chains should output either a string or a JSON object." + ) + else: + logger.warning( + "LangChain output is not a string. Chains should output either a string or a JSON object." + ) + + return generated_text @dataclass @@ -336,19 +407,35 @@ def add_concepts( if allergies: self._concepts.allergies.extend(allergies) - def generate_cds_cards( - self, cards: Union[List[Dict], List[Dict[str, Any]]] + def add_cds_cards( + self, cards: Union[List[Card], List[Dict[str, Any]]] ) -> List[Card]: - if isinstance(cards, dict): - cards = [Card(**card) for card in cards] + if not cards: + raise ValueError("Cards must be provided as a list!") + + try: + if isinstance(cards[0], dict): + cards = [Card(**card) for card in cards] + elif not isinstance(cards[0], Card): + raise TypeError("Cards must be either Card objects or dictionaries") + except (IndexError, KeyError) as e: + raise ValueError("Invalid card format") from e return self._cds.set_cards(cards) - def generate_cds_actions( - self, actions: Union[List[Dict], List[Dict[str, Any]]] + def add_cds_actions( + self, actions: Union[List[Action], List[Dict[str, Any]]] ) -> List[Action]: - if isinstance(actions, dict): - actions = [Action(**action) for action in actions] + if not actions: + raise ValueError("Actions must be provided as a list!") + + try: + if isinstance(actions[0], dict): + actions = [Action(**action) for action in actions] + elif not isinstance(actions[0], Action): + raise TypeError("Actions must be either Action objects or dictionaries") + except (IndexError, KeyError) as e: + raise ValueError("Invalid action format") from e return self._cds.set_actions(actions) diff --git a/healthchain/pipeline/base.py b/healthchain/pipeline/base.py index c25fc6b..c07ff54 100644 --- a/healthchain/pipeline/base.py +++ b/healthchain/pipeline/base.py @@ -73,12 +73,34 @@ class PipelineNode(Generic[T]): class BasePipeline(Generic[T], ABC): """ - Abstract BasePipeline class for creating and managing a data processing pipeline. - The BasePipeline class allows users to create a data processing pipeline by adding components and defining their dependencies and execution order. It provides methods for adding, removing, and replacing components, as well as building and executing the pipeline. - This is an abstract base class and should be subclassed to create specific pipeline implementations. + Abstract base class for creating and managing data processing pipelines. + + The BasePipeline class provides a framework for building modular data processing pipelines + by allowing users to add, remove, and configure components with defined dependencies and + execution order. Components can be added at specific positions, grouped into stages, and + connected via input/output connectors. + + This is an abstract base class that should be subclassed to create specific pipeline + implementations. + Attributes: - components (List[PipelineNode]): A list of PipelineNode objects representing the components in the pipeline. - stages (Dict[str, List[Callable]]): A dictionary mapping stage names to lists of component functions. + _components (List[PipelineNode[T]]): Ordered list of pipeline components + _stages (Dict[str, List[Callable]]): Components grouped by processing stage + _built_pipeline (Optional[Callable]): Compiled pipeline function + _input_connector (Optional[BaseConnector[T]]): Connector for processing input data + _output_connector (Optional[BaseConnector[T]]): Connector for processing output data + _output_template (Optional[str]): Template string for formatting pipeline outputs + _model_config (Optional[ModelConfig]): Configuration for the pipeline model + + Example: + >>> class MyPipeline(BasePipeline[str]): + ... def configure_pipeline(self, config: ModelConfig) -> None: + ... self.add_node(preprocess, stage="preprocessing") + ... self.add_node(process, stage="processing") + ... self.add_node(postprocess, stage="postprocessing") + ... + >>> pipeline = MyPipeline() + >>> result = pipeline("input text") """ def __init__(self): @@ -87,6 +109,7 @@ def __init__(self): self._built_pipeline: Optional[Callable] = None self._input_connector: Optional[BaseConnector[T]] = None self._output_connector: Optional[BaseConnector[T]] = None + self._output_template: Optional[str] = None def __repr__(self) -> str: components_repr = ", ".join( @@ -99,19 +122,30 @@ def load( cls, model_id: str, source: Union[str, ModelSource] = "huggingface", + template: Optional[str] = None, **model_kwargs: Any, ) -> "BasePipeline": """ Load and configure a pipeline from a given model. Args: - model_id: Identifier for the model (e.g. "microsoft/omniparser", "en_core_sci_md") - source: Model source - either "spacy" or "huggingface". Defaults to "huggingface" - **model_kwargs: Additional configuration options passed to the model. Common options: - - task: Task name for Hugging Face models (e.g. "ner", "summarization") - - device: Device to load model on ("cpu", "cuda") - - batch_size: Batch size for inference - - max_length: Maximum sequence length + model_id (str): Identifier for the model. Can be a model name from a supported source + (e.g. "gpt-4", "microsoft/omniparser") or a local path to a model. + source (Union[str, ModelSource], optional): Model source - can be "huggingface", "spacy", + "openai" or other supported sources. Defaults to "huggingface". + template (Optional[str], optional): Template string for formatting pipeline outputs. + Defaults to None. + **model_kwargs (Any): Additional configuration options passed to the model. Common options: + - task: Task name for models (e.g. "ner", "summarization") + - device: Device to load model on ("cpu", "cuda") + - batch_size: Batch size for inference + - max_length: Maximum sequence length + + Returns: + BasePipeline: Configured pipeline instance with the specified model. + + Raises: + ValueError: If an unsupported model source is provided. Examples: >>> # Load a summarization pipeline with GPT model @@ -120,7 +154,7 @@ def load( >>> # Load NER pipeline with SpaCy model >>> pipeline = NERPipeline.load("en_core_sci_md", source="spacy") - >>> # Load classification pipeline with BERT, specifying task + >>> # Load classification pipeline with BERT >>> pipeline = ClassificationPipeline.load( ... "microsoft/omniparser", # HuggingFace is default source ... task="sequence-classification" @@ -150,7 +184,9 @@ def load( config.config = model_kwargs pipeline._model_config = config + pipeline._output_template = template pipeline.configure_pipeline(config) + return pipeline @abstractmethod diff --git a/healthchain/pipeline/components/cdscardcreator.py b/healthchain/pipeline/components/cdscardcreator.py index 9fd3f76..b037480 100644 --- a/healthchain/pipeline/components/cdscardcreator.py +++ b/healthchain/pipeline/components/cdscardcreator.py @@ -1,10 +1,149 @@ +import logging +from typing import Optional, Dict, Any +from jinja2 import Template + from healthchain.pipeline.components.base import BaseComponent from healthchain.io.containers import Document +from healthchain.models.responses.cdsresponse import Card, Source, IndicatorEnum + + +logger = logging.getLogger(__name__) class CdsCardCreator(BaseComponent[str]): - def __init__(self): - pass + """ + Component that creates CDS Cards using templates. + + The component uses a template string to format model outputs into CDS Hooks cards + that can be displayed in an EHR. + + Args: + template (str, optional): Template string for card creation. If not provided, + uses a default template that creates an info card with the model output. + content (str, optional): Static content to use instead of model output. + source (str, optional): Source framework to get model output from (e.g. "huggingface", "langchain"). + task (str, optional): Task name to get model output from (e.g. "summarization", "chat"). + default_source (Dict[str, Any], optional): Default source information for the cards. + Defaults to {"label": "Card Generated by HealthChain"}. + + Example: + >>> # Create cards from model output + >>> creator = CdsCardCreator(source="huggingface", task="summarization") + >>> doc = creator(doc) # Creates cards from summarization output + + >>> # Create cards with static content + >>> creator = CdsCardCreator(content="This is a static card message") + >>> doc = creator(doc) # Creates card with the static content + + >>> # Create cards with custom template + >>> template = ''' + ... { + ... "summary": "Custom card: {{ model_output }}", + ... "indicator": "info", + ... "source": {{ default_source | tojson }}, + ... "detail": "{{ model_output }}" + ... } + ... ''' + >>> creator = CdsCardCreator(template=template) + >>> doc = creator(doc) + + """ + + # TODO: make source and other fields configurable from model too + DEFAULT_TEMPLATE = """ + { + "summary": "{{ model_output[:140] }}", + "indicator": "info", + "source": {{ default_source | tojson }}, + "detail": "{{ model_output }}" + } + """ + + def __init__( + self, + template: Optional[str] = None, + content: Optional[str] = None, + source: Optional[str] = None, + task: Optional[str] = None, + default_source: Optional[Dict[str, Any]] = None, + ): + self.template = Template( + template if template is not None else self.DEFAULT_TEMPLATE + ) + self.content = content + self.source = source + self.task = task + self.default_source = default_source or { + "label": "Card Generated by HealthChain" + } + + def create_card(self, content: str) -> Card: + """Creates a CDS Card using the template and model output.""" + try: + # Render template with model output + import json + + card_json = self.template.render( + model_output=content, default_source=self.default_source + ) + + # Parse the rendered JSON into card fields + card_fields = json.loads(card_json) + + return Card( + summary=card_fields["summary"][:140], # Enforce max length + indicator=IndicatorEnum(card_fields["indicator"]), + source=Source(**card_fields["source"]), + detail=card_fields.get("detail"), + suggestions=card_fields.get("suggestions"), + selectionBehavior=card_fields.get("selectionBehavior"), + overrideReasons=card_fields.get("overrideReasons"), + links=card_fields.get("links"), + ) + except Exception as e: + raise ValueError(f"Error creating card: {str(e)}") def __call__(self, doc: Document) -> Document: + """ + Process the document and create CDS cards from model outputs or provided content. + + This method creates CDS Hooks cards either from model-generated text outputs + stored in the document's model outputs container, or from explicitly provided + content. The cards are created using the configured template and added to + the document's CDS container. + + Args: + doc (Document): Input document containing model outputs or to add cards to + + Returns: + Document: Document with generated CDS cards added + + Raises: + ValueError: If neither model source/task nor content is provided + """ + if self.source and self.task: + generated_text = doc.models.get_generated_text(self.source, self.task) + if not generated_text: + logger.warning( + f"No generated text for {self.source}/{self.task} found for CDS card creation!" + ) + return doc + elif self.content: + generated_text = [self.content] + else: + raise ValueError( + "Either model output (source and task) or content need to be provided for CDS card creation!" + ) + + # Create card from model output + cards = [] + for text in generated_text: + try: + card = self.create_card(text) + cards.append(card) + except Exception as e: + logger.warning(f"Error creating card: {str(e)}") + + doc.add_cds_cards(cards) + return doc diff --git a/healthchain/pipeline/summarizationpipeline.py b/healthchain/pipeline/summarizationpipeline.py index e54d23c..19248bd 100644 --- a/healthchain/pipeline/summarizationpipeline.py +++ b/healthchain/pipeline/summarizationpipeline.py @@ -1,5 +1,6 @@ -from healthchain.io.cdsfhirconnector import CdsFhirConnector from healthchain.pipeline.base import BasePipeline +from healthchain.io.cdsfhirconnector import CdsFhirConnector +from healthchain.pipeline.components.cdscardcreator import CdsCardCreator from healthchain.pipeline.modelrouter import ModelRouter, ModelConfig @@ -38,8 +39,12 @@ def configure_pipeline(self, config: ModelConfig) -> None: self.add_input(cds_fhir_connector) self.add_node(model, stage="summarization") - - # TODO: need a component that creates cards from the summary - # self.add_node(CardCreator(), stage="card-creation") - + self.add_node( + CdsCardCreator( + source=config.source.value, + task="summarization", + template=self._output_template, + ), + stage="card-creation", + ) self.add_output(cds_fhir_connector) diff --git a/tests/components/test_cardcreator.py b/tests/components/test_cardcreator.py new file mode 100644 index 0000000..89ea127 --- /dev/null +++ b/tests/components/test_cardcreator.py @@ -0,0 +1,111 @@ +import pytest +from healthchain.pipeline.components.cdscardcreator import CdsCardCreator +from healthchain.io.containers import Document +from healthchain.models.responses.cdsresponse import Card, Source, IndicatorEnum + + +@pytest.fixture +def basic_creator(): + return CdsCardCreator() + + +@pytest.fixture +def custom_template_creator(): + template = """ + { + "summary": "Custom: {{ model_output }}", + "indicator": "warning", + "source": {{ default_source | tojson }}, + "detail": "{{ model_output }}" + } + """ + return CdsCardCreator(template=template) + + +def test_default_template_rendering(basic_creator): + content = "Test message" + card = basic_creator.create_card(content) + + assert isinstance(card, Card) + assert card.summary == "Test message"[:140] + assert card.indicator == IndicatorEnum.info + assert card.source == Source(label="Card Generated by HealthChain") + assert card.detail == "Test message" + + +def test_custom_template_rendering(custom_template_creator): + content = "Test message" + card = custom_template_creator.create_card(content) + + assert card.summary == "Custom: Test message" + assert card.indicator == IndicatorEnum.warning + assert card.source == Source(label="Card Generated by HealthChain") + assert card.detail == "Test message" + + +def test_long_summary_truncation(basic_creator): + long_content = "x" * 200 + card = basic_creator.create_card(long_content) + + assert len(card.summary) == 140 + assert card.summary == "x" * 140 + + +def test_invalid_template_json(basic_creator): + invalid_template = """ + { + "summary": {{ invalid_json }}, + } + """ + with pytest.raises(ValueError): + creator = CdsCardCreator(template=invalid_template) + creator(Document(data="test")) + + +def test_document_processing_with_model_output(): + creator = CdsCardCreator(source="huggingface", task="summarization") + doc = Document(data="test") + doc.models.add_output( + "huggingface", "summarization", [{"summary_text": "Model summary"}] + ) + + processed_doc = creator(doc) + cards = processed_doc.cds.get_cards() + + assert len(cards) == 1 + assert cards[0].summary == "Model summary" + assert cards[0].indicator == IndicatorEnum.info + + +def test_document_processing_with_static_content(): + creator = CdsCardCreator(content="Static content") + doc = Document(data="test") + + processed_doc = creator(doc) + cards = processed_doc.cds.get_cards() + + assert len(cards) == 1 + assert cards[0].summary == "Static content" + + +def test_missing_model_output_warning(caplog): + creator = CdsCardCreator(source="huggingface", task="missing_task") + doc = Document(data="test") + + processed_doc = creator(doc) + + assert "No generated text for huggingface/missing_task found" in caplog.text + assert not processed_doc.cds.get_cards() + + +def test_invalid_input_configuration(): + creator = CdsCardCreator() # No content or source/task specified + doc = Document(data="test") + + with pytest.raises(ValueError) as exc_info: + creator(doc) + + assert ( + "Either model output (source and task) or content need to be provided" + in str(exc_info.value) + ) diff --git a/tests/pipeline/test_containers.py b/tests/pipeline/test_containers.py index bed1cea..f149e61 100644 --- a/tests/pipeline/test_containers.py +++ b/tests/pipeline/test_containers.py @@ -1,5 +1,12 @@ import pytest from healthchain.io.containers import Document +from healthchain.models.responses import Card, Action +from healthchain.models.data import CcdData +from healthchain.models.data.concept import ( + ProblemConcept, + MedicationConcept, + AllergyConcept, +) @pytest.fixture @@ -23,33 +30,105 @@ def test_document_initialization(sample_document): assert sample_document.nlp.get_embeddings() is None -def test_document_word_count(sample_document): - assert sample_document.word_count() == 7 +def test_document_add_concepts(sample_document): + problems = [ProblemConcept(display_name="Hypertension")] + medications = [MedicationConcept(display_name="Aspirin")] + allergies = [AllergyConcept(display_name="Penicillin")] + + sample_document.add_concepts( + problems=problems, medications=medications, allergies=allergies + ) + + assert sample_document.concepts.problems == problems + assert sample_document.concepts.medications == medications + assert sample_document.concepts.allergies == allergies + + +def test_document_generate_ccd(sample_document): + problems = [ProblemConcept(display_name="Hypertension")] + sample_document.add_concepts(problems=problems) + + ccd_data = sample_document.generate_ccd() + assert isinstance(ccd_data, CcdData) + assert ccd_data.concepts.problems == problems + + +def test_document_add_cds_cards(sample_document): + cards = [ + { + "summary": "Test Card", + "detail": "Test Detail", + "indicator": "info", + "source": {"label": "Test Source"}, + } + ] + sample_document.add_cds_cards(cards) + + assert isinstance(sample_document.cds.get_cards()[0], Card) + assert sample_document.cds.get_cards()[0].summary == "Test Card" + + +def test_document_add_cds_actions(sample_document): + actions = [ + {"type": "create", "description": "Test Action", "resource": {"test": "test"}} + ] + sample_document.add_cds_actions(actions) + + assert isinstance(sample_document.cds.get_actions()[0], Action) + assert sample_document.cds.get_actions()[0].description == "Test Action" def test_document_add_huggingface_output(sample_document): - mock_output = {"label": "POSITIVE", "score": 0.9} + mock_output = [ + {"label": "POSITIVE", "score": 0.9, "generated_text": "Generated response"} + ] sample_document.models.add_output("huggingface", "sentiment", mock_output) assert sample_document.models.get_output("huggingface", "sentiment") == mock_output + assert sample_document.models.get_generated_text("huggingface", "sentiment") == [ + "Generated response" + ] + + mock_output_chat = [ + { + "generated_text": [ + { + "role": "user", + "content": "What is the capital of France? Answer in one word.", + }, + {"role": "assistant", "content": "Paris"}, + ] + } + ] + sample_document.models.add_output("huggingface", "chat", mock_output_chat) + assert sample_document.models.get_output("huggingface", "chat") == mock_output_chat + assert sample_document.models.get_generated_text("huggingface", "chat") == [ + "Paris", + ] def test_document_add_langchain_output(sample_document): mock_output = "Summarized text" - sample_document.models.add_output("langchain", "summarization", mock_output) assert ( sample_document.models.get_output("langchain", "summarization") == mock_output ) + assert sample_document.models.get_generated_text("langchain", "summarization") == [ + mock_output + ] + + mock_output_json = {"test": "test"} + sample_document.models.add_output("langchain", "json", mock_output_json) + assert sample_document.models.get_generated_text("langchain", "json") == [ + '{"test": "test"}' + ] def test_document_embeddings(sample_document): embeddings = [0.1, 0.2, 0.3] - sample_document.nlp.set_embeddings(embeddings) - assert sample_document.nlp.get_embeddings() == embeddings @@ -67,3 +146,7 @@ def test_document_iteration(sample_document): def test_document_length(sample_document): assert len(sample_document) == 34 + + +def test_document_word_count(sample_document): + assert sample_document.word_count() == 7