
Commit

Merge branch 'langchain-ai:master' into master
baur-krykpayev authored Jun 12, 2024
2 parents c124e49 + 8203c1f commit 86db74f
Showing 7 changed files with 126 additions and 30 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/check_new_docs.yml
@@ -0,0 +1,31 @@
---
name: Integration docs lint

on:
  push:
    branches: [master]
  pull_request:

# If another push to the same PR or branch happens while this workflow is still running,
# cancel the earlier run in favor of the next run.
#
# There's no point in testing an outdated version of the code. GitHub only allows
# a limited number of job runners to be active at the same time, so it's better to cancel
# pointless jobs early so that more useful jobs can run sooner.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - id: files
        uses: Ana06/[email protected]
      - name: Check new docs
        run: |
          python docs/scripts/check_templates.py ${{ steps.files.outputs.added }}
43 changes: 43 additions & 0 deletions docs/scripts/check_templates.py
@@ -0,0 +1,43 @@
import re
import sys
from pathlib import Path
from typing import Union

CURR_DIR = Path(__file__).parent.absolute()

CHAT_MODEL_HEADERS = (
    "## Overview",
    "### Integration details",
    "### Model features",
    "## Setup",
    "## Instantiation",
    "## Invocation",
    "## Chaining",
    "## API reference",
)
CHAT_MODEL_REGEX = r".*".join(CHAT_MODEL_HEADERS)


def check_chat_model(path: Path) -> None:
    with open(path, "r") as f:
        doc = f.read()
    if not re.search(CHAT_MODEL_REGEX, doc, re.DOTALL):
        raise ValueError(
            f"Document {path} does not match the ChatModel Integration page template. "
            f"Please see https://github.com/langchain-ai/langchain/issues/22296 for "
            f"instructions on how to correctly format a ChatModel Integration page."
        )


def main(*new_doc_paths: Union[str, Path]) -> None:
    for path in new_doc_paths:
        path = Path(path).resolve().absolute()
        if CURR_DIR.parent / "docs" / "integrations" / "chat" in path.parents:
            print(f"Checking chat model page {path}")
            check_chat_model(path)
        else:
            continue


if __name__ == "__main__":
    main(*sys.argv[1:])
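
The template check reduces to a single ordered match: the required headers joined by ".*" and searched with re.DOTALL, so a page passes as long as every header appears, in order, with any amount of text in between. A minimal, self-contained sketch of that behavior (the header tuple mirrors the script above; the sample documents are invented for illustration):

import re

HEADERS = (
    "## Overview",
    "### Integration details",
    "### Model features",
    "## Setup",
    "## Instantiation",
    "## Invocation",
    "## Chaining",
    "## API reference",
)
PATTERN = r".*".join(HEADERS)

# A page passes when every header appears, in this order, with anything in between.
good_doc = "\n\n".join(h + "\nsome prose about the integration" for h in HEADERS)
assert re.search(PATTERN, good_doc, re.DOTALL)

# Dropping or renaming any required header breaks the match.
bad_doc = good_doc.replace("## Instantiation", "## Usage")
assert not re.search(PATTERN, bad_doc, re.DOTALL)
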
31 changes: 23 additions & 8 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -186,9 +186,10 @@ async def _completion_with_retry(**kwargs: Any) -> Any:
        return await _completion_with_retry(**kwargs)


-def _convert_delta_to_message_chunk(
-    _delta: Dict, default_class: Type[BaseMessageChunk]
+def _convert_chunk_to_message_chunk(
+    chunk: Dict, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
+    _delta = chunk["choices"][0]["delta"]
    role = _delta.get("role")
    content = _delta.get("content") or ""
    if role == "user" or default_class == HumanMessageChunk:
@@ -216,10 +217,19 @@ def _convert_delta_to_message_chunk(
                pass
        else:
            tool_call_chunks = []
+        if token_usage := chunk.get("usage"):
+            usage_metadata = {
+                "input_tokens": token_usage.get("prompt_tokens", 0),
+                "output_tokens": token_usage.get("completion_tokens", 0),
+                "total_tokens": token_usage.get("total_tokens", 0),
+            }
+        else:
+            usage_metadata = None
        return AIMessageChunk(
            content=content,
            additional_kwargs=additional_kwargs,
            tool_call_chunks=tool_call_chunks,
+            usage_metadata=usage_metadata,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
@@ -484,14 +494,21 @@ def _generate(

    def _create_chat_result(self, response: Dict) -> ChatResult:
        generations = []
+        token_usage = response.get("usage", {})
        for res in response["choices"]:
            finish_reason = res.get("finish_reason")
+            message = _convert_mistral_chat_message_to_message(res["message"])
+            if token_usage and isinstance(message, AIMessage):
+                message.usage_metadata = {
+                    "input_tokens": token_usage.get("prompt_tokens", 0),
+                    "output_tokens": token_usage.get("completion_tokens", 0),
+                    "total_tokens": token_usage.get("total_tokens", 0),
+                }
            gen = ChatGeneration(
-                message=_convert_mistral_chat_message_to_message(res["message"]),
+                message=message,
                generation_info={"finish_reason": finish_reason},
            )
            generations.append(gen)
-        token_usage = response.get("usage", {})

        llm_output = {"token_usage": token_usage, "model": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)
@@ -525,8 +542,7 @@ def _stream(
        ):
            if len(chunk["choices"]) == 0:
                continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
            # make future chunks same type as first chunk
            default_chunk_class = new_chunk.__class__
            gen_chunk = ChatGenerationChunk(message=new_chunk)
@@ -552,8 +568,7 @@ async def _astream(
        ):
            if len(chunk["choices"]) == 0:
                continue
-            delta = chunk["choices"][0]["delta"]
-            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+            new_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
            # make future chunks same type as first chunk
            default_chunk_class = new_chunk.__class__
            gen_chunk = ChatGenerationChunk(message=new_chunk)
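
With _convert_chunk_to_message_chunk reading the top-level "usage" field, streamed AIMessageChunks from ChatMistralAI now carry usage_metadata with input_tokens, output_tokens, and total_tokens. A rough usage sketch, assuming a valid MISTRAL_API_KEY in the environment; the model name follows the test parameters below, and the exact counts depend on the prompt:

from langchain_mistralai import ChatMistralAI

llm = ChatMistralAI(model="mistral-large-latest")

full = None
for chunk in llm.stream("Write a one-line haiku about rain"):
    # token counts arrive on a single chunk and survive chunk addition
    full = chunk if full is None else full + chunk

print(full.usage_metadata)
# e.g. {'input_tokens': ..., 'output_tokens': ..., 'total_tokens': ...}
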
15 changes: 6 additions & 9 deletions libs/partners/mistralai/poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion libs/partners/mistralai/pyproject.toml
@@ -12,7 +12,7 @@ license = "MIT"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
-langchain-core = ">=0.2.0,<0.3"
+langchain-core = ">=0.2.2,<0.3"
tokenizers = ">=0.15.1,<1"
httpx = ">=0.25.2,<1"
httpx-sse = ">=0.3.1,<1"
23 changes: 22 additions & 1 deletion libs/partners/mistralai/tests/integration_tests/test_chat_models.py
@@ -1,11 +1,12 @@
"""Test ChatMistral chat model."""

import json
from typing import Any
from typing import Any, Optional

from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessageChunk,
HumanMessage,
)
from langchain_core.pydantic_v1 import BaseModel
@@ -25,8 +26,28 @@ async def test_astream() -> None:
"""Test streaming tokens from ChatMistralAI."""
llm = ChatMistralAI()

full: Optional[BaseMessageChunk] = None
chunks_with_token_counts = 0
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token, AIMessageChunk)
assert isinstance(token.content, str)
full = token if full is None else full + token
if token.usage_metadata is not None:
chunks_with_token_counts += 1
if chunks_with_token_counts != 1:
raise AssertionError(
"Expected exactly one chunk with token counts. "
"AIMessageChunk aggregation adds counts. Check that "
"this is behaving properly."
)
assert isinstance(full, AIMessageChunk)
assert full.usage_metadata is not None
assert full.usage_metadata["input_tokens"] > 0
assert full.usage_metadata["output_tokens"] > 0
assert (
full.usage_metadata["input_tokens"] + full.usage_metadata["output_tokens"]
== full.usage_metadata["total_tokens"]
)


async def test_abatch() -> None:
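
The "exactly one chunk with token counts" assertion leans on AIMessageChunk addition summing usage_metadata across chunks (the "aggregation adds counts" behavior the assertion message refers to, and presumably what the langchain-core >=0.2.2 bump above picks up). A small illustration with made-up counts:

from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(
    content="Hello ",
    usage_metadata={"input_tokens": 5, "output_tokens": 1, "total_tokens": 6},
)
right = AIMessageChunk(
    content="world",
    usage_metadata={"input_tokens": 0, "output_tokens": 1, "total_tokens": 1},
)

merged = left + right
print(merged.content)         # Hello world
print(merged.usage_metadata)  # input 5, output 2, total 7
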
11 changes: 0 additions & 11 deletions libs/partners/mistralai/tests/integration_tests/test_standard.py
@@ -20,14 +20,3 @@ def chat_model_params(self) -> dict:
            "model": "mistral-large-latest",
            "temperature": 0,
        }
-
-    @pytest.mark.xfail(reason="Not implemented.")
-    def test_usage_metadata(
-        self,
-        chat_model_class: Type[BaseChatModel],
-        chat_model_params: dict,
-    ) -> None:
-        super().test_usage_metadata(
-            chat_model_class,
-            chat_model_params,
-        )
