import collections

import h2.config
import h2.connection
import h2.events
import pytest

import httpcore
from httpcore._backends.auto import (
    AsyncBackend,
    AsyncLock,
    AsyncSemaphore,
    AsyncSocketStream,
)


class MockStream(AsyncSocketStream):
    """In-memory socket stream that replays pre-recorded server bytes.

    Each ``read`` hands back the next chunk from ``http_buffer``; anything
    written by the client is discarded.
    """

    def __init__(self, http_buffer, disconnect):
        # A private copy of the chunks, consumed front-first by read().
        self.read_buffer = collections.deque(http_buffer)
        # Reported verbatim by is_readable(); presumably lets a test simulate
        # a server-side disconnect on an idle connection — confirm in callers.
        self.disconnect = disconnect

    def get_http_version(self) -> str:
        """Report HTTP/2 as the negotiated protocol."""
        return "HTTP/2"

    async def write(self, data, timeout):
        """Discard outgoing bytes; the mock never inspects client output."""

    async def read(self, n, timeout):
        """Return the next queued chunk, ignoring ``n`` and ``timeout``.

        Raises ``IndexError`` once every queued chunk has been consumed.
        """
        next_chunk = self.read_buffer.popleft()
        return next_chunk

    async def aclose(self):
        """Closing the in-memory stream is a no-op."""

    def is_readable(self):
        """Return the ``disconnect`` flag supplied at construction."""
        return self.disconnect


class MockLock(AsyncLock):
    """Lock stand-in whose acquire and release do nothing."""

    async def acquire(self):
        """Return immediately without tracking any lock state."""

    async def release(self):
        """Return immediately; there is no state to release."""


class MockSemaphore(AsyncSemaphore):
    """Semaphore stand-in that never limits concurrency."""

    def __init__(self):
        """Create the mock without calling the base initializer; no state is kept."""

    async def acquire(self, timeout=None):
        """Return immediately; ``timeout`` is accepted and ignored."""

    async def release(self):
        """Return immediately; nothing is held."""


class MockBackend(AsyncBackend):
    """Backend whose TCP streams replay a fixed list of server byte chunks.

    Every opened stream gets its own copy of ``http_buffer`` (``MockStream``
    wraps it in a new deque). The locks and semaphores it creates are inert mocks.
    """

    def __init__(self, http_buffer, disconnect=False):
        self.http_buffer = http_buffer
        # Forwarded to each MockStream and returned by its is_readable().
        self.disconnect = disconnect

    async def open_tcp_stream(
        self, hostname, port, ssl_context, timeout, *, local_address
    ):
        """Return a new MockStream; all connection arguments are ignored."""
        stream = MockStream(self.http_buffer, self.disconnect)
        return stream

    def create_lock(self):
        """Return a no-op lock."""
        return MockLock()

    def create_semaphore(self, max_value, exc_class):
        """Return a no-op semaphore; ``max_value`` and ``exc_class`` are ignored."""
        return MockSemaphore()


class HTTP2BytesGenerator:
    """Produce the bytes an HTTP/2 server would send for canned exchanges.

    A client-side and a server-side ``H2Connection`` are driven in lockstep.
    Successive calls therefore continue one logical connection rather than
    starting over. The returned bytes are meant to be fed to
    ``MockBackend(http_buffer=...)``.
    """

    def __init__(self):
        self.client_config = h2.config.H2Configuration(client_side=True)
        self.client_conn = h2.connection.H2Connection(config=self.client_config)
        self.server_config = h2.config.H2Configuration(client_side=False)
        self.server_conn = h2.connection.H2Connection(config=self.server_config)
        # Connection prefaces are queued lazily on the first get_server_bytes().
        self.initialized = False

    def get_server_bytes(
        self, request_headers, request_data, response_headers, response_data
    ):
        """Return the server's outgoing bytes for one request/response pair.

        The request is sent through the client connection and its bytes are
        fed into the server connection. The response headers and data are then
        sent on the matching server stream. On the first call the result also
        includes whatever the server queued from ``initiate_connection()``.

        Raises:
            RuntimeError: if the server connection did not report a
                ``RequestReceived`` event for the client's bytes.
        """
        if not self.initialized:
            self.client_conn.initiate_connection()
            self.server_conn.initiate_connection()
            self.initialized = True

        # Feed the request events to the client-side state machine.
        client_stream_id = self.client_conn.get_next_available_stream_id()
        self.client_conn.send_headers(client_stream_id, headers=request_headers)
        self.client_conn.send_data(client_stream_id, data=request_data, end_stream=True)

        # Determine the bytes sent out the client side, and feed them into the
        # server-side state machine to get it into the correct state.
        client_bytes = self.client_conn.data_to_send()
        events = self.server_conn.receive_data(client_bytes)
        server_stream_id = next(
            (
                event.stream_id
                for event in events
                if isinstance(event, h2.events.RequestReceived)
            ),
            None,
        )
        if server_stream_id is None:
            raise RuntimeError(
                "Server connection did not receive a request from the client bytes"
            )

        # Feed the response events to the server-side state machine.
        self.server_conn.send_headers(server_stream_id, headers=response_headers)
        self.server_conn.send_data(
            server_stream_id, data=response_data, end_stream=True
        )

        return self.server_conn.data_to_send()


