Skip to content

Commit

Permalink
feat(py): support traceable(process_chunk=) (#1537)
Browse files Browse the repository at this point in the history
allow processing of chunks for tracing when using `traceable` with a
streaming function
  • Loading branch information
baskaryan authored Feb 25, 2025
1 parent 1818426 commit ff835ce
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 5 deletions.
29 changes: 25 additions & 4 deletions python/langsmith/run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def ensure_traceable(
project_name: Optional[str] = None,
process_inputs: Optional[Callable[[dict], dict]] = None,
process_outputs: Optional[Callable[..., dict]] = None,
process_chunk: Optional[Callable] = None,
) -> SupportsLangsmithExtra[P, R]:
"""Ensure that a function is traceable."""
if is_traceable_function(func):
Expand All @@ -191,6 +192,7 @@ def ensure_traceable(
project_name=project_name,
process_inputs=process_inputs,
process_outputs=process_outputs,
process_chunk=process_chunk,
)(func)


Expand Down Expand Up @@ -285,6 +287,7 @@ def traceable(
project_name: Optional[str] = None,
process_inputs: Optional[Callable[[dict], dict]] = None,
process_outputs: Optional[Callable[..., dict]] = None,
process_chunk: Optional[Callable] = None,
_invocation_params_fn: Optional[Callable[[dict], dict]] = None,
dangerously_allow_filesystem: bool = False,
) -> Callable[[Callable[P, R]], SupportsLangsmithExtra[P, R]]: ...
Expand Down Expand Up @@ -473,6 +476,7 @@ def manual_extra_function(x):
project_name=kwargs.pop("project_name", None),
run_type=run_type,
process_inputs=kwargs.pop("process_inputs", None),
process_chunk=kwargs.pop("process_chunk", None),
invocation_params_fn=kwargs.pop("_invocation_params_fn", None),
dangerously_allow_filesystem=kwargs.pop("dangerously_allow_filesystem", False),
)
Expand Down Expand Up @@ -583,6 +587,7 @@ async def async_generator_wrapper(
),
accepts_context=accepts_context,
results=results,
process_chunk=container_input.get("process_chunk"),
):
yield item
except BaseException as e:
Expand Down Expand Up @@ -659,6 +664,7 @@ def generator_wrapper(
run_container,
is_llm_run=run_type == "llm",
results=results,
process_chunk=container_input.get("process_chunk"),
)

if function_return is not None:
Expand Down Expand Up @@ -1226,6 +1232,7 @@ class _ContainerInput(TypedDict, total=False):
project_name: Optional[str]
run_type: ls_client.RUN_TYPE_T
process_inputs: Optional[Callable[[dict], dict]]
process_chunk: Optional[Callable]
invocation_params_fn: Optional[Callable[[dict], dict]]
dangerously_allow_filesystem: Optional[bool]

