# Copyright 2022 Amethyst Reese
# Licensed under the MIT license

import asyncio
from typing import AsyncIterator
from unittest import TestCase

import aioitertools as ait
from .helpers import async_test

slist = ["A", "B", "C"]
srange = range(3)
srange1 = range(1, 4)
srange0 = range(1)


class BuiltinsTest(TestCase):

    # aioitertools.all()

    @async_test
    async def test_all_list(self):
        self.assertTrue(await ait.all([True, 1, "string"]))
        self.assertFalse(await ait.all([True, 0, "string"]))

    @async_test
    async def test_all_range(self):
        self.assertTrue(await ait.all(srange1))
        self.assertFalse(await ait.all(srange))

    @async_test
    async def test_all_generator(self):
        self.assertTrue(await ait.all(x for x in srange1))
        self.assertFalse(await ait.all(x for x in srange))

    @async_test
    async def test_all_async_generator(self):
        self.assertTrue(await ait.all(ait.iter(srange1)))
        self.assertFalse(await ait.all(ait.iter(srange)))

    # aioitertools.any()

    @async_test
    async def test_any_list(self):
        self.assertTrue(await ait.any([False, 1, ""]))
        self.assertFalse(await ait.any([False, 0, ""]))

    @async_test
    async def test_any_range(self):
        self.assertTrue(await ait.any(srange))
        self.assertTrue(await ait.any(srange1))
        self.assertFalse(await ait.any(srange0))

    @async_test
    async def test_any_generator(self):
        self.assertTrue(await ait.any(x for x in srange))
        self.assertTrue(await ait.any(x for x in srange1))
        self.assertFalse(await ait.any(x for x in srange0))

    @async_test
    async def test_any_async_generator(self):
        self.assertTrue(await ait.any(ait.iter(srange)))
        self.assertTrue(await ait.any(ait.iter(srange1)))
        self.assertFalse(await ait.any(ait.iter(srange0)))

    # aioitertools.iter()

    @async_test
    async def test_iter_list(self):
        it = ait.iter(slist)
        self.assertIsInstance(it, AsyncIterator)
        idx = 0
        async for item in it:
            self.assertEqual(item, slist[idx])
            idx += 1

    @async_test
    async def test_iter_range(self):
        it = ait.iter(srange)
        self.assertIsInstance(it, AsyncIterator)
        idx = 0
        async for item in it:
            self.assertEqual(item, srange[idx])
            idx += 1

    @async_test
    async def test_iter_iterable(self):
        sentinel = object()

        class async_iterable:
            def __aiter__(self):
                return sentinel

        aiter = async_iterable()
        self.assertEqual(ait.iter(aiter), sentinel)

    @async_test
    async def test_iter_iterator(self):
        sentinel = object()

        class async_iterator:
            def __aiter__(self):
                return sentinel

            def __anext__(self):
                return sentinel

        aiter = async_iterator()
        self.assertEqual(ait.iter(aiter), aiter)

    @async_test
    async def test_iter_async_generator(self):
        async def async_gen():
            yield 1
            yield 2

        agen = async_gen()
        self.assertEqual(ait.iter(agen), agen)

    # aioitertools.next()

    @async_test
    async def test_next_list(self):
        it = ait.iter(slist)
        self.assertEqual(await ait.next(it), "A")
        self.assertEqual(await ait.next(it), "B")
        self.assertEqual(await ait.next(it), "C")
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_next_range(self):
        it = ait.iter(srange)
        self.assertEqual(await ait.next(it), 0)
        self.assertEqual(await ait.next(it), 1)
        self.assertEqual(await ait.next(it), 2)
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_next_iterable(self):
        class async_iter:
            def __init__(self):
                self.index = 0

            def __aiter__(self):
                return self

            def __anext__(self):
                if self.index > 2:
                    raise StopAsyncIteration()
                return self.fake_next()

            async def fake_next(self):
                value = slist[self.index]
                self.index += 1
                return value

        it = ait.iter(async_iter())
        self.assertEqual(await ait.next(it), "A")
        self.assertEqual(await ait.next(it), "B")
        self.assertEqual(await ait.next(it), "C")
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

        it = iter(slist)
        self.assertEqual(await ait.next(it), "A")
        self.assertEqual(await ait.next(it), "B")
        self.assertEqual(await ait.next(it), "C")
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_next_async_generator(self):
        async def async_gen():
            for item in slist:
                yield item

        it = ait.iter(async_gen())
        self.assertEqual(await ait.next(it), "A")
        self.assertEqual(await ait.next(it), "B")
        self.assertEqual(await ait.next(it), "C")
        with self.assertRaises(StopAsyncIteration):
            await ait.next(it)

    @async_test
    async def test_next_default_iterable(self):
        it = iter(["A"])

        self.assertEqual(await ait.next(it, "?"), "A")
        # End of iteration
        self.assertEqual(await ait.next(it, "?"), "?")

    @async_test
    async def test_next_default_async_iterable(self):
        it = ait.iter(["A"])
        self.assertEqual(await ait.next(it, "?"), "A")
        # End of iteration
        self.assertEqual(await ait.next(it, "?"), "?")

    # aioitertools.list()

    @async_test
    async def test_list(self):
        self.assertEqual(await ait.list(ait.iter(slist)), slist)

    @async_test
    async def test_tuple(self):
        self.assertEqual(await ait.tuple(ait.iter(slist)), tuple(slist))

    # aioitertools.set()

    @async_test
    async def test_set(self):
        self.assertEqual(await ait.set(ait.iter(slist)), set(slist))

    # aioitertools.enumerate()

    @async_test
    async def test_enumerate(self):
        async for index, value in ait.enumerate(slist):
            self.assertEqual(value, slist[index])

    @async_test
    async def test_enumerate_start(self):
        async for index, value in ait.enumerate(slist, 4):
            self.assertEqual(value, slist[index - 4])

    # aioitertools.map()

    @async_test
    async def test_map_function_list(self):
        idx = 0
        async for value in ait.map(str.lower, slist):
            self.assertEqual(value, slist[idx].lower())
            idx += 1

    @async_test
    async def test_map_function_async_generator(self):
        async def gen():
            for item in slist:
                yield item

        idx = 0
        async for value in ait.map(str.lower, gen()):
            self.assertEqual(value, slist[idx].lower())
            idx += 1

    @async_test
    async def test_map_coroutine_list(self):
        async def double(x):
            await asyncio.sleep(0.0001)
            return x * 2

        idx = 0
        async for value in ait.map(double, slist):
            self.assertEqual(value, slist[idx] * 2)
            idx += 1

    @async_test
    async def test_map_coroutine_generator(self):
        async def gen():
            for item in slist:
                yield item

        async def double(x):
            await asyncio.sleep(0.0001)
            return x * 2

        idx = 0
        async for value in ait.map(double, gen()):
            self.assertEqual(value, slist[idx] * 2)
            idx += 1

    # aioitertools.max()

    @async_test
    async def test_max_basic(self):
        async def gen():
            for item in slist:
                yield item

        self.assertEqual(await ait.max(gen()), "C")
        self.assertEqual(await ait.max(range(4)), 3)

        with self.assertRaisesRegex(ValueError, "iterable is empty"):
            await ait.max([])

        with self.assertRaisesRegex(ValueError, "kwarg .+ not supported"):
            await ait.max(None, foo="foo")

    @async_test
    async def test_max_default(self):
        self.assertEqual(await ait.max(range(2), default="x"), 1)
        self.assertEqual(await ait.max([], default="x"), "x")
        self.assertEqual(await ait.max([], default=None), None)

    @async_test
    async def test_max_key(self):
        words = ["star", "buzz", "guard"]

        def reverse(s):
            return s[::-1]

        self.assertEqual(reverse("python"), "nohtyp")

        self.assertEqual(await ait.max(words), "star")
        self.assertEqual(await ait.max(words, key=reverse), "buzz")

    # aioitertools.min()

    @async_test
    async def test_min_basic(self):
        async def gen():
            for item in slist:
                yield item

        self.assertEqual(await ait.min(gen()), "A")
        self.assertEqual(await ait.min(range(4)), 0)

        with self.assertRaisesRegex(ValueError, "iterable is empty"):
            await ait.min([])

        with self.assertRaisesRegex(ValueError, "kwarg .+ not supported"):
            await ait.min(None, foo="foo")

    @async_test
    async def test_min_default(self):
        self.assertEqual(await ait.min(range(2), default="x"), 0)
        self.assertEqual(await ait.min([], default="x"), "x")
        self.assertEqual(await ait.min([], default=None), None)

    @async_test
    async def test_min_key(self):
        words = ["star", "buzz", "guard"]

        def reverse(s):
            return s[::-1]

        self.assertEqual(reverse("python"), "nohtyp")

        self.assertEqual(await ait.min(words), "buzz")
        self.assertEqual(await ait.min(words, key=reverse), "guard")

    # aioitertools.sum()

    @async_test
    async def test_sum_range_default(self):
        self.assertEqual(await ait.sum(srange), sum(srange))

    @async_test
    async def test_sum_list_string(self):
        self.assertEqual(await ait.sum(slist, "foo"), "fooABC")

    # aioitertools.zip()

    @async_test
    async def test_zip_equal(self):
        idx = 0
        async for a, b in ait.zip(slist, srange):
            self.assertEqual(a, slist[idx])
            self.assertEqual(b, srange[idx])
            idx += 1

    @async_test
    async def test_zip_shortest(self):
        short = ["a", "b", "c"]
        long = [0, 1, 2, 3, 5]

        result = await ait.list(ait.zip(short, long))
        expected = [("a", 0), ("b", 1), ("c", 2)]
        self.assertListEqual(expected, result)
