Split .load method to from_model_id and from_local_model and added template path as init option
jenniferjiangkells committed Nov 14, 2024
1 parent cba3b17 commit 3977f75
Showing 4 changed files with 191 additions and 84 deletions.
240 changes: 162 additions & 78 deletions healthchain/pipeline/base.py
@@ -112,112 +112,173 @@ def __init__(self):
        self._input_connector: Optional[BaseConnector[T]] = None
        self._output_connector: Optional[BaseConnector[T]] = None
        self._output_template: Optional[str] = None
        self._output_template_path: Optional[Path] = None

    def __repr__(self) -> str:
        components_repr = ", ".join(
            [f'"{component.name}"' for component in self._components]
        )
        return f"[{components_repr}]"

    def _configure_output_templates(
        self,
        template: Optional[str] = None,
        template_path: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Configure template settings for the pipeline.

        Args:
            template (Optional[str]): Template string for formatting outputs.
                Defaults to None.
            template_path (Optional[Union[str, Path]]): Path to template file.
                Defaults to None.
        """
        self._output_template = template
        self._output_template_path = Path(template_path) if template_path else None

    @classmethod
    def load(
        cls,
        pipeline: Callable,
        task: Optional[str] = "text-generation",
        source: str = "langchain",
        template: Optional[str] = None,
        template_path: Optional[Union[str, Path]] = None,
        **kwargs: Any,
    ) -> "BasePipeline":
        """
        Load a pipeline from a pre-built pipeline object (e.g. a LangChain chain).

        Args:
            pipeline (Callable): A callable pipeline object (e.g. LangChain chain)
            task (Optional[str]): Task identifier used to retrieve model outputs.
                Defaults to "text-generation".
            source (str): Source framework of the pipeline object. Defaults to "langchain".
            template (Optional[str]): Template string for formatting outputs.
                Defaults to None.
            template_path (Optional[Union[str, Path]]): Path to template file.
                Defaults to None.
            **kwargs: Additional configuration options passed to the pipeline.

        Returns:
            BasePipeline: Configured pipeline instance.

        Raises:
            ValueError: If the pipeline object is not callable.

        Examples:
            >>> # Load pipeline with LangChain
            >>> from langchain_core.prompts import ChatPromptTemplate
            >>> from langchain_openai import ChatOpenAI
            >>> chain = ChatPromptTemplate.from_template("What is {input}?") | ChatOpenAI()
            >>> pipeline = Pipeline.load(chain, temperature=0.7)
        """
        if not hasattr(pipeline, "__call__") and not hasattr(pipeline, "invoke"):
            raise ValueError("Pipeline must be a callable object")

        instance = cls()
        instance._configure_output_templates(template, template_path)

        config = ModelConfig(
            source=ModelSource(source.lower()), model=pipeline, task=task, kwargs=kwargs
        )
        instance._model_config = config
        instance.configure_pipeline(config)

        return instance

    @classmethod
    def from_model_id(
        cls,
        model_id: str,
        source: Union[str, ModelSource] = "huggingface",
        task: Optional[str] = "text-generation",
        template: Optional[str] = None,
        template_path: Optional[Union[str, Path]] = None,
        **kwargs: Any,
    ) -> "BasePipeline":
        """
        Load pipeline from a model identifier.

        Args:
            model_id (str): Model identifier (e.g. HuggingFace model ID, SpaCy model name)
            source (Union[str, ModelSource]): Model source. Can be "huggingface" or "spacy".
                Defaults to "huggingface".
            task (Optional[str]): Task identifier for the model. Defaults to "text-generation".
            template (Optional[str]): Template string for formatting outputs. Defaults to None.
            template_path (Optional[Union[str, Path]]): Path to template file. Defaults to None.
            **kwargs: Additional configuration options passed to the model, e.g. temperature, max_length, etc.

        Returns:
            BasePipeline: Configured pipeline instance.

        Examples:
            >>> # Load HuggingFace model
            >>> pipeline = Pipeline.from_model_id(
            ...     "facebook/bart-large-cnn",
            ...     task="summarization",
            ...     temperature=0.7
            ... )
            >>>
            >>> # Load SpaCy model
            >>> pipeline = Pipeline.from_model_id(
            ...     "en_core_sci_md",
            ...     source="spacy",
            ...     disable=["parser"]
            ... )
        """
        pipeline = cls()
        pipeline._configure_output_templates(template, template_path)

        config = ModelConfig(
            source=ModelSource(source.lower()), model=model_id, task=task, kwargs=kwargs
        )
        pipeline._model_config = config
        pipeline.configure_pipeline(config)

        return pipeline
    @classmethod
    def from_local_model(
        cls,
        path: Union[str, Path],
        source: Union[str, ModelSource],
        task: Optional[str] = None,
        template: Optional[str] = None,
        template_path: Optional[Union[str, Path]] = None,
        **kwargs: Any,
    ) -> "BasePipeline":
        """
        Load pipeline from a local model path.

        Args:
            path (Union[str, Path]): Path to local model files/directory
            source (Union[str, ModelSource]): Model source (e.g. "huggingface", "spacy")
            task (Optional[str]): Task identifier for the model. Defaults to None.
            template (Optional[str]): Template string for formatting outputs. Defaults to None.
            template_path (Optional[Union[str, Path]]): Path to template file. Defaults to None.
            **kwargs: Additional configuration options passed to the model, e.g. temperature, max_length, etc.

        Returns:
            BasePipeline: Configured pipeline instance.

        Examples:
            >>> # Load local HuggingFace model
            >>> pipeline = Pipeline.from_local_model(
            ...     "models/my_summarizer",
            ...     source="huggingface",
            ...     task="summarization",
            ...     temperature=0.7
            ... )
            >>>
            >>> # Load local SpaCy model
            >>> pipeline = Pipeline.from_local_model(
            ...     "models/en_core_sci_md",
            ...     source="spacy",
            ...     disable=["parser"]
            ... )
        """
        pipeline = cls()
        pipeline._configure_output_templates(template, template_path)

        path = Path(path)
        config = ModelConfig(
            source=ModelSource(source.lower()),
            model=path.name,
            path=path,
            task=task,
            kwargs=kwargs,
        )
        pipeline._model_config = config
        pipeline.configure_pipeline(config)

@@ -230,16 +291,39 @@ def configure_pipeline(self, model_config: ModelConfig) -> None:
        This method should be implemented by subclasses to add specific components
        and configure the pipeline according to the given model configuration.

        The configuration typically involves:
        1. Setting up input/output connectors
        2. Adding model components based on the model source
        3. Adding any additional processing nodes
        4. Configuring the pipeline stages and execution order

        Args:
            model_config (ModelConfig): Configuration object containing:
                - source: Model source (e.g. huggingface, spacy, langchain)
                - model: Model identifier or path
                - task: Optional task name (e.g. summarization, ner)
                - path: Optional local path to model files
                - kwargs: Additional model configuration parameters

        Returns:
            None

        Raises:
            NotImplementedError: If the method is not implemented by a subclass.

        Example:
            >>> def configure_pipeline(self, config: ModelConfig):
            ...     # Add FHIR connector for input/output
            ...     connector = FhirConnector()
            ...     self.add_input(connector)
            ...
            ...     # Add model component
            ...     model = self.get_model_component(config)
            ...     self.add_node(model, stage="processing")
            ...
            ...     # Add output formatting
            ...     self.add_node(OutputFormatter(), stage="formatting")
            ...     self.add_output(connector)
        """
        raise NotImplementedError("This method must be implemented by subclasses.")

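To see how the split API fits together, here is a minimal usage sketch (assuming SummarizationPipeline is importable from healthchain.pipeline and that a local model directory ./models/my_summarizer exists; both are illustrative):

from healthchain.pipeline import SummarizationPipeline

# Hosted model, resolved by identifier (source defaults to "huggingface")
hf_pipeline = SummarizationPipeline.from_model_id(
    "facebook/bart-large-cnn",
    task="summarization",
)

# Local model files: the source must now be stated explicitly, and the
# directory name becomes the model name in ModelConfig
local_pipeline = SummarizationPipeline.from_local_model(
    "./models/my_summarizer",  # illustrative path
    source="huggingface",
    task="summarization",
)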
22 changes: 20 additions & 2 deletions healthchain/pipeline/components/cdscardcreator.py
@@ -1,7 +1,8 @@
import logging
import json
from typing import Optional, Dict, Any, Union
from jinja2 import Template
from pathlib import Path

from healthchain.pipeline.components.base import BaseComponent
from healthchain.io.containers import Document
@@ -23,6 +24,8 @@ class CdsCardCreator(BaseComponent[str]):
    Args:
        template (str, optional): Template string for card creation. If not provided,
            uses a default template that creates an info card with the model output.
        template_path (str, optional): Path to a template file. If not provided,
            uses a default template that creates an info card with the model output.
        static_content (str, optional): Static content to use instead of model output.
        source (str, optional): Source framework to get model output from (e.g. "huggingface", "langchain").
        task (str, optional): Task name to get model output from (e.g. "summarization", "chat").
@@ -76,12 +79,25 @@ class CdsCardCreator(BaseComponent[str]):
    def __init__(
        self,
        template: Optional[str] = None,
        template_path: Optional[Union[str, Path]] = None,
        static_content: Optional[str] = None,
        source: Optional[str] = None,
        task: Optional[str] = None,
        delimiter: Optional[str] = None,
        default_source: Optional[Dict[str, Any]] = None,
    ):
        # Load template from file or use string template
        if template_path:
            try:
                template_path = Path(template_path)
                if not template_path.exists():
                    raise FileNotFoundError(f"Template file not found: {template_path}")
                with open(template_path) as f:
                    template = f.read()
            except Exception as e:
                logger.error(f"Error loading template from {template_path}: {str(e)}")
                template = self.DEFAULT_TEMPLATE

        self.template = Template(
            template if template is not None else self.DEFAULT_TEMPLATE
        )
@@ -122,7 +138,9 @@ def create_card(self, content: str) -> Card:
                links=card_fields.get("links"),
            )
        except Exception as e:
            raise ValueError(
                f"Error creating CDS card: Failed to render template or parse card fields: {str(e)}"
            )

    def __call__(self, doc: Document) -> Document:
        """
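A sketch of the new template_path option in isolation (the file path and its contents are illustrative; the rendered template must produce the JSON card fields the component parses):

from healthchain.pipeline.components.cdscardcreator import CdsCardCreator

# Jinja2 template loaded from disk; if the file is missing or unreadable,
# __init__ logs the error and falls back to DEFAULT_TEMPLATE
creator = CdsCardCreator(
    template_path="templates/card.json",  # illustrative path
    source="huggingface",
    task="summarization",
)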
8 changes: 6 additions & 2 deletions healthchain/pipeline/medicalcodingpipeline.py
@@ -18,14 +18,18 @@ class MedicalCodingPipeline(BasePipeline, ModelRoutingMixin):
    Examples:
        >>> # Using with SpaCy/MedCAT
        >>> pipeline = MedicalCodingPipeline.from_model_id("medcatlite", source="spacy")
        >>> cda_response = pipeline(documents)
        >>>
        >>> # Using with Hugging Face
        >>> pipeline = MedicalCodingPipeline.from_model_id(
        ...     "bert-base-uncased",
        ...     task="ner"
        ... )
        >>> # Using with LangChain
        >>> chain = ChatPromptTemplate.from_template("Extract medical codes: {text}") | ChatOpenAI()
        >>> pipeline = MedicalCodingPipeline.load(chain)
        >>>
        >>> cda_response = pipeline(documents)
    """

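The LangChain example from the docstring, expanded into a runnable sketch (assuming langchain-core and langchain-openai are installed, an OpenAI API key is configured, and MedicalCodingPipeline is importable from healthchain.pipeline):

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

from healthchain.pipeline import MedicalCodingPipeline

# Pre-built chains still go through load(); only string model IDs moved
# to from_model_id / from_local_model
chain = ChatPromptTemplate.from_template("Extract medical codes: {text}") | ChatOpenAI()
pipeline = MedicalCodingPipeline.load(chain)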
5 changes: 3 additions & 2 deletions healthchain/pipeline/summarizationpipeline.py
@@ -21,11 +21,11 @@ class SummarizationPipeline(BasePipeline, ModelRoutingMixin):
    Examples:
        >>> # Using with GPT model
        >>> pipeline = SummarizationPipeline.from_model_id("gpt-4o", source="openai")
        >>> cds_response = pipeline(documents)
        >>>
        >>> # Using with Hugging Face
        >>> pipeline = SummarizationPipeline.from_model_id(
        ...     "facebook/bart-large-cnn",
        ...     task="summarization"
        ... )
@@ -53,6 +53,7 @@ def configure_pipeline(self, config: ModelConfig) -> None:
                source=config.source.value,
                task="summarization",
                template=self._output_template,
                template_path=self._output_template_path,
                delimiter="\n",
            ),
            stage="card-creation",
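A sketch of how the init-time template options reach card creation (the template path and documents input are illustrative):

from healthchain.pipeline import SummarizationPipeline

# template / template_path given at load time are stored on the pipeline
# and forwarded to CdsCardCreator in configure_pipeline above
pipeline = SummarizationPipeline.from_model_id(
    "facebook/bart-large-cnn",
    task="summarization",
    template_path="templates/cds_card.json",  # illustrative path
)
cds_response = pipeline(documents)  # documents: your CDS request input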
