Skip to content

Commit

Permalink
feat: add a configurable category-based logger
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinb committed Mar 3, 2025
1 parent 46b0a40 commit f80a305
Show file tree
Hide file tree
Showing 12 changed files with 414 additions and 47 deletions.
19 changes: 9 additions & 10 deletions llama_stack/distribution/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
# the root directory of this source tree.
import importlib
import inspect
import logging
from typing import Any, Dict, List, Set

from llama_stack import logcat
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
Expand Down Expand Up @@ -50,8 +50,6 @@
VectorDBsProtocolPrivate,
)

log = logging.getLogger(__name__)


class InvalidProviderError(Exception):
pass
Expand Down Expand Up @@ -128,19 +126,20 @@ async def resolve_impls(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
log.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logcat.warning("core", f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue

if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
log.error(p.deprecation_error, "red", attrs=["bold"])
logcat.error("core", p.deprecation_error)
raise InvalidProviderError(p.deprecation_error)

elif p.deprecation_warning:
log.warning(
logcat.warning(
"core",
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
)
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
Expand Down Expand Up @@ -214,10 +213,10 @@ async def resolve_impls(
)
)

log.info(f"Resolved {len(sorted_providers)} providers")
logcat.debug("core", f"Resolved {len(sorted_providers)} providers")
for api_str, provider in sorted_providers:
log.info(f" {api_str} => {provider.provider_id}")
log.info("")
logcat.debug("core", f" {api_str} => {provider.provider_id}")
logcat.debug("core", "")

impls = {}
inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
Expand Down Expand Up @@ -354,7 +353,7 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
logcat.error("core", f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
Expand Down
60 changes: 60 additions & 0 deletions llama_stack/distribution/routers/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import copy
from typing import Any, AsyncGenerator, Dict, List, Optional

from llama_stack import logcat
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
Expand Down Expand Up @@ -63,12 +64,15 @@ def __init__(
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing VectorIORouter")
self.routing_table = routing_table

async def initialize(self) -> None:
logcat.debug("core", "VectorIORouter.initialize")
pass

async def shutdown(self) -> None:
logcat.debug("core", "VectorIORouter.shutdown")
pass

async def register_vector_db(
Expand All @@ -79,6 +83,7 @@ async def register_vector_db(
provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None,
) -> None:
logcat.debug("core", f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
Expand All @@ -93,6 +98,10 @@ async def insert_chunks(
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
logcat.debug(
"core",
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.id for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

async def query_chunks(
Expand All @@ -101,6 +110,7 @@ async def query_chunks(
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse:
logcat.debug("core", f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)


Expand All @@ -111,12 +121,15 @@ def __init__(
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing InferenceRouter")
self.routing_table = routing_table

async def initialize(self) -> None:
logcat.debug("core", "InferenceRouter.initialize")
pass

async def shutdown(self) -> None:
logcat.debug("core", "InferenceRouter.shutdown")
pass

async def register_model(
Expand All @@ -127,6 +140,10 @@ async def register_model(
metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None:
logcat.debug(
"core",
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
)
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)

async def chat_completion(
Expand All @@ -142,6 +159,10 @@ async def chat_completion(
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
logcat.debug(
"core",
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
Expand Down Expand Up @@ -203,6 +224,10 @@ async def completion(
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
logcat.debug(
"core",
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
)
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
Expand Down Expand Up @@ -230,6 +255,7 @@ async def embeddings(
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
logcat.debug("core", f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
Expand All @@ -249,12 +275,15 @@ def __init__(
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing SafetyRouter")
self.routing_table = routing_table

async def initialize(self) -> None:
logcat.debug("core", "SafetyRouter.initialize")
pass

async def shutdown(self) -> None:
logcat.debug("core", "SafetyRouter.shutdown")
pass

async def register_shield(
Expand All @@ -264,6 +293,7 @@ async def register_shield(
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield:
logcat.debug("core", f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

async def run_shield(
Expand All @@ -272,6 +302,7 @@ async def run_shield(
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
logcat.debug("core", f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
Expand All @@ -284,12 +315,15 @@ def __init__(
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing DatasetIORouter")
self.routing_table = routing_table

async def initialize(self) -> None:
logcat.debug("core", "DatasetIORouter.initialize")
pass

async def shutdown(self) -> None:
logcat.debug("core", "DatasetIORouter.shutdown")
pass

async def get_rows_paginated(
Expand All @@ -299,6 +333,7 @@ async def get_rows_paginated(
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
logcat.debug("core", f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}")
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
Expand All @@ -307,6 +342,7 @@ async def get_rows_paginated(
)

async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
logcat.debug("core", f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
Expand All @@ -318,12 +354,15 @@ def __init__(
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ScoringRouter")
self.routing_table = routing_table

async def initialize(self) -> None:
logcat.debug("core", "ScoringRouter.initialize")
pass

async def shutdown(self) -> None:
logcat.debug("core", "ScoringRouter.shutdown")
pass

async def score_batch(
Expand All @@ -332,6 +371,7 @@ async def score_batch(
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logcat.debug("core", f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
Expand All @@ -352,6 +392,7 @@ async def score(
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
logcat.debug("core", f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
Expand All @@ -369,19 +410,23 @@ def __init__(
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing EvalRouter")
self.routing_table = routing_table

async def initialize(self) -> None:
logcat.debug("core", "EvalRouter.initialize")
pass

async def shutdown(self) -> None:
logcat.debug("core", "EvalRouter.shutdown")
pass

async def run_eval(
self,
benchmark_id: str,
task_config: BenchmarkConfig,
) -> Job:
logcat.debug("core", f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
task_config=task_config,
Expand All @@ -394,6 +439,7 @@ async def evaluate_rows(
scoring_functions: List[str],
task_config: BenchmarkConfig,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
Expand All @@ -406,13 +452,15 @@ async def job_status(
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
logcat.debug("core", f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

async def job_cancel(
self,
benchmark_id: str,
job_id: str,
) -> None:
logcat.debug("core", f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
Expand All @@ -423,6 +471,7 @@ async def job_result(
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logcat.debug("core", f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
Expand All @@ -435,6 +484,7 @@ def __init__(
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table

async def query(
Expand All @@ -443,6 +493,7 @@ async def query(
vector_db_ids: List[str],
query_config: Optional[RAGQueryConfig] = None,
) -> RAGQueryResult:
logcat.debug("core", f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
Expand All @@ -453,6 +504,10 @@ async def insert(
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logcat.debug(
"core",
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}",
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
Expand All @@ -461,6 +516,7 @@ def __init__(
self,
routing_table: RoutingTable,
) -> None:
logcat.debug("core", "Initializing ToolRuntimeRouter")
self.routing_table = routing_table

# HACK ALERT this should be in sync with "get_all_api_endpoints()"
Expand All @@ -469,12 +525,15 @@ def __init__(
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

async def initialize(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.initialize")
pass

async def shutdown(self) -> None:
logcat.debug("core", "ToolRuntimeRouter.shutdown")
pass

async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
logcat.debug("core", f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
Expand All @@ -483,4 +542,5 @@ async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> Any:
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
logcat.debug("core", f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
Loading

0 comments on commit f80a305

Please sign in to comment.