"""
Tests common to psycopg.AsyncCursor and its subclasses.
"""

import weakref
import datetime as dt
from typing import Any
from contextlib import aclosing

import pytest
from packaging.version import parse as ver

import psycopg
from psycopg import pq, rows, sql
from psycopg.adapt import PyFormat
from psycopg.types import TypeInfo

from . import _test_cursor
from .utils import raiseif
from .acompat import alist, gather, spawn
from .fix_crdb import crdb_encoding
from .test_adapt import make_loader
from ._test_cursor import my_row_factory, ph

# Bring the `execmany` fixture (and the `_execmany` fixture it depends on)
# from _test_cursor into this module's namespace so tests here can request
# them as parameters. Binding by assignment instead of `from ... import`
# avoids flake8 F811 where test functions take `execmany` as an argument.
execmany = _test_cursor.execmany  # avoid F811 underneath
_execmany = _test_cursor._execmany  # needed by the execmany fixture


# Cursor classes every test in this module is parametrized over, via the
# `aconn` fixture override.
# Allow to import (not necessarily to run) the module with psycopg 3.1.
# Needed to test psycopg_pool 3.2 tests with psycopg 3.1 imported, i.e. to run
# `pytest -m pool`. (which might happen when releasing pool packages).
# AsyncRawCursor is only referenced when the version check passes, so the
# attribute lookup never happens on older psycopg.
cursor_classes = [psycopg.AsyncCursor, psycopg.AsyncClientCursor] + (
    [psycopg.AsyncRawCursor]
    if ver(psycopg.__version__) >= ver("3.2.0.dev0")
    else []
)


@pytest.fixture(params=cursor_classes)
async def aconn(aconn, request, anyio_backend):
    """Override `aconn` so each test runs once per class in `cursor_classes`.

    The selected class becomes the connection's `cursor_factory`. The
    `anyio_backend` parameter is presumably requested to tie the fixture to
    the async backend (TODO confirm).
    """
    selected_cls = request.param
    aconn.cursor_factory = selected_cls
    return aconn


async def test_init(aconn):
    """Cursors built from the factory pick up the connection's row factory."""
    # Default row factory: rows are tuples.
    plain = aconn.cursor_factory(aconn)
    await plain.execute("select 1")
    assert (await plain.fetchone()) == (1,)

    # A row factory set on the connection applies to cursors created later.
    aconn.row_factory = rows.dict_row
    as_dict = aconn.cursor_factory(aconn)
    await as_dict.execute("select 1 as a")
    assert (await as_dict.fetchone()) == {"a": 1}


async def test_init_factory(aconn):
    """A row_factory passed to the cursor constructor is honoured."""
    cursor = aconn.cursor_factory(aconn, row_factory=rows.dict_row)
    await cursor.execute("select 1 as a")
    record = await cursor.fetchone()
    assert record == {"a": 1}


async def test_close(aconn):
    """A closed cursor refuses queries, and closing twice is harmless."""
    cursor = aconn.cursor()
    assert not cursor.closed
    await cursor.close()
    assert cursor.closed

    # Executing on a closed cursor is an interface error...
    with pytest.raises(psycopg.InterfaceError):
        await cursor.execute("select 'foo'")

    # ...while a second close() succeeds and leaves it closed.
    await cursor.close()
    assert cursor.closed


async def test_cursor_close_fetchone(aconn):
    """fetchone() raises InterfaceError after closing a partly-read cursor."""
    cursor = aconn.cursor()
    assert not cursor.closed

    await cursor.execute("select * from generate_series(1, 10)")
    # Read only half of the rows before closing.
    for _ in range(5):
        await cursor.fetchone()

    await cursor.close()
    assert cursor.closed
    with pytest.raises(psycopg.InterfaceError):
        await cursor.fetchone()


async def test_cursor_close_fetchmany(aconn):
    """fetchmany() raises InterfaceError after closing a partly-read cursor."""
    cursor = aconn.cursor()
    assert not cursor.closed

    await cursor.execute("select * from generate_series(1, 10)")
    batch = await cursor.fetchmany(2)
    assert len(batch) == 2

    await cursor.close()
    assert cursor.closed
    with pytest.raises(psycopg.InterfaceError):
        await cursor.fetchmany(2)


async def test_cursor_close_fetchall(aconn):
    """fetchall() raises InterfaceError once the cursor has been closed."""
    cursor = aconn.cursor()
    assert not cursor.closed

    await cursor.execute("select * from generate_series(1, 10)")
    everything = await cursor.fetchall()
    assert len(everything) == 10

    await cursor.close()
    assert cursor.closed
    with pytest.raises(psycopg.InterfaceError):
        await cursor.fetchall()


async def test_context(aconn):
    """Leaving the ``async with`` block closes the cursor."""
    async with aconn.cursor() as cursor:
        assert not cursor.closed
    assert cursor.closed


@pytest.mark.slow
async def test_weakref(aconn, gc_collect):
    """A closed, unreferenced cursor can be garbage collected."""
    cursor = aconn.cursor()
    ref = weakref.ref(cursor)
    await cursor.close()
    # Drop the only strong reference, then force a collection.
    del cursor
    gc_collect()
    assert ref() is None


