Skip to content

Commit

Permalink
✨ feat: add providers for ConfigTrees using pydantic-settings
Browse files Browse the repository at this point in the history
  • Loading branch information
guptadev21 committed Jan 28, 2025
1 parent c798adf commit 5ea4e1a
Show file tree
Hide file tree
Showing 7 changed files with 404 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ wheels/
main_test.py
test_config.json

*.env
ignore*
ignore/
.coverage
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.12.5
3.12.6
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ description = "Python SDK for rapyuta.io v2 APIs"
dependencies = [
"httpx>=0.27.2",
"munch>=4.0.0",
"pydantic-settings>=2.7.1",
"pyyaml>=6.0.2",
]
readme = "README.md"
license = { file = "LICENSE" }
Expand Down
1 change: 1 addition & 0 deletions settings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from settings.source import ConfigTreeSource # noqa: F401
50 changes: 50 additions & 0 deletions settings/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional, Type

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from rapyuta_io_sdk_v2 import Configuration
from settings import ConfigTreeSource

# Provider implementation


class AuthConfig(BaseSettings):
model_config = SettingsConfigDict(env_file=".env")
env: str
auth_token: str
organization_guid: str = Field(alias="ORG")
project_guid: str = Field(alias="PROJ")

def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
cls.__init__(instance, *args, **kwargs)

return Configuration(
auth_token=instance.auth_token,
environment=instance.env,
organization_guid=instance.organization_guid,
project_guid=instance.project_guid,
)


class RRTreeSource:
def __init__(self, config: Type[Configuration] = None, local_file: Optional[str] = None):
"""
Initialize RRTreeSource with optional local file support.
Args:
local_file (Optional[str]): Path to a local JSON/YAML file for the config tree. Defaults to None.
"""
config = config if config is not None else AuthConfig()

# Create a ConfigTreeSource instance
self.config_tree_source = ConfigTreeSource(
settings_cls=AuthConfig, config=config, local_file=local_file
)

self.config_tree_source.load_config_tree()

# Print the loaded configuration tree
print(self.config_tree_source.config_tree)

126 changes: 126 additions & 0 deletions settings/source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import asyncio
import json
from typing import Any, Type, Dict
from benedict import benedict
from munch import Munch

import yaml
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource

from rapyuta_io_sdk_v2 import AsyncClient, Configuration


class ConfigTreeSource(PydanticBaseSettingsSource):
def __init__(
self,
settings_cls: Type["BaseSettings"],
config: Configuration,
tree_name: str = "default",
key_prefix: str = "",
local_file: str = None,
):
super().__init__(settings_cls)
self.async_client = AsyncClient(config=config)
self.tree_name = tree_name
self.local_file = local_file
self._top_prefix = key_prefix

# ? Placeholder for the configuration tree data
self.config_tree = None

# * Load the configuration tree
self.load_config_tree()

processed_data = self._process_config_tree(raw_data=self.config_tree)

self._configtree_data = benedict(processed_data).unflatten(separator="/")
# print(self._configtree_data)

# * Methods to fetch Configtree
async def fetch_from_api(self):
"""
Load the configuration tree from an external API.
"""
try:
response = await self.async_client.get_configtree(
name=self.tree_name,
include_data=True,
content_types="kv",
with_project=False,
)
config_tree_response = Munch.toDict(response).get("keys")
self.config_tree = self._extract_data_api(input_data=config_tree_response)
except Exception as e:
raise ValueError(f"Failed to fetch configuration tree from API: {e}")

# TODO: Check remaining for this
def load_from_local_file(self):
"""
Load the configuration tree from a local JSON or YAML file.
"""
if not self.local_file:
raise ValueError("No local file path provided for configuration tree.")

try:
with open(self.local_file, "r") as file:
if self.local_file.endswith(".json"):
self.config_tree = self._extract_data_local(json.load(file))
elif self.local_file.endswith(".yaml") or self.local_file.endswith(
".yml"
):
self.config_tree = self._extract_data_local(yaml.safe_load(file))
else:
raise ValueError(
"Unsupported file format. Use .json or .yaml/.yml."
)
except FileNotFoundError:
raise ValueError(f"Local file '{self.local_file}' not found.")
except Exception as e:
raise ValueError(f"Failed to load configuration tree from file: {e}")

def load_config_tree(self):
"""
Load the configuration tree from either an API or a local file.
"""
if self.local_file:
self.load_from_local_file()
else:
asyncio.run(self.fetch_from_api())

# * Methods to process the tree

def _extract_data_api(self, input_data: Dict[str, Any] = None) -> Dict[str, Any]:
return {key: value.get("data") for key, value in input_data.items() if "data" in value}

def _extract_data_local(self, input_data: Dict[str, Any] = None) -> Dict[str, Any]:
for key, value in input_data.items():
if isinstance(value, dict):
if "value" in value:
input_data[key] = value.get("value")
else:
self._extract_data_local(value)
return input_data

# * This method is extracting the data from the raw data and removing the top level prefix
def _process_config_tree(self, raw_data: Dict[str, Any]) -> Dict[str, Any]:
d: Dict[str, Any] = {}
prefix_length = len(self._top_prefix)

if prefix_length == 0:
return raw_data

for key in raw_data:
processed_key = key[prefix_length + 1 :]
d[processed_key] = raw_data[key]

return d

def __call__(self) -> dict[str, Any]:
if self.config_tree is None:
raise ValueError("Configuration tree is not loaded.")
return self._configtree_data

def get_field_value(self, field_name: str) -> Any:
if self.config_tree is None:
raise ValueError("Configuration tree is not loaded.")
return self.config_tree.get(field_name, None)
Loading

0 comments on commit 5ea4e1a

Please sign in to comment.