Expand Down Expand Up @@ -1565,21 +1572,26 @@ def _process_iterator(
is_llm_run: bool,
# Results is mutated
results: List[Any],
process_chunk: Optional[Callable],
) -> Generator[T, None, Any]:
try:
while True:
item: T = run_container["context"].run(next, generator) # type: ignore[arg-type]
if process_chunk:
traced_item = process_chunk(item)
else:
traced_item = item
if is_llm_run and run_container["new_run"]:
run_container["new_run"].add_event(
{
"name": "new_token",
"time": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
"kwargs": {"token": item},
"kwargs": {"token": traced_item},
}
)
results.append(item)
results.append(traced_item)
yield item
except StopIteration as e:
return e.value
Expand All @@ -1592,6 +1604,7 @@ async def _process_async_iterator(
is_llm_run: bool,
accepts_context: bool,
results: List[Any],
process_chunk: Optional[Callable],
) -> AsyncGenerator[T, None]:
try:
while True:
Expand All @@ -1604,17 +1617,21 @@ async def _process_async_iterator(
# Python < 3.11
with tracing_context(**get_tracing_context(run_container["context"])):
item = await aitertools.py_anext(generator)
if process_chunk:
traced_item = process_chunk(item)
else:
traced_item = item
if is_llm_run and run_container["new_run"]:
run_container["new_run"].add_event(
{
"name": "new_token",
"time": datetime.datetime.now(
datetime.timezone.utc
).isoformat(),
"kwargs": {"token": item},
"kwargs": {"token": traced_item},
}
)
results.append(item)
results.append(traced_item)
yield item
except StopAsyncIteration:
pass
Expand Down Expand Up @@ -1689,6 +1706,7 @@ def __init__(
stream: Iterator[T],
trace_container: _TraceableContainer,
reduce_fn: Optional[Callable] = None,
process_chunk: Optional[Callable] = None,
):
super().__init__(
stream=stream, trace_container=trace_container, reduce_fn=reduce_fn
Expand All @@ -1699,6 +1717,7 @@ def __init__(
self.__ls_trace_container__,
is_llm_run=self.__is_llm_run__,
results=self.__ls_accumulated_output__,
process_chunk=process_chunk,
)

def __next__(self) -> T:
Expand Down Expand Up @@ -1735,6 +1754,7 @@ def __init__(
stream: AsyncIterator[T],
trace_container: _TraceableContainer,
reduce_fn: Optional[Callable] = None,
process_chunk: Optional[Callable] = None,
):
super().__init__(
stream=stream, trace_container=trace_container, reduce_fn=reduce_fn
Expand All @@ -1746,6 +1766,7 @@ def __init__(
is_llm_run=self.__is_llm_run__,
accepts_context=aitertools.asyncio_accepts_context(),
results=self.__ls_accumulated_output__,
process_chunk=process_chunk,
)

async def _aend_trace(self, error: Optional[BaseException] = None):
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langsmith"
version = "0.3.10"
version = "0.3.11"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
authors = ["LangChain <[email protected]>"]
license = "MIT"
Expand Down
50 changes: 50 additions & 0 deletions python/tests/unit_tests/test_run_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,3 +1833,53 @@ def foo(x: dict = {}) -> str:
expected = {"x": {"b": "c"}}
actual = _get_inputs_and_attachments_safe(inspect.signature(foo), x={"b": "c"})[0]
assert expected == actual


def test_traceable_iterator_process_chunk(mock_client: Client) -> None:
    """Chunks yielded by a traced sync generator go through ``process_chunk``.

    The generator yields full dicts to its caller, but only the value
    returned by ``process_chunk`` (here the ``"val"`` field) should be
    recorded in the traced run's outputs.
    """
    with tracing_context(enabled=True):

        def _take_val(chunk):
            # Strip each yielded dict down to the field we want traced.
            return chunk["val"]

        @traceable(client=mock_client, process_chunk=_take_val)
        def my_iterator_fn(a, b, d, **kwargs) -> Any:
            assert kwargs == {"e": 5}
            for i in range(a + b + d):
                yield {"val": i, "squared": i**2}

        # The caller still receives the unmodified chunks.
        stream = my_iterator_fn(1, 2, 3, e=5)
        collected = [chunk for chunk in stream]
        assert collected == [{"val": n, "squared": n * n} for n in range(6)]

        # Inspect what was posted to the (mocked) LangSmith API.
        calls = _get_calls(mock_client, minimum=1)
        assert 1 <= len(calls) <= 2

        first = calls[0]
        assert first.args[0] == "POST"
        assert first.args[1].startswith("https://api.smith.langchain.com")
        payload = json.loads(first.kwargs["data"])
        assert payload["post"]
        # Only the processed values — not the full dicts — are traced.
        assert payload["post"][0]["outputs"]["output"] == list(range(6))


async def test_traceable_async_iterator_process_chunk(mock_client: Client) -> None:
    """Chunks yielded by a traced async generator go through ``process_chunk``.

    Async counterpart of the sync test: the consumer sees the raw dicts,
    while the traced run's outputs contain only what ``process_chunk``
    (here the ``"val"`` field) returns for each chunk.
    """
    with tracing_context(enabled=True):

        def _take_val(chunk):
            # Strip each yielded dict down to the field we want traced.
            return chunk["val"]

        @traceable(client=mock_client, process_chunk=_take_val)
        async def my_iterator_fn(a, b, d, **kwargs) -> Any:
            assert kwargs == {"e": 5}
            for i in range(a + b + d):
                yield {"val": i, "squared": i**2}

        # The caller still receives the unmodified chunks.
        stream = my_iterator_fn(1, 2, 3, e=5)
        collected = [chunk async for chunk in stream]
        assert collected == [{"val": n, "squared": n * n} for n in range(6)]

        # Inspect what was posted to the (mocked) LangSmith API.
        calls = _get_calls(mock_client, minimum=1)
        assert 1 <= len(calls) <= 2

        first = calls[0]
        assert first.args[0] == "POST"
        assert first.args[1].startswith("https://api.smith.langchain.com")
        payload = json.loads(first.kwargs["data"])
        assert payload["post"]
        # Only the processed values — not the full dicts — are traced.
        assert payload["post"][0]["outputs"]["output"] == list(range(6))

0 comments on commit ff835ce

Please sign in to comment.