Skip to content

Commit

Permalink
Merge pull request #243 from defenseunicorns/refresh-sdk
Browse files Browse the repository at this point in the history
Spike at new CLI and SDK implementation
  • Loading branch information
gerred authored Nov 27, 2023
2 parents c046e9a + 6f7b294 commit 0cf8e37
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 5 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ dependencies = [
"grpcio-health-checking >=1.58.0",
"confz >= 2.0.1",
"pydantic < 2",
"click >= 8.1.7",
]

[project.scripts]
leapfrogai = "leapfrogai.cli:cli"

[project.optional-dependencies]
tests = ["pytest"]
build = ["build", "hatchling", "twine"]
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
# This file is autogenerated by pip-compile with Python 3.11
# by the following command:
#
# pip-compile --output-file=leapfrogai/requirements.txt pyproject.toml
# pip-compile --output-file=requirements.txt pyproject.toml
#
click==8.1.7
# via leapfrogai (pyproject.toml)
confz==2.0.1
# via leapfrogai (pyproject.toml)
grpcio==1.58.0
Expand Down
39 changes: 39 additions & 0 deletions src/leapfrogai/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import click
import sys
import asyncio

from leapfrogai.utils import import_app
from leapfrogai.serve import serve

@click.command()
@click.argument("app", envvar="LEAPFROGAI_APP")
@click.option(
    "--host",
    type=str,
    default="0.0.0.0",
    help="Bind socket to this host.",
    show_default=True,
)
@click.option(
    "--port",
    type=int,
    default=50051,
    help="Bind socket to this port. If 0, an available port will be picked.",
    show_default=True,
)
@click.option(
    "--app-dir",
    type=str,
    default="",
    help="Path to the directory containing the app module. Defaults to the current directory.",
    show_default=True,
)
def cli(app: str, host: str, port: int, app_dir: str):
    """Leapfrog AI CLI: import APP (a ``module:attribute`` string, also read
    from ``LEAPFROGAI_APP``) and serve it over gRPC on HOST:PORT."""
    # NOTE: the docstring must be the first statement so click can use it as
    # the command's help text; it was previously dead code below the insert.
    # Make the app's directory importable before resolving the module path.
    sys.path.insert(0, app_dir)
    app = import_app(app)
    # Instantiate the backend class and run the async gRPC server to completion.
    asyncio.run(serve(app(), host, port))


if __name__ == "__main__":
    cli()  # pragma: no cover
3 changes: 2 additions & 1 deletion src/leapfrogai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class ModelConfig(BaseConfig):
class BackendConfig(BaseConfig):
name: str | None = None
model: ModelConfig | None = None
max_seq_len: int = 2048
max_context_length: int = 2048
stop_tokens: list[str] | None = None
prompt_format: PromptFormat | None = None
defaults: LLMDefaults = LLMDefaults()

Expand Down
4 changes: 4 additions & 0 deletions src/leapfrogai/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Adapted from Gunicorn's errors module.

class AppImportError(Exception):
    """Raised when the target application cannot be imported or resolved."""
123 changes: 123 additions & 0 deletions src/leapfrogai/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import Any, Generator, List

from pydantic import BaseModel

from leapfrogai import (
BackendConfig,
ChatCompletionChoice,
ChatCompletionRequest,
ChatCompletionResponse,
ChatItem,
ChatRole,
CompletionChoice,
CompletionRequest,
CompletionResponse,
GrpcContext,
)


class GenerationConfig(BaseModel):
    """Decoding/sampling parameters handed to a backend's ``generate`` method.

    Built by ``LLM._build_gen_stream`` from an incoming chat/completion
    request; field names mirror the request message's fields.
    """

    max_new_tokens: int
    temperature: float
    top_k: int
    top_p: float
    do_sample: bool
    n: int
    stop: List[str]
    repetition_penalty: float
    presence_penalty: float
    # Optional; NOTE(review): _build_gen_stream never forwards this field,
    # so it always stays None -- confirm whether that is intentional.
    frequency_penalty: float | None = None
    # NOTE(review): typed str here, but best_of is conventionally an integer
    # count of candidate completions -- confirm against the request schema.
    best_of: str
    logit_bias: dict[str, int]
    return_full_text: bool
    truncate: int
    typical_p: float
    watermark: bool
    seed: int


