-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #243 from defenseunicorns/refresh-sdk
Spike at new CLI and SDK implementation
- Loading branch information
Showing
8 changed files
with
294 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import click | ||
import sys | ||
import asyncio | ||
|
||
from leapfrogai.utils import import_app | ||
from leapfrogai.serve import serve | ||
|
||
# NOTE: ``@click.command()`` is intentionally left as the innermost decorator
# (as in the original) so the order in which parameters are attached to the
# command does not change.
@click.argument("app", envvar="LEAPFROGAI_APP")
@click.option(
    "--host",
    type=str,
    default="0.0.0.0",
    help="Bind socket to this host.",
    show_default=True,
)
@click.option(
    "--port",
    type=int,
    default=50051,
    help="Bind socket to this port. If 0, an available port will be picked.",
    show_default=True,
)
@click.option(
    "--app-dir",
    type=str,
    default="",
    help="Path to the directory containing the app module. Defaults to the current directory.",
    show_default=True,
)
@click.command()
def cli(app: str, host: str, port: int, app_dir: str):
    """Leapfrog AI CLI"""
    # The docstring must be the first statement of the function; otherwise it
    # is just a discarded string expression and the command has no help text.
    #
    # Parameters:
    #   app     -- "module" or "module:obj" reference handed to ``import_app``
    #              (may also be supplied through the LEAPFROGAI_APP env var).
    #   host    -- host passed to ``serve``.
    #   port    -- port passed to ``serve``; click converts it to ``int``.
    #   app_dir -- directory put at the front of ``sys.path`` so the app
    #              module can be imported; the default "" makes Python look
    #              in the current working directory.
    sys.path.insert(0, app_dir)

    # ``import_app`` returns a callable; calling it yields the object that is
    # handed to ``serve``.
    app = import_app(app)
    asyncio.run(serve(app(), host, port))