async def test_pgresult(aconn):
    """pgresult is set after execute() and cleared by close()."""
    cursor = aconn.cursor()
    await cursor.execute("select 1")
    assert cursor.pgresult
    await cursor.close()
    assert not cursor.pgresult


async def test_statusmessage(aconn):
    """statusmessage carries the command tag of the last statement."""
    cursor = aconn.cursor()
    # Nothing has been executed yet.
    assert cursor.statusmessage is None

    for query, tag in [
        ("select generate_series(1, 10)", "SELECT 10"),
        ("create table statusmessage ()", "CREATE TABLE"),
    ]:
        await cursor.execute(query)
        assert cursor.statusmessage == tag

    # A failing statement leaves no status message.
    with pytest.raises(psycopg.ProgrammingError):
        await cursor.execute("wat")
    assert cursor.statusmessage is None


async def test_execute_sql(aconn):
    """execute() accepts a composed ``sql.SQL`` object."""
    query = sql.SQL("select {value}").format(value="hello")
    cursor = aconn.cursor()
    await cursor.execute(query)
    assert (await cursor.fetchone()) == ("hello",)


async def test_next(aconn):
    """anext() returns rows, then raises StopAsyncIteration when exhausted."""
    cursor = aconn.cursor()
    await cursor.execute("select 1")
    first = await anext(cursor)
    assert first == (1,)
    with pytest.raises(StopAsyncIteration):
        await anext(cursor)


async def test_query_parse_cache_size(aconn):
    """Check which converted queries go through the query-parse cache.

    The expected hit/miss counters encode the intended policy: the shorter
    query text (3500 filler chars) and the 40-parameter query are cached
    (a miss, then a hit). The longer text (4500 chars) and the 60-parameter
    query leave the counters unchanged, i.e. they bypass the cache.
    """
    cur = aconn.cursor()
    cls = type(cur)

    # Warning: testing internal structures. Test might need refactoring with the code.
    cache: Any
    if cls is psycopg.AsyncCursor:
        cache = psycopg._queries._query2pg
    elif cls is psycopg.AsyncClientCursor:
        cache = psycopg._queries._query2pg_client
    elif cls is psycopg.AsyncRawCursor:
        pytest.skip("RawCursor has no query parse cache")
    else:
        # Explicit failure: a bare `assert False` would be stripped under -O.
        pytest.fail(f"unexpected cursor class: {cls}")

    cache.cache_clear()
    ci = cache.cache_info()
    h0, m0 = ci.hits, ci.misses
    # (query, params, expected total hits, expected total misses) after
    # converting that query.
    tests = [
        (f"select 1 -- {'x' * 3500}", (), h0, m0 + 1),
        (f"select 1 -- {'x' * 3500}", (), h0 + 1, m0 + 1),
        (f"select 1 -- {'x' * 4500}", (), h0 + 1, m0 + 1),
        (f"select 1 -- {'x' * 4500}", (), h0 + 1, m0 + 1),
        (f"select 1 -- {'%s' * 40}", ("x",) * 40, h0 + 1, m0 + 2),
        (f"select 1 -- {'%s' * 40}", ("x",) * 40, h0 + 2, m0 + 2),
        (f"select 1 -- {'%s' * 60}", ("x",) * 60, h0 + 2, m0 + 2),
        (f"select 1 -- {'%s' * 60}", ("x",) * 60, h0 + 2, m0 + 2),
    ]
    for i, (query, params, hits, misses) in enumerate(tests):
        # Named `pgq` so it does not shadow the module-level `pq` import.
        pgq = cur._query_cls(psycopg.adapt.Transformer())
        pgq.convert(query, params)
        ci = cache.cache_info()
        assert ci.hits == hits, f"at {i}"
        assert ci.misses == misses, f"at {i}"


async def test_execute_many_results(aconn):
    """A multi-statement execute() exposes each result through nextset()."""
    cursor = aconn.cursor()
    # Nothing executed: there is no next result set.
    assert cursor.nextset() is None

    returned = await cursor.execute("select 'foo'; select generate_series(1,3)")
    assert returned is cursor

    # First result set.
    assert (await cursor.fetchall()) == [("foo",)]
    assert cursor.rowcount == 1

    # Second, and last, result set.
    assert cursor.nextset()
    assert (await cursor.fetchall()) == [(1,), (2,), (3,)]
    assert cursor.rowcount == 3
    assert cursor.nextset() is None

    await cursor.close()
    assert cursor.nextset() is None


async def test_set_results(aconn):
    """set_result() selects a result by index, including negative indices."""
    cursor = aconn.cursor()

    # No results at all before execute().
    with pytest.raises(IndexError):
        await cursor.set_result(0)

    await cursor.execute("select 'foo'; select generate_series(1,3)")
    for index, expected in [(0, [("foo",)]), (-1, [(1,), (2,), (3,)])]:
        assert await cursor.set_result(index) is cursor
        assert (await cursor.fetchall()) == expected
        assert cursor.rowcount == len(expected)

    # Two results only: 2 and -3 are both out of range.
    for bad_index in (2, -3):
        with pytest.raises(IndexError):
            await cursor.set_result(bad_index)


