diff --git a/examples/cmd/examples-go/client/main.go b/examples/cmd/examples-go/client/main.go index 94a67ba..37f80f2 100644 --- a/examples/cmd/examples-go/client/main.go +++ b/examples/cmd/examples-go/client/main.go @@ -6,24 +6,74 @@ import ( "net/http" "connectrpc.com/connect" - elizav1 "github.com/gaudiy/connect-python/examples/proto/connectrpc/eliza/v1" "github.com/gaudiy/connect-python/examples/proto/connectrpc/eliza/v1/v1connect" ) +// func main() { +// client := v1connect.NewElizaServiceClient( +// http.DefaultClient, +// "http://localhost:8080/", +// ) +// req := connect.NewRequest(&elizav1.SayRequest{ +// Sentence: "Hi", +// }) +// req.Header().Set("Some-Header", "hello from connect") +// res, err := client.Say(context.Background(), req) +// if err != nil { +// log.Fatalln(err) +// } +// log.Println(res.Msg) +// log.Println(res.Header()) +// } + +// func main() { +// client := v1connect.NewElizaServiceClient( +// http.DefaultClient, +// "http://localhost:8080/", +// ) + +// stream := client.IntroduceClient(context.Background()) +// for i := 0; i < 5; i++ { +// err := stream.Send(&elizav1.IntroduceRequest{ +// Name: "Alice", +// }) +// if err != nil { +// log.Fatalln(err) +// break +// } +// } + +// res, err := stream.CloseAndReceive() +// if err != nil { +// log.Fatalln(err) +// } + +// log.Println(res.Msg) +// } + func main() { client := v1connect.NewElizaServiceClient( http.DefaultClient, "http://localhost:8080/", ) - req := connect.NewRequest(&elizav1.SayRequest{ - Sentence: "Hi", + + request := connect.NewRequest(&elizav1.IntroduceRequest{ + Name: "Alice", }) - req.Header().Set("Some-Header", "hello from connect") - res, err := client.Say(context.Background(), req) + stream, err := client.IntroduceServer(context.Background(), request) if err != nil { log.Fatalln(err) } - log.Println(res.Msg) - log.Println(res.Header()) + + number := int64(1) + for ; stream.Receive(); number++ { + log.Printf("Received message %d: %s", number, stream.Msg().Sentence) + } + + if err := stream.Err(); err != nil { + log.Fatalln(err) + } + + stream.Close() } diff --git a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py index aa6b6a6..b26bf03 100644 --- a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py +++ b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py @@ -5,12 +5,11 @@ """Generated connect code.""" import abc -from collections.abc import AsyncIterator from enum import Enum from connect.client import Client -from connect.connect import StreamRequest, UnaryRequest, UnaryResponse -from connect.handler import UnaryHandler +from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse +from connect.handler import ServerStreamHander, UnaryHandler from connect.options import ClientOptions, ConnectOptions from connect.session import AsyncClientSession from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor @@ -67,13 +66,9 @@ class ElizaServiceHandler(metaclass=abc.ABCMeta): @abc.abstractmethod async def Say(self, request: UnaryRequest[SayRequest]) -> UnaryResponse[SayResponse]: ... @abc.abstractmethod - async def IntroduceServer( - self, request: UnaryRequest[IntroduceRequest] - ) -> AsyncIterator[UnaryResponse[IntroduceResponse]]: ... + async def IntroduceServer(self, request: StreamRequest[IntroduceRequest]) -> StreamResponse[IntroduceResponse]: ... @abc.abstractmethod - async def IntroduceClient( - self, request: StreamRequest[IntroduceRequest] - ) -> AsyncIterator[UnaryResponse[IntroduceResponse]]: ... + async def IntroduceClient(self, request: StreamRequest[IntroduceRequest]) -> StreamResponse[IntroduceResponse]: ... def create_ElizaService_handlers( @@ -86,6 +81,13 @@ def create_ElizaService_handlers( input=SayRequest, output=SayResponse, options=options, - ) + ), + ServerStreamHander( + procedure=ElizaServiceProcedures.IntroduceServer.value, + stream=service.IntroduceServer, + input=IntroduceRequest, + output=IntroduceResponse, + options=options, + ), ] return handlers diff --git a/src/connect/connect.py b/src/connect/connect.py index 944e03e..fb428ee 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -8,6 +8,7 @@ from pydantic import BaseModel +from connect.error import ConnectError from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel from connect.utils import aiterate, get_callable_attribute @@ -294,7 +295,7 @@ def messages(self) -> AsyncIterator[T]: return self._messages -class StreamingHandlerConn(abc.ABC): +class UnaryHandlerConn(abc.ABC): """Abstract base class for a streaming handler connection. This class defines the interface for handling streaming connections, including @@ -389,6 +390,96 @@ def response_trailers(self) -> Headers: raise NotImplementedError() +class StreamingHandlerConn(abc.ABC): + """Abstract base class for a streaming handler connection. + + This class defines the interface for handling streaming connections, including + methods for specifying the connection, handling peer communication, receiving + and sending messages, and managing request and response headers and trailers. + + """ + + @property + @abc.abstractmethod + def spec(self) -> Spec: + """Return the specification details. + + Returns: + Spec: The specification details. + + """ + raise NotImplementedError() + + @property + @abc.abstractmethod + def peer(self) -> Peer: + """Establish a connection to a peer in the network. + + Returns: + Any: The result of the connection attempt. The exact type and structure + of the return value will depend on the implementation details. + + """ + raise NotImplementedError() + + @abc.abstractmethod + def receive(self, message: Any) -> AsyncIterator[Any]: + """Receives a message and processes it. + + Args: + message (Any): The message to be received and processed. + + Returns: + Any: The result of processing the message. + + """ + raise NotImplementedError() + + @property + @abc.abstractmethod + def request_headers(self) -> Headers: + """Generate and return the request headers. + + Returns: + Any: The request headers. + + """ + raise NotImplementedError() + + @abc.abstractmethod + def send(self, messages: AsyncIterator[Any]) -> AsyncIterator[bytes]: + raise NotImplementedError() + + @property + @abc.abstractmethod + def response_headers(self) -> Headers: + """Retrieve the response headers. + + Returns: + Any: The response headers. + + """ + raise NotImplementedError() + + @property + @abc.abstractmethod + def response_trailers(self) -> Headers: + """Handle response trailers. + + This method is intended to be overridden in subclasses to provide + specific functionality for processing response trailers. + + Returns: + Any: The return type is not specified as this is a placeholder method. + + """ + raise NotImplementedError() + + @abc.abstractmethod + def finally_send(self, error: ConnectError | None) -> AsyncIterator[bytes]: + raise NotImplementedError() + + class UnaryClientConn: """Abstract base class for a streaming client connection.""" @@ -524,11 +615,11 @@ async def receive(self, message: Any) -> Any: raise NotImplementedError() -async def receive_unary_request[T](conn: StreamingHandlerConn, t: type[T]) -> UnaryRequest[T]: +async def receive_unary_request[T](conn: UnaryHandlerConn, t: type[T]) -> UnaryRequest[T]: """Receives a unary request from the given connection and returns a UnaryRequest object. Args: - conn (StreamingHandlerConn): The connection from which to receive the unary request. + conn (UnaryHandlerConn): The connection from which to receive the unary request. t (type[T]): The type of the message to be received. Returns: @@ -551,6 +642,22 @@ async def receive_unary_request[T](conn: StreamingHandlerConn, t: type[T]) -> Un ) +async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> StreamRequest[T]: + return StreamRequest( + messages=receive_stream_message(conn, t), + spec=conn.spec, + peer=conn.peer, + headers=conn.request_headers, + method=HTTPMethod.POST, + ) + + +async def receive_stream_message[T](conn: StreamingHandlerConn, t: type[T]) -> AsyncIterator[T]: + async for message in conn.receive(t): + # TODO(tsubakiky): Add validation for message type + yield cast(T, message) + + async def recieve_unary_response[T](conn: UnaryClientConn, t: type[T]) -> UnaryResponse[T]: """Receive a unary response from a streaming client connection. diff --git a/src/connect/demo.py b/src/connect/demo.py new file mode 100644 index 0000000..e745f23 --- /dev/null +++ b/src/connect/demo.py @@ -0,0 +1,66 @@ +from collections.abc import AsyncIterator + +import hypercorn.asyncio +from starlette.applications import Starlette +from starlette.middleware import Middleware + +from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse +from connect.middleware import ConnectMiddleware +from examples.proto.connectrpc.eliza.v1.eliza_pb2 import IntroduceRequest, IntroduceResponse, SayRequest, SayResponse +from examples.proto.connectrpc.eliza.v1.v1connect.eliza_connect_pb2 import ( + ElizaServiceHandler, + create_ElizaService_handlers, +) + + +class ElizaService(ElizaServiceHandler): + """Ping service implementation.""" + + async def Say(self, request: UnaryRequest[SayRequest]) -> UnaryResponse[SayResponse]: + """Return a ping response.""" + data = request.message + return UnaryResponse(SayResponse(sentence=data.sentence)) + + async def IntroduceClient(self, request: StreamRequest[IntroduceRequest]) -> StreamResponse[IntroduceResponse]: + raise NotImplementedError() + # """Introduce the client.""" + + # async def handler() -> AsyncIterator[IntroduceResponse]: + # async for message in request.messages: + # yield IntroduceResponse(sentence=f"Hello, {message.name}!") + + # return StreamResponse(handler()) + + async def IntroduceServer(self, request: StreamRequest[IntroduceRequest]) -> StreamResponse[IntroduceResponse]: + """Introduce the server.""" + messages = "" + async for message in request.messages: + messages += message.name + + print(f"Received messages: {messages}") + + async def handler() -> AsyncIterator[IntroduceResponse]: + for _ in range(3): + yield IntroduceResponse(sentence=f"Hello, {messages}!") + + return StreamResponse(handler()) + + +middleware = [ + Middleware( + ConnectMiddleware, + create_ElizaService_handlers(service=ElizaService()), + ) +] + +app = Starlette(middleware=middleware) + + +if __name__ == "__main__": + import asyncio + + import hypercorn + + config = hypercorn.Config() + config.bind = ["localhost:8080"] + asyncio.run(hypercorn.asyncio.serve(app, config)) # type: ignore diff --git a/src/connect/handler.py b/src/connect/handler.py index 96854f5..f35c010 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -1,11 +1,12 @@ """Module provides handler configurations and implementations for unary procedures and stream types.""" -from collections.abc import Awaitable, Callable +import inspect +from collections.abc import AsyncIterator, Awaitable, Callable from http import HTTPMethod, HTTPStatus -from typing import Any +from typing import Any, TypeGuard, cast import anyio -from starlette.responses import PlainTextResponse +from starlette.responses import PlainTextResponse, StreamingResponse from connect.code import Code from connect.codec import Codec, CodecMap, CodecNameType, ProtoBinaryCodec, ProtoJSONCodec @@ -13,9 +14,13 @@ from connect.connect import ( Spec, StreamingHandlerConn, + StreamRequest, + StreamResponse, StreamType, + UnaryHandlerConn, UnaryRequest, UnaryResponse, + receive_stream_request, receive_unary_request, ) from connect.error import ConnectError @@ -30,19 +35,25 @@ ProtocolHandlerParams, exclude_protocol_headers, mapped_method_handlers, + negotiate_compression, sorted_accept_post_value, sorted_allow_method_value, ) from connect.protocol_connect import ( CONNECT_HEADER_TIMEOUT, + CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, + CONNECT_STREAMING_HEADER_COMPRESSION, CONNECT_UNARY_CONTENT_TYPE_JSON, CONNECT_UNARY_TRAILER_PREFIX, + EndStreamMarshaler, + EnvelopeMarshaler, ProtocolConnect, connect_code_to_http, error_to_json_bytes, ) from connect.request import Request from connect.response import Response +from connect.utils import achain class HandlerConfig: @@ -141,68 +152,31 @@ def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: return handlers -type UnaryFunc[T_Request, T_Response] = Callable[[UnaryRequest[T_Request]], Awaitable[UnaryResponse[T_Response]]] - - -class UnaryHandler[T_Request, T_Response]: - """A handler for unary RPC (Remote Procedure Call) operations. +UnaryImplementationFunc = Callable[[UnaryHandlerConn], Awaitable[bytes]] +StreamImplementationFunc = Callable[[StreamingHandlerConn], Awaitable[AsyncIterator[bytes]]] - Attributes: - protocol_handlers (dict[HTTPMethod, list[ProtocolHandler]]): A dictionary mapping HTTP methods to lists of protocol handlers. - procedure (str): The name of the procedure being handled. - unary (UnaryFunc[Req, Res]): The unary function to be executed. - input (type[Req]): The type of the request input. - output (type[Res]): The type of the response output. - options (ConnectOptions | None): Optional configuration options for the handler. - - """ +class Handler: procedure: str - implementation: Callable[[StreamingHandlerConn], Awaitable[bytes]] + implementation: UnaryImplementationFunc | StreamImplementationFunc protocol_handlers: dict[HTTPMethod, list[ProtocolHandler]] allow_methods: str accept_post: str + protocol_handler: ProtocolHandler def __init__( self, procedure: str, - unary: UnaryFunc[T_Request, T_Response], - input: type[T_Request], - output: type[T_Response], - options: ConnectOptions | None = None, + implementation: UnaryImplementationFunc | StreamImplementationFunc, + protocol_handlers: dict[HTTPMethod, list[ProtocolHandler]], + allow_methods: str, + accept_post: str, ): - """Initialize the unary handler.""" - options = options if options is not None else ConnectOptions() - - config = HandlerConfig(procedure=procedure, stream_type=StreamType.Unary, options=options) - protocol_handlers = create_protocol_handlers(config) - - async def _untyped(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response]: - response = await unary(request) - - return response - - untyped = apply_interceptors(_untyped, options.interceptors) - - async def implementation(conn: StreamingHandlerConn) -> bytes: - request = await receive_unary_request(conn, input) - response = await untyped(request) - - if not isinstance(response.message, output): - raise ConnectError( - f"expected response of type: {output.__name__}", - Code.INTERNAL, - ) - - conn.response_headers.update(exclude_protocol_headers(response.headers)) - conn.response_trailers.update(exclude_protocol_headers(response.trailers)) - return conn.send(response.message) - self.procedure = procedure self.implementation = implementation - self.protocol_handlers = mapped_method_handlers(protocol_handlers) - self.allow_methods = sorted_allow_method_value(protocol_handlers) - self.accept_post = sorted_accept_post_value(protocol_handlers) + self.protocol_handlers = protocol_handlers + self.allow_methods = allow_methods + self.accept_post = accept_post async def handle(self, request: Request) -> Response: """Handle an incoming HTTP request and return an HTTP response. @@ -229,6 +203,7 @@ async def handle(self, request: Request) -> Response: content_type = request.headers.get(HEADER_CONTENT_TYPE, "") protocol_handler: ProtocolHandler | None = None + for handler in protocol_handlers: if handler.can_handle_payload(request, content_type): protocol_handler = handler @@ -239,6 +214,8 @@ async def handle(self, request: Request) -> Response: status = HTTPStatus.UNSUPPORTED_MEDIA_TYPE return PlainTextResponse(content=status.phrase, headers=response_headers, status_code=status.value) + self.protocol_handler = protocol_handler + if HTTPMethod(request.method) == HTTPMethod.GET: content_length = request.headers.get(HEADER_CONTENT_LENGTH, None) has_body = False @@ -259,6 +236,8 @@ async def handle(self, request: Request) -> Response: return PlainTextResponse(content=status.phrase, headers=response_headers, status_code=status.value) status_code = HTTPStatus.OK.value + body: bytes | AsyncIterator[bytes] = b"" + error: ConnectError | None = None try: timeout = request.headers.get(CONNECT_HEADER_TIMEOUT, None) timeout_sec = None @@ -270,9 +249,124 @@ async def handle(self, request: Request) -> Response: timeout_sec = timeout_ms / 1000 + conn = await protocol_handler.conn(request, response_headers, response_trailers) + with anyio.fail_after(timeout_sec): - conn = await protocol_handler.conn(request, response_headers, response_trailers) - body = await self.implementation(conn) + if isinstance(conn, UnaryHandlerConn): + body = await cast(UnaryImplementationFunc, self.implementation)(conn) + else: + body = await cast(StreamImplementationFunc, self.implementation)(conn) + + except Exception as e: + error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + + if isinstance(e, TimeoutError): + error = ConnectError("the operation timed out", Code.DEADLINE_EXCEEDED) + + finally: + if isinstance(body, bytes): + if error: + status_code = connect_code_to_http(error.code) + + response_headers[HEADER_CONTENT_TYPE] = CONNECT_UNARY_CONTENT_TYPE_JSON + if not error.wire_error: + response_headers.update(error.metadata) + + body = error_to_json_bytes(error) + + for key, value in response_trailers.items(): + response_headers[CONNECT_UNARY_TRAILER_PREFIX + key] = value + + response = Response(content=body, headers=response_headers, status_code=status_code) + else: + assert isinstance(conn, StreamingHandlerConn) + + body = achain(body, conn.finally_send(error)) + response = StreamingResponse(content=body, headers=response_headers, status_code=status_code) + + return response + + def is_stream(self, conn: UnaryHandlerConn | StreamingHandlerConn) -> TypeGuard[StreamingHandlerConn]: + return isinstance(conn, StreamingHandlerConn) + + def is_unary(self, conn: UnaryHandlerConn | StreamingHandlerConn) -> TypeGuard[UnaryHandlerConn]: + return isinstance(conn, UnaryHandlerConn) + + def is_stream_impl( + self, impl: UnaryImplementationFunc | StreamImplementationFunc + ) -> TypeGuard[StreamImplementationFunc]: + signature = inspect.signature(impl) + parameters = list(signature.parameters.values()) + return bool(callable(next) and len(parameters) == 1 and parameters[0].annotation == StreamingHandlerConn) + + def is_unary_impl( + self, impl: UnaryImplementationFunc | StreamImplementationFunc + ) -> TypeGuard[UnaryImplementationFunc]: + signature = inspect.signature(impl) + parameters = list(signature.parameters.values()) + return bool(callable(next) and len(parameters) == 1 and parameters[0].annotation == UnaryHandlerConn) + + async def stream_handle(self, request: Request, response_headers: Headers, response_trailers: Headers) -> Response: + status_code = HTTPStatus.OK.value + body: AsyncIterator[bytes] + error: ConnectError | None = None + + content_encoding = request.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) + accept_encoding = request.headers.get(CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, None) + + _, response_compression = negotiate_compression( + self.protocol_handler.params.compressions, content_encoding, accept_encoding + ) + + end_stream_marshaler = EndStreamMarshaler( + marshaler=EnvelopeMarshaler( + compression=response_compression, + send_max_bytes=self.protocol_handler.params.send_max_bytes, + compress_min_bytes=self.protocol_handler.params.compress_min_bytes, + ) + ) + + try: + conn = await self.protocol_handler.conn(request, response_headers, response_trailers) + + implementation = self.implementation + if self.is_stream(conn) and self.is_stream_impl(implementation): + _conn = conn + body = await implementation(conn) + + except Exception as e: + error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) + + finally: + assert isinstance(conn, StreamingHandlerConn) + body = achain(body, end_stream_marshaler.marshal(error, response_headers)) + response = StreamingResponse(content=body, headers=conn.response_headers, status_code=status_code) + + return response + + async def unary_handle(self, request: Request, response_headers: Headers, response_trailers: Headers) -> Response: + status_code = HTTPStatus.OK.value + body: bytes + error: ConnectError | None = None + try: + timeout = request.headers.get(CONNECT_HEADER_TIMEOUT, None) + timeout_sec = None + if timeout is not None: + try: + timeout_ms = float(timeout) + except ValueError as e: + raise ConnectError(f"parse timeout: {str(e)}", Code.INVALID_ARGUMENT) from e + + timeout_sec = timeout_ms / 1000 + + conn = await self.protocol_handler.conn(request, response_headers, response_trailers) + + implementation = self.implementation + with anyio.fail_after(delay=timeout_sec): + if self.is_unary(conn) and self.is_unary_impl(implementation): + body = await implementation(conn) + else: + raise ValueError(f"Invalid function type for unary handler: {implementation}") except Exception as e: error = e if isinstance(e, ConnectError) else ConnectError("internal error", Code.INTERNAL) @@ -280,17 +374,114 @@ async def handle(self, request: Request) -> Response: if isinstance(e, TimeoutError): error = ConnectError("the operation timed out", Code.DEADLINE_EXCEEDED) - status_code = connect_code_to_http(error.code) + finally: + if error: + status_code = connect_code_to_http(error.code) - response_headers[HEADER_CONTENT_TYPE] = CONNECT_UNARY_CONTENT_TYPE_JSON - if not error.wire_error: - response_headers.update(error.metadata) + conn.response_headers[HEADER_CONTENT_TYPE] = CONNECT_UNARY_CONTENT_TYPE_JSON + if not error.wire_error: + conn.response_headers.update(error.metadata) - body = error_to_json_bytes(error) + body = error_to_json_bytes(error) - for key, value in response_trailers.items(): - response_headers[CONNECT_UNARY_TRAILER_PREFIX + key] = value + for key, value in conn.response_trailers.items(): + conn.response_headers[CONNECT_UNARY_TRAILER_PREFIX + key] = value - response = Response(content=body, headers=response_headers, status_code=status_code) + response = Response(content=body, headers=conn.response_headers, status_code=status_code) return response + + +type UnaryFunc[T_Request, T_Response] = Callable[[UnaryRequest[T_Request]], Awaitable[UnaryResponse[T_Response]]] +type StreamFunc[T_Request, T_Response] = Callable[[StreamRequest[T_Request]], Awaitable[StreamResponse[T_Response]]] + + +class UnaryHandler[T_Request, T_Response](Handler): + """A handler for unary RPC (Remote Procedure Call) operations. + + Attributes: + protocol_handlers (dict[HTTPMethod, list[ProtocolHandler]]): A dictionary mapping HTTP methods to lists of protocol handlers. + procedure (str): The name of the procedure being handled. + unary (UnaryFunc[Req, Res]): The unary function to be executed. + input (type[Req]): The type of the request input. + output (type[Res]): The type of the response output. + options (ConnectOptions | None): Optional configuration options for the handler. + + """ + + procedure: str + protocol_handlers: dict[HTTPMethod, list[ProtocolHandler]] + allow_methods: str + accept_post: str + + def __init__( + self, + procedure: str, + unary: UnaryFunc[T_Request, T_Response], + input: type[T_Request], + output: type[T_Response], + options: ConnectOptions | None = None, + ): + """Initialize the unary handler.""" + options = options if options is not None else ConnectOptions() + + config = HandlerConfig(procedure=procedure, stream_type=StreamType.Unary, options=options) + protocol_handlers = create_protocol_handlers(config) + + async def _untyped(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response]: + response = await unary(request) + + return response + + untyped = apply_interceptors(_untyped, options.interceptors) + + async def implementation(conn: UnaryHandlerConn) -> bytes: + request = await receive_unary_request(conn, input) + response = await untyped(request) + + if not isinstance(response.message, output): + raise ConnectError( + f"expected response of type: {output.__name__}", + Code.INTERNAL, + ) + + conn.response_headers.update(exclude_protocol_headers(response.headers)) + conn.response_trailers.update(exclude_protocol_headers(response.trailers)) + return conn.send(response.message) + + super().__init__( + procedure=procedure, + implementation=implementation, + protocol_handlers=mapped_method_handlers(protocol_handlers), + allow_methods=sorted_allow_method_value(protocol_handlers), + accept_post=sorted_accept_post_value(protocol_handlers), + ) + + +class ServerStreamHander[T_Request, T_Response](Handler): + def __init__( + self, + procedure: str, + stream: StreamFunc[T_Request, T_Response], + input: type[T_Request], + output: type[T_Response], # noqa: ARG002 + options: ConnectOptions | None = None, + ): + options = options if options is not None else ConnectOptions() + config = HandlerConfig(procedure=procedure, stream_type=StreamType.ServerStream, options=options) + protocol_handlers = create_protocol_handlers(config) + + async def implementation(conn: StreamingHandlerConn) -> AsyncIterator[bytes]: + request = await receive_stream_request(conn, input) + + response = await stream(request) + + return conn.send(response.messages) + + super().__init__( + procedure=procedure, + implementation=implementation, + protocol_handlers=mapped_method_handlers(protocol_handlers), + allow_methods=sorted_allow_method_value(protocol_handlers), + accept_post=sorted_accept_post_value(protocol_handlers), + ) diff --git a/src/connect/protocol.py b/src/connect/protocol.py index 04c33c2..f05db12 100644 --- a/src/connect/protocol.py +++ b/src/connect/protocol.py @@ -9,7 +9,15 @@ from connect.code import Code from connect.codec import Codec, ReadOnlyCodecs from connect.compression import COMPRESSION_IDENTITY, Compression -from connect.connect import Peer, Spec, StreamingClientConn, StreamingHandlerConn, StreamType, UnaryClientConn +from connect.connect import ( + Peer, + Spec, + StreamingClientConn, + StreamingHandlerConn, + StreamType, + UnaryClientConn, + UnaryHandlerConn, +) from connect.error import ConnectError from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel @@ -111,6 +119,17 @@ def methods(self) -> list[HTTPMethod]: """ raise NotImplementedError() + @property + @abc.abstractmethod + def params(self) -> ProtocolHandlerParams: + """Retrieve the parameters for the protocol handler. + + Returns: + ProtocolHandlerParams: The parameters for the protocol handler. + + """ + raise NotImplementedError + @abc.abstractmethod def content_types(self) -> list[str]: """Handle content types. @@ -144,7 +163,7 @@ def can_handle_payload(self, request: Request, content_type: str) -> bool: @abc.abstractmethod async def conn( self, request: Request, response_headers: Headers, response_trailers: Headers - ) -> StreamingHandlerConn: + ) -> UnaryHandlerConn | StreamingHandlerConn: """Handle the connection for a given request and response headers. Args: @@ -153,7 +172,7 @@ async def conn( response_trailers (Headers): The mutable headers for the response trailers. Returns: - StreamingHandlerConn: The connection handler for streaming. + UnaryHandlerConn: The connection handler for streaming. """ raise NotImplementedError() diff --git a/src/connect/protocol_connect.py b/src/connect/protocol_connect.py index bae892e..bf6b445 100644 --- a/src/connect/protocol_connect.py +++ b/src/connect/protocol_connect.py @@ -23,7 +23,16 @@ from connect.code import Code from connect.codec import Codec, CodecNameType, StableCodec from connect.compression import COMPRESSION_IDENTITY, Compression, get_compresion_from_name -from connect.connect import Address, Peer, Spec, StreamingClientConn, StreamingHandlerConn, StreamType, UnaryClientConn +from connect.connect import ( + Address, + Peer, + Spec, + StreamingClientConn, + StreamingHandlerConn, + StreamType, + UnaryClientConn, + UnaryHandlerConn, +) from connect.envelope import Envelope, EnvelopeFlags from connect.error import DEFAULT_ANY_RESOLVER_PREFIX, ConnectError, ErrorDetail from connect.headers import Headers, include_request_headers @@ -104,6 +113,117 @@ def connect_content_type_from_codec_name(stream_type: StreamType, codec_name: st return CONNECT_STREAMING_CONTENT_TYPE_PREFIX + codec_name +class ConnectStreamingHandler(ProtocolHandler): + """A handler for managing protocol connections. + + Attributes: + params (ProtocolHandlerParams): Parameters for the protocol handler. + __methods (list[HTTPMethod]): List of HTTP methods supported by the handler. + accept (list[str]): List of accepted content types. + + """ + + params: ProtocolHandlerParams + _methods: list[HTTPMethod] + accept: list[str] + + def __init__(self, params: ProtocolHandlerParams, methods: list[HTTPMethod], accept: list[str]) -> None: + """Initialize the ProtocolConnect instance. + + Args: + params (ProtocolHandlerParams): The parameters for the protocol handler. + methods (list[HTTPMethod]): A list of HTTP methods. + accept (list[str]): A list of accepted content types. + + """ + self.params = params + self._methods = methods + self.accept = accept + + @property + def methods(self) -> list[HTTPMethod]: + """Return the list of HTTP methods. + + Returns: + list[HTTPMethod]: A list of HTTP methods. + + """ + return self._methods + + def content_types(self) -> list[str]: + """Handle content types. + + This method currently does nothing and serves as a placeholder for future + implementation related to content types. + + """ + return self.accept + + def can_handle_payload(self, request: Request, content_type: str) -> bool: + """Check if the handler can handle the payload.""" + return content_type in self.accept + + async def conn( + self, request: Request, response_headers: Headers, response_trailers: Headers + ) -> UnaryHandlerConn | StreamingHandlerConn: + content_encoding = request.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) + accept_encoding = request.headers.get(CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, None) + + request_compression, response_compression = negotiate_compression( + self.params.compressions, content_encoding, accept_encoding + ) + + connect_check_protocol_version(request, False) + + request_stream = AsyncByteStream(aiterator=request.stream()) + content_type = request.headers.get(HEADER_CONTENT_TYPE, "") + codec_name = connect_codec_from_content_type(self.params.spec.stream_type, content_type) + + codec = self.params.codecs.get(codec_name) + if codec is None: + raise ConnectError( + f"invalid message encoding: {codec_name}", + Code.INVALID_ARGUMENT, + ) + + response_headers[HEADER_CONTENT_TYPE] = content_type + + accept_compression_header = CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION + if response_compression and response_compression.name != COMPRESSION_IDENTITY: + response_headers[CONNECT_STREAMING_HEADER_COMPRESSION] = response_compression.name + + response_headers[accept_compression_header] = f"{', '.join(c.name for c in self.params.compressions)}" + + peer = Peer( + address=Address(host=request.client.host, port=request.client.port) if request.client else request.client, + protocol=PROTOCOL_CONNECT, + query=request.query_params, + ) + + stream_conn = ConnectStreamingHandlerConn( + request=request, + peer=peer, + spec=self.params.spec, + marshaler=ConnectStreamingMarshaler( + codec=codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=response_compression, + ), + unmarshaler=ConnectStreamingUnmarshaler( + stream=request_stream, + codec=codec, + compression=request_compression, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=Headers(request.headers, encoding="latin-1"), + response_headers=response_headers, + response_trailers=response_trailers, + ) + + return stream_conn + + class ConnectHandler(ProtocolHandler): """A handler for managing protocol connections. @@ -160,7 +280,7 @@ def can_handle_payload(self, request: Request, content_type: str) -> bool: async def conn( self, request: Request, response_headers: Headers, response_trailers: Headers - ) -> StreamingHandlerConn: + ) -> UnaryHandlerConn | StreamingHandlerConn: """Handle the connection for the given request and response headers. Args: @@ -169,7 +289,7 @@ async def conn( response_trailers (Headers): The headers for the response trailers. Returns: - StreamingHandlerConn: The connection handler for the request. + UnaryHandlerConn: The connection handler for the request. Raises: ValueError: If the request method is not supported or if the codec is not found. @@ -191,9 +311,8 @@ async def conn( accept_encoding = request.headers.get(CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION, None) else: - # Streaming support is not yet implemented - content_encoding = None - accept_encoding = None + content_encoding = request.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, None) + accept_encoding = request.headers.get(CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, None) request_compression, response_compression = negotiate_compression( self.params.compressions, content_encoding, accept_encoding @@ -210,7 +329,7 @@ async def conn( f"missing {CONNECT_UNARY_ENCODING_QUERY_PARAMETER} parameter", Code.INVALID_ARGUMENT, ) - elif message is None: + if message is None: raise ConnectError( f"missing {CONNECT_UNARY_MESSAGE_QUERY_PARAMETER} parameter", Code.INVALID_ARGUMENT, @@ -242,8 +361,10 @@ async def stream() -> AsyncGenerator[bytes]: response_headers[HEADER_CONTENT_TYPE] = content_type accept_compression_header = CONNECT_UNARY_HEADER_ACCEPT_COMPRESSION if self.params.spec.stream_type != StreamType.Unary: - # TODO(tsubakiky): Add streaming support - pass + accept_compression_header = CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION + if response_compression and response_compression.name != COMPRESSION_IDENTITY: + response_headers[CONNECT_STREAMING_HEADER_COMPRESSION] = response_compression.name + response_headers[accept_compression_header] = f"{', '.join(c.name for c in self.params.compressions)}" peer = Peer( @@ -253,7 +374,7 @@ async def stream() -> AsyncGenerator[bytes]: ) if self.params.spec.stream_type == StreamType.Unary: - conn = ConnectUnaryHandlerConn( + unary_conn = ConnectUnaryHandlerConn( request=request, peer=peer, spec=self.params.spec, @@ -274,11 +395,30 @@ async def stream() -> AsyncGenerator[bytes]: response_headers=response_headers, response_trailers=response_trailers, ) + return unary_conn else: - # TODO(tsubakiky): Add streaming support - pass + stream_conn = ConnectStreamingHandlerConn( + request=request, + peer=peer, + spec=self.params.spec, + marshaler=ConnectStreamingMarshaler( + codec=codec, + compress_min_bytes=self.params.compress_min_bytes, + send_max_bytes=self.params.send_max_bytes, + compression=response_compression, + ), + unmarshaler=ConnectStreamingUnmarshaler( + stream=request_stream, + codec=codec, + compression=request_compression, + read_max_bytes=self.params.read_max_bytes, + ), + request_headers=Headers(request.headers, encoding="latin-1"), + response_headers=response_headers, + response_trailers=response_trailers, + ) - return conn + return stream_conn class ProtocolConnect(Protocol): @@ -511,7 +651,7 @@ def marshal(self, message: Any) -> bytes: return data -class ConnectUnaryHandlerConn(StreamingHandlerConn): +class ConnectUnaryHandlerConn(UnaryHandlerConn): """ConnectUnaryHandlerConn is a handler connection class for unary RPCs in the Connect protocol. Attributes: @@ -1060,6 +1200,35 @@ async def marshal(self, messages: AsyncIterator[Any]) -> AsyncIterator[bytes]: env.data = compressed_data yield env.encode() + def write_envelope(self, env: Envelope) -> Envelope: + if env.is_set(EnvelopeFlags.compressed) or self.compression is None or len(env.data) < self.compress_min_bytes: + if self.send_max_bytes > 0 and len(env.data) > self.send_max_bytes: + raise ConnectError( + f"message size {len(env.data)} exceeds sendMaxBytes {self.send_max_bytes}", Code.RESOURCE_EXHAUSTED + ) + compressed_data = env.data + flags = env.flags + else: + compressed_data = self.compression.compress(env.data) + flags = EnvelopeFlags(env.flags | EnvelopeFlags.compressed) + + if self.send_max_bytes > 0 and len(env.data) > self.send_max_bytes: + raise ConnectError( + f"compressed message size {len(env.data)} exceeds send_mas_bytes {self.send_max_bytes}", + Code.RESOURCE_EXHAUSTED, + ) + + return Envelope( + data=compressed_data, + flags=flags, + ) + + def marshal_end_stream(self, data: bytes) -> bytes: + env = Envelope(data, EnvelopeFlags(EnvelopeFlags.end_stream)) + env = self.write_envelope(env) + + return env.encode() + class ConnectStreamingUnmarshaler: """A class to handle the unmarshaling of streaming data. @@ -1199,6 +1368,164 @@ def end_stream_error(self) -> ConnectError | None: return self._end_stream_error +class EnvelopeMarshaler: + compress_min_bytes: int + send_max_bytes: int + compression: Compression | None + + def __init__(self, compression: Compression | None, compress_min_bytes: int, send_max_bytes: int) -> None: + self.compress_min_bytes = compress_min_bytes + self.send_max_bytes = send_max_bytes + self.compression = compression + + def write(self, env: Envelope) -> Envelope: + if env.is_set(EnvelopeFlags.compressed) or self.compression is None or len(env.data) < self.compress_min_bytes: + if self.send_max_bytes > 0 and len(env.data) > self.send_max_bytes: + raise ConnectError( + f"message size {len(env.data)} exceeds sendMaxBytes {self.send_max_bytes}", Code.RESOURCE_EXHAUSTED + ) + compressed_data = env.data + flags = env.flags + else: + compressed_data = self.compression.compress(env.data) + flags = EnvelopeFlags(env.flags | EnvelopeFlags.compressed) + + if self.send_max_bytes > 0 and len(env.data) > self.send_max_bytes: + raise ConnectError( + f"compressed message size {len(env.data)} exceeds send_mas_bytes {self.send_max_bytes}", + Code.RESOURCE_EXHAUSTED, + ) + + return Envelope( + data=compressed_data, + flags=flags, + ) + + def marshal_end_stream(self, data: bytes) -> bytes: + env = Envelope(data, EnvelopeFlags(EnvelopeFlags.end_stream)) + env = self.write(env) + + return env.encode() + + +class EndStreamMarshaler: + marshaler: EnvelopeMarshaler + + def __init__(self, marshaler: EnvelopeMarshaler) -> None: + self.marshaler = marshaler + + async def marshal(self, error: ConnectError | None, trailers: Headers) -> AsyncIterator[bytes]: + json_obj = end_stream_to_json(error, trailers) + json_str = json.dumps(json_obj) + + yield self.marshaler.marshal_end_stream(json_str.encode()) + + +class ConnectStreamingHandlerConn(StreamingHandlerConn): + request: Request + _peer: Peer + _spec: Spec + marshaler: ConnectStreamingMarshaler + unmarshaler: ConnectStreamingUnmarshaler + _request_headers: Headers + _response_headers: Headers + _response_trailers: Headers + + def __init__( + self, + request: Request, + peer: Peer, + spec: Spec, + marshaler: ConnectStreamingMarshaler, + unmarshaler: ConnectStreamingUnmarshaler, + request_headers: Headers, + response_headers: Headers, + response_trailers: Headers | None = None, + ) -> None: + self.request = request + self._peer = peer + self._spec = spec + self.marshaler = marshaler + self.unmarshaler = unmarshaler + self._request_headers = request_headers + self._response_headers = response_headers + self._response_trailers = response_trailers or Headers() + + @property + def spec(self) -> Spec: + """Return the specification object. + + Returns: + Spec: The specification object. + + """ + return self._spec + + @property + def peer(self) -> Peer: + """Return the peer associated with this instance. + + :return: The peer associated with this instance. + :rtype: Peer + """ + return self._peer + + async def receive(self, message: Any) -> AsyncIterator[Any]: + """Receives a message, unmarshals it, and returns the resulting object. + + Args: + message (Any): The message to be unmarshaled. + + Returns: + Any: The unmarshaled object. + + """ + async for obj, _ in self.unmarshaler.unmarshal(message): + yield obj + + @property + def request_headers(self) -> Headers: + """Retrieve the headers from the request. + + Returns: + Mapping[str, str]: A dictionary-like object containing the request headers. + + """ + return self._request_headers + + def send(self, messages: AsyncIterator[Any]) -> AsyncIterator[bytes]: + return self.marshaler.marshal(messages) + + @property + def response_headers(self) -> Headers: + """Retrieve the response headers. + + Returns: + Any: The response headers. + + """ + return self._response_headers + + @property + def response_trailers(self) -> Headers: + """Handle response trailers. + + This method is intended to be overridden in subclasses to provide + specific functionality for processing response trailers. + + Returns: + Any: The processed response trailer data. + + """ + return self._response_trailers + + async def finally_send(self, error: ConnectError | None) -> AsyncIterator[bytes]: + json_obj = end_stream_to_json(error, self.response_trailers) + json_str = json.dumps(json_obj) + + yield self.marshaler.marshal_end_stream(json_str.encode()) + + EventHook = Callable[..., Any] @@ -1412,8 +1739,6 @@ async def send(self, messages: AsyncIterator[Any]) -> None: await self._validate_response(response) - return - async def _validate_response(self, response: httpcore.Response) -> None: response_headers = Headers(response.headers) @@ -1709,8 +2034,6 @@ def json_ummarshal(data: bytes, _message: Any) -> Any: wire_error.metadata.update(self._response_trailers) raise wire_error - return - @property def event_hooks(self) -> dict[str, list[EventHook]]: """Return the event hooks. @@ -2002,6 +2325,20 @@ def end_stream_from_bytes(data: bytes) -> tuple[ConnectError | None, Headers]: return None, metadata +def end_stream_to_json(error: ConnectError | None, trailers: Headers) -> dict[str, Any]: + json_dict = {} + + metadata = Headers(trailers.copy()) + if error: + json_dict["error"] = error_to_json(error) + metadata.update(error.metadata.copy()) + + if len(metadata) > 0: + json_dict["metadata"] = {k: v.split(", ") for k, v in metadata.items()} + + return json_dict + + def error_to_json(error: ConnectError) -> dict[str, Any]: """Convert a ConnectError object to a JSON-serializable dictionary. diff --git a/src/connect/utils.py b/src/connect/utils.py index 7ef6b49..13869b6 100644 --- a/src/connect/utils.py +++ b/src/connect/utils.py @@ -249,3 +249,9 @@ def map_httpcore_exceptions() -> Iterator[None]: raise ConnectError(str(exc), to_code) from exc raise exc + + +async def achain[T](*itrs: typing.AsyncIterable[T]) -> typing.AsyncIterator[T]: + for itr in itrs: + async for item in itr: + yield item