Skip to content

Commit

Permalink
Pass kwargs to integration components
Browse files Browse the repository at this point in the history
  • Loading branch information
jenniferjiangkells committed Nov 7, 2024
1 parent 4086a8c commit d9f2a6d
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 27 deletions.
32 changes: 25 additions & 7 deletions healthchain/pipeline/components/integrations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any
from healthchain.io.containers import Document
from healthchain.pipeline.components.base import BaseComponent
from healthchain.models.data import ProblemConcept
Expand All @@ -15,6 +16,7 @@ class SpacyNLP(BaseComponent[str]):
Args:
path_to_pipeline (str): The path or name of the spaCy model to load.
Can be a model name like 'en_core_web_sm' or path to saved model.
**kwargs: Additional configuration options passed to spacy.load
Raises:
ImportError: If spaCy or the specified model is not installed.
Expand All @@ -24,12 +26,13 @@ class SpacyNLP(BaseComponent[str]):
>>> doc = component(doc) # Processes doc.data with spaCy
"""

def __init__(self, path_to_pipeline: str):
# TODO: might need to store model specific info
def __init__(self, path_to_pipeline: str, **kwargs: Any):
import spacy

try:
nlp = spacy.load(path_to_pipeline)
nlp = spacy.load(path_to_pipeline, **kwargs)
except TypeError as e:
raise TypeError(f"Invalid kwargs for spacy.load: {str(e)}")
except Exception as e:
raise ImportError(
f"Could not load spaCy model {path_to_pipeline}! "
Expand Down Expand Up @@ -84,6 +87,7 @@ class HFTransformer(BaseComponent[str]):
Must be a valid task supported by the Hugging Face pipeline API.
model (str): The model identifier or path to use for the task.
Can be a model ID from the Hugging Face Hub or a local path.
**kwargs: Additional configuration options passed to the transformers.pipeline API
Raises:
ImportError: If the transformers package is not installed.
Expand All @@ -96,7 +100,7 @@ class HFTransformer(BaseComponent[str]):
>>> doc = component(doc) # Runs sentiment analysis on doc.data
"""

def __init__(self, task, model):
def __init__(self, task: str, model: str, **kwargs: Any):
try:
from transformers import pipeline
except ImportError:
Expand All @@ -105,7 +109,13 @@ def __init__(self, task, model):
"`pip install transformers`"
)

nlp = pipeline(task=task, model=model)
try:
nlp = pipeline(task=task, model=model, **kwargs)
except TypeError as e:
raise TypeError(f"Invalid kwargs for transformers.pipeline: {str(e)}")
except Exception as e:
raise ValueError(f"Error initializing transformer pipeline: {str(e)}")

self.nlp = nlp
self.task = task

