Skip to content

Commit

Permalink
WIP Models
Browse files Browse the repository at this point in the history
  • Loading branch information
lenhoanglnh committed Feb 23, 2025
1 parent 51e9b26 commit 494c377
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 21 deletions.
18 changes: 16 additions & 2 deletions solidago/src/solidago/state/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,17 @@ def is_base(self) -> bool:
def evaluated_entities(self, entities: "Entities") -> "Entities":
return entities if self.is_base() else self.parent.evaluated_entities(entities)

def set_depth(self, depth: int) -> None:
def set_depth(self, depth: int, change_scales=True) -> None:
self.depth = depth
if self.user_models is None:
self.scales["depth"] = self.scales["depth"] + 1
else:
scales = self.user_models.user_scales
indices = scales.get(username=self.username).index
for i in indices:
self.user_models.user_scales.iloc[i, "depth"] = scales.iloc[i, "depth"] + 1
if not self.is_base():
self.parent.set_depth(depth + 1)
self.parent.set_depth(depth + 1, change_scales=False)

def to_direct(self, entities: "Entities") -> "DirectScoring":
from .direct import DirectScoring
Expand Down Expand Up @@ -121,6 +128,13 @@ def save(self, filename: Optional[Union[str, Path]]=None) -> tuple[str, dict]:
with open(filename, "w") as f:
json.dump([type(self).__name__, saved_dict], f, indent=4)
return type(self).__name__, kwargs

def is_cls(self, cls: tuple[str, dict]) -> bool:
if type(self).__name__ != cls[0]:
return False
if any({ getattr(self, key) != value for key, value in cls[1].items() if key != "parent" }):
return False
return self.is_base() or self.parent.is_cls(cls[1]["parent"])

def base_model(self) -> "BaseModel":
return self if self.is_base() else self.parent.base_model()
Expand Down
64 changes: 45 additions & 19 deletions solidago/src/solidago/state/models/user_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(self,
user_directs: Optional[Union[str, DataFrame, MultiScore]]=None,
user_scales: Optional[Union[str, DataFrame, MultiScore]]=None,
common_scales: Optional[Union[str, DataFrame, MultiScore]]=None,
default_model_cls: tuple[str, dict]=("DirectScoring", None),
default_model_cls: Optional[tuple[str, dict]]=None,
user_model_cls_dict: Optional[dict[str, tuple]]=None
):
self.user_directs = user_directs or MultiScore.load(directs,
Expand All @@ -29,7 +29,7 @@ def __init__(self,
key_names=["depth", "kind", "criterion"],
name="common_scales"
)
self.default_model_cls = default_model_cls
self.default_model_cls = default_model_cls or ("DirectScoring", dict())
self.user_model_cls_dict = user_model_cls_dict or dict()
self._cache_users = set()

Expand Down Expand Up @@ -62,32 +62,35 @@ def __getitem__(self, user: Union[str, "User"]) -> ScoringModel:
import solidago.state.models as models
constructor_name, kwargs = self.model_cls(user)
return models.constructor_name(
directs=self.directs.get(username=user, cache_group=True),
scales=self.scales.get(username=user, cache_group=True),
directs=self.user_directs.get(username=user, cache_group=True),
scales=self.user_scales.get(username=user, cache_group=True) \
| self.common_scales.assign(username=str(user)),
username=str(user),
user_models=self,
**kwargs
)

def __delitem__(self, user: Union[str, "User"]) -> None:
self.directs = self.directs.delete(username=str(user))
self.scales = self.scales.delete(username=str(user))
self.user_directs = self.user_directs.delete(username=str(user))
self.user_scales = self.user_scales.delete(username=str(user))
if str(user) in self.user_model_cls_dict:
del self.user_model_cls_dict[str(user)]
if self._cache_users is not None:
self._cache_users.remove(str(user))

def __setitem__(self, user: Union[str, "User"], model: ScoringModel) -> None:
del self[user]
self.directs = self.directs | model.directs.assign(username=str(user))
self.scales = self.scales | model.scales.assign(username=str(user))
self.user_model_cls_dict[str(user)] = model.save()
self.user_directs = self.user_directs | model.directs.assign(username=str(user))
self.user_scales = self.user_scales | model.scales.assign(username=str(user))
if not model.is_cls(self.default_model_cls):
self.user_model_cls_dict[str(user)] = model.save()
if self._cache_users is not None:
self._cache_users.add(str(user))

def users(self) -> set[str]:
if self._cache_users is None:
self._cache_users = set(self.directs["username"]) | set(self.scales["username"]) \
self._cache_users = set(self.user_directs["username"]) \
| set(self.user_scales["username"]) \
| set(self.user_model_cls_dict.keys())
return self._cache_users

Expand All @@ -100,19 +103,38 @@ def __iter__(self) -> Iterable:

def scale(self,
mutlipliers: Optional[MultiScore]=None,
translations: Optional[MultiScore]=None
translations: Optional[MultiScore]=None,
note: str="None",
) -> UserModels:
scale_key_names = ["username", "depth", "kind", "criterion"]
multipliers = multipliers or MultiScore(key_names=scale_key_names)
translations = translations or MultiScore(key_names=scale_key_names)
assert multipliers is not None or translations is not None
multipliers = multipliers or MultiScore(key_names=translations.key_names)
translations = translations or MultiScore(key_names=multipliers.key_names)
user_scales = self.user_scales.assign(depth=self.user_scales["depth"] + 1)
common_scales = self.common_scales.assign(depth=self.common_scales["depth"] + 1)
if "username" in multipliers.key_names:
user_scales = user_scales | multipliers | translations
else:
common_scales = common_scales | multipliers | translations
return UserModels(
user_directs=self.user_directs,
user_scales=user_scales,
common_scales=common_scales,
default_model_cls=("ScaledModel", dict(note=note, parent=self.default_model_cls)),
user_model_cls_dict={
username: ("ScaledModel", dict(note=note, parent=model_cls))
for username, model_cls in self.user_model_cls_dict.items()
}
)

def save(self, directory: Union[Path, str], json_dump: bool=False) -> tuple[str, dict]:
assert isinstance(directory, (Path, str)), directory
j = type(self).__name__, dict()
if not self.directs.empty:
j[1]["directs"] = self.directs.to_csv(directory)[1]
if not self.scales.empty:
j[1]["scales"] = self.scales.to_csv(directory)[1]
if not self.user_directs.empty:
j[1]["user_directs"] = self.user_directs.to_csv(directory)[1]
if not self.user_scales.empty:
j[1]["user_scales"] = self.user_scales.to_csv(directory)[1]
if not self.common_scales.empty:
j[1]["common_scales"] = self.common_scales.to_csv(directory)[1]
if self.default_model_cls is not None:
j[1]["default_model_cls"] = self.default_model_cls
if len(self.user_model_cls_dict) > 0:
Expand All @@ -123,4 +145,8 @@ def save(self, directory: Union[Path, str], json_dump: bool=False) -> tuple[str,
return j

def __repr__(self) -> str:
return f"{repr(self.directs)}\n\n{repr(self.scales)}"
return "\n\n".join([
repr(df)
for df in (self.user_directs, self.user_scales, self.common_scales)
])

0 comments on commit 494c377

Please sign in to comment.