Skip to content

Commit

Permalink
Add "request" as a parameter in the methods get_list_query and get_co…
Browse files Browse the repository at this point in the history
…unt_query (#592)

* Add "request" as a parameter in the method get_list_query and get_count_query

* Add "request" as a parameter in the method get_list_query and get_count_query
  • Loading branch information
jowilf authored Oct 25, 2024
1 parent 5ccd356 commit 889ad5f
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions starlette_admin/contrib/sqla/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async def handle_row_action(
except SQLAlchemyError as exc:
raise ActionFailed(str(exc)) from exc

def get_list_query(self) -> Select:
def get_list_query(self, request: Request) -> Select:
"""
Return a Select expression which is used as base statement for
[find_all][starlette_admin.views.BaseModelView.find_all] method.
Expand All @@ -185,10 +185,10 @@ def get_list_query(self) -> Select:
```python hl_lines="3-4"
class PostView(ModelView):
def get_list_query(self):
def get_list_query(self, request: Request):
return super().get_list_query().where(Post.published == true())
def get_count_query(self):
def get_count_query(self, request: Request):
return super().get_count_query().where(Post.published == true())
```
Expand All @@ -198,7 +198,7 @@ def get_count_query(self):
"""
return select(self.model)

def get_count_query(self) -> Select:
def get_count_query(self, request: Request) -> Select:
"""
Return a Select expression which is used as base statement for
[count][starlette_admin.views.BaseModelView.count] method.
Expand All @@ -207,10 +207,10 @@ def get_count_query(self) -> Select:
```python hl_lines="6-7"
class PostView(ModelView):
def get_list_query(self):
def get_list_query(self, request: Request):
return super().get_list_query().where(Post.published == true())
def get_count_query(self):
def get_count_query(self, request: Request):
return super().get_count_query().where(Post.published == true())
```
"""
Expand Down Expand Up @@ -252,7 +252,7 @@ async def count(
where: Union[Dict[str, Any], str, None] = None,
) -> int:
session: Union[Session, AsyncSession] = request.state.session
stmt = self.get_count_query()
stmt = self.get_count_query(request)
if where is not None:
if isinstance(where, dict):
where = build_query(where, self.model)
Expand All @@ -274,7 +274,7 @@ async def find_all(
order_by: Optional[List[str]] = None,
) -> Sequence[Any]:
session: Union[Session, AsyncSession] = request.state.session
stmt = self.get_list_query().offset(skip)
stmt = self.get_list_query(request).offset(skip)
if limit > 0:
stmt = stmt.limit(limit)
if where is not None:
Expand Down

0 comments on commit 889ad5f

Please sign in to comment.