Don't attempt batching with InstructLab's llama-cpp-python #358

Merged 1 commit on Nov 11, 2024
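For context: the change below teaches server_supports_batched to probe the server root for InstructLab's llama-cpp-python welcome message and to skip the batched test request when that message is found. A minimal standalone sketch of the same probe, written directly against httpx; the helper name, timeout, and URL handling are illustrative assumptions, not code from this PR:

import httpx


def looks_like_instructlab_llama_cpp(base_url: str) -> bool:
    """Best-effort check for InstructLab's llama-cpp-python welcome message."""
    # The OpenAI-compatible endpoint lives under /v1; the welcome message
    # is served from the root of the same host.
    root = base_url.rstrip("/").removesuffix("/v1")
    try:
        resp = httpx.get(root, timeout=5)
    except httpx.HTTPError:
        # Unreachable server or transport error: assume it is not llama-cpp-python.
        return False
    return "Hello from InstructLab" in resp.text

In the PR itself the probe reuses the OpenAI client's own get() helper (with cast_to=httpx.Response) rather than a separate httpx call, so it inherits the client's base URL, headers, and transport settings.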
src/instructlab/sdg/llmblock.py (36 changes: 26 additions & 10 deletions)
@@ -9,6 +9,7 @@
 # Third Party
 from datasets import Dataset
 from tqdm import tqdm
+import httpx
 import openai
 
 # Local
@@ -40,16 +41,31 @@ def server_supports_batched(client, model_id: str) -> bool:
     supported = getattr(client, "server_supports_batched", None)
     if supported is not None:
         return supported
-    try:
-        # Make a test call to the server to determine whether it supports
-        # multiple input prompts per request and also the n parameter
-        response = client.completions.create(
-            model=model_id, prompt=["test1", "test2"], max_tokens=1, n=3
-        )
-        # Number outputs should be 2 * 3 = 6
-        supported = len(response.choices) == 6
-    except openai.InternalServerError:
-        supported = False
+    # Start looking for InstructLab's default llama-cpp-python so we
+    # can avoid throwing an assertion error in the server, as
+    # llama-cpp-python does not like us explicitly testing batches
+    if "/v1" in client.base_url.path:
+        try:
+            # The root (without /v1) will have InstructLab's welcome
+            # message
+            http_res = client.get("../", cast_to=httpx.Response)
+            if "Hello from InstructLab" in http_res.text:
+                # The server is llama-cpp-python, so disable batching
+                supported = False
+        except openai.APIStatusError:
+            # The server is not InstructLab's llama-cpp-python
+            pass
+    if supported is None:
+        try:
+            # Make a test call to the server to determine whether it supports
+            # multiple input prompts per request and also the n parameter
+            response = client.completions.create(
+                model=model_id, prompt=["test1", "test2"], max_tokens=1, n=3
+            )
+            # Number outputs should be 2 * 3 = 6
+            supported = len(response.choices) == 6
+        except openai.InternalServerError:
+            supported = False
     client.server_supports_batched = supported
     logger.info(f"LLM server supports batched inputs: {client.server_supports_batched}")
     return supported
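A hypothetical usage sketch, not part of the diff: checking batching support for an OpenAI client pointed at a local server. The base URL, API key, and model name are placeholders, and the import path assumes the installed instructlab.sdg package layout.

from openai import OpenAI

from instructlab.sdg.llmblock import server_supports_batched

# Placeholder endpoint and model; a server must be listening for the probe to run.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="no-key-needed")

if server_supports_batched(client, "my-model"):
    print("Server accepts batched prompts; multiple prompts can share one request")
else:
    print("Batching disabled; send one prompt per request")

The result is cached on the client as client.server_supports_batched, so the probe only runs once per client instance.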
tests/test_llmblock.py (56 changes: 55 additions & 1 deletion)
@@ -6,9 +6,11 @@
 
 # Third Party
 from datasets import Dataset, Features, Value
+from httpx import URL
+from openai import InternalServerError, NotFoundError, OpenAI
 
 # First Party
-from src.instructlab.sdg.llmblock import LLMBlock
+from src.instructlab.sdg.llmblock import LLMBlock, server_supports_batched
 
 
 class TestLLMBlockModelPrompt(unittest.TestCase):
@@ -103,3 +105,55 @@ def test_max_num_tokens_override(self, mock_load_config):
         )
         num_tokens = block.gen_kwargs["max_tokens"]
         assert num_tokens == 512
+
+    def test_server_supports_batched_llama_cpp(self):
+        resp_text = """{"message":"Hello from InstructLab! Visit us at https://instructlab.ai"}"""
+        mock_client = MagicMock()
+        mock_client.server_supports_batched = None
+        mock_client.base_url = URL("http://localhost:8000/v1")
+        mock_client.get = MagicMock()
+        mock_client.get.return_value = MagicMock()
+        mock_client.get().text = resp_text
+        self.mock_ctx.client = mock_client
+        supports_batched = server_supports_batched(self.mock_ctx.client, "my-model")
+        assert not supports_batched
+
+    def test_server_supports_batched_other_llama_cpp(self):
+        resp_text = "another server"
+        mock_client = MagicMock()
+        mock_client.server_supports_batched = None
+        mock_client.base_url = URL("http://localhost:8000/v1")
+        mock_client.get = MagicMock()
+        mock_client.get.return_value = MagicMock()
+        mock_client.get().text = resp_text
+        mock_completion = MagicMock()
+        mock_completion.create = MagicMock()
+        mock_completion.create.side_effect = InternalServerError(
+            "mock error",
+            response=MagicMock(),
+            body=MagicMock(),
+        )
+        mock_client.completions = mock_completion
+        self.mock_ctx.client = mock_client
+        supports_batched = server_supports_batched(self.mock_ctx.client, "my-model")
+        assert not supports_batched
+
+    def test_server_supports_batched_vllm(self):
+        mock_client = MagicMock()
+        mock_client.server_supports_batched = None
+        mock_client.base_url = URL("http://localhost:8000/v1")
+        mock_client.get = MagicMock()
+        mock_client.get.side_effect = NotFoundError(
+            "mock error",
+            response=MagicMock(),
+            body=MagicMock(),
+        )
+        mock_completion_resp = MagicMock()
+        mock_completion_resp.choices = ["a", "b", "c", "d", "e", "f"]
+        mock_completion = MagicMock()
+        mock_completion.create = MagicMock()
+        mock_completion.create.return_value = mock_completion_resp
+        mock_client.completions = mock_completion
+        self.mock_ctx.client = mock_client
+        supports_batched = server_supports_batched(self.mock_ctx.client, "my-model")
+        assert supports_batched