async def test_execute_sequence(aconn):
    """Positional parameters given as a list are bound in order."""
    cursor = aconn.cursor()
    query = ph(cursor, "select %s::int, %s::text, %s::text")
    returned = await cursor.execute(query, [1, "foo", None])
    assert returned is cursor
    assert len(cursor._results) == 1

    result = cursor.pgresult
    assert result.get_value(0, 0) == b"1"
    assert result.get_value(0, 1) == b"foo"
    # SQL NULL comes back as None at the libpq level.
    assert result.get_value(0, 2) is None
    assert cursor.nextset() is None


@pytest.mark.parametrize("query", ["", " ", ";"])
async def test_execute_empty_query(aconn, query):
    """Blank queries give an EMPTY_QUERY result with nothing to fetch."""
    cursor = aconn.cursor()
    await cursor.execute(query)
    status = cursor.pgresult.status
    assert status == pq.ExecStatus.EMPTY_QUERY
    with pytest.raises(psycopg.ProgrammingError):
        await cursor.fetchone()


async def test_execute_type_change(aconn):
    """Re-executing an insert with integers of different magnitude works.

    Regression test for issue #112: the same statement is executed with 1 and
    then 100_000, which may be adapted to different integer types.
    """
    await aconn.execute("create table bug_112 (num integer)")
    cur = aconn.cursor()
    # Named `query` so it does not shadow the module-level `sql` import.
    query = ph(cur, "insert into bug_112 (num) values (%s)")
    await cur.execute(query, (1,))
    await cur.execute(query, (100_000,))
    await cur.execute("select num from bug_112 order by num")
    assert (await cur.fetchall()) == [(1,), (100_000,)]


async def test_executemany_type_change(aconn):
    """executemany() copes with integers of different magnitude per row.

    executemany() counterpart of issue #112: the rows 1 and 100_000 may be
    adapted to different integer types.
    """
    await aconn.execute("create table bug_112 (num integer)")
    cur = aconn.cursor()
    # Named `query` so it does not shadow the module-level `sql` import.
    query = ph(cur, "insert into bug_112 (num) values (%s)")
    await cur.executemany(query, [(1,), (100_000,)])
    await cur.execute("select num from bug_112 order by num")
    assert (await cur.fetchall()) == [(1,), (100_000,)]


@pytest.mark.parametrize(
    "query", ["copy testcopy from stdin", "copy testcopy to stdout"]
)
async def test_execute_copy(aconn, query):
    """COPY statements cannot be run through execute()."""
    cursor = aconn.cursor()
    await cursor.execute("create table testcopy (id int)")
    with pytest.raises(psycopg.ProgrammingError):
        await cursor.execute(query)


async def test_fetchone(aconn):
    """fetchone() returns text-format rows, then None once exhausted."""
    cursor = aconn.cursor()
    query = ph(cursor, "select %s::int, %s::text, %s::text")
    await cursor.execute(query, [1, "foo", None])
    # Format code 0 means text.
    assert cursor.pgresult.fformat(0) == 0

    assert await cursor.fetchone() == (1, "foo", None)
    assert await cursor.fetchone() is None


async def test_binary_cursor_execute(aconn):
    """A binary cursor gets binary results; client cursors reject it."""
    client_side = aconn.cursor_factory is psycopg.AsyncClientCursor
    with raiseif(client_side, psycopg.NotSupportedError) as ex:
        cursor = aconn.cursor(binary=True)
        await cursor.execute(ph(cursor, "select %s, %s"), [1, None])
    if ex:
        # Expected failure already verified by raiseif().
        return

    assert (await cursor.fetchone()) == (1, None)
    result = cursor.pgresult
    # Format code 1 means binary; 1 arrives as two big-endian bytes.
    assert result.fformat(0) == 1
    assert result.get_value(0, 0) == b"\x00\x01"


async def test_execute_binary(aconn):
    """execute(binary=True) requests binary results; client cursors reject it."""
    cursor = aconn.cursor()
    client_side = aconn.cursor_factory is psycopg.AsyncClientCursor
    with raiseif(client_side, psycopg.NotSupportedError) as ex:
        await cursor.execute(ph(cursor, "select %s, %s"), [1, None], binary=True)
    if ex:
        # Expected failure already verified by raiseif().
        return

    assert (await cursor.fetchone()) == (1, None)
    result = cursor.pgresult
    # Format code 1 means binary; 1 arrives as two big-endian bytes.
    assert result.fformat(0) == 1
    assert result.get_value(0, 0) == b"\x00\x01"


async def test_binary_cursor_text_override(aconn):
    """binary=False on execute() overrides a binary cursor's default."""
    cursor = aconn.cursor(binary=True)
    await cursor.execute(ph(cursor, "select %s, %s"), [1, None], binary=False)
    row = await cursor.fetchone()
    assert row == (1, None)

    result = cursor.pgresult
    # Text format: the value is the ASCII digit.
    assert result.fformat(0) == 0
    assert result.get_value(0, 0) == b"1"


@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
async def test_query_encode(aconn, encoding):
    """Query text containing the euro sign round-trips in these encodings."""
    await aconn.execute(f"set client_encoding to {encoding}")
    cursor = aconn.cursor()
    await cursor.execute("select '\u20ac'")
    (value,) = await cursor.fetchone()
    assert value == "\u20ac"