def LLM(_cls):
    """Class decorator that wraps a backend class with the gRPC LLM service API.

    The decorated class must provide a ``generate(prompt, config)`` method that
    yields text chunks; the wrapper adds the unary and streaming
    ``ChatComplete``/``Complete`` handlers on top of it and loads a
    ``BackendConfig`` at construction time.

    Raises ``ValueError`` if the class has no ``generate`` method.
    """
    if not hasattr(_cls, "generate"):
        raise ValueError("LLM class requires a generate method")

    class NewClass(_cls):
        config: BackendConfig

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Loaded from the backend's configured sources (confz BaseConfig).
            self.config = BackendConfig()

        def _build_gen_stream(
            self, prompt: str, request: ChatCompletionRequest | CompletionRequest
        ) -> Generator[str, Any, Any]:
            """Map the request's sampling fields onto a GenerationConfig and
            return the backend's text-chunk generator for *prompt*."""
            # NOTE(review): frequency_penalty is not forwarded here even though
            # GenerationConfig declares it -- confirm whether that is intended.
            config = GenerationConfig(
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                do_sample=request.do_sample,
                n=request.n,
                stop=list(request.stop),
                repetition_penalty=request.repetition_penalty,
                presence_penalty=request.presence_penalty,
                best_of=request.best_of,
                logit_bias=request.logit_bias,
                return_full_text=request.return_full_text,
                truncate=request.truncate,
                typical_p=request.typical_p,
                watermark=request.watermark,
                seed=request.seed,
            )
            return self.generate(prompt, config)

        async def ChatComplete(
            self, request: ChatCompletionRequest, context: GrpcContext
        ) -> ChatCompletionResponse:
            """Unary chat completion: drain the generator into one response."""
            gen_stream = self._build_gen_stream(
                self.config.apply_chat_template(request.chat_items), request
            )

            content = ""
            for text_chunk in gen_stream:
                content += text_chunk

            item = ChatItem(role=ChatRole.ASSISTANT, content=content)
            choice = ChatCompletionChoice(index=0, chat_item=item)
            return ChatCompletionResponse(choices=[choice])

        async def ChatCompleteStream(
            self, request: ChatCompletionRequest, context: GrpcContext
        ) -> Generator[ChatCompletionResponse, Any, Any]:
            """Streaming chat completion: one response per generated chunk."""
            gen_stream = self._build_gen_stream(
                self.config.apply_chat_template(request.chat_items), request
            )

            for text_chunk in gen_stream:
                item = ChatItem(role=ChatRole.ASSISTANT, content=text_chunk)
                choice = ChatCompletionChoice(index=0, chat_item=item)

                yield ChatCompletionResponse(choices=[choice])

        async def Complete(
            self, request: CompletionRequest, context: GrpcContext
        ) -> CompletionResponse:
            """Unary text completion: drain the generator into one response."""
            gen_stream = self._build_gen_stream(request.prompt, request)

            content = ""
            for text_chunk in gen_stream:
                content += text_chunk

            choice = CompletionChoice(index=0, text=content)
            return CompletionResponse(choices=[choice])

        async def CompleteStream(
            self, request: CompletionRequest, context: GrpcContext
        ) -> Generator[CompletionResponse, Any, Any]:
            """Streaming text completion: one response per generated chunk."""
            gen_stream = self._build_gen_stream(request.prompt, request)
            for text_chunk in gen_stream:
                # Removed a leftover debug print(text_chunk) that spammed
                # stdout with every streamed chunk.
                choice = CompletionChoice(index=0, text=text_chunk)
                yield CompletionResponse(choices=[choice])

    # Preserve the wrapped class's name for logging/reflection.
    NewClass.__name__ = _cls.__name__
    return NewClass
6 changes: 3 additions & 3 deletions src/leapfrogai/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from leapfrogai.name import name_pb2_grpc


async def serve(o):
async def serve(o, host, port):
# Create a tuple of all of the services we want to export via reflection.
services = (reflection.SERVICE_NAME, health.SERVICE_NAME)

Expand Down Expand Up @@ -58,8 +58,8 @@ async def serve(o):
health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)

# Listen on port 50051
print("Starting server. Listening on port 50051.")
server.add_insecure_port("[::]:50051")
server.add_insecure_port("{}:{}".format(host, port))
print("Starting server. Listening on {}:{}.".format(host, port))
await server.start()

