Skip to content

Commit

Permalink
Change load method to use source parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
jenniferjiangkells committed Nov 6, 2024
1 parent ea6327d commit 3afdb6a
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 89 deletions.
110 changes: 29 additions & 81 deletions healthchain/pipeline/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import re
import logging
from abc import ABC, abstractmethod
from inspect import signature
Expand Down Expand Up @@ -96,114 +95,63 @@ def __repr__(self) -> str:
return f"[{components_repr}]"

@classmethod
def load(cls, model_path: str, **model_kwargs: Any) -> "BasePipeline":
def load(
cls,
model_id: str,
source: Union[str, ModelSource] = "huggingface",
**model_kwargs: Any,
) -> "BasePipeline":
"""
Load and configure a pipeline from a given model path.
Load and configure a pipeline from a given model.
Args:
model_path: Path or identifier for the model in the format "source/model_id" or a local path.
Sources can be "spacy" or "huggingface". For example:
- "spacy/en_core_sci_md" (remote/installed model)
- "huggingface/bert-base" (remote model)
- "./models/spacy/my_model" (local model)
model_id: Identifier for the model (e.g. "microsoft/omniparser", "en_core_sci_md")
source: Model source - either "spacy" or "huggingface". Defaults to "huggingface"
**model_kwargs: Additional configuration options passed to the model. Common options:
- task: Task name for Hugging Face models (e.g. "ner", "summarization")
- device: Device to load model on ("cpu", "cuda")
- batch_size: Batch size for inference
- max_length: Maximum sequence length
Returns:
BasePipeline: A configured pipeline instance with the specified model loaded
Raises:
ValueError: If model_path format is invalid or source is not supported
ImportError: If required model dependencies are not installed
Examples:
>>> # Load a summarization pipeline with GPT model
>>> pipeline = SummarizationPipeline.load("openai/gpt-4")
>>> pipeline = SummarizationPipeline.load("gpt-4", source="openai")
>>> # Load NER pipeline with SpaCy model
>>> pipeline = NERPipeline.load("spacy/en_core_sci_md")
>>> pipeline = NERPipeline.load("en_core_sci_md", source="spacy")
>>> # Load classification pipeline with BERT, specifying task
>>> pipeline = ClassificationPipeline.load(
... "huggingface/bert-base",
... "microsoft/omniparser", # HuggingFace is default source
... task="sequence-classification"
... )
>>> # Load custom pipeline from local SpaCy model
>>> pipeline = CustomPipeline.load("./models/spacy/my_model")
"""
pipeline = cls()
config = cls._parse_model_path(model_path)
config.config = model_kwargs
pipeline._model_config = config
pipeline.configure_pipeline(config)
return pipeline

@staticmethod
def _parse_model_path(model_path: str) -> ModelConfig:
"""
Parse model path to determine source and model ID.
This method parses a model path string to extract the model source and ID.
It handles both local paths and remote model identifiers.
Args:
model_path (str): Path or identifier for the model. Can be:
- Local path starting with "./" or "/" (e.g. "./models/spacy/my_model")
- Remote identifier in format "source/model_id" (e.g. "spacy/en_core_sci_md")
Returns:
ModelConfig: Configuration object containing:
- source: ModelSource enum value (e.g. ModelSource.SPACY)
- model_id: Name/ID of the model
- path: Optional Path object for local models
Raises:
ValueError: If model_path format is invalid or source is not supported
For local paths, if source cannot be determined from path
For remote paths, if format is not "source/model_id"
"""
# Handle local paths
if model_path.startswith("./") or model_path.startswith("/"):
path = Path(model_path)
# Match spacy or huggingface (case insensitive) anywhere in the path
source_match = re.search(r"(spacy|huggingface)", str(path), re.IGNORECASE)

if not source_match:
# Convert string source to enum if needed
if isinstance(source, str):
try:
source = ModelSource(source.lower())
except ValueError:
raise ValueError(
"Local models are only supported for Spacy and Huggingface and must be in folder containing the name of the model library. "
f"No valid source found in path: {path}"
f"Unsupported model source: {source}. "
f"Supported sources: {', '.join(s.value for s in ModelSource)}"
)

source_name = source_match.group(1).lower()
source = ModelSource(source_name)
return ModelConfig(source=source, model_id=path.name, path=path)

# Handle remote sources
pattern = r"^(?P<source>[a-zA-Z]+)/(?P<model_id>.+)$"
match = re.match(pattern, model_path)

if not match:
raise ValueError(
f"Invalid model path format: {model_path}. "
"Expected format: 'source/model_id' (e.g., 'spacy/en_core_sci_md')"
)

