from __future__ import annotations

import dataclasses
import typing
import weakref
from collections.abc import Iterator

from flask import request
from flask.wrappers import Response
from limits import RateLimitItem, parse_many
from limits.strategies import RateLimiter
from limits.util import WindowStats

from .typing import Callable

if typing.TYPE_CHECKING:
    from .extension import Limiter


class RequestLimit:
    """
    Provides details of a rate limit within the context of a request
    """

    #: The instance of the rate limit
    limit: RateLimitItem

    #: The full key for the request against which the rate limit is tested
    key: str

    #: Whether the limit was breached within the context of this request
    breached: bool

    #: Whether the limit is a shared limit
    shared: bool

    def __init__(
        self,
        extension: Limiter,
        limit: RateLimitItem,
        request_args: list[str],
        breached: bool,
        shared: bool,
    ) -> None:
        """
        :param extension: the :class:`Limiter` extension that owns the
         rate limiter. Only a weak proxy is stored, so this object does not
         keep the extension alive.
        :param limit: the rate limit item that was tested
        :param request_args: identifiers passed to ``limit.key_for`` to build
         :attr:`key` and to the limiter when fetching window statistics
        :param breached: whether the limit was breached for this request
        :param shared: whether the limit is a shared limit
        """
        self.extension: weakref.ProxyType[Limiter] = weakref.proxy(extension)
        self.limit = limit
        self.request_args = request_args
        self.key = limit.key_for(*request_args)
        self.breached = breached
        self.shared = shared
        # Populated lazily by :attr:`window`; ``None`` means "not fetched yet".
        self._window: WindowStats | None = None

    @property
    def limiter(self) -> RateLimiter:
        """The rate limiter exposed by the owning extension."""
        # The string form of the target type avoids a runtime name lookup;
        # ``typing.cast`` returns its value unchanged either way.
        return typing.cast("RateLimiter", self.extension.limiter)

    @property
    def window(self) -> WindowStats:
        """
        Window statistics ``(reset time, remaining)`` for this limit.

        Fetched from the limiter on first access and cached on this object,
        so repeated reads within a request do not query storage again.
        """
        if self._window is None:
            self._window = self.limiter.get_window_stats(
                self.limit, *self.request_args
            )

        return self._window

    @property
    def reset_at(self) -> int:
        """Timestamp at which the rate limit will be reset"""

        return int(self.window[0] + 1)

    @property
    def remaining(self) -> int:
        """Quantity remaining for this rate limit"""

        return self.window[1]