# block the thread until the server terminates...without using async to await the completion
Expand Down
116 changes: 116 additions & 0 deletions src/leapfrogai/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Adapted from Gunicorn's utils module.

import ast
import importlib
import logging
import os
import sys
import traceback

from leapfrogai.errors import AppImportError

def import_app(module):
    """Import and return the application object named by *module*.

    *module* is a ``"pkg.mod"`` or ``"pkg.mod:attr"`` string; when no
    attribute is given, ``application`` is assumed.  The attribute may also be
    a factory call with literal arguments (e.g. ``"pkg.mod:create('x')"``),
    in which case the call's result is used as the application.

    Raises ImportError when the module cannot be imported, and AppImportError
    when the attribute cannot be parsed, found, called, or is not callable.
    """
    parts = module.split(":", 1)
    if len(parts) == 1:
        obj = "application"
    else:
        module, obj = parts[0], parts[1]

    try:
        mod = importlib.import_module(module)
    except ImportError:
        # A common mistake is passing a file path instead of a dotted module
        # path; give a pointed hint in that case. (Requires the top-level
        # ``import os`` -- previously missing, causing a NameError here.)
        if module.endswith(".py") and os.path.exists(module):
            msg = "Failed to find application, did you mean '%s:%s'?"
            raise ImportError(msg % (module.rsplit(".", 1)[0], obj))
        raise

    # Parse obj as a single expression to determine if it's a valid
    # attribute name or function call.
    try:
        expression = ast.parse(obj, mode="eval").body
    except SyntaxError:
        raise AppImportError(
            "Failed to parse %r as an attribute name or function call." % obj
        )

    if isinstance(expression, ast.Name):
        name = expression.id
        args = kwargs = None
    elif isinstance(expression, ast.Call):
        # Ensure the function name is an attribute name only.
        if not isinstance(expression.func, ast.Name):
            raise AppImportError("Function reference must be a simple name: %r" % obj)

        name = expression.func.id

        # Parse the positional and keyword arguments as literals.
        try:
            args = [ast.literal_eval(arg) for arg in expression.args]
            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in expression.keywords}
        except ValueError:
            # literal_eval gives cryptic error messages, show a generic
            # message with the full expression instead.
            raise AppImportError(
                "Failed to parse arguments as literal values: %r" % obj
            )
    else:
        raise AppImportError(
            "Failed to parse %r as an attribute name or function call." % obj
        )

    is_debug = logging.root.level == logging.DEBUG
    try:
        app = getattr(mod, name)
    except AttributeError:
        if is_debug:
            traceback.print_exception(*sys.exc_info())
        raise AppImportError("Failed to find attribute %r in %r." % (name, module))

    # If the expression was a function call, call the retrieved object
    # to get the real application.
    if args is not None:
        try:
            app = app(*args, **kwargs)
        except TypeError as e:
            # If the TypeError was due to bad arguments to the factory
            # function, show Python's nice error message without a
            # traceback.
            if _called_with_wrong_args(app):
                raise AppImportError(
                    "".join(traceback.format_exception_only(TypeError, e)).strip()
                )

            # Otherwise it was raised from within the function, show the
            # full traceback.
            raise

    if app is None:
        raise AppImportError("Failed to find application object: %r" % obj)

    if not callable(app):
        raise AppImportError("Application object must be callable.")
    return app

def _called_with_wrong_args(f):
    """Report whether the active ``TypeError`` came from the call site of *f*.

    Walks the current exception's traceback looking for *f*'s code object.
    :param f: The function that was called.
    :return: ``True`` if the call failed before entering *f* (bad arguments);
        ``False`` if the error was raised from inside *f*'s body.
    """
    current_tb = sys.exc_info()[2]

    try:
        target_code = f.__code__
        while current_tb is not None:
            if current_tb.tb_frame.f_code is target_code:
                # The traceback entered f, so the call itself succeeded.
                return False
            current_tb = current_tb.tb_next

        # f's frame never appeared: the call failed at the call site.
        return True
    finally:
        # Delete the traceback reference to break a circular reference.
        # https://docs.python.org/2/library/sys.html#sys.exc_info
        del current_tb

0 comments on commit 0cf8e37

Please sign in to comment.