Skip to content

Commit

Permalink
frequency penalty fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gerred committed Nov 14, 2023
1 parent f991de8 commit 9b5c172
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
4 changes: 3 additions & 1 deletion src/leapfrogai/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import click
import sys
import asyncio

from leapfrogai.errors import AppImportError
from leapfrogai.utils import import_app
from leapfrogai.serve import serve

@click.argument("app", envvar="LEAPFROGAI_APP")
@click.option(
Expand Down Expand Up @@ -31,7 +33,7 @@ def cli(app: str, host: str, port: str, app_dir: str):
sys.path.insert(0, app_dir)
"""Leapfrog AI CLI"""
app = import_app(app)
app().serve(host, port)
asyncio.run(serve(app(), host, port))


if __name__ == "__main__":
Expand Down
8 changes: 2 additions & 6 deletions src/leapfrogai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class GenerationConfig(BaseModel):
stop: List[str]
repetition_penalty: float
presence_penalty: float
frequency_penalty: float
frequency_penalty: float | None = None
best_of: str
logit_bias: dict[str, int]
return_full_text: bool
Expand All @@ -58,10 +58,9 @@ def _build_gen_stream(self, prompt: str, request: ChatCompletionRequest | Comple
top_p=request.top_p,
do_sample=request.do_sample,
n=request.n,
stop=request.stop,
stop=list(request.stop),
repetition_penalty=request.repetition_penalty,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
best_of=request.best_of,
logit_bias=request.logit_bias,
return_full_text=request.return_full_text,
Expand Down Expand Up @@ -107,9 +106,6 @@ async def CompleteStream(self, request: CompletionRequest, context: GrpcContext)
for text_chunk in gen_stream:
choice = CompletionChoice(index=0, text=text_chunk)
yield CompletionResponse(choices=[choice])

def serve(self, host, port):
asyncio.run(serve(self, host, port))

NewClass.__name__ = _cls.__name__
return NewClass

0 comments on commit 9b5c172

Please sign in to comment.