@dataclasses.dataclass(eq=True, unsafe_hash=True)
class Limit:
    """
    simple wrapper to encapsulate limits and their context

    Instances are hashable (``unsafe_hash=True``), so every field value,
    including :attr:`methods`, must itself be hashable.
    """

    limit: RateLimitItem
    key_func: Callable[[], str]
    #: Static scope string, a callable mapping the current endpoint to a
    #: scope, or ``None``; resolved through :attr:`scope`.
    _scope: str | Callable[[str], str] | None
    #: Whether :meth:`scope_for` appends the request method to the bucket.
    per_method: bool = False
    #: HTTP methods the limit applies to, normalised to a lower-cased tuple;
    #: ``None`` means every method.
    methods: tuple[str, ...] | None = None
    error_message: str | None = None
    #: Zero-argument predicate whose result is returned by :attr:`is_exempt`.
    exempt_when: Callable[[], bool] | None = None
    override_defaults: bool | None = False
    deduct_when: Callable[[Response], bool] | None = None
    on_breach: Callable[[RequestLimit], Response | None] | None = None
    #: Fixed cost, or a zero-argument callable computing it; resolved
    #: through :attr:`cost`.
    _cost: Callable[[], int] | int = 1
    #: When true and a scope is set, :meth:`scope_for` omits the endpoint so
    #: the bucket is shared between routes.
    shared: bool = False

    def __post_init__(self) -> None:
        # Normalise every supplied sequence -- including an empty one -- to a
        # tuple of lower-cased names. Leaving e.g. ``[]`` as a list would make
        # the instance unhashable despite ``unsafe_hash=True``.
        if self.methods is not None:
            self.methods = tuple(k.lower() for k in self.methods)

    @property
    def is_exempt(self) -> bool:
        """Check if the limit is exempt (result of ``exempt_when``, if set)."""

        if self.exempt_when:
            return self.exempt_when()

        return False

    @property
    def scope(self) -> str | None:
        """
        The scope for the current request: ``_scope`` itself, or, when it is
        callable, the result of calling it with the current request endpoint
        (``""`` if the request has no endpoint).
        """
        return (
            self._scope(request.endpoint or "")
            if callable(self._scope)
            else self._scope
        )

    @property
    def cost(self) -> int:
        """The cost of a hit: the fixed ``_cost`` or the value it computes."""
        if isinstance(self._cost, int):
            return self._cost

        return self._cost()

    @property
    def method_exempt(self) -> bool:
        """
        Check if the limit is not applicable for this method.

        ``methods=None`` applies the limit to every method; an empty
        ``methods`` tuple exempts every method.
        """

        return self.methods is not None and request.method.lower() not in self.methods

    def scope_for(self, endpoint: str, method: str | None) -> str:
        """
        Derive final bucket (scope) for this limit given the endpoint
        and request method. If the limit is shared between multiple
        routes, the scope does not include the endpoint.

        :param endpoint: the endpoint the limit is being applied to
        :param method: the request method; required when ``per_method`` is set
        :return: ``endpoint``, ``endpoint:scope`` or ``scope`` (shared),
         suffixed with ``:METHOD`` when ``per_method`` is set
        """
        limit_scope = self.scope

        if limit_scope:
            if self.shared:
                scope = limit_scope
            else:
                scope = f"{endpoint}:{limit_scope}"
        else:
            scope = endpoint

        if self.per_method:
            assert method
            scope += f":{method.upper()}"

        return scope


@dataclasses.dataclass(eq=True, unsafe_hash=True)
class LimitGroup:
    """
    represents a group of related limits either from a string or a callable
    that returns one
    """

    #: A rate limit string, or a zero-argument callable returning one,
    #: parsed with :func:`limits.parse_many` on each iteration.
    limit_provider: Callable[[], str] | str
    key_function: Callable[[], str]
    scope: str | Callable[[str], str] | None = None
    methods: tuple[str, ...] | None = None
    error_message: str | None = None
    exempt_when: Callable[[], bool] | None = None
    override_defaults: bool | None = False
    deduct_when: Callable[[Response], bool] | None = None
    on_breach: Callable[[RequestLimit], Response | None] | None = None
    per_method: bool = False
    #: Cost per hit (int or zero-argument callable); ``None`` means 1.
    cost: Callable[[], int] | int | None = None
    shared: bool = False

    def __iter__(self) -> Iterator[Limit]:
        """
        Resolve :attr:`limit_provider` and yield one :class:`Limit` per
        parsed rate limit, each carrying this group's shared context.

        Yields nothing when the provider gives an empty (or falsy) value.
        """
        limit_str = (
            self.limit_provider()
            if callable(self.limit_provider)
            else self.limit_provider
        )
        if not limit_str:
            return

        # Only a missing cost falls back to the default; an explicit 0 is
        # kept rather than being coerced to 1 by a truthiness check.
        cost = 1 if self.cost is None else self.cost

        for limit in parse_many(limit_str):
            # Keyword arguments keep this call correct even if the field
            # order of ``Limit`` changes.
            yield Limit(
                limit=limit,
                key_func=self.key_function,
                _scope=self.scope,
                per_method=self.per_method,
                methods=self.methods,
                error_message=self.error_message,
                exempt_when=self.exempt_when,
                override_defaults=self.override_defaults,
                deduct_when=self.deduct_when,
                on_breach=self.on_breach,
                _cost=cost,
                shared=self.shared,
            )
