# Copyright 2022 Amethyst Reese
# Licensed under the MIT license

"""
Friendlier version of asyncio standard library.

Provisional library.  Must be imported as `aioitertools.asyncio`.
"""

import asyncio
import time
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterable,
    Awaitable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
)

from .builtins import iter as aiter, maybe_await
from .types import AnyIterable, AsyncIterator, MaybeAwaitable, T


async def as_completed(
    aws: Iterable[Awaitable[T]],
    *,
    timeout: Optional[float] = None,
) -> AsyncIterator[T]:
    """
    Run awaitables in `aws` concurrently, and yield results as they complete.

    Unlike `asyncio.as_completed`, this yields actual results, and does not require
    awaiting each item in the iterable.

    Every awaitable is scheduled with `asyncio.ensure_future` before the first
    result is yielded. If an awaitable raises, that exception propagates out of
    the iteration when its turn comes.

    `timeout` is a total budget in seconds for the whole iteration, measured on
    the monotonic clock so that wall-clock adjustments cannot shorten or extend
    it. `None`, zero, or a negative value disables the timeout. Once the budget
    is exhausted, all remaining awaitables are cancelled and
    `asyncio.TimeoutError` is raised.

    Example::

        async for value in as_completed(futures):
            ...  # use value immediately

    """
    done: Set[Awaitable[T]] = set()
    pending: Set[Awaitable[T]] = {asyncio.ensure_future(a) for a in aws}

    # Absolute deadline on the monotonic clock; None means "wait forever".
    deadline: Optional[float] = None
    if timeout and timeout > 0:
        deadline = time.monotonic() + timeout

    while pending:
        remaining: Optional[float] = None
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                for fut in pending:
                    # Everything in `pending` came from ensure_future(); the
                    # isinstance check only narrows the type for the checker.
                    if isinstance(fut, asyncio.Future):
                        fut.cancel()
                raise asyncio.TimeoutError()

        # asyncio.Future inherits from typing.Awaitable
        # asyncio.wait takes Iterable[Union[Future, Generator, Awaitable]], but
        # returns Tuple[Set[Future], Set[Future]]. Because mypy doesn't like
        # assigning these values to existing Set[Awaitable] or even
        # Set[Union[Awaitable, Future]], we need to first cast the results to
        # something that we can actually use
        # asyncio.Future: https://github.com/python/typeshed/blob/72ff7b94e534c610ddf8939bacbc55343e9465d2/stdlib/3/asyncio/futures.pyi#L30  # noqa: E501
        # asyncio.wait(): https://github.com/python/typeshed/blob/72ff7b94e534c610ddf8939bacbc55343e9465d2/stdlib/3/asyncio/tasks.pyi#L89  # noqa: E501
        done, pending = cast(
            Tuple[Set[Awaitable[T]], Set[Awaitable[T]]],
            await asyncio.wait(
                pending,
                timeout=remaining,
                return_when=asyncio.FIRST_COMPLETED,
            ),
        )

        # `done` is empty when the wait itself timed out; the next pass of the
        # loop then hits the deadline check above.
        for item in done:
            yield await item


async def as_generated(
    iterables: Iterable[AsyncIterable[T]],
    *,
    return_exceptions: bool = False,
) -> AsyncIterable[T]:
    """
    Yield results from one or more async iterables, in the order they are produced.

    Like :func:`as_completed`, but for async iterators or generators instead of futures.
    Creates a separate task to drain each iterable, and a single queue for results.
    The queues are polled, sleeping for one millisecond whenever nothing is ready.

    If ``return_exceptions`` is ``False``, then any exception will be raised, and
    pending iterables and tasks will be cancelled, and async generators will be closed.
    If ``return_exceptions`` is ``True``, any exceptions will be yielded as results,
    and execution will continue until all iterables have been fully consumed;
    every exception captured from the iterables is yielded before iteration ends.

    Example::

        async def generator(x):
            for i in range(x):
                yield i

        gen1 = generator(10)
        gen2 = generator(12)

        async for value in as_generated([gen1, gen2]):
            ...  # intermixed values yielded from gen1 and gen2
    """

    exc_queue: asyncio.Queue[Exception] = asyncio.Queue()
    queue: asyncio.Queue[T] = asyncio.Queue()

    async def tailer(source: AsyncIterable[T]) -> None:
        # Drain one iterable into the shared queues.
        try:
            async for item in source:
                await queue.put(item)
        except asyncio.CancelledError:
            # Cancelled during cleanup: close async generators explicitly so
            # their own cleanup code gets a chance to run.
            if isinstance(source, AsyncGenerator):  # pragma:nocover
                await source.aclose()
            raise
        except Exception as e:
            await exc_queue.put(e)

    tasks = [asyncio.ensure_future(tailer(source)) for source in iterables]
    pending = set(tasks)

    try:
        while pending:
            # The try bodies below contain only the queue reads, so that an
            # asyncio.QueueEmpty raised by a source iterable is delivered like
            # any other exception instead of being mistaken for "nothing ready".
            try:
                exc = exc_queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
            else:
                if return_exceptions:
                    yield exc  # type: ignore
                else:
                    raise exc

            try:
                value = queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
            else:
                yield value
                continue

            # Tailers may have queued exceptions while we were suspended at a
            # yield above. Only prune finished tasks (and possibly end the
            # loop) once both queues are empty, so nothing is dropped.
            if not exc_queue.empty():
                continue

            pending = {task for task in pending if not task.done()}
            await asyncio.sleep(0.001)

    except (asyncio.CancelledError, GeneratorExit):
        # Cancellation or early close of this generator: fall through to the
        # cleanup below without re-raising.
        pass

    finally:
        for task in tasks:
            if not task.done():
                task.cancel()

        # Wait for every tailer to actually finish before returning.
        for task in tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass


async def gather(
    *args: Awaitable[T],
    return_exceptions: bool = False,
    limit: int = -1,
) -> List[Any]:
    """
    Like asyncio.gather but with a limit on concurrency.

    Results are returned in the order of ``args``. An awaitable that is passed
    more than once is scheduled only once, and its result is copied into every
    position where it appeared.

    ``limit`` is the maximum number of awaitables scheduled at the same time;
    ``-1`` (the default) means no limit. Any other value below 1 raises
    ``ValueError`` when there is at least one awaitable to run, because no
    task could ever be scheduled.

    If ``return_exceptions`` is ``True``, exceptions raised by awaitables are
    stored in the result list; otherwise the first one encountered is raised.

    Note that all results are buffered.

    If gather is cancelled all tasks that were internally created and still pending
    will be cancelled as well.

    Example::

        futures = [some_coro(i) for i in range(10)]

        results = await gather(*futures, limit=2)
    """

    # A limit of 0 (or below -1) would never let the scheduling loop start a
    # task, and the outer loop would spin forever without yielding to the event
    # loop. Fail fast instead. With no args there is nothing to schedule, so
    # the call keeps returning [] as it always has.
    if args and limit != -1 and limit < 1:
        raise ValueError(
            f"limit must be -1 (unlimited) or a positive integer, got {limit!r}"
        )

    # For detecting input duplicates and reconciling them at the end
    input_map: Dict[Awaitable[T], List[int]] = {}
    # This is keyed on what we'll get back from asyncio.wait
    pos: Dict[asyncio.Future[T], int] = {}
    ret: List[Any] = [None] * len(args)

    pending: Set[asyncio.Future[T]] = set()
    done: Set[asyncio.Future[T]] = set()

    next_arg = 0

    while True:
        while next_arg < len(args) and (limit == -1 or len(pending) < limit):
            # We have to defer the creation of the Task as long as possible
            # because once we do, it starts executing, regardless of what we
            # have in the pending set.
            arg = args[next_arg]
            if arg in input_map:
                input_map[arg].append(next_arg)
            else:
                # We call ensure_future directly to ensure that we have a Task
                # because the return value of asyncio.wait will be an implicit
                # task otherwise, and we won't be able to know which input it
                # corresponds to.
                task: asyncio.Future[T] = asyncio.ensure_future(arg)
                pending.add(task)
                pos[task] = next_arg
                input_map[arg] = [next_arg]
            next_arg += 1

        # pending might be empty if the last items of args were dupes;
        # asyncio.wait([]) will raise an exception.
        if pending:
            try:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )
                for x in done:
                    # Only inspect the exception when the caller asked for
                    # exceptions as results; otherwise result() re-raises it.
                    exc = x.exception() if return_exceptions else None
                    if exc is not None:
                        ret[pos[x]] = exc
                    else:
                        ret[pos[x]] = x.result()
            except asyncio.CancelledError:
                # Since we created these tasks we should cancel them
                for x in pending:
                    x.cancel()
                # we ensure that all tasks are cancelled before we raise
                await asyncio.gather(*pending, return_exceptions=True)
                raise

        if not pending and next_arg == len(args):
            break

    # Fan the result of each de-duplicated awaitable out to its other positions.
    for first, *dupes in input_map.values():
        for i in dupes:
            ret[i] = ret[first]

    return ret


async def gather_iter(
    itr: AnyIterable[MaybeAwaitable[T]],
    return_exceptions: bool = False,
    limit: int = -1,
) -> List[T]:
    """
    Gather the items of an iterable, rather than positional arguments.

    Every item of ``itr`` is passed through ``maybe_await`` and the collected
    awaitables are handed to :func:`gather`, together with
    ``return_exceptions`` and ``limit``.

    Note that the iterable values don't have to be awaitable.
    """
    # The whole iterable is consumed up front, before gather is called.
    awaitables = []
    async for value in aiter(itr):
        awaitables.append(maybe_await(value))

    results = await gather(
        *awaitables,
        return_exceptions=return_exceptions,
        limit=limit,
    )
    return results