diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e33b9c..aaf77ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed +- Return model dumps of DB schema objects. + ### Added - LLM evaluation logic - Integrated Alembic for managing chat history migrations diff --git a/src/neuroagent/cell_types.py b/src/neuroagent/cell_types.py index df84d38..9c91e32 100644 --- a/src/neuroagent/cell_types.py +++ b/src/neuroagent/cell_types.py @@ -17,7 +17,7 @@ class CellTypesMeta: """ def __init__(self) -> None: - self.name_: dict[str, str] = {} + self.name_: dict[Any, Any | None] = {} self.descendants_ids: dict[str, set[str]] = {} def descendants(self, ids: str | set[str]) -> set[str]: diff --git a/swarm_copy/cell_types.py b/swarm_copy/cell_types.py index df84d38..9c91e32 100644 --- a/swarm_copy/cell_types.py +++ b/swarm_copy/cell_types.py @@ -17,7 +17,7 @@ class CellTypesMeta: """ def __init__(self) -> None: - self.name_: dict[str, str] = {} + self.name_: dict[Any, Any | None] = {} self.descendants_ids: dict[str, set[str]] = {} def descendants(self, ids: str | set[str]) -> set[str]: diff --git a/swarm_copy/tools/bluenaas_memodel_getall.py b/swarm_copy/tools/bluenaas_memodel_getall.py index a7cee19..6c95c0e 100644 --- a/swarm_copy/tools/bluenaas_memodel_getall.py +++ b/swarm_copy/tools/bluenaas_memodel_getall.py @@ -1,7 +1,7 @@ """BlueNaaS single cell stimulation, simulation and synapse placement tool.""" import logging -from typing import ClassVar, Literal +from typing import Any, ClassVar, Literal from pydantic import BaseModel, Field @@ -46,7 +46,7 @@ class MEModelGetAllTool(BaseTool): metadata: MEModelGetAllMetadata input_schema: InputMEModelGetAll - async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelResponse: + async def arun(self) -> dict[str, Any]: """Run the MEModelGetAll tool.""" logger.info( f"Running MEModelGetAll tool with inputs {self.input_schema.model_dump()}" @@ -61,7 +61,6 @@ async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelRespo }, headers={"Authorization": f"Bearer {self.metadata.token}"}, ) - breakpoint() return PaginatedResponseUnionMEModelResponseSynaptomeModelResponse( **response.json() - ) + ).model_dump() diff --git a/swarm_copy/tools/bluenaas_memodel_getone.py b/swarm_copy/tools/bluenaas_memodel_getone.py index f84acfa..2f36f64 100644 --- a/swarm_copy/tools/bluenaas_memodel_getone.py +++ b/swarm_copy/tools/bluenaas_memodel_getone.py @@ -1,7 +1,7 @@ """BlueNaaS single cell stimulation, simulation and synapse placement tool.""" import logging -from typing import ClassVar +from typing import Any, ClassVar from urllib.parse import quote_plus from pydantic import BaseModel, Field @@ -38,7 +38,7 @@ class MEModelGetOneTool(BaseTool): metadata: MEModelGetOneMetadata input_schema: InputMEModelGetOne - async def arun(self) -> MEModelResponse: + async def arun(self) -> dict[str, Any]: """Run the MEModelGetOne tool.""" logger.info( f"Running MEModelGetOne tool with inputs {self.input_schema.model_dump()}" @@ -49,4 +49,4 @@ async def arun(self) -> MEModelResponse: headers={"Authorization": f"Bearer {self.metadata.token}"}, ) - return MEModelResponse(**response.json()) + return MEModelResponse(**response.json()).model_dump() diff --git a/swarm_copy/tools/bluenaas_scs_getall.py b/swarm_copy/tools/bluenaas_scs_getall.py index 95897dc..533ed7f 100644 --- a/swarm_copy/tools/bluenaas_scs_getall.py +++ b/swarm_copy/tools/bluenaas_scs_getall.py @@ -1,7 +1,7 @@ """BlueNaaS single cell stimulation, simulation and synapse placement tool.""" import logging -from typing import ClassVar, Literal +from typing import Any, ClassVar, Literal from pydantic import BaseModel, Field @@ -47,7 +47,7 @@ class SCSGetAllTool(BaseTool): metadata: SCSGetAllMetadata input_schema: InputSCSGetAll - async def arun(self) -> PaginatedResponseSimulationDetailsResponse: + async def arun(self) -> dict[str, Any]: """Run the SCSGetAll tool.""" logger.info( f"Running SCSGetAll tool with inputs {self.input_schema.model_dump()}" @@ -63,4 +63,6 @@ async def arun(self) -> PaginatedResponseSimulationDetailsResponse: headers={"Authorization": f"Bearer {self.metadata.token}"}, ) - return PaginatedResponseSimulationDetailsResponse(**response.json()) + return PaginatedResponseSimulationDetailsResponse( + **response.json() + ).model_dump() diff --git a/swarm_copy/tools/bluenaas_scs_getone.py b/swarm_copy/tools/bluenaas_scs_getone.py index 4957be9..2575682 100644 --- a/swarm_copy/tools/bluenaas_scs_getone.py +++ b/swarm_copy/tools/bluenaas_scs_getone.py @@ -1,7 +1,7 @@ """BlueNaaS single cell stimulation, simulation and synapse placement tool.""" import logging -from typing import ClassVar +from typing import Any, ClassVar from pydantic import BaseModel, Field @@ -39,7 +39,7 @@ class SCSGetOneTool(BaseTool): metadata: SCSGetOneMetadata input_schema: InputSCSGetOne - async def arun(self) -> SimulationDetailsResponse: + async def arun(self) -> dict[str, Any]: """Run the SCSGetOne tool.""" logger.info( f"Running SCSGetOne tool with inputs {self.input_schema.model_dump()}" @@ -50,4 +50,4 @@ async def arun(self) -> SimulationDetailsResponse: headers={"Authorization": f"Bearer {self.metadata.token}"}, ) - return SimulationDetailsResponse(**response.json()) + return SimulationDetailsResponse(**response.json()).model_dump() diff --git a/swarm_copy/tools/bluenaas_scs_post.py b/swarm_copy/tools/bluenaas_scs_post.py index 6c8e154..7e3144c 100644 --- a/swarm_copy/tools/bluenaas_scs_post.py +++ b/swarm_copy/tools/bluenaas_scs_post.py @@ -94,7 +94,7 @@ class SCSPostTool(BaseTool): metadata: SCSPostMetadata input_schema: InputSCSPost - async def arun(self) -> SCSPostOutput: + async def arun(self) -> dict[str, Any]: """Run the SCSPost tool.""" logger.info( f"Running SCSPost tool with inputs {self.input_schema.model_dump()}" @@ -126,7 +126,7 @@ async def arun(self) -> SCSPostOutput: status=json_response["status"], name=json_response["name"], error=json_response["error"], - ) + ).model_dump() @staticmethod def create_json_api( diff --git a/swarm_copy/tools/electrophys_tool.py b/swarm_copy/tools/electrophys_tool.py index 00673c1..2984836 100644 --- a/swarm_copy/tools/electrophys_tool.py +++ b/swarm_copy/tools/electrophys_tool.py @@ -194,7 +194,7 @@ class ElectrophysFeatureTool(BaseTool): input_schema: ElectrophysInput metadata: ElectrophysMetadata - async def arun(self) -> FeatureOutput: + async def arun(self) -> dict[str, Any]: """Give features about trace.""" logger.info( f"Entering electrophys tool. Inputs: {self.input_schema.trace_id=}, {self.input_schema.calculated_feature=}," @@ -329,4 +329,4 @@ async def arun(self) -> FeatureOutput: ) return FeatureOutput( brain_region=metadata.brain_region, feature_dict=output_features - ) + ).model_dump() diff --git a/swarm_copy/tools/get_morpho_tool.py b/swarm_copy/tools/get_morpho_tool.py index dc8d4a6..45c72bb 100644 --- a/swarm_copy/tools/get_morpho_tool.py +++ b/swarm_copy/tools/get_morpho_tool.py @@ -70,7 +70,7 @@ class GetMorphoTool(BaseTool): input_schema: GetMorphoInput metadata: GetMorphoMetadata - async def arun(self) -> list[KnowledgeGraphOutput]: + async def arun(self) -> list[dict[str, Any]]: """From a brain region ID, extract morphologies. Returns @@ -175,7 +175,7 @@ def create_query( return entire_query @staticmethod - def _process_output(output: Any) -> list[KnowledgeGraphOutput]: + def _process_output(output: Any) -> list[dict[str, Any]]: """Process output to fit the KnowledgeGraphOutput pydantic class defined above. Parameters @@ -211,7 +211,7 @@ def _process_output(output: Any) -> list[KnowledgeGraphOutput]: if "subjectAge" in res["_source"] else None ), - ) + ).model_dump() for res in output["hits"]["hits"] ] return formatted_output diff --git a/swarm_copy/tools/kg_morpho_features_tool.py b/swarm_copy/tools/kg_morpho_features_tool.py index 24eeac8..7298636 100644 --- a/swarm_copy/tools/kg_morpho_features_tool.py +++ b/swarm_copy/tools/kg_morpho_features_tool.py @@ -186,7 +186,7 @@ class KGMorphoFeatureTool(BaseTool): input_schema: KGMorphoFeatureInput metadata: KGMorphoFeatureMetadata - async def arun(self) -> list[KGMorphoFeatureOutput]: + async def arun(self) -> list[dict[str, Any]]: """Run the tool async. Returns @@ -319,7 +319,7 @@ def create_query( return entire_query @staticmethod - def _process_output(output: Any) -> list[KGMorphoFeatureOutput]: + def _process_output(output: Any) -> list[dict[str, Any]]: """Process output. Parameters @@ -347,7 +347,7 @@ def _process_output(output: Any) -> list[KGMorphoFeatureOutput]: morphology_id=morpho_source["neuronMorphology"]["@id"], morphology_name=morpho_source["neuronMorphology"].get("name"), features=feature_output, - ) + ).model_dump() ) return formatted_output diff --git a/swarm_copy/tools/literature_search_tool.py b/swarm_copy/tools/literature_search_tool.py index 99880b9..92ebf7d 100644 --- a/swarm_copy/tools/literature_search_tool.py +++ b/swarm_copy/tools/literature_search_tool.py @@ -61,7 +61,7 @@ class LiteratureSearchTool(BaseTool): input_schema: LiteratureSearchInput metadata: LiteratureSearchMetadata - async def arun(self) -> list[ParagraphMetadata]: + async def arun(self) -> list[dict[str, Any]]: """Async search the scientific literature and returns citations. Returns @@ -91,7 +91,7 @@ async def arun(self) -> list[ParagraphMetadata]: return self._process_output(response.json()) @staticmethod - def _process_output(output: list[dict[str, Any]]) -> list[ParagraphMetadata]: + def _process_output(output: list[dict[str, Any]]) -> list[dict[str, Any]]: """Process output.""" paragraphs_metadata = [ ParagraphMetadata( @@ -101,7 +101,7 @@ def _process_output(output: list[dict[str, Any]]) -> list[ParagraphMetadata]: section=paragraph["section"], article_doi=paragraph["article_doi"], journal_issn=paragraph["journal_issn"], - ) + ).model_dump() for paragraph in output ] return paragraphs_metadata diff --git a/swarm_copy/tools/morphology_features_tool.py b/swarm_copy/tools/morphology_features_tool.py index 00bd349..31b2548 100644 --- a/swarm_copy/tools/morphology_features_tool.py +++ b/swarm_copy/tools/morphology_features_tool.py @@ -52,7 +52,7 @@ class MorphologyFeatureTool(BaseTool): input_schema: MorphologyFeatureInput metadata: MorphologyFeatureMetadata - async def arun(self) -> list[MorphologyFeatureOutput]: + async def arun(self) -> list[dict[str, Any]]: """Give features about morphology.""" logger.info( f"Entering morphology feature tool. Inputs: {self.input_schema.morphology_id=}" @@ -71,7 +71,7 @@ async def arun(self) -> list[MorphologyFeatureOutput]: return [ MorphologyFeatureOutput( brain_region=metadata.brain_region, feature_dict=features - ) + ).model_dump() ] def get_features(self, morphology_content: bytes, reader: str) -> dict[str, Any]: diff --git a/swarm_copy/tools/resolve_entities_tool.py b/swarm_copy/tools/resolve_entities_tool.py index 1264ac1..e6ab88d 100644 --- a/swarm_copy/tools/resolve_entities_tool.py +++ b/swarm_copy/tools/resolve_entities_tool.py @@ -1,7 +1,7 @@ """Tool to resolve the brain region from natural english to a KG ID.""" import logging -from typing import ClassVar +from typing import Any, ClassVar from pydantic import BaseModel, Field @@ -86,14 +86,14 @@ class ResolveEntitiesTool(BaseTool): async def arun( self, - ) -> list[BRResolveOutput | MTypeResolveOutput | EtypeResolveOutput]: + ) -> list[dict[str, Any]]: """Given a brain region in natural language, resolve its ID.""" logger.info( f"Entering Brain Region resolver tool. Inputs: {self.input_schema.brain_region=}, " f"{self.input_schema.mtype=}, {self.input_schema.etype=}" ) # Prepare the output list. - output: list[BRResolveOutput | MTypeResolveOutput | EtypeResolveOutput] = [] + output: list[dict[str, Any]] = [] # First resolve the brain regions. brain_regions = await resolve_query( @@ -108,7 +108,9 @@ async def arun( # Extend the resolved BRs. output.extend( [ - BRResolveOutput(brain_region_name=br["label"], brain_region_id=br["id"]) + BRResolveOutput( + brain_region_name=br["label"], brain_region_id=br["id"] + ).model_dump() for br in brain_regions ] ) @@ -127,7 +129,9 @@ async def arun( # Extend the resolved mtypes. output.extend( [ - MTypeResolveOutput(mtype_name=mtype["label"], mtype_id=mtype["id"]) + MTypeResolveOutput( + mtype_name=mtype["label"], mtype_id=mtype["id"] + ).model_dump() for mtype in mtypes ] ) @@ -138,7 +142,7 @@ async def arun( EtypeResolveOutput( etype_name=self.input_schema.etype, etype_id=ETYPE_IDS[self.input_schema.etype], - ) + ).model_dump() ) return output diff --git a/swarm_copy/tools/traces_tool.py b/swarm_copy/tools/traces_tool.py index 0434013..0703259 100644 --- a/swarm_copy/tools/traces_tool.py +++ b/swarm_copy/tools/traces_tool.py @@ -71,7 +71,7 @@ class GetTracesTool(BaseTool): input_schema: GetTracesInput metadata: GetTracesMetadata - async def arun(self) -> list[TracesOutput]: + async def arun(self) -> list[dict[str, Any]]: """From a brain region ID, extract traces.""" logger.info( f"Entering get trace tool. Inputs: {self.input_schema.brain_region_id=}, {self.input_schema.etype_id=}" @@ -153,7 +153,7 @@ def create_query( return entire_query @staticmethod - def _process_output(output: Any) -> list[TracesOutput]: + def _process_output(output: Any) -> list[dict[str, Any]]: """Process output to fit the TracesOutput pydantic class defined above. Parameters @@ -190,7 +190,7 @@ def _process_output(output: Any) -> list[TracesOutput]: if "subjectAge" in res["_source"] else None ), - ) + ).model_dump() for res in output["hits"]["hits"] ] return results