if __name__ == "__main__":
    cli()  # pragma: no cover
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Adapted from Gunicorn's errors module. | ||
|
||
class AppImportError(Exception):
    """Signals that an application object could not be loaded."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from typing import Any, Generator, List | ||
|
||
from pydantic import BaseModel | ||
|
||
from leapfrogai import ( | ||
BackendConfig, | ||
ChatCompletionChoice, | ||
ChatCompletionRequest, | ||
ChatCompletionResponse, | ||
ChatItem, | ||
ChatRole, | ||
CompletionChoice, | ||
CompletionRequest, | ||
CompletionResponse, | ||
GrpcContext, | ||
) | ||
|
||
|
||
class GenerationConfig(BaseModel):
    # Generation parameters handed to a model's ``generate`` method.
    # Every field except ``frequency_penalty`` is required (no default).
    max_new_tokens: int
    temperature: float
    top_k: int
    top_p: float
    do_sample: bool
    n: int
    stop: List[str]
    repetition_penalty: float
    presence_penalty: float
    # The only optional field; it stays ``None`` unless explicitly provided.
    frequency_penalty: float | None = None
    # NOTE(review): declared as ``str`` -- presumably mirrors the request
    # message's field type; confirm against the request definition.
    best_of: str
    logit_bias: dict[str, int]
    return_full_text: bool
    truncate: int
    typical_p: float
    watermark: bool
    seed: int
|
||
|
||
def LLM(_cls):
    """Class decorator that layers completion handlers on top of ``generate``.

    The decorated class must expose a ``generate(prompt, config)`` method
    whose result can be iterated to obtain text chunks.  The returned
    subclass adds ``ChatComplete``, ``ChatCompleteStream``, ``Complete`` and
    ``CompleteStream`` coroutines that build a ``GenerationConfig`` from the
    incoming request and delegate to ``generate``.

    :param _cls: The class to wrap.
    :return: A subclass of ``_cls`` carrying the original class's name.
    :raises ValueError: If ``_cls`` has no ``generate`` attribute.
    """
    if not hasattr(_cls, "generate"):
        raise ValueError("LLM class requires a generate method")

    # Imported locally so this block does not depend on the module-level
    # ``typing`` import list being extended.
    from typing import AsyncGenerator

    class NewClass(_cls):
        config: BackendConfig

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.config = BackendConfig()

        def _build_gen_stream(
            self, prompt: str, request: ChatCompletionRequest | CompletionRequest
        ) -> Generator[str, Any, Any]:
            # Copy the generation settings off the request and hand them,
            # together with the prompt, to the wrapped class's ``generate``.
            # ``frequency_penalty`` is not forwarded, so it keeps the
            # ``GenerationConfig`` default.
            config = GenerationConfig(
                max_new_tokens=request.max_new_tokens,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
                do_sample=request.do_sample,
                n=request.n,
                stop=list(request.stop),
                repetition_penalty=request.repetition_penalty,
                presence_penalty=request.presence_penalty,
                best_of=request.best_of,
                logit_bias=request.logit_bias,
                return_full_text=request.return_full_text,
                truncate=request.truncate,
                typical_p=request.typical_p,
                watermark=request.watermark,
                seed=request.seed,
            )
            return self.generate(prompt, config)

        # NOTE(review): the handlers below iterate ``generate`` synchronously
        # inside coroutines; if ``generate`` blocks, the event loop is held
        # for the duration -- confirm this is acceptable for the server.

        async def ChatComplete(
            self, request: ChatCompletionRequest, context: GrpcContext
        ) -> ChatCompletionResponse:
            # Render the chat items into one prompt, then collect the whole
            # stream into a single assistant message.
            gen_stream = self._build_gen_stream(
                self.config.apply_chat_template(request.chat_items), request
            )
            content = "".join(gen_stream)

            item = ChatItem(role=ChatRole.ASSISTANT, content=content)
            choice = ChatCompletionChoice(index=0, chat_item=item)
            return ChatCompletionResponse(choices=[choice])

        async def ChatCompleteStream(
            self, request: ChatCompletionRequest, context: GrpcContext
        ) -> AsyncGenerator[ChatCompletionResponse, None]:
            # Same as ChatComplete, but emit one response per generated chunk.
            gen_stream = self._build_gen_stream(
                self.config.apply_chat_template(request.chat_items), request
            )

            for text_chunk in gen_stream:
                item = ChatItem(role=ChatRole.ASSISTANT, content=text_chunk)
                choice = ChatCompletionChoice(index=0, chat_item=item)
                yield ChatCompletionResponse(choices=[choice])

        async def Complete(
            self, request: CompletionRequest, context: GrpcContext
        ) -> CompletionResponse:
            # The request prompt is used verbatim; collect the full output.
            gen_stream = self._build_gen_stream(request.prompt, request)
            content = "".join(gen_stream)

            choice = CompletionChoice(index=0, text=content)
            return CompletionResponse(choices=[choice])

        async def CompleteStream(
            self, request: CompletionRequest, context: GrpcContext
        ) -> AsyncGenerator[CompletionResponse, None]:
            # Emit one response per generated chunk (no stdout side output).
            gen_stream = self._build_gen_stream(request.prompt, request)
            for text_chunk in gen_stream:
                choice = CompletionChoice(index=0, text=text_chunk)
                yield CompletionResponse(choices=[choice])

    # Present the wrapper under the wrapped class's identity so logs and
    # reprs show the user's class rather than "NewClass".
    NewClass.__name__ = _cls.__name__
    NewClass.__qualname__ = _cls.__qualname__
    NewClass.__doc__ = _cls.__doc__
    return NewClass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Adapted from Gunicorn's utils module. | ||
|
||
import sys | ||
import ast | ||
import importlib | ||
import logging | ||
import traceback | ||
from leapfrogai.errors import AppImportError | ||
|
||
def import_app(module):
    """Import and return the application callable referenced by ``module``.

    ``module`` is either ``"pkg.mod"`` (the attribute ``application`` is
    looked up) or ``"pkg.mod:obj"``, where ``obj`` is a plain attribute name
    or a call of a plain name with literal arguments, e.g.
    ``"pkg.mod:make_app('prod', debug=False)"``.

    :param module: The ``module[:obj]`` reference string.
    :return: The resolved application object, which is guaranteed callable.
    :raises ImportError: If the module itself cannot be imported.
    :raises AppImportError: If ``obj`` cannot be parsed, is missing from the
        module, resolves to ``None``, or is not callable.
    :raises TypeError: If a factory call fails from inside the factory.
    """
    # ``os`` is not among the module-level imports visible in this file, so
    # bring it into scope here; without it the ".py" hint below would die
    # with NameError instead of raising the intended ImportError.
    import os

    parts = module.split(":", 1)
    if len(parts) == 1:
        obj = "application"
    else:
        module, obj = parts[0], parts[1]

    try:
        mod = importlib.import_module(module)
    except ImportError:
        # A common mistake is passing a file name ("app.py") instead of a
        # module name ("app"); if that file exists, suggest the right form.
        if module.endswith(".py") and os.path.exists(module):
            msg = "Failed to find application, did you mean '%s:%s'?"
            raise ImportError(msg % (module.rsplit(".", 1)[0], obj))
        raise

    # Parse obj as a single expression to determine if it's a valid
    # attribute name or function call.
    try:
        expression = ast.parse(obj, mode="eval").body
    except SyntaxError:
        raise AppImportError(
            "Failed to parse %r as an attribute name or function call." % obj
        )

    if isinstance(expression, ast.Name):
        name = expression.id
        # ``args is None`` marks "plain attribute, do not call" further down.
        args = kwargs = None
    elif isinstance(expression, ast.Call):
        # Ensure the function name is an attribute name only.
        if not isinstance(expression.func, ast.Name):
            raise AppImportError("Function reference must be a simple name: %r" % obj)

        name = expression.func.id

        # Parse the positional and keyword arguments as literals.
        try:
            args = [ast.literal_eval(arg) for arg in expression.args]
            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in expression.keywords}
        except ValueError:
            # literal_eval gives cryptic error messages, show a generic
            # message with the full expression instead.
            raise AppImportError(
                "Failed to parse arguments as literal values: %r" % obj
            )
    else:
        raise AppImportError(
            "Failed to parse %r as an attribute name or function call." % obj
        )

    is_debug = logging.root.level == logging.DEBUG
    try:
        app = getattr(mod, name)
    except AttributeError:
        # Only show the underlying traceback when the root logger is at DEBUG.
        if is_debug:
            traceback.print_exception(*sys.exc_info())
        raise AppImportError("Failed to find attribute %r in %r." % (name, module))

    # If the expression was a function call, call the retrieved object
    # to get the real application.
    if args is not None:
        try:
            app = app(*args, **kwargs)
        except TypeError as e:
            # If the TypeError was due to bad arguments to the factory
            # function, show Python's nice error message without a
            # traceback.
            if _called_with_wrong_args(app):
                raise AppImportError(
                    "".join(traceback.format_exception_only(TypeError, e)).strip()
                )

            # Otherwise it was raised from within the function, show the
            # full traceback.
            raise

    if app is None:
        raise AppImportError("Failed to find application object: %r" % obj)

    if not callable(app):
        raise AppImportError("Application object must be callable.")
    return app
|
||
def _called_with_wrong_args(f): | ||
"""Check whether calling a function raised a ``TypeError`` because | ||
the call failed or because something in the function raised the | ||
error. | ||
:param f: The function that was called. | ||
:return: ``True`` if the call failed. | ||
""" | ||
tb = sys.exc_info()[2] | ||
|
||
try: | ||
while tb is not None: | ||
if tb.tb_frame.f_code is f.__code__: | ||
# In the function, it was called successfully. | ||
return False | ||
|
||
tb = tb.tb_next | ||
|
||
# Didn't reach the function. | ||
return True | ||
finally: | ||
# Delete tb to break a circular reference in Python 2. | ||
# https://docs.python.org/2/library/sys.html#sys.exc_info | ||
del tb | ||
|