@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
async def test_query_badenc(aconn, encoding):
    """Query text not representable in the client encoding raises."""
    await aconn.execute(f"set client_encoding to {encoding}")
    cursor = aconn.cursor()
    with pytest.raises(UnicodeEncodeError):
        await cursor.execute("select '\u20ac'")


async def test_executemany(aconn, execmany):
    """executemany() inserts one row per parameter sequence."""
    cursor = aconn.cursor()
    query = ph(cursor, "insert into execmany(num, data) values (%s, %s)")
    await cursor.executemany(query, [(10, "hello"), (20, "world")])

    await cursor.execute("select num, data from execmany order by 1")
    stored = await cursor.fetchall()
    assert stored == [(10, "hello"), (20, "world")]


async def test_executemany_name(aconn, execmany):
    """executemany() binds named parameters from mappings."""
    cursor = aconn.cursor()
    query = ph(cursor, "insert into execmany(num, data) values (%(num)s, %(data)s)")
    # The extra "x" key is not referenced by the query.
    params = [{"num": 11, "data": "hello", "x": 1}, {"num": 21, "data": "world"}]
    await cursor.executemany(query, params)

    await cursor.execute("select num, data from execmany order by 1")
    stored = await cursor.fetchall()
    assert stored == [(11, "hello"), (21, "world")]


async def test_executemany_no_data(aconn, execmany):
    """executemany() with no parameter sets reports a rowcount of 0."""
    cursor = aconn.cursor()
    query = ph(cursor, "insert into execmany(num, data) values (%s, %s)")
    await cursor.executemany(query, [])
    assert cursor.rowcount == 0


async def test_executemany_rowcount(aconn, execmany):
    """executemany() rowcount is the total over all executions."""
    cursor = aconn.cursor()
    query = ph(cursor, "insert into execmany(num, data) values (%s, %s)")
    await cursor.executemany(query, [(10, "hello"), (20, "world")])
    assert cursor.rowcount == 2


async def test_executemany_returning(aconn, execmany):
    """With returning=True each execution's rows form a separate result set."""
    cursor = aconn.cursor()
    await cursor.executemany(
        ph(cursor, "insert into execmany(num, data) values (%s, %s) returning num"),
        [(10, "hello"), (20, "world")],
        returning=True,
    )
    for position, num in enumerate([10, 20]):
        # Move on to the next set for every row after the first.
        if position:
            assert cursor.nextset()
        assert cursor.rowcount == 1
        assert (await cursor.fetchone()) == (num,)
    assert cursor.nextset() is None


async def test_executemany_returning_discard(aconn, execmany):
    """Without returning=True, RETURNING rows are not kept, only the count."""
    cursor = aconn.cursor()
    await cursor.executemany(
        ph(cursor, "insert into execmany(num, data) values (%s, %s) returning num"),
        [(10, "hello"), (20, "world")],
    )
    assert cursor.rowcount == 2
    with pytest.raises(psycopg.ProgrammingError):
        await cursor.fetchone()
    assert cursor.nextset() is None


async def test_executemany_no_result(aconn, execmany):
    """returning=True with a row-less statement still keeps one result per run."""
    cursor = aconn.cursor()
    await cursor.executemany(
        ph(cursor, "insert into execmany(num, data) values (%s, %s)"),
        [(10, "hello"), (20, "world")],
        returning=True,
    )
    assert cursor.rowcount == 1
    assert cursor.statusmessage.startswith("INSERT")
    # No RETURNING clause: nothing to fetch.
    with pytest.raises(psycopg.ProgrammingError):
        await cursor.fetchone()
    first_result = cursor.pgresult

    # The second execution has its own, distinct result.
    assert cursor.nextset()
    assert cursor.rowcount == 1
    assert cursor.statusmessage.startswith("INSERT")
    assert first_result is not cursor.pgresult
    assert cursor.nextset() is None


async def test_executemany_rowcount_no_hit(aconn, execmany):
    """rowcount is 0 when no execution touches any row."""
    cursor = aconn.cursor()
    for query, params in [
        ("delete from execmany where id = %s", [(-1,), (-2,)]),
        ("delete from execmany where id = %s", []),
        ("delete from execmany where id = %s returning num", [(-1,), (-2,)]),
    ]:
        await cursor.executemany(ph(cursor, query), params)
        assert cursor.rowcount == 0


@pytest.mark.parametrize(
    "query",
    [
        "insert into nosuchtable values (%s, %s)",
        "copy (select %s, %s) to stdout",
        "wat (%s, %s)",
    ],
)
async def test_executemany_badquery(aconn, query):
    """executemany() surfaces a DatabaseError for invalid statements."""
    cursor = aconn.cursor()
    params = [(10, "hello"), (20, "world")]
    with pytest.raises(psycopg.DatabaseError):
        await cursor.executemany(ph(cursor, query), params)


@pytest.mark.parametrize("fmt_in", PyFormat)
async def test_executemany_null_first(aconn, fmt_in):
    """executemany() accepts a NULL in the first row; a bad value still fails."""
    cursor = aconn.cursor()
    await cursor.execute("create table testmany (a bigint, b bigint)")
    marker = f"%{fmt_in.value}"
    query = f"insert into testmany values ({marker}, {marker})"

    await cursor.executemany(ph(cursor, query), [[1, None], [3, 4]])
    # An empty string is not a valid bigint.
    with pytest.raises((psycopg.DataError, psycopg.ProgrammingError)):
        await cursor.executemany(ph(cursor, query), [[1, ""], [3, 4]])


