"""
FastAPI dependencies that provide an async SQLAlchemy session and repository instances bound to that session.
"""
import typing
import fastapi
from src.data.repositories import base
from sqlalchemy.ext.asyncio import AsyncSession
from src.data.model.db.connection import async_engine


async def get_async_session() -> typing.AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to ``async_engine`` for one dependency scope.

    If an exception is raised while the session is in use (i.e. thrown back
    into this generator at the ``yield``), the session is rolled back and the
    exception is re-raised so the caller / framework still sees the error.
    The session is always closed afterwards.

    Yields:
        AsyncSession: a fresh session bound to ``async_engine``.
    """
    async_session = AsyncSession(bind=async_engine)
    try:
        yield async_session
    except Exception:
        # Undo any partial work, then propagate: swallowing the exception here
        # would hide the real error from the endpoint's caller and from logs.
        await async_session.rollback()
        raise
    finally:
        await async_session.close()


def get_repository(
        repo_type: typing.Type[base.BaseRepository],
) -> typing.Callable[[AsyncSession], base.BaseRepository]:
    """Build a FastAPI dependency that constructs a ``repo_type`` repository.

    Args:
        repo_type: repository class to instantiate; it is called with the
            keyword argument ``async_session``.

    Returns:
        A callable suitable for ``fastapi.Depends``. FastAPI resolves its
        ``session`` parameter through ``get_async_session`` and the callable
        returns ``repo_type(async_session=session)``.
    """
    def _repository_dependency(
            session: AsyncSession = fastapi.Depends(get_async_session),
    ) -> base.BaseRepository:
        # Hand the injected session to a freshly constructed repository.
        repository = repo_type(async_session=session)
        return repository

    return _repository_dependency
