From 3977f7520954827c9f9271aa2476231430ee756c Mon Sep 17 00:00:00 2001 From: jenniferjiangkells Date: Thu, 14 Nov 2024 15:30:52 +0000 Subject: [PATCH] Split .load method to from_model_id and from_local_model and added template path as init option --- healthchain/pipeline/base.py | 240 ++++++++++++------ .../pipeline/components/cdscardcreator.py | 22 +- healthchain/pipeline/medicalcodingpipeline.py | 8 +- healthchain/pipeline/summarizationpipeline.py | 5 +- 4 files changed, 191 insertions(+), 84 deletions(-) diff --git a/healthchain/pipeline/base.py b/healthchain/pipeline/base.py index 444e3ad..2aac5f2 100644 --- a/healthchain/pipeline/base.py +++ b/healthchain/pipeline/base.py @@ -112,6 +112,7 @@ def __init__(self): self._input_connector: Optional[BaseConnector[T]] = None self._output_connector: Optional[BaseConnector[T]] = None self._output_template: Optional[str] = None + self._output_template_path: Optional[Path] = None def __repr__(self) -> str: components_repr = ", ".join( @@ -119,105 +120,165 @@ def __repr__(self) -> str: ) return f"[{components_repr}]" + def _configure_output_templates( + self, + template: Optional[str] = None, + template_path: Optional[Union[str, Path]] = None, + ) -> None: + """ + Configure template settings for the pipeline. + + Args: + template (Optional[str]): Template string for formatting outputs. + Defaults to None. + template_path (Optional[Union[str, Path]]): Path to template file. + Defaults to None. + """ + self._output_template = template + self._output_template_path = Path(template_path) if template_path else None + @classmethod def load( cls, - model: Union[str, Callable], - source: Union[str, ModelSource] = "huggingface", + pipeline: Callable, task: Optional[str] = "text-generation", + source: str = "langchain", template: Optional[str] = None, - **model_kwargs: Any, + template_path: Optional[Union[str, Path]] = None, + **kwargs: Any, ) -> "BasePipeline": """ - Load and configure a pipeline from a given model or LangChain chain. + Load a pipeline from a pre-built pipeline object (e.g. LangChain chain). Args: - model (Union[str, Callable]): Identifier for the model or LangChain chain. Can be: - - a model name from a supported source - (e.g. "en_core_sci_md", "meta-llama/Llama-3.2-1B") - - a local path to a model - - a LangChain Chain instance - source (Union[str, ModelSource], optional): Model source - can be "huggingface", "spacy", - "openai" or other supported sources. Defaults to "huggingface". - task (Optional[str], optional): Task name for models. used as keys to retrieve model outputs - (e.g. "ner", "summarization"). Defaults to "text-generation". - template (Optional[str], optional): Template string for formatting pipeline outputs. + pipeline (Callable): A callable pipeline object (e.g. LangChain chain) + task (Optional[str]): Task identifier used to retrieve model outputs. + Defaults to "text-generation". + template (Optional[str]): Template string for formatting outputs. Defaults to None. - **model_kwargs (Any): Additional configuration options passed to the model. Common options: - - device: Device to load model on ("cpu", "cuda") - - batch_size: Batch size for inference - - max_length: Maximum sequence length + **kwargs: Additional configuration options passed to the pipeline. Returns: - BasePipeline: Configured pipeline instance with the specified model. - - Raises: - ValueError: If an unsupported model source is provided. + BasePipeline: Configured pipeline instance. Examples: - >>> # Load NER pipeline with SpaCy model - >>> pipeline = MedicalCodingPipeline.load("en_core_sci_md", source="spacy") - - >>> # Load summarization pipeline with Hugging Face model - >>> pipeline = SummarizationPipeline.load( - ... "meta-llama/Llama-3.2-1B", # HuggingFace is default source - ... task="text-generation" - ... ) - - >>> # Load pipeline with LangChain >>> from langchain_core.prompts import ChatPromptTemplate >>> from langchain_openai import ChatOpenAI >>> chain = ChatPromptTemplate.from_template("What is {input}?") | ChatOpenAI() - >>> pipeline = Pipeline.load(chain, temperature=0.7, max_tokens=100) + >>> pipeline = Pipeline.load(chain, temperature=0.7) + """ + if not hasattr(pipeline, "__call__") and not hasattr(pipeline, "invoke"): + raise ValueError("Pipeline must be a callable object") + + instance = cls() + instance._configure_output_templates(template, template_path) + + config = ModelConfig( + source=ModelSource(source.lower()), model=pipeline, task=task, kwargs=kwargs + ) - >>> # Load custom pipeline from local SpaCy model - >>> pipeline = CustomPipeline.load("./models/spacy/my_model") + instance._model_config = config + instance.configure_pipeline(config) + + return instance + + @classmethod + def from_model_id( + cls, + model_id: str, + source: Union[str, ModelSource] = "huggingface", + task: Optional[str] = "text-generation", + template: Optional[str] = None, + template_path: Optional[Union[str, Path]] = None, + **kwargs: Any, + ) -> "BasePipeline": + """ + Load pipeline from a model identifier. + + Args: + model_id (str): Model identifier (e.g. HuggingFace model ID, SpaCy model name) + source (Union[str, ModelSource]): Model source. Defaults to "huggingface". + Can be "huggingface", "spacy". + task (Optional[str]): Task identifier for the model. Defaults to "text-generation". + **kwargs: Additional configuration options passed to the model. e.g. temperature, max_length, etc. + + Returns: + BasePipeline: Configured pipeline instance. + + Examples: + >>> # Load HuggingFace model + >>> pipeline = Pipeline.from_model_id( + ... "facebook/bart-large-cnn", + ... task="summarization", + ... temperature=0.7 + ... ) + >>> + >>> # Load SpaCy model + >>> pipeline = Pipeline.from_model_id( + ... "en_core_sci_md", + ... source="spacy", + ... disable=["parser"] + ... ) """ - # Initialize pipeline instance pipeline = cls() - pipeline._output_template = template - - # Create model config based on input type - if hasattr(model, "invoke") or hasattr(model, "__call__"): - # Handle LangChain Chain - config = ModelConfig( - source=ModelSource.LANGCHAIN, - model=model, - task=task, - ) - else: - # Convert string source to enum if needed - if isinstance(source, str): - source_lower = source.lower() - - if source_lower == "langchain": - raise ValueError( - "LangChain models must be passed directly as chain objects, " - "not as string source identifiers." - ) + pipeline._configure_output_templates(template, template_path) - try: - source = ModelSource(source_lower) - except ValueError: - supported = ", ".join(s.value for s in ModelSource) - raise ValueError( - f"Unsupported model source: {source}. " - f"Supported sources: {supported}" - ) + config = ModelConfig( + source=ModelSource(source.lower()), model=model_id, task=task, kwargs=kwargs + ) + pipeline._model_config = config + pipeline.configure_pipeline(config) - # Handle local paths vs model IDs - if isinstance(model, str) and ( - model.startswith("./") or model.startswith("/") - ): - path = Path(model) - config = ModelConfig( - source=source, model=path.name, path=path, task=task - ) - else: - config = ModelConfig(source=source, model=model, task=task) + return pipeline - # Configure and return pipeline - config.kwargs = model_kwargs + @classmethod + def from_local_model( + cls, + path: Union[str, Path], + source: Union[str, ModelSource], + task: Optional[str] = None, + template: Optional[str] = None, + template_path: Optional[Union[str, Path]] = None, + **kwargs: Any, + ) -> "BasePipeline": + """Load pipeline from a local model path. + + Args: + path (Union[str, Path]): Path to local model files/directory + source (Union[str, ModelSource]): Model source (e.g. "huggingface", "spacy") + task (Optional[str]): Task identifier for the model. Defaults to None. + **kwargs: Additional configuration options passed to the model. e.g. temperature, max_length, etc. + + Returns: + BasePipeline: Configured pipeline instance. + + Examples: + >>> # Load local HuggingFace model + >>> pipeline = Pipeline.from_local_model( + ... "models/my_summarizer", + ... source="huggingface", + ... task="summarization", + ... temperature=0.7 + ... ) + >>> + >>> # Load local SpaCy model + >>> pipeline = Pipeline.from_local_model( + ... "models/en_core_sci_md", + ... source="spacy", + ... disable=["parser"] + ... ) + """ + pipeline = cls() + pipeline._configure_output_templates(template, template_path) + + path = Path(path) + config = ModelConfig( + source=ModelSource(source.lower()), + model=path.name, + path=path, + task=task, + kwargs=kwargs, + ) pipeline._model_config = config pipeline.configure_pipeline(config) @@ -230,16 +291,39 @@ def configure_pipeline(self, model_config: ModelConfig) -> None: This method should be implemented by subclasses to add specific components and configure the pipeline according to the given model configuration. + The configuration typically involves: + 1. Setting up input/output connectors + 2. Adding model components based on the model source + 3. Adding any additional processing nodes + 4. Configuring the pipeline stages and execution order Args: - model_config (ModelConfig): Configuration object containing model source, - ID and optional path information. + model_config (ModelConfig): Configuration object containing: + - source: Model source (e.g. huggingface, spacy, langchain) + - model: Model identifier or path + - task: Optional task name (e.g. summarization, ner) + - path: Optional local path to model files + - kwargs: Additional model configuration parameters Returns: None Raises: NotImplementedError: If the method is not implemented by a subclass. + + Example: + >>> def configure_pipeline(self, config: ModelConfig): + ... # Add FHIR connector for input/output + ... connector = FhirConnector() + ... self.add_input(connector) + ... + ... # Add model component + ... model = self.get_model_component(config) + ... self.add_node(model, stage="processing") + ... + ... # Add output formatting + ... self.add_node(OutputFormatter(), stage="formatting") + ... self.add_output(connector) """ raise NotImplementedError("This method must be implemented by subclasses.") diff --git a/healthchain/pipeline/components/cdscardcreator.py b/healthchain/pipeline/components/cdscardcreator.py index 453317a..44c0e48 100644 --- a/healthchain/pipeline/components/cdscardcreator.py +++ b/healthchain/pipeline/components/cdscardcreator.py @@ -1,7 +1,8 @@ import logging import json -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Union from jinja2 import Template +from pathlib import Path from healthchain.pipeline.components.base import BaseComponent from healthchain.io.containers import Document @@ -23,6 +24,8 @@ class CdsCardCreator(BaseComponent[str]): Args: template (str, optional): Template string for card creation. If not provided, uses a default template that creates an info card with the model output. + template_path (str, optional): Path to a template file. If not provided, + uses a default template that creates an info card with the model output. static_content (str, optional): Static content to use instead of model output. source (str, optional): Source framework to get model output from (e.g. "huggingface", "langchain"). task (str, optional): Task name to get model output from (e.g. "summarization", "chat"). @@ -76,12 +79,25 @@ class CdsCardCreator(BaseComponent[str]): def __init__( self, template: Optional[str] = None, + template_path: Optional[Union[str, Path]] = None, static_content: Optional[str] = None, source: Optional[str] = None, task: Optional[str] = None, delimiter: Optional[str] = None, default_source: Optional[Dict[str, Any]] = None, ): + # Load template from file or use string template + if template_path: + try: + template_path = Path(template_path) + if not template_path.exists(): + raise FileNotFoundError(f"Template file not found: {template_path}") + with open(template_path) as f: + template = f.read() + except Exception as e: + logger.error(f"Error loading template from {template_path}: {str(e)}") + template = self.DEFAULT_TEMPLATE + self.template = Template( template if template is not None else self.DEFAULT_TEMPLATE ) @@ -122,7 +138,9 @@ def create_card(self, content: str) -> Card: links=card_fields.get("links"), ) except Exception as e: - raise ValueError(f"Error creating card: {str(e)}") + raise ValueError( + f"Error creating CDS card: Failed to render template or parse card fields: {str(e)}" + ) def __call__(self, doc: Document) -> Document: """ diff --git a/healthchain/pipeline/medicalcodingpipeline.py b/healthchain/pipeline/medicalcodingpipeline.py index 4d09dd1..eb70fc1 100644 --- a/healthchain/pipeline/medicalcodingpipeline.py +++ b/healthchain/pipeline/medicalcodingpipeline.py @@ -18,14 +18,18 @@ class MedicalCodingPipeline(BasePipeline, ModelRoutingMixin): Examples: >>> # Using with SpaCy/MedCAT - >>> pipeline = MedicalCodingPipeline.load("medcatlite", source="spacy") + >>> pipeline = MedicalCodingPipeline.from_model_id("medcatlite", source="spacy") >>> cda_response = pipeline(documents) >>> >>> # Using with Hugging Face - >>> pipeline = MedicalCodingPipeline.load( + >>> pipeline = MedicalCodingPipeline.from_model_id( ... "bert-base-uncased", ... task="ner" ... ) + >>> # Using with LangChain + >>> chain = ChatPromptTemplate.from_template("Extract medical codes: {text}") | ChatOpenAI() + >>> pipeline = MedicalCodingPipeline.load(chain) + >>> >>> cda_response = pipeline(documents) """ diff --git a/healthchain/pipeline/summarizationpipeline.py b/healthchain/pipeline/summarizationpipeline.py index aecf5d3..9375cef 100644 --- a/healthchain/pipeline/summarizationpipeline.py +++ b/healthchain/pipeline/summarizationpipeline.py @@ -21,11 +21,11 @@ class SummarizationPipeline(BasePipeline, ModelRoutingMixin): Examples: >>> # Using with GPT model - >>> pipeline = SummarizationPipeline.load("gpt-4o", source="openai") + >>> pipeline = SummarizationPipeline.from_model_id("gpt-4o", source="openai") >>> cds_response = pipeline(documents) >>> >>> # Using with Hugging Face - >>> pipeline = SummarizationPipeline.load( + >>> pipeline = SummarizationPipeline.from_model_id( ... "facebook/bart-large-cnn", ... task="summarization" ... ) @@ -53,6 +53,7 @@ def configure_pipeline(self, config: ModelConfig) -> None: source=config.source.value, task="summarization", template=self._output_template, + template_path=self._output_template_path, delimiter="\n", ), stage="card-creation",