from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.data.model.db import article as article_db, partner_agent as partner_agent_db
from src.data.model.processing import price_conditions


class BaseRepository:
    """Shared async data-access helpers built on an SQLAlchemy ``AsyncSession``."""

    def __init__(self, async_session: AsyncSession):
        self.async_session = async_session

    async def next_id(self, id_attr) -> int:
        """Return the next id for the column *id_attr*, computed as ``MAX(id_attr) + 1``.

        Returns 0 when the column has no values (``MAX`` yields NULL).

        NOTE(review): MAX+1 allocation is not safe against concurrent inserts from
        other sessions/transactions; confirm callers serialize id allocation or
        consider a DB sequence/identity column.
        """
        max_id = await self.async_session.scalar(select(func.max(id_attr)))
        # Compare against None explicitly: an existing max id of 0 is a real value
        # and must yield 1 rather than handing out 0 a second time.
        return 0 if max_id is None else max_id + 1

    @staticmethod
    def _restrict_to_partner_pricelist(stmt, order_type: int, partner_num: int):
        """Join *stmt* to ``SalesOrderTypePartner`` and keep only the pricelist for *partner_num*/*order_type*."""
        partner_link = partner_agent_db.SalesOrderTypePartner
        return stmt.join(
            partner_link,
            onclause=partner_link.pricelist_code == article_db.Pricelist.code,
        ).where(
            partner_link.partner_num == partner_num,
            partner_link.code == order_type,
        )

    async def get_price(self, order_type: int, partner_num: int, ean: str):
        """Return the price of *ean* on the pricelist linked to *partner_num* and *order_type*.

        Returns None when no matching pricelist row exists. If several rows match,
        the first row returned by the database is used (no ORDER BY is applied).
        """
        stmt = self._restrict_to_partner_pricelist(
            select(article_db.Pricelist.price), order_type, partner_num
        ).where(article_db.Pricelist.article_detail_ean == ean)
        return await self.async_session.scalar(stmt)

    async def get_prices(self, order_type: int, partner_num: int, eans: set[str]):
        """Return ``{ean: price}`` for each of *eans* priced on the partner's pricelist.

        EANs without a matching pricelist row are simply absent from the result.
        An empty *eans* returns an empty dict without querying the database.
        """
        if not eans:
            return {}
        stmt = self._restrict_to_partner_pricelist(
            select(article_db.Pricelist), order_type, partner_num
        ).where(article_db.Pricelist.article_detail_ean.in_(list(eans)))
        rows = await self.async_session.scalars(stmt)
        # Key by the mapped column actually filtered on above (``article_detail_ean``).
        return {row.article_detail_ean: row.price for row in rows.all()}
