use max_output_tokens from all models to show a slider with max value in lm settings
nikochiko committed Feb 28, 2025
1 parent 65b7c59 commit 9148a9d
Showing 3 changed files with 72 additions and 24 deletions.
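In short: LLMSpec.max_output_tokens becomes optional per-model metadata, and the language model settings widget turns its "Max Output Tokens" input into a slider whose maximum is the smallest limit among the currently selected models, falling back to a model's context window when it declares no explicit output limit. A minimal sketch of that cap rule, using a stand-in ModelSpec and illustrative numbers rather than the repo's real enum:

from typing import NamedTuple


class ModelSpec(NamedTuple):
    context_window: int
    max_output_tokens: int | None = None  # None means no declared output limit


def slider_max(specs: list[ModelSpec], default: int = 4096) -> int:
    # The chosen value must be valid for every selected model, so take the
    # most restrictive cap; models without an explicit cap contribute their
    # context window instead.
    if not specs:
        return default
    return min(spec.max_output_tokens or spec.context_window for spec in specs)


assert slider_max([ModelSpec(128_000, 4096), ModelSpec(8192)]) == 4096
assert slider_max([]) == 4096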
32 changes: 29 additions & 3 deletions daras_ai_v2/language_model.py
@@ -68,7 +68,7 @@ class LLMSpec(typing.NamedTuple):
     model_id: str | tuple
     llm_api: LLMApis
     context_window: int
-    max_output_tokens: int = 4096
+    max_output_tokens: int | None = None
     price: int = 1
     is_chat_model: bool = True
     is_vision_model: bool = False
@@ -171,6 +171,7 @@ class LargeLanguageModels(Enum):
         ),
         llm_api=LLMApis.openai,
         context_window=128_000,
+        max_output_tokens=4096,
         price=6,
         is_vision_model=True,
         supports_json=True,
@@ -180,6 +181,7 @@ class LargeLanguageModels(Enum):
         model_id="gpt-4-vision-preview",
         llm_api=LLMApis.openai,
         context_window=128_000,
+        max_output_tokens=4096,
         price=6,
         is_vision_model=True,
         is_deprecated=True,
@@ -191,6 +193,7 @@ class LargeLanguageModels(Enum):
         model_id=("openai-gpt-4-turbo-prod-ca-1", "gpt-4-1106-preview"),
         llm_api=LLMApis.openai,
         context_window=128_000,
+        max_output_tokens=4096,
         price=5,
         supports_json=True,
     )
@@ -201,13 +204,15 @@ class LargeLanguageModels(Enum):
         model_id=("openai-gpt-4-prod-ca-1", "gpt-4"),
         llm_api=LLMApis.openai,
         context_window=8192,
+        max_output_tokens=8192,
         price=10,
     )
     gpt_4_32k = LLMSpec(
         label="GPT-4 32K (openai) 🔻",
         model_id="openai-gpt-4-32k-prod-ca-1",
         llm_api=LLMApis.openai,
         context_window=32_768,
+        max_output_tokens=8192,
         price=20,
     )
 
@@ -225,6 +230,7 @@ class LargeLanguageModels(Enum):
         model_id=("openai-gpt-35-turbo-16k-prod-ca-1", "gpt-3.5-turbo-16k-0613"),
         llm_api=LLMApis.openai,
         context_window=16_384,
+        max_output_tokens=4096,
         price=2,
     )
     gpt_3_5_turbo_instruct = LLMSpec(
@@ -251,6 +257,7 @@ class LargeLanguageModels(Enum):
         model_id="llama-3.3-70b-versatile",
         llm_api=LLMApis.groq,
         context_window=128_000,
+        max_output_tokens=32_768,
         price=1,
         supports_json=True,
     )
@@ -259,6 +266,7 @@ class LargeLanguageModels(Enum):
         model_id="llama-3.2-90b-vision-preview",
         llm_api=LLMApis.groq,
         context_window=128_000,
+        max_output_tokens=8192,
         price=1,
         supports_json=True,
         is_vision_model=True,
@@ -268,6 +276,7 @@ class LargeLanguageModels(Enum):
         model_id="llama-3.2-11b-vision-preview",
         llm_api=LLMApis.groq,
         context_window=128_000,
+        max_output_tokens=8192,
         price=1,
         supports_json=True,
         is_vision_model=True,
@@ -278,6 +287,7 @@ class LargeLanguageModels(Enum):
         model_id="llama-3.2-3b-preview",
         llm_api=LLMApis.groq,
         context_window=128_000,
+        max_output_tokens=8192,
         price=1,
         supports_json=True,
     )
@@ -286,6 +296,7 @@ class LargeLanguageModels(Enum):
         model_id="llama-3.2-1b-preview",
         llm_api=LLMApis.groq,
         context_window=128_000,
+        max_output_tokens=8192,
         price=1,
         supports_json=True,
     )
@@ -295,6 +306,7 @@ class LargeLanguageModels(Enum):
         model_id="accounts/fireworks/models/llama-v3p1-405b-instruct",
         llm_api=LLMApis.fireworks,
         context_window=128_000,
+        max_output_tokens=4096,
         price=1,
         supports_json=True,
     )
@@ -303,6 +315,7 @@ class LargeLanguageModels(Enum):
         model_id="llama-3.1-70b-versatile",
         llm_api=LLMApis.groq,
         context_window=128_000,
+        max_output_tokens=4096,
         price=1,
         supports_json=True,
         is_deprecated=True,
@@ -311,7 +324,8 @@ class LargeLanguageModels(Enum):
         label="Llama 3.1 8B (Meta AI)",
         model_id="llama-3.1-8b-instant",
         llm_api=LLMApis.groq,
-        context_window=128_00,
+        context_window=128_000,
+        max_output_tokens=8192,
         price=1,
         supports_json=True,
     )
@@ -338,6 +352,7 @@ class LargeLanguageModels(Enum):
         model_id="pixtral-large-2411",
         llm_api=LLMApis.mistral,
         context_window=131_000,
+        max_output_tokens=4096,
         is_vision_model=True,
         supports_json=True,
     )
@@ -346,13 +361,15 @@ class LargeLanguageModels(Enum):
         model_id="mistral-large-2411",
         llm_api=LLMApis.mistral,
         context_window=131_000,
+        max_output_tokens=4096,
         supports_json=True,
     )
     mistral_small_24b_instruct = LLMSpec(
         label="Mistral Small 25/01",
         model_id="mistral-small-2501",
         llm_api=LLMApis.mistral,
         context_window=32_768,
+        max_output_tokens=4096,
         price=1,
         supports_json=True,
     )
@@ -361,6 +378,7 @@ class LargeLanguageModels(Enum):
         model_id="mixtral-8x7b-32768",
         llm_api=LLMApis.groq,
         context_window=32_768,
+        max_output_tokens=4096,
         price=1,
         supports_json=True,
         is_deprecated=True,
@@ -370,6 +388,7 @@ class LargeLanguageModels(Enum):
         model_id="gemma2-9b-it",
         llm_api=LLMApis.groq,
         context_window=8_192,
+        max_output_tokens=4096,
         price=1,
         supports_json=True,
     )
@@ -378,6 +397,7 @@ class LargeLanguageModels(Enum):
         model_id="gemma-7b-it",
         llm_api=LLMApis.groq,
         context_window=8_192,
+        max_output_tokens=4096,
         price=1,
         supports_json=True,
         is_deprecated=True,
@@ -419,7 +439,6 @@ class LargeLanguageModels(Enum):
         model_id="gemini-1.0-pro-vision",
         llm_api=LLMApis.gemini,
         context_window=2048,
-        max_output_tokens=8192,
         price=25,
         is_vision_model=True,
         is_chat_model=False,
@@ -436,13 +455,15 @@ class LargeLanguageModels(Enum):
         model_id="chat-bison",
         llm_api=LLMApis.palm2,
         context_window=4096,
+        max_output_tokens=1024,
         price=10,
     )
     palm2_text = LLMSpec(
         label="PaLM 2 Text (Google)",
         model_id="text-bison",
         llm_api=LLMApis.palm2,
         context_window=8192,
+        max_output_tokens=1024,
         price=15,
         is_chat_model=False,
     )
@@ -463,6 +484,7 @@ class LargeLanguageModels(Enum):
         model_id="claude-3-opus-20240229",
         llm_api=LLMApis.anthropic,
         context_window=200_000,
+        max_output_tokens=4096,
         price=75,
         is_vision_model=True,
         supports_json=True,
@@ -472,6 +494,7 @@ class LargeLanguageModels(Enum):
         model_id="claude-3-sonnet-20240229",
         llm_api=LLMApis.anthropic,
         context_window=200_000,
+        max_output_tokens=4096,
         price=15,
         is_vision_model=True,
         supports_json=True,
@@ -481,6 +504,7 @@ class LargeLanguageModels(Enum):
         model_id="claude-3-haiku-20240307",
         llm_api=LLMApis.anthropic,
         context_window=200_000,
+        max_output_tokens=4096,
         price=2,
         is_vision_model=True,
         supports_json=True,
@@ -514,6 +538,7 @@ class LargeLanguageModels(Enum):
         model_id="llama3-groq-70b-8192-tool-use-preview",
         llm_api=LLMApis.groq,
         context_window=8192,
+        max_output_tokens=4096,
         price=1,
         supports_json=True,
         is_deprecated=True,
@@ -523,6 +548,7 @@ class LargeLanguageModels(Enum):
         model_id="llama3-groq-8b-8192-tool-use-preview",
         llm_api=LLMApis.groq,
         context_window=8192,
+        max_output_tokens=4096,
         price=1,
         supports_json=True,
         is_deprecated=True,
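Since max_output_tokens now defaults to None, downstream code cannot assume an int. The fallback idiom the settings widget relies on reads roughly as below; a usage sketch against members whose values appear in the diff above, mirroring how the widget accesses attributes on enum members:

llm = LargeLanguageModels.gpt_4_32k
cap = llm.max_output_tokens or llm.context_window  # 8192: the explicit cap wins
# An entry whose max_output_tokens was removed, like Gemini 1.0 Pro Vision
# above, falls back to its context_window (2048) instead.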
60 changes: 40 additions & 20 deletions daras_ai_v2/language_model_settings_widgets.py
@@ -32,16 +32,24 @@ def language_model_selector(
     )
 
 
-def language_model_settings(selected_model: str = None):
-    try:
-        llm = LargeLanguageModels[selected_model]
-    except KeyError:
-        llm = None
+def language_model_settings(selected_models: str | list[str] | None = None) -> None:
+    if isinstance(selected_models, str):
+        selected_models = [selected_models]
+    elif not selected_models:
+        selected_models = []
 
+    llms = []
+    for model in selected_models:
+        try:
+            llms.append(LargeLanguageModels[model])
+        except KeyError:
+            pass
+
     col1, col2 = gui.columns(2)
     with col1:
         gui.checkbox("Avoid Repetition", key="avoid_repetition")
-    if not llm or llm.supports_json:
+
+    if any(map(lambda llm: llm.supports_json, llms)):
         with col2:
             gui.selectbox(
                 f"###### {field_title_desc(LanguageModelSettings, 'response_format_type')}",
@@ -55,27 +63,37 @@ def language_model_settings(selected_model: str = None):
 
     col1, col2 = gui.columns(2)
     with col1:
-        gui.number_input(
+        if llms:
+            max_output_tokens = min(
+                [llm.max_output_tokens or llm.context_window for llm in llms]
+            )
+        else:
+            max_output_tokens = 4096
+
+        gui.slider(
             label="""
             ###### Max Output Tokens
             The maximum number of tokens to generate in the completion. Increase to generate longer responses.
             """,
             key="max_tokens",
             min_value=10,
-            step=10,
+            max_value=max_output_tokens,
+            step=2,
         )
-    with col2:
-        gui.slider(
-            label="""
-            ###### Creativity (aka Sampling Temperature)
-            Higher values allow the LLM to take more risks. Try values larger than 1 for more creative applications or 0 to ensure that LLM gives the same answer when given the same user input.
-            """,
-            key="sampling_temperature",
-            min_value=0.0,
-            max_value=2.0,
-        )
+
+    if any(map(lambda llm: llm.supports_temperature, llms)):
+        with col2:
+            gui.slider(
+                label="""
+                ###### Creativity (aka Sampling Temperature)
+                Higher values allow the LLM to take more risks. Try values larger than 1 for more creative applications or 0 to ensure that LLM gives the same answer when given the same user input.
+                """,
+                key="sampling_temperature",
+                min_value=0.0,
+                max_value=2.0,
+            )
 
     col1, col2 = gui.columns(2)
     with col1:
         gui.slider(
@@ -87,7 +105,9 @@ def language_model_settings(selected_model: str = None):
             min_value=1,
             max_value=4,
         )
-    if llm and not llm.is_chat_model and llm.llm_api == LLMApis.openai:
+    if llms and any(
+        map(lambda llm: not llm.is_chat_model and llm.llm_api == LLMApis.openai, llms)
+    ):
         with col2:
             gui.slider(
                 label="""
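The reworked signature accepts a single model key, a list of keys, or None, and unknown keys are skipped by the KeyError guard. A usage sketch with member names taken from the diffs above:

language_model_settings("gpt_4_32k")  # one model: slider capped at 8192
language_model_settings(["gpt_4_32k", "mistral_small_24b_instruct"])  # capped at min(8192, 4096)
language_model_settings(None)  # nothing selected yet: default cap of 4096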
4 changes: 3 additions & 1 deletion recipes/CompareLLM.py
@@ -93,7 +93,9 @@ def render_usage_guide(self):
         youtube_video("dhexRRDAuY8")
 
     def render_settings(self):
-        language_model_settings()
+        language_model_settings(
+            selected_models=gui.session_state.get("selected_models")
+        )
 
     def run(self, state: dict) -> typing.Iterator[str | None]:
         request: CompareLLMPage.RequestModel = self.RequestModel.parse_obj(state)
