diff --git a/src/leapfrogai/cli.py b/src/leapfrogai/cli.py index 467716883..3ece2a3c9 100644 --- a/src/leapfrogai/cli.py +++ b/src/leapfrogai/cli.py @@ -1,8 +1,10 @@ import click import sys +import asyncio from leapfrogai.errors import AppImportError from leapfrogai.utils import import_app +from leapfrogai.serve import serve @click.argument("app", envvar="LEAPFROGAI_APP") @click.option( @@ -31,7 +33,7 @@ def cli(app: str, host: str, port: str, app_dir: str): sys.path.insert(0, app_dir) """Leapfrog AI CLI""" app = import_app(app) - app().serve(host, port) + asyncio.run(serve(app(), host, port)) if __name__ == "__main__": diff --git a/src/leapfrogai/llm.py b/src/leapfrogai/llm.py index 3b78ae9de..c9ab82e5e 100644 --- a/src/leapfrogai/llm.py +++ b/src/leapfrogai/llm.py @@ -31,7 +31,7 @@ class GenerationConfig(BaseModel): stop: List[str] repetition_penalty: float presence_penalty: float - frequency_penalty: float + frequency_penalty: float | None = None best_of: str logit_bias: dict[str, int] return_full_text: bool @@ -58,10 +58,9 @@ def _build_gen_stream(self, prompt: str, request: ChatCompletionRequest | Comple top_p=request.top_p, do_sample=request.do_sample, n=request.n, - stop=request.stop, + stop=list(request.stop), repetition_penalty=request.repetition_penalty, presence_penalty=request.presence_penalty, - frequency_penalty=request.frequency_penalty, best_of=request.best_of, logit_bias=request.logit_bias, return_full_text=request.return_full_text, @@ -107,9 +106,6 @@ async def CompleteStream(self, request: CompletionRequest, context: GrpcContext) for text_chunk in gen_stream: choice = CompletionChoice(index=0, text=text_chunk) yield CompletionResponse(choices=[choice]) - - def serve(self, host, port): - asyncio.run(serve(self, host, port)) NewClass.__name__ = _cls.__name__ return NewClass \ No newline at end of file