@pytest.mark.slow
async def test_executemany_lock(aconn):
    """Concurrent executemany() and execute() on one connection both complete."""

    async def run_executemany():
        async with aconn.cursor() as cur:
            await cur.executemany(ph(cur, "select pg_sleep(%s)"), [(0.1,)] * 10)

    async def run_execute():
        async with aconn.cursor() as cur:
            for _ in range(100):
                await cur.execute("select 1")

    await gather(spawn(run_executemany), spawn(run_execute))


async def test_rowcount(aconn):
    """rowcount reports returned/affected rows, or -1 when not applicable."""
    cursor = aconn.cursor()
    cases = [
        ("select 1 from generate_series(1, 0)", 0),
        ("select 1 from generate_series(1, 42)", 42),
        ("show timezone", 1),
        ("create table test_rowcount_notuples (id int primary key)", -1),
        ("insert into test_rowcount_notuples select generate_series(1, 42)", 42),
    ]
    for query, expected in cases:
        await cursor.execute(query)
        assert cursor.rowcount == expected


async def test_rownumber(aconn):
    """Check that rownumber counts the rows already consumed from the result.

    It is None before anything is executed, 0 right after execute(), and
    advances with fetchone(), fetchmany(), async iteration and fetchall().
    """
    cur = aconn.cursor()
    assert cur.rownumber is None

    await cur.execute("select 1 from generate_series(1, 42)")
    assert cur.rownumber == 0

    await cur.fetchone()
    assert cur.rownumber == 1
    await cur.fetchone()
    assert cur.rownumber == 2
    await cur.fetchmany(10)
    assert cur.rownumber == 12
    # Iterating the cursor advances rownumber too; stop after three rows.
    rns: list[int] = []
    async for i in cur:
        assert cur.rownumber
        rns.append(cur.rownumber)
        if len(rns) >= 3:
            break
    assert rns == [13, 14, 15]
    # fetchall() returns only the rows not consumed yet.
    assert len(await cur.fetchall()) == 42 - rns[-1]
    assert cur.rownumber == 42


@pytest.mark.parametrize("query", ["", "set timezone to utc"])
async def test_rownumber_none(aconn, query):
    """rownumber stays None for statements that produce no tuple result."""
    cursor = aconn.cursor()
    await cursor.execute(query)
    assert cursor.rownumber is None


async def test_rownumber_mixed(aconn):
    """Check rownumber across the result sets of a multi-statement query.

    Each tuple-returning result starts counting from 0; while the SET
    command's result (which has no rows) is current, rownumber is None.
    """
    cur = aconn.cursor()
    await cur.execute("""
select x from generate_series(1, 3) x;
set timezone to utc;
select x from generate_series(4, 6) x;
""")
    assert cur.rownumber == 0
    assert await cur.fetchone() == (1,)
    assert cur.rownumber == 1
    assert await cur.fetchone() == (2,)
    assert cur.rownumber == 2
    # Move to the SET result: no tuples, so no row number.
    cur.nextset()
    assert cur.rownumber is None
    # The third result restarts counting from 0.
    cur.nextset()
    assert cur.rownumber == 0
    assert await cur.fetchone() == (4,)
    assert cur.rownumber == 1


async def test_iter(aconn):
    """Async-iterating the cursor yields every row of the result."""
    cursor = aconn.cursor()
    await cursor.execute("select generate_series(1, 3)")
    fetched = await alist(cursor)
    assert fetched == [(1,), (2,), (3,)]


async def test_iter_stop(aconn):
    """Breaking out of iteration keeps the position for later fetches."""
    cursor = aconn.cursor()
    await cursor.execute("select generate_series(1, 3)")
    # Each new loop resumes right after the row where the previous one broke.
    for expected in [(1,), (2,)]:
        async for rec in cursor:
            assert rec == expected
            break

    assert (await cursor.fetchone()) == (3,)
    assert (await alist(cursor)) == []


async def test_row_factory(aconn):
    """A custom row factory shapes each result; it can be replaced later."""
    cursor = aconn.cursor(row_factory=my_row_factory)

    # A command without tuples has nothing to fetch.
    await cursor.execute("reset search_path")
    with pytest.raises(psycopg.ProgrammingError):
        await cursor.fetchone()

    await cursor.execute("select 'foo' as bar")
    (value,) = await cursor.fetchone()
    assert value == "FOObar"

    # The factory is applied per result set, each with its own columns.
    await cursor.execute("select 'x' as x; select 'y' as y, 'z' as z")
    assert await cursor.fetchall() == [["Xx"]]
    assert cursor.nextset()
    assert await cursor.fetchall() == [["Yy", "Zz"]]

    # A new row factory applies to rows fetched afterwards.
    await cursor.scroll(-1)
    cursor.row_factory = rows.dict_row
    assert await cursor.fetchone() == {"y": "y", "z": "z"}


async def test_row_factory_none(aconn):
    """row_factory=None falls back to tuple rows."""
    cursor = aconn.cursor(row_factory=None)
    assert cursor.row_factory is rows.tuple_row
    await cursor.execute("select 1 as a, 2 as b")
    record = await cursor.fetchone()
    assert type(record) is tuple
    assert record == (1, 2)


async def test_bad_row_factory(aconn):
    """Errors from a row factory or its row maker reach the caller."""

    def failing_factory(cur):
        1 / 0

    # The factory itself is called during execute().
    cursor = aconn.cursor(row_factory=failing_factory)
    with pytest.raises(ZeroDivisionError):
        await cursor.execute("select 1")

    def failing_maker(cur):
        def make_row(seq):
            1 / 0

        return make_row

    # The row maker it returns is only called when a row is fetched.
    cursor = aconn.cursor(row_factory=failing_maker)
    await cursor.execute("select 1")
    with pytest.raises(ZeroDivisionError):
        await cursor.fetchone()


async def test_scroll(aconn):
    """Check scroll() in relative (default) and absolute modes and its bounds.

    Valid target positions are 0..9 for a 10-row result; anything outside
    raises IndexError, and an unknown mode raises ValueError.
    """
    cur = aconn.cursor()
    # No result to scroll yet.
    with pytest.raises(psycopg.ProgrammingError):
        await cur.scroll(0)

    await cur.execute("select generate_series(0,9)")
    # Default mode is relative to the current position, which fetchone()
    # advances by one each time.
    await cur.scroll(2)
    assert await cur.fetchone() == (2,)
    await cur.scroll(2)
    assert await cur.fetchone() == (5,)
    await cur.scroll(2, mode="relative")
    assert await cur.fetchone() == (8,)
    await cur.scroll(-1)
    assert await cur.fetchone() == (8,)
    await cur.scroll(-2)
    assert await cur.fetchone() == (7,)
    await cur.scroll(2, mode="absolute")
    assert await cur.fetchone() == (2,)

    # on the boundary
    await cur.scroll(0, mode="absolute")
    assert await cur.fetchone() == (0,)
    with pytest.raises(IndexError):
        await cur.scroll(-1, mode="absolute")

    # Relative scroll before the first row.
    await cur.scroll(0, mode="absolute")
    with pytest.raises(IndexError):
        await cur.scroll(-1)

    # Position 10 (one past the last row) is rejected in both modes.
    await cur.scroll(9, mode="absolute")
    assert await cur.fetchone() == (9,)
    with pytest.raises(IndexError):
        await cur.scroll(10, mode="absolute")

    await cur.scroll(9, mode="absolute")
    with pytest.raises(IndexError):
        await cur.scroll(1)

    # Unknown scroll mode.
    with pytest.raises(ValueError):
        await cur.scroll(1, "wat")


@pytest.mark.parametrize(
    "query, params, want",
    [
        ("select %(x)s", {"x": 1}, (1,)),
        ("select %(x)s, %(y)s", {"x": 1, "y": 2}, (1, 2)),
        ("select %(x)s, %(x)s", {"x": 1}, (1, 1)),
    ],
)
async def test_execute_params_named(aconn, query, params, want):
    """Named placeholders bind from a mapping, and may be repeated."""
    cursor = aconn.cursor()
    await cursor.execute(ph(cursor, query), params)
    assert await cursor.fetchone() == want


async def test_stream(aconn):
    """stream() yields adapted rows for a parametrized query."""
    cursor = aconn.cursor()
    query = ph(
        cursor,
        "select i, '2021-01-01'::date + i from generate_series(1, %s) as i",
    )
    recs = [rec async for rec in cursor.stream(query, [2])]
    assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]


async def test_stream_sql(aconn):
    """stream() accepts a composed ``sql.SQL`` query."""
    query = sql.SQL(
        "select i, '2021-01-01'::date + i from generate_series(1, {}) as i"
    ).format(2)
    cursor = aconn.cursor()
    recs = await alist(cursor.stream(query))
    assert recs == [(1, dt.date(2021, 1, 2)), (2, dt.date(2021, 1, 3))]


async def test_stream_row_factory(aconn):
    """Changing row_factory mid-stream applies to the following rows."""
    cursor = aconn.cursor(row_factory=rows.dict_row)
    gen = cursor.stream("select generate_series(1,2) as a")
    first = await anext(gen)
    assert first["a"] == 1
    cursor.row_factory = rows.namedtuple_row
    second = await anext(gen)
    assert second.a == 2


async def test_stream_no_row(aconn):
    """Streaming an empty result yields nothing."""
    cursor = aconn.cursor()
    assert await alist(cursor.stream("select generate_series(2,1) as a")) == []


async def test_stream_chunked_invalid_size(aconn):
    """stream() rejects a chunk size smaller than 1."""
    cursor = aconn.cursor()
    with pytest.raises(ValueError, match=r"size must be >= 1"):
        gen = cursor.stream("select 1", size=0)
        await anext(gen)


@pytest.mark.libpq("< 17")
async def test_stream_chunked_not_supported(aconn):
    """Chunked streaming is refused when libpq is older than 17."""
    cursor = aconn.cursor()
    with pytest.raises(psycopg.NotSupportedError):
        gen = cursor.stream("select generate_series(1, 4)", size=2)
        await anext(gen)


@pytest.mark.libpq(">= 17")
async def test_stream_chunked(aconn):
    """Chunked streaming still yields every row individually, in order."""
    cursor = aconn.cursor()
    gen = cursor.stream("select generate_series(1, 5) as a", size=2)
    assert await alist(gen) == [(n,) for n in range(1, 6)]


@pytest.mark.libpq(">= 17")
async def test_stream_chunked_row_factory(aconn):
    """Chunked streaming applies the row factory and keeps a description."""
    cursor = aconn.cursor(row_factory=rows.scalar_row)
    gen = cursor.stream("select generate_series(1, 5) as a", size=2)
    for expected in range(1, 6):
        assert await anext(gen) == expected
        column_names = [col.name for col in cursor.description]
        assert column_names == ["a"]


@pytest.mark.crdb_skip("no col query")
async def test_stream_no_col(aconn):
    """A column-less query streams one empty row."""
    cursor = aconn.cursor()
    assert await alist(cursor.stream("select")) == [()]


@pytest.mark.parametrize(
    "query",
    [
        "create table test_stream_badq ()",
        "copy (select 1) to stdout",
        "wat?",
    ],
)
async def test_stream_badquery(aconn, query):
    """stream() raises ProgrammingError for row-less, COPY and invalid SQL."""
    cursor = aconn.cursor()
    with pytest.raises(psycopg.ProgrammingError):
        async for _ in cursor.stream(query):
            pass


async def test_stream_error_tx(aconn):
    """A failed stream leaves a transactional connection in INERROR."""
    cursor = aconn.cursor()
    with pytest.raises(psycopg.ProgrammingError):
        async for _ in cursor.stream("wat"):
            pass
    status = aconn.info.transaction_status
    assert status == pq.TransactionStatus.INERROR


async def test_stream_error_notx(aconn):
    """In autocommit mode a failed stream leaves the connection IDLE."""
    await aconn.set_autocommit(True)
    cursor = aconn.cursor()
    with pytest.raises(psycopg.ProgrammingError):
        async for _ in cursor.stream("wat"):
            pass
    status = aconn.info.transaction_status
    assert status == pq.TransactionStatus.IDLE


async def test_stream_error_python_to_consume(aconn):
    """A consumer error with rows still pending propagates after the stream
    is closed, leaving the transaction INTRANS or INERROR."""
    acceptable = (pq.TransactionStatus.INTRANS, pq.TransactionStatus.INERROR)
    cursor = aconn.cursor()
    with pytest.raises(ZeroDivisionError):
        async with aclosing(cursor.stream("select generate_series(1, 10000)")) as gen:
            async for _ in gen:
                1 / 0
    assert aconn.info.transaction_status in acceptable


async def test_stream_error_python_consumed(aconn):
    """A consumer error on the only row leaves the connection INTRANS."""
    cursor = aconn.cursor()
    with pytest.raises(ZeroDivisionError):
        gen = cursor.stream("select 1")
        async for _ in gen:
            1 / 0

    # Close the abandoned generator explicitly before checking the state.
    await gen.aclose()
    status = aconn.info.transaction_status
    assert status == pq.TransactionStatus.INTRANS


@pytest.mark.parametrize("autocommit", [False, True])
async def test_stream_close(aconn, autocommit):
    """Closing the connection while streaming makes the stream fail.

    The first row triggers aconn.close(); fetching further rows must raise
    OperationalError rather than yield anything else.
    """
    await aconn.set_autocommit(autocommit)
    cur = aconn.cursor()
    with pytest.raises(psycopg.OperationalError):
        async for rec in cur.stream("select generate_series(1, 3)"):
            if rec[0] == 1:
                await aconn.close()
            else:
                # Unlike `assert False`, pytest.fail() is not stripped by -O.
                pytest.fail(f"received {rec} after closing the connection")

    assert aconn.closed


async def test_stream_binary_cursor(aconn):
    """A binary cursor streams binary rows; client cursors reject it."""
    client_side = aconn.cursor_factory is psycopg.AsyncClientCursor
    with raiseif(client_side, psycopg.NotSupportedError):
        cursor = aconn.cursor(binary=True)
        recs = []
        async for rec in cursor.stream("select x::int4 from generate_series(1, 2) x"):
            recs.append(rec)
            result = cursor.pgresult
            # Binary int4: four big-endian bytes.
            assert result.fformat(0) == 1
            assert result.get_value(0, 0) == bytes([0, 0, 0, rec[0]])

        assert recs == [(1,), (2,)]


async def test_stream_execute_binary(aconn):
    """stream(binary=True) yields binary rows; client cursors reject it."""
    cursor = aconn.cursor()
    recs = []
    client_side = aconn.cursor_factory is psycopg.AsyncClientCursor
    with raiseif(client_side, psycopg.NotSupportedError):
        gen = cursor.stream(
            "select x::int4 from generate_series(1, 2) x", binary=True
        )
        async for rec in gen:
            recs.append(rec)
            result = cursor.pgresult
            # Binary int4: four big-endian bytes.
            assert result.fformat(0) == 1
            assert result.get_value(0, 0) == bytes([0, 0, 0, rec[0]])

        assert recs == [(1,), (2,)]


async def test_stream_binary_cursor_text_override(aconn):
    """stream(binary=False) overrides a binary cursor's default."""
    cursor = aconn.cursor(binary=True)
    recs = []
    async for rec in cursor.stream("select generate_series(1, 2)", binary=False):
        recs.append(rec)
        result = cursor.pgresult
        # Text format: the value is its decimal representation.
        assert result.fformat(0) == 0
        assert result.get_value(0, 0) == str(rec[0]).encode()

    assert recs == [(1,), (2,)]


async def test_str(aconn):
    """str(cursor) shows the transaction status, result status and closed flag."""
    cursor = aconn.cursor()
    fresh = str(cursor)
    assert "[IDLE]" in fresh
    assert "[closed]" not in fresh
    assert "[no result]" in fresh

    await cursor.execute("select 1")
    active = str(cursor)
    assert "[INTRANS]" in active
    assert "[TUPLES_OK]" in active
    assert "[closed]" not in active
    assert "[no result]" not in active

    await cursor.close()
    closed = str(cursor)
    assert "[closed]" in closed
    assert "[INTRANS]" in closed


@pytest.mark.pipeline
async def test_message_0x33(aconn):
    """Running a query in a pipeline must not produce any notice.

    See https://github.com/psycopg/psycopg/issues/314
    """
    notices = []
    aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))

    await aconn.set_autocommit(True)
    async with aconn.pipeline():
        cursor = await aconn.execute("select 'test'")
        row = await cursor.fetchone()
        assert row == ("test",)

    assert notices == []


async def test_typeinfo(aconn):
    """TypeInfo.fetch() works with the connection's configured cursor class."""
    assert await TypeInfo.fetch(aconn, "jsonb") is not None


async def test_error_no_result(aconn):
    """Fetching without tuples raises an error that explains the reason."""
    cursor = aconn.cursor()
    with pytest.raises(psycopg.ProgrammingError, match="no result available"):
        await cursor.fetchone()

    for query, pattern in [
        ("set timezone to utc", "last operation.*command status: SET"),
        ("", "last operation.*result status: EMPTY_QUERY"),
    ]:
        await cursor.execute(query)
        with pytest.raises(psycopg.ProgrammingError, match=pattern):
            await cursor.fetchone()


async def test_row_maker_returns_none(aconn):
    """A row that maps to None does not stop any of the fetch methods early."""
    cursor = aconn.cursor(row_factory=rows.scalar_row)
    query = "values (null), (0)"
    expected = [None, 0]

    await cursor.execute(query)
    fetched = []
    for _ in expected:
        fetched.append(await cursor.fetchone())
    assert fetched == expected

    await cursor.execute(query)
    assert await cursor.fetchmany(len(expected)) == expected
    await cursor.execute(query)
    assert await cursor.fetchall() == expected
    await cursor.execute(query)
    assert await alist(cursor) == expected
    assert await alist(cursor.stream(query)) == expected


@pytest.mark.parametrize("count", [1, 3])
async def test_results_after_execute(aconn, count):
    """results() visits every result set of a multi-statement execute()."""
    statements = [f"select * from generate_series(1, {i})" for i in range(count)]
    async with aconn.cursor() as cur:
        await cur.execute(";".join(statements))
        ress = [await res.fetchall() async for res in cur.results()]
        assert ress == [[(j,) for j in range(1, i + 1)] for i in range(count)]


@pytest.mark.parametrize("count", [0, 1, 3])
@pytest.mark.parametrize("returning", [False, True])
async def test_results_after_executemany(aconn, count, returning):
    """results() yields one set per execution only when returning=True."""
    async with aconn.cursor() as cur:
        await cur.executemany(
            ph(cur, "select * from generate_series(1, %s)"),
            [(i,) for i in range(count)],
            returning=returning,
        )
        ress = [await res.fetchall() async for res in cur.results()]
        expected = (
            [[(j,) for j in range(1, i + 1)] for i in range(count)]
            if returning
            else []
        )
        assert ress == expected


async def test_change_loader_results(aconn):
    """Check that loaders registered on the cursor apply to later fetches.

    A loader change takes effect for rows fetched afterwards, also from
    results already received and when re-reading rows after scroll() or
    set_result(). Judging by the asserted values, make_loader(suffix) builds
    a text loader that appends `suffix` to each value.
    """
    cur = aconn.cursor()
    # Register a loader before the cursor has any result.
    cur.adapters.register_loader("text", make_loader("1"))

    await cur.execute("""
        values ('foo'::text);
        values ('bar'::text), ('baz');
        values ('qux'::text);
        """)
    assert (await cur.fetchall()) == [("foo1",)]

    # Switching loader in the middle of a result affects the remaining rows.
    cur.nextset()
    assert (await cur.fetchone()) == ("bar1",)
    cur.adapters.register_loader("text", make_loader("2"))
    assert (await cur.fetchone()) == ("baz2",)
    # Rows re-read after scrolling back use the new loader too.
    await cur.scroll(-2)
    assert (await cur.fetchall()) == [("bar2",), ("baz2",)]

    cur.nextset()
    assert (await cur.fetchall()) == [("qux2",)]

    # After the final result
    assert not cur.nextset()
    cur.adapters.register_loader("text", make_loader("3"))
    assert (await cur.fetchone()) is None
    # Going back to the first result loads it with the latest loader.
    await cur.set_result(0)
    assert (await cur.fetchall()) == [("foo3",)]