@pytest.mark.trio
async def test_get_request() -> None:
    """Two GETs over one pooled HTTP/2 connection both return the canned body.

    After each response the pool must report a single idle HTTP/2 connection
    with no open streams, i.e. the second request reuses the first connection.
    """
    # Shared canned exchange; every header value is bytes for consistency.
    request_headers = [
        (b":method", b"GET"),
        (b":authority", b"www.example.com"),
        (b":scheme", b"https"),
        (b":path", b"/"),
    ]
    response_headers = [
        (b":status", b"200"),
        (b"date", b"Sat, 06 Oct 2049 12:34:56 GMT"),
        (b"server", b"Apache"),
        (b"content-length", b"13"),
        (b"content-type", b"text/plain"),
    ]

    # One chunk of server bytes per request. The generator keeps connection
    # state between calls, so the chunks form one continuous HTTP/2 session.
    bytes_generator = HTTP2BytesGenerator()
    http_buffer = [
        bytes_generator.get_server_bytes(
            request_headers=request_headers,
            request_data=b"",
            response_headers=response_headers,
            response_data=b"Hello, world.",
        )
        for _ in range(2)
    ]
    backend = MockBackend(http_buffer=http_buffer)

    async with httpcore.AsyncConnectionPool(http2=True, backend=backend) as http:
        # Each request uses a standard keep-alive connection, so the same
        # connection stays in the pool and carries both requests.
        for _ in range(2):
            response = await http.handle_async_request(
                method=b"GET",
                url=(b"https", b"example.org", None, b"/"),
                headers=[(b"Host", b"example.org")],
                stream=httpcore.ByteStream(b""),
                extensions={},
            )
            status_code, headers, stream, extensions = response
            body = await stream.aread()
            assert status_code == 200
            assert body == b"Hello, world."
            assert await http.get_connection_info() == {
                "https://example.org": ["HTTP/2, IDLE, 0 streams"]
            }


@pytest.mark.trio
async def test_post_request() -> None:
    """A POST with a 13-byte body over HTTP/2 returns the canned response.

    After the response the pool must report one idle HTTP/2 connection.
    """
    bytes_generator = HTTP2BytesGenerator()
    bytes_to_send = bytes_generator.get_server_bytes(
        request_headers=[
            (b":method", b"POST"),
            (b":authority", b"www.example.com"),
            (b":scheme", b"https"),
            # bytes like every other header value (was a str).
            (b":path", b"/"),
            (b"content-length", b"13"),
        ],
        request_data=b"Hello, world.",
        response_headers=[
            (b":status", b"200"),
            (b"date", b"Sat, 06 Oct 2049 12:34:56 GMT"),
            (b"server", b"Apache"),
            (b"content-length", b"13"),
            (b"content-type", b"text/plain"),
        ],
        response_data=b"Hello, world.",
    )
    backend = MockBackend(http_buffer=[bytes_to_send])

    async with httpcore.AsyncConnectionPool(http2=True, backend=backend) as http:
        # We're sending a request with a standard keep-alive connection, so
        # it will remain in the pool once we've sent the request.
        response = await http.handle_async_request(
            method=b"POST",
            url=(b"https", b"example.org", None, b"/"),
            headers=[(b"Host", b"example.org"), (b"Content-length", b"13")],
            stream=httpcore.ByteStream(b"Hello, world."),
            extensions={},
        )
        status_code, headers, stream, extensions = response
        body = await stream.aread()
        assert status_code == 200
        assert body == b"Hello, world."
        assert await http.get_connection_info() == {
            "https://example.org": ["HTTP/2, IDLE, 0 streams"]
        }


@pytest.mark.trio
async def test_request_with_missing_host_header() -> None:
    """A request with no headers at all is rejected with LocalProtocolError.

    The mock server supplies only its HTTP/2 connection preface. The error
    message must be exactly "Missing mandatory Host: header".
    """
    # Only the server's initial preface bytes are queued; no response frames
    # follow. The earlier empty-buffer MockBackend assignment was dead code,
    # since it was immediately overwritten.
    server_config = h2.config.H2Configuration(client_side=False)
    server_conn = h2.connection.H2Connection(config=server_config)
    server_conn.initiate_connection()
    backend = MockBackend(http_buffer=[server_conn.data_to_send()])

    async with httpcore.AsyncConnectionPool(backend=backend) as http:
        with pytest.raises(httpcore.LocalProtocolError) as excinfo:
            await http.handle_async_request(
                method=b"GET",
                url=(b"http", b"example.org", None, b"/"),
                headers=[],
                stream=httpcore.ByteStream(b""),
                extensions={},
            )
        assert str(excinfo.value) == "Missing mandatory Host: header"
