diff --git a/alembic/env.py b/alembic/env.py index bce58936822dbb13aecf1857ee188e0b5ce17346..cb290b379aec01d0897fb53bd58ab298c63de1a8 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -3,11 +3,6 @@ from logging.config import fileConfig from sqlalchemy import engine_from_config, pool import iou.db.sql_root as iou_db_model -from iou.lib.user import User -from iou.lib.group import Group, GroupMembership -from iou.lib.deposit import Deposit -from iou.lib.withdrawal import Withdrawal -from iou.lib.transaction import Transaction # alembic.context is only available through the alembic CLI # pylint: disable=no-name-in-module diff --git a/alembic/versions/0ccceb2e05e0_.py b/alembic/versions/0ccceb2e05e0_.py index 14e8f29ea099652724a1054eeb81418c69b16e94..cc213f671c66643362be24f27010b6da2401cbd4 100644 --- a/alembic/versions/0ccceb2e05e0_.py +++ b/alembic/versions/0ccceb2e05e0_.py @@ -1,7 +1,7 @@ """empty message Revision ID: 0ccceb2e05e0 -Revises: +Revises: Create Date: 2024-12-13 23:44:32.120337 """ diff --git a/iou/api/v1/schemas/pagination.py b/iou/api/v1/schemas/pagination.py index 07bdc79c22e692721d03fe42a9b529ad97ce01fd..e128b4f37539039150438dac6f470549c5411884 100644 --- a/iou/api/v1/schemas/pagination.py +++ b/iou/api/v1/schemas/pagination.py @@ -5,7 +5,7 @@ from pydantic import BaseModel class PaginatedResult[T](BaseModel): - result_list: list[T] + page: list[T] more: bool diff --git a/iou/api/v1/users.py b/iou/api/v1/users.py index 35a5cce0c88c078d62e3ad3684926447647fb396..7981c1eb7d791b8afa887af9768c54dbc7a92ba1 100644 --- a/iou/api/v1/users.py +++ b/iou/api/v1/users.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Annotated from fastapi import APIRouter, Depends, status @@ -5,6 +6,7 @@ from fastapi import APIRouter, Depends, status from iou.api import dependencies from iou.api.v1 import utils from iou.api.v1.schemas.group import GroupOut +from iou.api.v1.schemas.pagination import PaginatedResult, PaginationParams from iou.api.v1.schemas.user import UserIn, UserOut, UserSearchParams, UserUpdate from iou.db.db_interface import IouDBInterface from iou.lib.user import User @@ -13,18 +15,23 @@ from iou.security import Authentication router = APIRouter() -@router.get("", response_model=list[UserOut]) +@router.get("", response_model=PaginatedResult[UserOut]) def read_users( authentication: Annotated[Authentication, Depends(dependencies.get_authentication)], database: Annotated[IouDBInterface, Depends(dependencies.get_db)], search_params: Annotated[UserSearchParams, Depends()], -) -> list[UserOut]: - return [ - UserOut(name=user.name, email=user.email, user_id=user.user_id) - for _, user in database.users( - name=search_params.name, email=search_params.email - ).items() - ] + pagination_params: Annotated[PaginationParams, Depends()], +) -> PaginatedResult[UserOut]: + users_partial = partial( + database.users, limit=25, name=search_params.name, email=search_params.email + ) + return PaginatedResult( + page=[ + UserOut(name=user.name, email=user.email, user_id=user.user_id) + for _, user in users_partial(offset=pagination_params.offset).items() + ], + more=len(users_partial(offset=pagination_params.offset + 1)) > 0, + ) @router.post("", response_model=UserOut) diff --git a/iou/db/db_interface.py b/iou/db/db_interface.py index 0aa3253f44a379889dc3a8b635ea502c8234e729..506b822da2235b1f3ff448d03a3af9e8be6422e5 100644 --- a/iou/db/db_interface.py +++ b/iou/db/db_interface.py @@ -23,7 +23,13 @@ class IouDBInterface(ABC): pass @abstractmethod - def get_users(self) -> list[User]: + def get_users( + self, + offset: int = 0, + limit: int = 25, + name: str | None = None, + email: str | None = None, + ) -> list[User]: pass @abstractmethod @@ -60,7 +66,11 @@ class IouDBInterface(ABC): @abstractmethod def users( - self, name: str | None = None, email: str | None = None + self, + offset: int = 0, + limit: int = 25, + name: str | None = None, + email: str | None = None, ) -> dict[str, User]: pass diff --git a/iou/db/mock_db.py b/iou/db/mock_db.py index afcf07d616d8838a57b201b48a4fa4d640203233..5ef2882c5094d86f5192ef013cceb97e06d767f9 100644 --- a/iou/db/mock_db.py +++ b/iou/db/mock_db.py @@ -26,8 +26,21 @@ class MockDB(IouDBInterface): ) -> None: pass - def get_users(self) -> list[User]: - return list(self._users.values()) + def get_users( + self, + offset: int = 0, + limit: int = 25, + name: str | None = None, + email: str | None = None, + ) -> list[User]: + users = [ + user + for user in self._users.values() + if (name is None and email is None) + or (name is not None and user.name == name) + or (email is not None and user.email == email) + ] + return users[limit * offset : limit] def add_user(self, user: User) -> None: self._users[user.user_id] = user @@ -56,12 +69,14 @@ class MockDB(IouDBInterface): del self._groups[group_id] def users( - self, name: str | None = None, email: str | None = None + self, + offset: int = 0, + limit: int = 25, + name: str | None = None, + email: str | None = None, ) -> dict[str, User]: return { - user_id: user - for user_id, user in self._users.items() - if user.name == name or user.email == email + user.user_id: user for user in self.get_users(offset, limit, name, email) } def groups(self) -> dict[str, Group]: diff --git a/iou/db/sql_db.py b/iou/db/sql_db.py index 489d065d2b458d4d5f9ce3a08aec83e708355b26..0f6a6bd55009963e7612f0c1ae6c32ef36c50d5f 100644 --- a/iou/db/sql_db.py +++ b/iou/db/sql_db.py @@ -127,14 +127,18 @@ class SqlDb(IouDBInterface): logger.info("Disposed database connection pool") def get_users( - self, name: str | None = None, email: str | None = None + self, + offset: int = 0, + limit: int = 25, + name: str | None = None, + email: str | None = None, ) -> list[User]: query: Query[User] = self.session.query(User) if name is not None: query = query.filter(User.name.contains(name)) if email is not None: query = query.filter(User.email.contains(email)) - return query.offset(0).limit(25).all() + return query.offset(offset).limit(limit).all() def add_user(self, user: User) -> None: self.session.add(user) @@ -165,9 +169,15 @@ class SqlDb(IouDBInterface): self.session.delete(self.session.get(Group, group_id)) def users( - self, name: str | None = None, email: str | None = None + self, + offset: int = 0, + limit: int = 25, + name: str | None = None, + email: str | None = None, ) -> dict[str, User]: - return {user.user_id: user for user in self.get_users(name, email)} + return { + user.user_id: user for user in self.get_users(offset, limit, name, email) + } def groups(self) -> dict[str, Group]: return {group.group_id: group for group in self.get_groups()} diff --git a/test/test_api.py b/test/test_api.py index c629c3b006be60942473a4f5ad5a2a8a9a529257..1c94343de01f88f3bc6b5beb5dcd72acdc15ab68 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -178,7 +178,7 @@ class AbstractTestAPI(ABC): params={"name": "user5"}, ) assert response.status_code == 200, response.json() - assert response.json()[0]["name"] == "user5" + assert response.json()["page"][0]["name"] == "user5" @pytest.mark.asyncio async def test_create_user(self, iou_client: AsyncClient) -> None: