diff --git a/pyproject.toml b/pyproject.toml
index e71c78449..384751ed5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,8 +21,12 @@ dependencies = [
     "grpcio-health-checking >=1.58.0",
     "confz >= 2.0.1",
     "pydantic < 2",
+    "click >= 8.1.7",
 ]
 
+[project.scripts]
+leapfrogai = "leapfrogai.cli:cli"
+
 [project.optional-dependencies]
 tests = ["pytest"]
 build = ["build", "hatchling", "twine"]
diff --git a/requirements.txt b/requirements.txt
index 7973a9738..26f2494a0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,10 @@
 # This file is autogenerated by pip-compile with Python 3.11
 # by the following command:
 #
-#    pip-compile --output-file=leapfrogai/requirements.txt pyproject.toml
+#    pip-compile --output-file=requirements.txt pyproject.toml
 #
+click==8.1.7
+    # via leapfrogai (pyproject.toml)
 confz==2.0.1
     # via leapfrogai (pyproject.toml)
 grpcio==1.58.0
diff --git a/src/leapfrogai/cli.py b/src/leapfrogai/cli.py
new file mode 100644
index 000000000..7cf98b038
--- /dev/null
+++ b/src/leapfrogai/cli.py
@@ -0,0 +1,39 @@
+import asyncio
+import sys
+
+import click
+
+from leapfrogai.serve import serve
+from leapfrogai.utils import import_app
+
+
+@click.command()
+@click.argument("app", envvar="LEAPFROGAI_APP")
+@click.option(
+    "--host",
+    type=str,
+    default="0.0.0.0",
+    help="Bind socket to this host.",
+    show_default=True,
+)
+@click.option(
+    "--port",
+    type=int,
+    default=50051,
+    help="Bind socket to this port. If 0, an available port will be picked.",
+    show_default=True,
+)
+@click.option(
+    "--app-dir",
+    type=str,
+    default="",
+    help="Path to the directory containing the app module. Defaults to the current directory.",
+    show_default=True,
+)
+def cli(app: str, host: str, port: int, app_dir: str):
+    """LeapfrogAI CLI"""
+    sys.path.insert(0, app_dir)
+    app = import_app(app)
+    asyncio.run(serve(app(), host, port))
+
+
+if __name__ == "__main__":
+    cli()  # pragma: no cover
\ No newline at end of file
diff --git a/src/leapfrogai/config.py b/src/leapfrogai/config.py
index 8db5a22cf..a4ed27f32 100644
--- a/src/leapfrogai/config.py
+++ b/src/leapfrogai/config.py
@@ -36,7 +36,8 @@ class ModelConfig(BaseConfig):
 class BackendConfig(BaseConfig):
     name: str | None = None
     model: ModelConfig | None = None
-    max_seq_len: int = 2048
+    max_context_length: int = 2048
+    stop_tokens: list[str] | None = None
     prompt_format: PromptFormat | None = None
     defaults: LLMDefaults = LLMDefaults()
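The `max_seq_len` → `max_context_length` rename and the new `stop_tokens` field are read by any backend that holds a `BackendConfig`. A minimal sketch of how a backend might consume them; the field names match this diff, but the truncation and stop-checking helpers are hypothetical, not part of the patch:

```python
from leapfrogai import BackendConfig

config = BackendConfig()  # confz populates fields from its configured sources

def trim_context(tokens: list[int]) -> list[int]:
    # Keep only the most recent max_context_length tokens.
    return tokens[-config.max_context_length:]

def should_stop(text: str) -> bool:
    # Stop generation once any configured stop token appears.
    return any(s in text for s in (config.stop_tokens or []))
```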
diff --git a/src/leapfrogai/errors.py b/src/leapfrogai/errors.py
new file mode 100644
index 000000000..dcdbbf91c
--- /dev/null
+++ b/src/leapfrogai/errors.py
@@ -0,0 +1,4 @@
+# Adapted from Gunicorn's errors module.
+
+class AppImportError(Exception):
+    """Exception raised when loading an application."""
\ No newline at end of file
diff --git a/src/leapfrogai/llm.py b/src/leapfrogai/llm.py
new file mode 100644
index 000000000..fc8321432
--- /dev/null
+++ b/src/leapfrogai/llm.py
@@ -0,0 +1,123 @@
+from typing import Any, Generator, List
+
+from pydantic import BaseModel
+
+from leapfrogai import (
+    BackendConfig,
+    ChatCompletionChoice,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatItem,
+    ChatRole,
+    CompletionChoice,
+    CompletionRequest,
+    CompletionResponse,
+    GrpcContext,
+)
+
+
+class GenerationConfig(BaseModel):
+    max_new_tokens: int
+    temperature: float
+    top_k: int
+    top_p: float
+    do_sample: bool
+    n: int
+    stop: List[str]
+    repetition_penalty: float
+    presence_penalty: float
+    frequency_penalty: float | None = None
+    best_of: int
+    logit_bias: dict[str, int]
+    return_full_text: bool
+    truncate: int
+    typical_p: float
+    watermark: bool
+    seed: int
+
+
+def LLM(_cls):
+    """Class decorator that adds the gRPC completion and chat servicer
+    methods to a backend class implementing `generate(prompt, config)`."""
+    if not hasattr(_cls, "generate"):
+        raise ValueError("LLM class requires a generate method")
+
+    class NewClass(_cls):
+        config: BackendConfig
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.config = BackendConfig()
+
+        def _build_gen_stream(
+            self, prompt: str, request: ChatCompletionRequest | CompletionRequest
+        ) -> Generator[str, Any, Any]:
+            config = GenerationConfig(
+                max_new_tokens=request.max_new_tokens,
+                temperature=request.temperature,
+                top_k=request.top_k,
+                top_p=request.top_p,
+                do_sample=request.do_sample,
+                n=request.n,
+                stop=list(request.stop),
+                repetition_penalty=request.repetition_penalty,
+                presence_penalty=request.presence_penalty,
+                best_of=request.best_of,
+                logit_bias=request.logit_bias,
+                return_full_text=request.return_full_text,
+                truncate=request.truncate,
+                typical_p=request.typical_p,
+                watermark=request.watermark,
+                seed=request.seed,
+            )
+            return self.generate(prompt, config)
+
+        async def ChatComplete(
+            self, request: ChatCompletionRequest, context: GrpcContext
+        ) -> ChatCompletionResponse:
+            gen_stream = self._build_gen_stream(
+                self.config.apply_chat_template(request.chat_items), request
+            )
+
+            content = ""
+            for text_chunk in gen_stream:
+                content += text_chunk
+
+            item = ChatItem(role=ChatRole.ASSISTANT, content=content)
+            choice = ChatCompletionChoice(index=0, chat_item=item)
+            return ChatCompletionResponse(choices=[choice])
+
+        async def ChatCompleteStream(
+            self, request: ChatCompletionRequest, context: GrpcContext
+        ) -> Generator[ChatCompletionResponse, Any, Any]:
+            gen_stream = self._build_gen_stream(
+                self.config.apply_chat_template(request.chat_items), request
+            )
+
+            for text_chunk in gen_stream:
+                item = ChatItem(role=ChatRole.ASSISTANT, content=text_chunk)
+                choice = ChatCompletionChoice(index=0, chat_item=item)
+
+                yield ChatCompletionResponse(choices=[choice])
+
+        async def Complete(
+            self, request: CompletionRequest, context: GrpcContext
+        ) -> CompletionResponse:
+            gen_stream = self._build_gen_stream(request.prompt, request)
+
+            content = ""
+            for text_chunk in gen_stream:
+                content += text_chunk
+
+            choice = CompletionChoice(index=0, text=content)
+            return CompletionResponse(choices=[choice])
+
+        async def CompleteStream(
+            self, request: CompletionRequest, context: GrpcContext
+        ) -> Generator[CompletionResponse, Any, Any]:
+            gen_stream = self._build_gen_stream(request.prompt, request)
+            for text_chunk in gen_stream:
+                choice = CompletionChoice(index=0, text=text_chunk)
+                yield CompletionResponse(choices=[choice])
+
+    NewClass.__name__ = _cls.__name__
+    return NewClass
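The `LLM` decorator turns any class with a `generate(prompt, config)` method into a full gRPC backend by supplying the four servicer methods above. A minimal sketch of a conforming backend; the class name and its echo logic are illustrative stand-ins for a real model:

```python
from typing import Any, Generator

from leapfrogai.llm import LLM, GenerationConfig

@LLM
class EchoBackend:
    def generate(
        self, prompt: str, config: GenerationConfig
    ) -> Generator[str, Any, Any]:
        # A real backend would run model inference here, honoring
        # config.max_new_tokens, config.temperature, config.stop, etc.
        for word in prompt.split():
            yield word + " "
```

Because the decorator copies `__name__` back onto the wrapped class, the backend keeps its original identity while gaining `Complete`, `CompleteStream`, `ChatComplete`, and `ChatCompleteStream`.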
diff --git a/src/leapfrogai/serve.py b/src/leapfrogai/serve.py
index 5ad8563c9..2c4a89f4e 100644
--- a/src/leapfrogai/serve.py
+++ b/src/leapfrogai/serve.py
@@ -11,7 +11,7 @@ from leapfrogai.name import name_pb2_grpc
 
 
-async def serve(o):
+async def serve(o, host, port):
     # Create a tuple of all of the services we want to export via reflection.
     services = (reflection.SERVICE_NAME, health.SERVICE_NAME)
 
@@ -58,8 +58,8 @@ async def serve(o):
         health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)
 
     # Listen on port 50051
-    print("Starting server. Listening on port 50051.")
-    server.add_insecure_port("[::]:50051")
+    server.add_insecure_port("{}:{}".format(host, port))
+    print("Starting server. Listening on {}:{}.".format(host, port))
     await server.start()
 
     # block the thread until the server terminates...without using async to await the completion
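With `serve` now taking an explicit host and port, a backend can also be started programmatically instead of through the CLI. A sketch reusing the hypothetical `EchoBackend` from the earlier example:

```python
import asyncio

from leapfrogai.serve import serve

# Equivalent to: leapfrogai my_backend:EchoBackend --host 0.0.0.0 --port 50051
asyncio.run(serve(EchoBackend(), host="0.0.0.0", port=50051))
```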
diff --git a/src/leapfrogai/utils.py b/src/leapfrogai/utils.py
new file mode 100644
index 000000000..dd0600508
--- /dev/null
+++ b/src/leapfrogai/utils.py
@@ -0,0 +1,116 @@
+# Adapted from Gunicorn's utils module.
+
+import ast
+import importlib
+import logging
+import os
+import sys
+import traceback
+
+from leapfrogai.errors import AppImportError
+
+
+def import_app(module):
+    parts = module.split(":", 1)
+    if len(parts) == 1:
+        obj = "application"
+    else:
+        module, obj = parts[0], parts[1]
+
+    try:
+        mod = importlib.import_module(module)
+    except ImportError:
+        if module.endswith(".py") and os.path.exists(module):
+            msg = "Failed to find application, did you mean '%s:%s'?"
+            raise ImportError(msg % (module.rsplit(".", 1)[0], obj))
+        raise
+
+    # Parse obj as a single expression to determine if it's a valid
+    # attribute name or function call.
+    try:
+        expression = ast.parse(obj, mode="eval").body
+    except SyntaxError:
+        raise AppImportError(
+            "Failed to parse %r as an attribute name or function call." % obj
+        )
+
+    if isinstance(expression, ast.Name):
+        name = expression.id
+        args = kwargs = None
+    elif isinstance(expression, ast.Call):
+        # Ensure the function name is an attribute name only.
+        if not isinstance(expression.func, ast.Name):
+            raise AppImportError("Function reference must be a simple name: %r" % obj)
+
+        name = expression.func.id
+
+        # Parse the positional and keyword arguments as literals.
+        try:
+            args = [ast.literal_eval(arg) for arg in expression.args]
+            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in expression.keywords}
+        except ValueError:
+            # literal_eval gives cryptic error messages, show a generic
+            # message with the full expression instead.
+            raise AppImportError(
+                "Failed to parse arguments as literal values: %r" % obj
+            )
+    else:
+        raise AppImportError(
+            "Failed to parse %r as an attribute name or function call." % obj
+        )
+
+    is_debug = logging.root.level == logging.DEBUG
+    try:
+        app = getattr(mod, name)
+    except AttributeError:
+        if is_debug:
+            traceback.print_exception(*sys.exc_info())
+        raise AppImportError("Failed to find attribute %r in %r." % (name, module))
+
+    # If the expression was a function call, call the retrieved object
+    # to get the real application.
+    if args is not None:
+        try:
+            app = app(*args, **kwargs)
+        except TypeError as e:
+            # If the TypeError was due to bad arguments to the factory
+            # function, show Python's nice error message without a
+            # traceback.
+            if _called_with_wrong_args(app):
+                raise AppImportError(
+                    "".join(traceback.format_exception_only(TypeError, e)).strip()
+                )
+
+            # Otherwise it was raised from within the function, show the
+            # full traceback.
+            raise
+
+    if app is None:
+        raise AppImportError("Failed to find application object: %r" % obj)
+
+    if not callable(app):
+        raise AppImportError("Application object must be callable.")
+    return app
+
+
+def _called_with_wrong_args(f):
+    """Check whether calling a function raised a ``TypeError`` because
+    the call failed or because something in the function raised the
+    error.
+
+    :param f: The function that was called.
+    :return: ``True`` if the call failed.
+    """
+    tb = sys.exc_info()[2]
+
+    try:
+        while tb is not None:
+            if tb.tb_frame.f_code is f.__code__:
+                # In the function, it was called successfully.
+                return False
+
+            tb = tb.tb_next
+
+        # Didn't reach the function.
+        return True
+    finally:
+        # Delete tb to break the traceback's circular reference.
+        del tb
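`import_app` accepts Gunicorn-style targets: a module path, optionally followed by `:` and either an attribute name or a factory call whose arguments are Python literals; with no `:` it falls back to an attribute named `application`. A few illustrative targets (the module, class, and factory names are hypothetical):

```python
from leapfrogai.utils import import_app

app = import_app("my_backend")                      # resolves my_backend.application
app = import_app("my_backend:EchoBackend")          # plain attribute lookup
app = import_app("my_backend:make_backend('7b')")   # factory call with literal args
```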