-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,5 +14,7 @@ wheels/ | |
main_test.py | ||
test_config.json | ||
|
||
*.env | ||
ignore* | ||
ignore/ | ||
.coverage |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
3.12.5 | ||
3.12.6 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from settings.source import ConfigTreeSource, RevisionSource | ||
Check failure on line 1 in settings/__init__.py
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import asyncio | ||
from typing import Optional, Type | ||
|
||
from pydantic import Field | ||
from pydantic_settings import BaseSettings, SettingsConfigDict | ||
|
||
from rapyuta_io_sdk_v2 import Configuration | ||
from settings import ConfigTreeSource, RevisionSource | ||
|
||
# Provider implementation | ||
|
||
|
||
class AuthConfig(BaseSettings): | ||
model_config = SettingsConfigDict(env_file=".env") | ||
env: str | ||
auth_token: str | ||
organization_guid: str = Field(alias="ORG") | ||
project_guid: str = Field(alias="PROJ") | ||
|
||
def __new__(cls, *args, **kwargs): | ||
instance = super().__new__(cls) | ||
cls.__init__(instance, *args, **kwargs) | ||
|
||
return Configuration( | ||
auth_token=instance.auth_token, | ||
environment=instance.env, | ||
organization_guid=instance.organization_guid, | ||
project_guid=instance.project_guid, | ||
) | ||
|
||
|
||
class RRTreeSource: | ||
def __init__(self, config: Type[Configuration] = None, local_file: Optional[str] = None): | ||
""" | ||
Initialize RRTreeSource with optional local file support. | ||
Args: | ||
local_file (Optional[str]): Path to a local JSON/YAML file for the config tree. Defaults to None. | ||
""" | ||
config = config if config is not None else AuthConfig() | ||
|
||
# Create a ConfigTreeSource instance | ||
self.config_tree_source = ConfigTreeSource( | ||
settings_cls=AuthConfig, config=config, local_file=local_file | ||
) | ||
|
||
asyncio.run(self.config_tree_source.load_config_tree()) | ||
|
||
# Print the loaded configuration tree | ||
print(self.config_tree_source.config_tree) | ||
|
||
|
||
class RevSource: | ||
def __init__(self): | ||
auth_config = AuthConfig() # AuthConfig instance | ||
# Pass AuthConfig (subclass of BaseSettings) as settings_cls | ||
revision_source = RevisionSource( | ||
AuthConfig, config=auth_config, tree_name="ankit" | ||
) | ||
asyncio.run(revision_source.list_revisions()) | ||
print(revision_source.revisions) | ||
|
||
|
||
def test(): | ||
print("Config Tree Source") | ||
RRTreeSource(local_file='ignore_config.json') | ||
Check failure on line 66 in settings/main.py
|
||
print("\nRevision Source") | ||
RevSource() | ||
|
||
|
||
test() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import json | ||
from typing import Any, Type | ||
|
||
import yaml | ||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource | ||
|
||
from rapyuta_io_sdk_v2 import AsyncClient, Configuration | ||
|
||
|
||
class ConfigTreeSource(PydanticBaseSettingsSource): | ||
def __init__( | ||
self, | ||
settings_cls: Type["BaseSettings"], | ||
config: Configuration, | ||
local_file: str = None, | ||
): | ||
super().__init__(settings_cls) | ||
self.async_client = AsyncClient(config=config) | ||
self.local_file = local_file | ||
self.config_tree = None # Placeholder for the configuration tree data | ||
|
||
async def fetch_from_api(self): | ||
""" | ||
Load the configuration tree from an external API. | ||
""" | ||
try: | ||
self.config_tree = await self.async_client.list_configtrees() | ||
except Exception as e: | ||
raise ValueError(f"Failed to fetch configuration tree from API: {e}") | ||
|
||
def load_from_local_file(self): | ||
""" | ||
Load the configuration tree from a local JSON or YAML file. | ||
""" | ||
if not self.local_file: | ||
raise ValueError("No local file path provided for configuration tree.") | ||
|
||
try: | ||
with open(self.local_file, "r") as file: | ||
if self.local_file.endswith(".json"): | ||
self.config_tree = json.load(file) | ||
elif self.local_file.endswith(".yaml") or self.local_file.endswith( | ||
".yml" | ||
): | ||
self.config_tree = yaml.safe_load(file) | ||
else: | ||
raise ValueError("Unsupported file format. Use .json or .yaml/.yml.") | ||
except FileNotFoundError: | ||
raise ValueError(f"Local file '{self.local_file}' not found.") | ||
except Exception as e: | ||
raise ValueError(f"Failed to load configuration tree from file: {e}") | ||
|
||
async def load_config_tree(self): | ||
""" | ||
Load the configuration tree from either an API or a local file. | ||
""" | ||
if self.local_file: | ||
self.load_from_local_file() | ||
else: | ||
await self.fetch_from_api() | ||
|
||
def __call__(self) -> dict[str, Any]: | ||
if self.config_tree is None: | ||
raise ValueError("Configuration tree is not loaded.") | ||
return self.config_tree | ||
|
||
def get_field_value(self, field_name: str) -> Any: | ||
if self.config_tree is None: | ||
raise ValueError("Configuration tree is not loaded.") | ||
return self.config_tree.get(field_name, None) | ||
|
||
|
||
class RevisionSource(PydanticBaseSettingsSource): | ||
def __init__( | ||
self, settings_cls: Type[BaseSettings], config: Configuration, tree_name: str | ||
): | ||
super().__init__(settings_cls) | ||
self.tree_name = tree_name | ||
self.revisions = None | ||
self.async_client = AsyncClient(config=config) | ||
|
||
async def list_revisions(self): | ||
"""Load the revisions from the external source.""" | ||
self.revisions = await self.async_client.list_revisions(tree_name=self.tree_name) | ||
|
||
def __call__(self) -> dict[str, Any]: | ||
"""Provide the revisions as a dictionary.""" | ||
if self.revisions is None: | ||
raise ValueError("Revisions is not loaded.") | ||
return self.revisions | ||
|
||
def get_field_value(self, field_name: str) -> Any: | ||
"""Retrieve the value for a specific field.""" | ||
if self.revisions is None: | ||
raise ValueError("Revisions is not loaded.") | ||
return self.revisions.get(field_name, None) |