source_name = match.group("source").lower()
model_id = match.group("model_id")

try:
source = ModelSource(source_name)
except ValueError:
raise ValueError(
f"Unsupported model source: {source_name}. "
f"Supported sources: {', '.join(s.value for s in ModelSource)}"
)
# Handle local paths
if model_id.startswith("./") or model_id.startswith("/"):
path = Path(model_id)
config = ModelConfig(source=source, model_id=path.name, path=path)
else:
config = ModelConfig(source=source, model_id=model_id)

return ModelConfig(source=source, model_id=model_id)
config.config = model_kwargs
pipeline._model_config = config
pipeline.configure_pipeline(config)
return pipeline

@abstractmethod
def configure_pipeline(self, model_config: ModelConfig) -> None:
Expand Down
2 changes: 1 addition & 1 deletion healthchain/pipeline/medicalcodingpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class MedicalCodingPipeline(BasePipeline):
Examples:
>>> # Using with SpaCy/MedCAT
>>> pipeline = MedicalCodingPipeline.load("medcatlite")
>>> pipeline = MedicalCodingPipeline.load("medcatlite", source="spacy")
>>>
>>> # Using with Hugging Face
>>> pipeline = MedicalCodingPipeline.load(
Expand Down
2 changes: 1 addition & 1 deletion healthchain/pipeline/summarizationpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class SummarizationPipeline(BasePipeline):
Examples:
>>> # Using with GPT model
>>> pipeline = SummarizationPipeline.load("openai/gpt-4")
>>> pipeline = SummarizationPipeline.load("gpt-4o", source="openai")
>>>
>>> # Using with Hugging Face
>>> pipeline = SummarizationPipeline.load(
Expand Down
4 changes: 2 additions & 2 deletions tests/pipeline/prebuilt/test_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_summarization_pipeline(mock_cds_fhir_connector, mock_llm, test_cds_requ
"healthchain.pipeline.summarizationpipeline.ModelRouter.get_component", mock_llm
):
# This also doesn't do anything yet
pipeline = SummarizationPipeline.load("huggingface/llama3")
pipeline = SummarizationPipeline.load("llama3")

# Process the request through the pipeline
cds_response = pipeline(test_cds_request)
Expand Down Expand Up @@ -66,7 +66,7 @@ def test_full_summarization_pipeline_integration(mock_llm, test_cds_request):
with patch(
"healthchain.pipeline.summarizationpipeline.ModelRouter.get_component", mock_llm
):
pipeline = SummarizationPipeline.load("huggingface/llama3")
pipeline = SummarizationPipeline.load("llama3")

cds_response = pipeline(test_cds_request)
print(cds_response)
Expand Down
4 changes: 2 additions & 2 deletions tests/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def invalid_output_component(data: DataContainer) -> DataContainer:
def test_pipeline_class_and_representation(mock_basic_pipeline):
pipeline = Pipeline()
assert hasattr(pipeline, "configure_pipeline")
pipeline.configure_pipeline("spacy/dummy_path") # Should not raise any exception
pipeline.configure_pipeline("dummy_path") # Should not raise any exception

mock_basic_pipeline.add_node(mock_component, name="comp1")
mock_basic_pipeline.add_node(mock_component, name="comp2")
Expand All @@ -257,7 +257,7 @@ def test_pipeline_class_and_representation(mock_basic_pipeline):
assert "comp1" in repr_string
assert "comp2" in repr_string

loaded_pipeline = Pipeline.load("spacy/dummy_path")
loaded_pipeline = Pipeline.load("dummy_path")
assert isinstance(loaded_pipeline, Pipeline)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def validated_component(data: DataContainer) -> DataContainer:
def test_pipeline_class_and_representation(basic_pipeline):
pipeline = Pipeline()
assert hasattr(pipeline, "configure_pipeline")
pipeline.configure_pipeline("spacy/dummy_path") # Should not raise any exception
pipeline.configure_pipeline("dummy_path") # Should not raise any exception

basic_pipeline.add_node(mock_component, name="comp1")
basic_pipeline.add_node(mock_component, name="comp2")
Expand All @@ -212,5 +212,5 @@ def test_pipeline_class_and_representation(basic_pipeline):
assert "comp1" in repr_string
assert "comp2" in repr_string

loaded_pipeline = Pipeline.load("spacy/dummy_path")
loaded_pipeline = Pipeline.load("dummy_path")
assert isinstance(loaded_pipeline, Pipeline)

0 comments on commit 3afdb6a

Please sign in to comment.