import csv

from sqlalchemy.orm import selectinload
from typing import Iterable
from src.data.repositories.base import BaseRepository
from sqlalchemy import select
from src.data.model.schemas import article as schema, PictureTypes
from src.data.model.db import article as db
from src.data.model.db import partner_agent as db_partner
from typing import Sequence
from src.data.model.schemas import Shop
from src.config.shops_config import *
from io import StringIO
from src.data.model.processing import price_conditions

# Shop configuration, loaded once when this module is imported. Its
# ``mobi_season`` value is read by ArticleRepository to decide which articles
# belong to which shop.
# NOTE(review): load_shops_config presumably comes from the star import of
# src.config.shops_config - confirm.
shop_config = load_shops_config()


class ArticleRepository(BaseRepository):
    """
    Repository for article data
    """

    async def get_articles_csv_export(self):
        """
        Creates a csv export for b2b articles

        One row is written per article detail selected for the B2B shop. Every
        row has exactly as many cells as the header, so each value ends up
        under its own column title.
        :return: in-memory text buffer (StringIO) holding the csv data,
            rewound to the start
        """
        picture_rows = await self.async_session.scalars(
            select(db.Picture).where(db.Picture.type == PictureTypes.ARTICLE_IMAGE))
        # article picture path per picture reference id
        picture_paths = {
            picture.ref_id: picture.path
            for picture in picture_rows.all()
        }

        stmt = (select(db.ArticleDetail, db.Article)
                .join(db.Article)
                .options(selectinload(db.ArticleDetail.article))
                .options(selectinload(db.ArticleDetail.color))
                .options(selectinload(db.Article.group))
                .options(selectinload(db.Article.season))
                .where(self.__select_articles_by_shop(Shop.B2B)))

        # scalars() keeps only the first selected entity, i.e. the article details
        article_details = await self.async_session.scalars(stmt)
        buffer = StringIO()
        writer = csv.writer(buffer)
        writer.writerow(['Model number', 'Article', 'Category', 'Gender',
                         'Picture', 'EAN', 'Size register', 'Size',
                         'Color code', 'Color name', 'Season'])

        for article_detail in article_details.all():
            writer.writerow(
                [
                    article_detail.article_number,
                    article_detail.article.group.designation,
                    # NOTE(review): this writes the string form of the related
                    # article object into the 'Category' column - confirm which
                    # attribute is meant here
                    article_detail.article,
                    # 'Gender': no value is available, the cell stays empty
                    "",
                    # 'Picture': this cell used to be left out completely, which
                    # shifted EAN and all following values one column to the
                    # left of their header.
                    # NOTE(review): looked up by article number, the same key
                    # load_pictures_for_article matches Picture.ref_id against;
                    # the earlier todo ("rebuild database") planned a dedicated
                    # picture_reference field - confirm the key. Articles
                    # without a picture get an empty cell.
                    picture_paths.get(article_detail.article_number, ""),
                    article_detail.ean,
                    article_detail.size_register,
                    article_detail.size,
                    article_detail.color.code,
                    article_detail.color.designation,
                    article_detail.article.season.designation
                ]
            )

        # rewind so that the caller can read the export from the beginning
        buffer.seek(0)
        return buffer

    async def get_articles(self, shop: Shop, color: str | None = None, size: str | None = None,
                           theme: int | None = None) -> list[schema.Article]:
        """
        Get all available articles based on shop, optionally narrowed down by filters
        :param shop: shop constant
        :param color: color code an article detail has to match; ignored when falsy
        :param size: size an article detail has to match; ignored when falsy
        :param theme: theme group number the article has to match; ignored when falsy
        :return: the list of available articles
        """
        # start with the shop availability and add one condition per given filter
        criteria = self.__select_articles_by_shop(shop)
        if color:
            criteria = criteria & (db.ArticleDetail.color_code == color)
        if size:
            criteria = criteria & (db.ArticleDetail.size == size)
        if theme:
            criteria = criteria & (db.Article.theme_group_nr == theme)

        stmt = (
            select(db.Article)
            .distinct(db.Article.article_number)
            .options(selectinload(db.Article.group))
            .join(db.Group)
            .join(db.ArticleDetail)
            .where(criteria)
        )
        rows = await self.async_session.scalars(stmt)

        articles = []
        for row in rows.all():
            articles.append(
                schema.Article(
                    name=row.group.designation,
                    number=row.article_number,
                    material=row.composition,
                    brand=row.brand_designation
                )
            )
        return articles

    async def get_article_color_size(self, query: schema.GetArticleColorSize):
        """
        Get the color size information of an article
        :param query: article query to determine the article color size data;
            its partner_id, sales_order_type and article_number are read
        :return: the color size information of the article selected by query;
            colors are listed in the order they first appear in the article's
            color_sizes, and colors without any size in stock are left out
        :raises ValueError: if the partner is unknown or there are no prices
            for the given partner and sales order type
        :raises KeyError: if a size with stock has no entry in the loaded prices
        """
        placeholder_picture = "https://via.placeholder.com/700x400"

        # load partner
        partner = await self.async_session.get(db_partner.Partner, query.partner_id)
        if partner is None:
            # session.get() yields None for an unknown primary key; fail with
            # the same exception type as the missing-prices case below instead
            # of an AttributeError on the attribute access
            raise ValueError("There is no partner for given partner id")
        surcharge_code = partner.surcharge_code

        # load relevant price lists
        price_lists = await self.async_session.scalars(
            select(db.Pricelist).
                join(db_partner.SalesOrderTypePartner,
                     onclause=db.Pricelist.code == db_partner.SalesOrderTypePartner.pricelist_code).
                where(
                (db_partner.SalesOrderTypePartner.code == query.sales_order_type) &
                (db_partner.SalesOrderTypePartner.partner_num == query.partner_id)
            )
        )
        # price per ean
        prices = {
            price_list_item.article_detail_ean: price_list_item.price
            for price_list_item in price_lists.all()
        }

        if not prices:
            raise ValueError("There are no prices for given partner and sales order type")

        def article_size_to_scheme(article_sizes: Iterable[db.ArticleDetail]):
            # only sizes with stock are offered; the price lookup happens after
            # the stock filter, so sizes without stock never need a price
            return [
                schema.ArticleSize(
                    size=article_size.size,
                    size_register=article_size.size_group,
                    price=price_conditions.apply_surcharge(prices[article_size.ean], article_size.size, surcharge_code),
                    ean=article_size.ean,
                    quantity=article_size.quantity
                )
                for article_size in article_sizes if article_size.quantity > 0
            ]

        def get_colors(article: db.Article, color_pics: dict):
            # Group the details per color code in a single pass instead of
            # re-scanning all details for every color. The distinct
            # (code, designation) pairs are kept as dict keys, which preserves
            # first-seen order and makes the result order reproducible.
            details_by_code = {}
            color_keys = {}
            for color_size in article.color_sizes:
                details_by_code.setdefault(color_size.color_code, []).append(color_size)
                color_keys[(color_size.color.code, color_size.color.designation)] = None

            colors = []
            for code, designation in color_keys:
                sizes = article_size_to_scheme(details_by_code.get(code, ()))
                if not sizes:
                    # nothing in stock for this color
                    continue
                colors.append(
                    schema.ArticleColor(
                        color_name=designation,
                        color_code=code,
                        sizes=sizes,
                        # colors without an own picture show the placeholder
                        picture=color_pics.get(code, placeholder_picture)
                    )
                )
            return colors

        stmt = (select(db.Article, db.ArticleDetail)
                .join(db.ArticleDetail)
                .options(selectinload(db.Article.group))
                .options(selectinload(db.Article.season))
                .options(selectinload(db.Article.color_sizes))
                .options(selectinload(db.ArticleDetail.color))
                .where(db.Article.article_number == query.article_number))

        # The join yields one row per article detail; unique() collapses them
        # to the single article and one() requires exactly one to exist.
        article = (await self.async_session.scalars(stmt)).unique().one()
        pictures = await self.load_pictures_for_article(article)
        article_picture = pictures['article']
        if article_picture is None:
            # fall back to the placeholder if no article picture path is available
            article_picture = placeholder_picture

        return schema.ArticleColorSize(
            article_number=article.article_number,
            article_name=article.group.designation,
            season=article.season.code,
            delivery_date_to=article.date_to,
            delivery_date_from=article.date_from,
            material=article.composition,
            picture=article_picture,
            colors=get_colors(article, pictures['color'])
        )

    async def get_colors(self, shop: Shop) -> list[schema.Color]:
        """
        Get the colors of all available articles, one entry per color name
        :param shop: shop constant
        :return: the list of colors; if several colors share a designation,
            only the first one returned by the query is kept
        """
        available = self.__select_articles_by_shop(shop)
        stmt = select(db.Color).distinct().join(db.ArticleDetail).join(db.Article).where(available)
        rows = await self.async_session.scalars(stmt)

        # keep the first color seen for every designation, in query order
        first_by_designation = {}
        for row in rows.all():
            first_by_designation.setdefault(row.designation, row)

        return [
            schema.Color(code=entry.code, name=entry.designation)
            for entry in first_by_designation.values()
        ]

    async def get_sizes(self, shop: Shop) -> list[schema.Size]:
        """
        Get the distinct sizes of all available articles, ordered by size
        :param shop: shop constant which determines the available articles
        :return: the list of unique sizes of the available articles
        """
        available = self.__select_articles_by_shop(shop)
        stmt = (
            select(db.ArticleDetail.size)
            .distinct()
            .order_by(db.ArticleDetail.size)
            .join(db.Article)
            .where(available)
        )
        rows = await self.async_session.scalars(stmt)

        sizes = []
        for value in rows.all():
            sizes.append(schema.Size(size=value))
        return sizes

    async def __get_articles(self, shop: Shop) -> Sequence[db.ArticleDetail]:
        """
        Get available article details for the selected shop
        :param shop: shop constant which determines the availability
        :return: at most 25 orm objects representing article details
        """
        available = self.__select_articles_by_shop(shop)
        stmt = (
            select(db.ArticleDetail)
            .join(db.Color)
            .join(db.Article)
            .where(available)
            .limit(25)
        )
        rows = await self.async_session.scalars(stmt)
        return rows.unique().fetchall()

    @staticmethod
    def __select_articles_by_shop(shop: Shop):
        """
        Build the filter condition selecting the articles offered in a shop
        :param shop: shop constant
        :return: condition requiring stock (quantity > 0) and, for B2B, a
            season different from the configured mobi season; for every
            other shop, a season equal to it
        """
        mobi_season = shop_config.mobi_season
        in_stock = db.ArticleDetail.quantity > 0
        if shop == Shop.B2B:
            season_matches = db.Article.season_nr != mobi_season
        else:
            season_matches = db.Article.season_nr == mobi_season
        return season_matches & in_stock

    async def load_pictures_for_article(self, article: db.Article):
        """
        Load the picture of an article and the pictures of its colors
        :param article: article orm object; its article_number and the colors
            of its color_sizes are read
        :return: dict with the key "article" (path of the article picture, or
            None if no picture matches the article number) and the key "color"
            (mapping of picture reference id to the uri of the color picture)
        """
        article_stmt = select(db.Picture).where(db.Picture.ref_id == article.article_number)
        # scalar() returns None when no picture row matches
        article_picture = await self.async_session.scalar(article_stmt)

        color_codes = {detail.color.code for detail in article.color_sizes}
        color_stmt = select(db.Picture).where(db.Picture.ref_id.in_(color_codes))
        color_pictures = (await self.async_session.scalars(color_stmt)).all()

        return {
            # an article without a picture used to fail here with an
            # AttributeError on None; report the missing picture as None instead
            "article": article_picture.path if article_picture is not None else None,
            # NOTE(review): the article picture is read from .path while the
            # color pictures are read from .uri - confirm both are intended
            "color": {
                color_pic.ref_id: color_pic.uri
                for color_pic in color_pictures
            }
        }