Expand All @@ -126,6 +136,7 @@ class LangChainLLM(BaseComponent[str]):
Args:
chain: The LangChain chain to run on the document text.
Can be any chain object from the LangChain library.
**kwargs: Additional parameters to pass to the chain's invoke method
Example:
>>> from langchain.chains import LLMChain
Expand All @@ -134,10 +145,17 @@ class LangChainLLM(BaseComponent[str]):
>>> doc = component(doc) # Runs the chain on doc.data
"""

def __init__(self, chain: Any, **kwargs: Any):
    """Store the chain and the keyword arguments to forward on each call.

    Args:
        chain: The object whose ``invoke`` method is run on the document
            text when the component is called.
        **kwargs: Saved unchanged and passed to ``chain.invoke`` on every
            call of the component.
    """
    # Nothing is validated here; unsupported kwargs only surface once
    # ``chain.invoke`` is actually executed.
    self.chain, self.kwargs = chain, kwargs

def __call__(self, doc: Document) -> Document:
    """Run the chain on the document text and record its output.

    ``doc.data`` is passed to ``self.chain.invoke`` together with the
    keyword arguments stored on the component. The result is saved on the
    document through ``doc.models.add_output`` under the ``"langchain"``
    source and the ``"chain_output"`` key.

    Args:
        doc: The document whose ``data`` is sent to the chain.

    Returns:
        The same ``doc`` object, after the chain output has been added.

    Raises:
        TypeError: If ``chain.invoke`` raises a ``TypeError``. Any
            ``TypeError`` from the chain is reported this way, not only
            one caused by unsupported keyword arguments.
        ValueError: If ``chain.invoke`` raises any other exception.
    """
    try:
        output = self.chain.invoke(doc.data, **self.kwargs)
    except TypeError as e:
        # Chain explicitly so the original traceback stays attached as
        # ``__cause__`` instead of being reported as an incidental context.
        raise TypeError(f"Invalid kwargs for chain.invoke: {e}") from e
    except Exception as e:
        raise ValueError(f"Error during chain invocation: {e}") from e

    doc.models.add_output("langchain", "chain_output", output)
    return doc
173 changes: 153 additions & 20 deletions tests/test_pipeline_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,42 @@ def sample_document():


@pytest.mark.parametrize(
    "component_class,mock_module,kwargs,expected_kwargs",
    [
        (
            SpacyNLP,
            "spacy.load",
            {"disable": ["ner", "parser"]},
            {"disable": ["ner", "parser"]},
        ),
        pytest.param(
            HFTransformer,
            "transformers.pipeline",
            {"device": "cuda", "batch_size": 32},
            {
                "task": "dummy_task",
                "model": "dummy_model",
                "device": "cuda",
                "batch_size": 32,
            },
            marks=pytest.mark.skipif(
                not transformers_installed, reason="transformers package not installed"
            ),
        ),
    ],
)
def test_component_initialization(
    component_class, mock_module, kwargs, expected_kwargs
):
    """Check that each component forwards its kwargs to the patched loader.

    ``mock_module`` is patched, the component is built with ``kwargs``, and
    the mock must have been called exactly once with ``expected_kwargs``.
    The mock's return value must end up on ``component.nlp``.
    """
    with patch(mock_module) as mock:
        mock_instance = Mock()
        mock.return_value = mock_instance
        if component_class == SpacyNLP:
            component = component_class("dummy_path", **kwargs)
            # SpacyNLP passes the model path positionally to spacy.load.
            mock.assert_called_once_with("dummy_path", **expected_kwargs)
        else:
            component = component_class("dummy_task", "dummy_model", **kwargs)
            # HFTransformer passes task/model as keywords only, so there is
            # no positional argument to expect on the pipeline call.
            mock.assert_called_once_with(**expected_kwargs)

        assert hasattr(component, "nlp")
        assert component.nlp == mock_instance

Expand All @@ -47,9 +62,23 @@ def test_spacy_component(sample_document):
mock_instance = MagicMock(items=[])
mock_instance.__iter__.return_value = []
mock_load.return_value = mock_instance
component = SpacyNLP("en_core_web_sm")
result = component(sample_document)
assert result.nlp.get_spacy_doc()

# Test with and without kwargs
test_cases = [
({"disable": ["ner", "parser"]}, "with kwargs"),
({}, "without kwargs"),
]

for kwargs, case in test_cases:
component = SpacyNLP("en_core_web_sm", **kwargs)
result = component(sample_document)

# Verify kwargs were passed correctly
expected_args = {"disable": ["ner", "parser"]} if kwargs else {}
mock_load.assert_called_with("en_core_web_sm", **expected_args)

assert result.nlp.get_spacy_doc(), f"SpacyNLP failed {case}"
mock_load.reset_mock()


@pytest.mark.skipif(
Expand All @@ -59,21 +88,125 @@ def test_huggingface_component(sample_document):
with patch("transformers.pipeline") as mock_pipeline:
mock_instance = Mock()
mock_pipeline.return_value = mock_instance
component = HFTransformer(
"sentiment-analysis", "distilbert-base-uncased-finetuned-sst-2-english"
)
result = component(sample_document)
assert result.models.get_output("huggingface", "sentiment-analysis")

# Test with and without kwargs
test_cases = [
(
{
"device": "cuda",
"batch_size": 32,
"max_length": 512,
"truncation": True,
},
"with kwargs",
),
({}, "without kwargs"),
]

for kwargs, case in test_cases:
component = HFTransformer(
"sentiment-analysis",
"distilbert-base-uncased-finetuned-sst-2-english",
**kwargs,
)
result = component(sample_document)

# Verify kwargs were passed correctly
expected_kwargs = {
"task": "sentiment-analysis",
"model": "distilbert-base-uncased-finetuned-sst-2-english",
**kwargs,
}
mock_pipeline.assert_called_once_with(**expected_kwargs)

assert result.models.get_output(
"huggingface", "sentiment-analysis"
), f"HFTransformer failed {case}"
mock_pipeline.reset_mock()


def test_langchain_component(sample_document):
    """Check that LangChainLLM forwards stored kwargs to ``chain.invoke``.

    The component is exercised once with keyword arguments and once without.
    Each time the mocked chain must be invoked exactly once with the
    document data plus those kwargs, and the mocked return value must be
    readable from the document's model outputs.
    """
    chain = Mock()
    chain.invoke.return_value = "mocked chain output"

    # Case label -> kwargs handed to the component constructor.
    invoke_kwargs_by_case = {
        "with kwargs": {
            "temperature": 0.7,
            "max_tokens": 100,
            "stop_sequences": ["END"],
        },
        "without kwargs": {},
    }

    for case, kwargs in invoke_kwargs_by_case.items():
        processed = LangChainLLM(chain, **kwargs)(sample_document)

        chain.invoke.assert_called_once_with(sample_document.data, **kwargs)
        stored = processed.models.get_output("langchain", "chain_output")
        assert stored == "mocked chain output", f"LangChainLLM failed {case}"

        # Clear the call history so the next case can assert "called once".
        chain.invoke.reset_mock()


# Test error handling
@pytest.mark.parametrize(
    "component_class,args,kwargs,expected_error,expected_message",
    [
        (
            SpacyNLP,
            ["en_core_web_sm"],
            {"invalid_kwarg": "value"},
            TypeError,
            "Invalid kwargs for spacy.load",
        ),
        pytest.param(
            HFTransformer,
            ["sentiment-analysis", "model"],
            {"invalid_kwarg": "value"},
            TypeError,
            "Invalid kwargs for transformers.pipeline",
            marks=pytest.mark.skipif(
                not transformers_installed, reason="transformers package not installed"
            ),
        ),
        (
            LangChainLLM,
            [Mock()],  # Mock chain
            {"invalid_kwarg": "value"},
            TypeError,
            "Invalid kwargs for chain.invoke",
        ),
    ],
)
def test_component_invalid_kwargs(
    component_class, args, kwargs, expected_error, expected_message
):
    """Check that a TypeError from the wrapped library is re-raised with context.

    For SpacyNLP and HFTransformer the loader function is patched to raise a
    ``TypeError`` during construction. For LangChainLLM the mocked chain's
    ``invoke`` raises when the component is called on a document. In every
    case the raised error must contain ``expected_message``.
    """
    unexpected_kwarg = TypeError("got an unexpected keyword argument 'invalid_kwarg'")

    # Components that fail at construction time, mapped to the loader to patch.
    loader_targets = {
        SpacyNLP: "spacy.load",
        HFTransformer: "transformers.pipeline",
    }
    target = loader_targets.get(component_class)

    if target is not None:
        with patch(target) as mock_loader:
            mock_loader.side_effect = unexpected_kwarg
            with pytest.raises(expected_error) as exc_info:
                component_class(*args, **kwargs)
    else:  # LangChainLLM: the failure only occurs once the chain is invoked
        chain = args[0]
        chain.invoke.side_effect = unexpected_kwarg
        with pytest.raises(expected_error) as exc_info:
            llm_component = component_class(*args, **kwargs)
            llm_component(Document("test"))

    assert expected_message in str(exc_info.value)

0 comments on commit d9f2a6d

Please sign in to comment.