Address review comments
* Make data loading during tabulate() optional
* Use bulk insert for IndexSetData
* Use a normal property for `.data`
glatterf42 committed Oct 29, 2024
1 parent 2b62ae7 commit a7246cd
Showing 8 changed files with 151 additions and 58 deletions.
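To put the first bullet in context, here is a minimal usage sketch of the now-optional data loading in tabulate(). The platform name and the IndexSet contents are made up; the calls mirror the core-layer API exercised in the tests further down.

import ixmp4

platform = ixmp4.Platform("test")  # hypothetical platform name
run = platform.runs.create("Model", "Scenario")

indexset = run.optimization.indexsets.create("Indexset 1")
indexset.add(["foo", "bar"])

# Default: the per-IndexSet data is not loaded; a data_type column is returned instead.
summary = run.optimization.indexsets.tabulate()

# Opt in to loading the full data column (slower for large IndexSets).
full = run.optimization.indexsets.tabulate(include_data=True)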
12 changes: 7 additions & 5 deletions ixmp4/core/optimization/indexset.py
@@ -48,21 +48,21 @@ def created_by(self) -> str | None:
return self._model.created_by

@property
def docs(self):
def docs(self) -> str | None:
try:
return self.backend.optimization.indexsets.docs.get(self.id).description
except DocsModel.NotFound:
return None

@docs.setter
def docs(self, description):
def docs(self, description: str | None) -> None:
if description is None:
self.backend.optimization.indexsets.docs.delete(self.id)
else:
self.backend.optimization.indexsets.docs.set(self.id, description)

@docs.deleter
def docs(self):
def docs(self) -> None:
try:
self.backend.optimization.indexsets.docs.delete(self.id)
# TODO: silently failing
@@ -105,7 +105,9 @@ def list(self, name: str | None = None) -> list[IndexSet]:
for i in indexsets
]

def tabulate(self, name: str | None = None) -> pd.DataFrame:
def tabulate(
self, name: str | None = None, include_data: bool = False
) -> pd.DataFrame:
return self.backend.optimization.indexsets.tabulate(
run_id=self._run.id, name=name
run_id=self._run.id, name=name, include_data=include_data
)
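As a quick illustration of the docs property annotated above, a sketch assuming a run created as in the earlier example; the description string is made up:

indexset = run.optimization.indexsets.create("Indexset 2")

indexset.docs = "Regions covered by the model"  # stores a description
assert indexset.docs == "Regions covered by the model"

indexset.docs = None  # assigning None deletes the docs entry
assert indexset.docs is None

del indexset.docs  # deleting again fails silently (see the TODO above)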
9 changes: 7 additions & 2 deletions ixmp4/data/abstract/optimization/indexset.py
@@ -102,13 +102,18 @@ def list(self, *, name: str | None = None, **kwargs) -> list[IndexSet]:
"""
...

def tabulate(self, *, name: str | None = None, **kwargs) -> pd.DataFrame:
def tabulate(
self, *, name: str | None = None, include_data: bool = False, **kwargs
) -> pd.DataFrame:
r"""Tabulate IndexSets by specified criteria.
Parameters
----------
name : str
name : str, optional
The name of an IndexSet. If supplied only one result will be returned.
include_data : bool, optional
Whether to load all IndexSet data, which reduces loading speed. Defaults to
`False`.
# TODO: Update kwargs
\*\*kwargs: any
More filter parameters as specified in
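A sketch of how this abstract method is reached through a backend (here backend stands for any concrete ixmp4 backend, as in the core layer's self.backend above; the run id is made up):

# Summary only; a single row because a name filter is supplied.
backend.optimization.indexsets.tabulate(run_id=1, name="Indexset 1")

# Also load the data column, at the cost of slower loading.
backend.optimization.indexsets.tabulate(run_id=1, name="Indexset 1", include_data=True)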
18 changes: 4 additions & 14 deletions ixmp4/data/api/optimization/indexset.py
@@ -2,7 +2,6 @@
from typing import ClassVar, List

import pandas as pd
from pydantic import StrictFloat, StrictInt, StrictStr

from ixmp4.data import abstract

@@ -17,13 +16,7 @@ class IndexSet(base.BaseModel):

id: int
name: str
data: (
StrictFloat
| StrictInt
| StrictStr
| list[StrictFloat | StrictInt | StrictStr]
| None
)
data: float | int | str | list[int | float | str] | None
run__id: int

created_at: datetime | None
@@ -64,16 +57,13 @@ def enumerate(self, **kwargs) -> list[IndexSet] | pd.DataFrame:
def list(self, **kwargs) -> list[IndexSet]:
return super()._list(json=kwargs)

def tabulate(self, **kwargs) -> pd.DataFrame:
return super()._tabulate(json=kwargs)
def tabulate(self, include_data: bool = False, **kwargs) -> pd.DataFrame:
return super()._tabulate(json=kwargs, params={"include_data": include_data})

def add_data(
self,
indexset_id: int,
data: StrictFloat
| StrictInt
| List[StrictFloat | StrictInt | StrictStr]
| StrictStr,
data: float | int | str | List[float | int | str],
) -> None:
kwargs = {"indexset_id": indexset_id, "data": data}
self._request("PATCH", self.prefix + str(indexset_id) + "/", json=kwargs)
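In short, include_data now travels as a query parameter while the remaining filters stay in the JSON body, and add_data() still issues a single PATCH with a plain JSON payload. A sketch, where repo stands for this API-layer repository and the URL prefix is elided:

# Filters such as name go into the JSON body; include_data becomes a query parameter.
repo.tabulate(include_data=True, name="Indexset 1")

# Sends PATCH <prefix>1/ with JSON body {"indexset_id": 1, "data": ["foo", "bar"]}.
repo.add_data(indexset_id=1, data=["foo", "bar"])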
16 changes: 8 additions & 8 deletions ixmp4/data/db/optimization/indexset/model.py
@@ -16,24 +16,22 @@ class IndexSet(base.BaseModel):
DataInvalid: ClassVar = OptimizationDataValidationError
DeletionPrevented: ClassVar = abstract.IndexSet.DeletionPrevented

data_type: types.OptimizationDataType
_data_type: types.OptimizationDataType

_data: types.Mapped[list["IndexSetData"]] = db.relationship(
back_populates="indexset"
)

@db.hybrid_property
@property
def data(self) -> list[float | int | str]:
return (
[]
if self.data_type is None
else np.array([d.value for d in self._data], dtype=self.data_type).tolist()
if self._data_type is None
else np.array([d.value for d in self._data], dtype=self._data_type).tolist()
)

# NOTE For the core layer (setting and retrieving) to work, the property needs a
# setter method
@data.inplace.setter
def _data_setter(self, value: list[float | int | str]) -> None:
@data.setter
def data(self, value: list[float | int | str]) -> None:
return None

run__id: types.RunId
@@ -42,6 +40,8 @@ def _data_setter(self, value: list[float | int | str]) -> None:


class IndexSetData(base.RootBaseModel):
table_prefix = "optimization_"

indexset: types.Mapped["IndexSet"] = db.relationship(back_populates="_data")
indexset__id: types.IndexSetId
value: types.String = db.Column(db.String, nullable=False)
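Replacing db.hybrid_property with a plain property means .data is assembled purely in Python from the related IndexSetData rows. A self-contained sketch of that pattern with simplified stand-in table definitions (not the actual ixmp4 base classes or column types):

from __future__ import annotations

import numpy as np
from sqlalchemy import ForeignKey, String, create_engine
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class IndexSet(Base):
    __tablename__ = "optimization_indexset"

    id: Mapped[int] = mapped_column(primary_key=True)
    _data_type: Mapped[str | None] = mapped_column(String, nullable=True)
    _data: Mapped[list["IndexSetData"]] = relationship(back_populates="indexset")

    @property
    def data(self) -> list[float | int | str]:
        # Values are stored as strings; cast them back using the recorded data type.
        return (
            []
            if self._data_type is None
            else np.array([d.value for d in self._data], dtype=self._data_type).tolist()
        )

    @data.setter
    def data(self, value: list[float | int | str]) -> None:
        # No-op setter so the core layer can assign .data without writing anything here.
        return None


class IndexSetData(Base):
    __tablename__ = "optimization_indexset_data"

    id: Mapped[int] = mapped_column(primary_key=True)
    indexset__id: Mapped[int] = mapped_column(ForeignKey("optimization_indexset.id"))
    indexset: Mapped[IndexSet] = relationship(back_populates="_data")
    value: Mapped[str] = mapped_column(String, nullable=False)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    indexset = IndexSet(_data_type="float")
    session.add(indexset)
    session.add_all(
        [IndexSetData(indexset=indexset, value=v) for v in ("1.2", "3.4")]
    )
    session.commit()
    print(indexset.data)  # [1.2, 3.4]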
38 changes: 23 additions & 15 deletions ixmp4/data/db/optimization/indexset/repository.py
@@ -60,14 +60,21 @@ def list(self, *args, **kwargs) -> list[IndexSet]:
return super().list(*args, **kwargs)

@guard("view")
def tabulate(self, *args, **kwargs) -> pd.DataFrame:
result = super().tabulate(*args, **kwargs).drop(labels="data_type", axis=1)
result.insert(
loc=0,
column="data",
value=[self.get_by_id(id=indexset_id).data for indexset_id in result.id],
)
return result
def tabulate(self, *args, include_data: bool = False, **kwargs) -> pd.DataFrame:
if not include_data:
return (
super()
.tabulate(*args, **kwargs)
.rename(columns={"_data_type": "data_type"})
)
else:
result = super().tabulate(*args, **kwargs).drop(labels="_data_type", axis=1)
result.insert(
loc=0,
column="data",
value=[indexset.data for indexset in self.list(**kwargs)],
)
return result

@guard("edit")
def add_data(
@@ -78,19 +85,20 @@ def add_data(
indexset = self.get_by_id(id=indexset_id)
if not isinstance(data, list):
data = [data]
# TODO If adding rows one by one is too expensive, look into executemany pattern
for value in data:
self.session.add(
IndexSetData(indexset=indexset, indexset__id=indexset_id, value=value)
)

bulk_insert_enabled_data: list[dict[str, str]] = [
{"value": str(d)} for d in data
]
try:
self.session.flush()
self.session.execute(
db.insert(IndexSetData).values(indexset__id=indexset_id),
bulk_insert_enabled_data,
)
except db.IntegrityError as e:
self.session.rollback()
raise indexset.DataInvalid from e

indexset.data_type = type(data[0]).__name__
indexset._data_type = type(data[0]).__name__

self.session.add(indexset)
self.session.commit()
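The row-by-row session.add() loop is replaced by a single executemany-style bulk insert, which is what the db.insert(...).values(...) call above performs. A minimal sketch of the same pattern, reusing the engine and models from the previous sketch (error handling and the ixmp4 db wrapper omitted):

from sqlalchemy import insert

with Session(engine) as session:
    indexset = session.get(IndexSet, 1)  # the IndexSet created in the sketch above
    data: list[float | int | str] = [5.6, 7.8]

    # One INSERT statement, executed with a list of per-row parameters.
    session.execute(
        insert(IndexSetData).values(indexset__id=indexset.id),
        [{"value": str(d)} for d in data],
    )
    indexset._data_type = type(data[0]).__name__  # "float"
    session.add(indexset)
    session.commit()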
13 changes: 8 additions & 5 deletions ixmp4/server/rest/optimization/indexset.py
@@ -1,5 +1,4 @@
from fastapi import APIRouter, Body, Depends, Query
from pydantic import StrictFloat, StrictInt, StrictStr

from ixmp4.data import api
from ixmp4.data.backend.db import SqlAlchemyBackend as Backend
@@ -21,16 +20,19 @@ class IndexSetInput(BaseModel):


class DataInput(BaseModel):
data: (
StrictFloat | StrictInt | StrictStr | list[StrictFloat | StrictInt | StrictStr]
)
data: float | int | str | list[float | int | str]


@autodoc
@router.patch("/", response_model=EnumerationOutput[api.IndexSet])
@router.patch(
"/",
response_model=EnumerationOutput[api.IndexSet],
response_model_exclude={"_data_type"},
)
def query(
filter: OptimizationIndexSetFilter = Body(OptimizationIndexSetFilter()),
table: bool = Query(False),
include_data: bool = Query(False),
pagination: Pagination = Depends(),
backend: Backend = Depends(deps.get_backend),
):
@@ -40,6 +42,7 @@ def query(
limit=pagination.limit,
offset=pagination.offset,
table=bool(table),
include_data=bool(include_data),
),
total=backend.optimization.indexsets.count(_filter=filter),
pagination=pagination,
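On the server side, the new flag is just another boolean query parameter next to table, and the private _data_type column is excluded from the response model. A standalone FastAPI sketch of the query-parameter part (made-up app and path, not the actual ixmp4 router):

from fastapi import FastAPI, Query

app = FastAPI()


@app.patch("/indexsets/")
def query(
    table: bool = Query(False),
    include_data: bool = Query(False),
) -> dict[str, bool]:
    # FastAPI coerces ?table=true&include_data=false into Python bools.
    return {"table": table, "include_data": include_data}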
44 changes: 40 additions & 4 deletions tests/core/test_optimization_indexset.py
@@ -9,13 +9,12 @@
from ..utils import create_indexsets_for_run


def df_from_list(indexsets: list[IndexSet]):
return pd.DataFrame(
def df_from_list(indexsets: list[IndexSet], include_data: bool = False) -> pd.DataFrame:
result = pd.DataFrame(
# Order is important here to avoid utils.assert_unordered_equality,
# which doesn't like lists
[
[
indexset.data,
indexset.run_id,
indexset.name,
indexset.id,
@@ -25,14 +24,27 @@ def df_from_list(indexsets: list[IndexSet]):
for indexset in indexsets
],
columns=[
"data",
"run__id",
"name",
"id",
"created_at",
"created_by",
],
)
if include_data:
result.insert(
loc=0, column="data", value=[indexset.data for indexset in indexsets]
)
else:
result.insert(
loc=0,
column="data_type",
value=[
type(indexset.data[0]).__name__ if indexset.data != [] else None
for indexset in indexsets
],
)
return result


class TestCoreIndexset:
@@ -74,6 +86,21 @@ def test_add_data(self, platform: ixmp4.Platform):
with pytest.raises(OptimizationDataValidationError):
indexset_2.add(["baz", "baz"])

# Test data types are conserved
indexset_3 = run.optimization.indexsets.create("Indexset 3")
test_data_2: list[float | int | str] = [1.2, 3.4, 5.6]
indexset_3.add(data=test_data_2)

assert indexset_3.data == test_data_2
assert type(indexset_3.data[0]).__name__ == "float"

indexset_4 = run.optimization.indexsets.create("Indexset 4")
test_data_3: list[float | int | str] = [0, 1, 2]
indexset_4.add(data=test_data_3)

assert indexset_4.data == test_data_3
assert type(indexset_4.data[0]).__name__ == "int"

def test_list_indexsets(self, platform: ixmp4.Platform):
run = platform.runs.create("Model", "Scenario")
indexset_1, indexset_2 = create_indexsets_for_run(
@@ -116,6 +143,15 @@ def test_tabulate_indexsets(self, platform: ixmp4.Platform):
result = run.optimization.indexsets.tabulate(name="Indexset 2")
pdt.assert_frame_equal(expected, result)

# Test tabulating including the data
expected = df_from_list(indexsets=[indexset_2], include_data=True)
pdt.assert_frame_equal(
expected,
run.optimization.indexsets.tabulate(
name=indexset_2.name, include_data=True
),
)

def test_indexset_docs(self, platform: ixmp4.Platform):
run = platform.runs.create("Model", "Scenario")
(indexset_1,) = tuple(