Skip to content

Commit

Permalink
refactor: update start_graph_thread to use current settings and add m…
Browse files Browse the repository at this point in the history
…odel guessing logic
  • Loading branch information
provos committed Jan 31, 2025
1 parent ced8a4d commit af599ee
Showing 1 changed file with 50 additions and 13 deletions.
63 changes: 50 additions & 13 deletions examples/deepsearch/deepsearch/deepsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,13 @@ def notify(metadata, message: Response):
task_queue.put((sid, session_id, message))


def start_graph_thread(
provider: str = "ollama", model: str = "llama2", host: str = "localhost:11434"
):
def start_graph_thread(current_settings: Dict[str, str]):
"""Modified to use current settings."""
global graph_thread, graph, entry_worker, chat_worker, debug_saver, current_settings
global graph_thread, graph, entry_worker, chat_worker, debug_saver

# Update current settings
current_settings["provider"] = provider
current_settings["model"] = model
provider = current_settings["provider"]
model = current_settings["model"]
host = current_settings["ollamaHost"]

graph, entry_worker, chat_worker = setup_graph(
provider=provider, model=model, host=host, notify=notify
Expand Down Expand Up @@ -453,12 +451,46 @@ def stop_threads():
stop_graph_thread()


def guess_model(settings: Dict[str, str]):
"""Guess a model based on the current configuration."""
if settings["provider"] == "ollama":
interface = llm_from_config(
provider=settings["provider"], host=settings["ollamaHost"]
)
# try to see whether we can list the models from ollama - our preferred case
try:
models = interface.list()
if settings["model"] not in [model.model for model in models.models]:
# choose either a llama3 or phi4 model
for model in models.models:
if model.model.startswith("llama3") or model.model.startswith(
"phi4"
):
settings["model"] = model.model
return
except Exception:
pass

if os.environ.get("OPENAI_API_KEY"):
settings["provider"] = "openai"
settings["model"] = "gpt-4o-mini"
return

if os.environ.get("ANTHROPIC_API_KEY"):
settings["provider"] = "anthropic"
settings["model"] = "claude-3-haiku-20240307"

# we leave the model as is
return


def validate_provider(provider: str, api_key: str = None) -> Tuple[bool, List[str]]:
"""Validate provider and return available models.
For Ollama, api_key parameter is used to pass the host address.
For other providers, it's used as the API key.
"""
global current_settings
try:
kwargs = {}
if provider == "ollama":
Expand Down Expand Up @@ -517,6 +549,8 @@ def handle_validate_provider(data):
@socketio.on("load_settings")
def handle_load_settings():
"""Load current settings and validate all providers."""
global current_settings

settings = {
"provider": current_settings["provider"],
"modelName": current_settings["model"],
Expand Down Expand Up @@ -555,6 +589,7 @@ def handle_load_settings():
def handle_save_settings(data):
"""Save settings and update environment variables."""
logging.info(f"Received settings: {data.keys()}")
global current_settings
try:
# Update current settings with lowercase provider
current_settings["provider"] = data.get("provider", "ollama").lower()
Expand All @@ -580,11 +615,7 @@ def handle_save_settings(data):
global graph, entry_worker
if graph:
stop_graph_thread()
start_graph_thread(
provider=current_settings["provider"],
model=current_settings["model"],
host=current_settings["ollamaHost"],
)
start_graph_thread(current_settings)

emit("settings_saved", {"status": "success"})
except Exception as e:
Expand Down Expand Up @@ -637,8 +668,14 @@ def main():
args = parser.parse_args()

# Initialize current settings from args
global current_settings
current_settings["provider"] = args.provider
current_settings["model"] = args.model
current_settings["ollamaHost"] = f"localhost:{args.ollama_port}"
guess_model(current_settings)
print(
f"Starting with settings: {current_settings['provider']} {current_settings['model']}"
)

if args.debug:
global debug_saver
Expand All @@ -659,7 +696,7 @@ def main():
debug_saver.load_replays()

setup_logging(level=logging.DEBUG if args.debug else logging.ERROR)
start_graph_thread(args.provider, args.model, host=f"localhost:{args.ollama_port}")
start_graph_thread(current_settings=current_settings)
setup_web_interface(port=args.port)


Expand Down

0 comments on commit af599ee

Please sign in to comment.