pFad - Phone/Frame/Anonymizer/Declutterfier! Saves Data!


--- a PPN by Garber Painting Akron. With Image Size Reduction included!

URL: http://github.com/modelcontextprotocol/python-sdk/pull/2363.diff

MessageMetadata = None, ) -> None: if handler := self._notification_handlers.get(notify.method): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - ctx = ServerRequestContext( - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=None, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - ) - await handler(ctx, notify.params) + request_data = None + if isinstance(message_metadata, ServerMessageMetadata): + request_data = message_metadata.request_context + + with _bind_request_auth_context(request_data): + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + ctx = ServerRequestContext( + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=None, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + ) + await handler(ctx, notify.params) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c9..0ffda24f6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -43,9 +43,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, + NotificationWithMetadata, RequestResponder, ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -60,7 +61,9 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception + RequestResponder[types.ClientRequest, types.ServerResult] + | NotificationWithMetadata[types.ClientNotification] + | Exception ) @@ -683,7 +686,15 @@ async def send_message(self, message: SessionMessage) -> None: """ await self._write_stream.send(message) - async def _handle_incoming(self, req: ServerRequestResponder) -> None: + async def _handle_incoming( + self, + req: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message_metadata: MessageMetadata = None, + ) -> None: + if isinstance(req, types.ClientNotification): + await self._incoming_message_stream_writer.send(NotificationWithMetadata(req, message_metadata)) + return + await self._incoming_message_stream_writer.send(req) @property diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 6fc59923f..b024f6221 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable from contextlib import AsyncExitStack +from dataclasses import dataclass from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -53,6 +54,14 @@ async def __call__( ) -> None: ... # pragma: no branch +@dataclass +class NotificationWithMetadata(Generic[ReceiveNotificationT]): + """A validated notification paired with its transport metadata.""" + + notification: ReceiveNotificationT + message_metadata: MessageMetadata = None + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -396,7 +405,7 @@ async def _receive_loop(self) -> None: except Exception: logging.exception("Progress callback raised an exception") await self._received_notification(notification) - await self._handle_incoming(notification) + await self._handle_incoming(notification, message.metadata) except Exception: # For other validation errors, log and continue logging.warning( # pragma: no cover @@ -515,6 +524,8 @@ async def send_progress_notification( """Sends a progress notification for a request that is currently being processed.""" async def _handle_incoming( - self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception + self, + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, + message_metadata: MessageMetadata = None, ) -> None: """A generic handler for incoming messages. Overridden by subclasses.""" diff --git a/tests/server/auth/middleware/test_auth_context_streamable_http.py b/tests/server/auth/middleware/test_auth_context_streamable_http.py new file mode 100644 index 000000000..28a91da13 --- /dev/null +++ b/tests/server/auth/middleware/test_auth_context_streamable_http.py @@ -0,0 +1,198 @@ +"""Regression tests for auth context in StreamableHTTP servers.""" + +from __future__ import annotations + +import multiprocessing +import queue +import socket +import time +from collections.abc import Generator +from multiprocessing.queues import Queue + +import anyio +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.routing import Mount + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.provider import AccessToken +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) +from tests.test_helpers import wait_for_server + + +class _EchoTokenVerifier: + async def verify_token(self, token: str) -> AccessToken | None: + return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600) + + +async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + access = get_access_token() + text = access.token if access else "" + return CallToolResult(content=[TextContent(type="text", text=text)]) + + +async def _handle_list_tools( + ctx: ServerRequestContext, + params: PaginatedRequestParams | None, +) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})]) + + +class _MutableBearerAuth(httpx.Auth): + def __init__(self, token: str) -> None: + self.token = token + + def auth_flow(self, request: httpx.Request): + request.headers["Authorization"] = f"Bearer {self.token}" + yield request + + +def _create_stateful_auth_app(progress_tokens: Queue[str] | None = None) -> Starlette: + async def _handle_progress(ctx: ServerRequestContext, params: object) -> None: + if progress_tokens is None: + return + + access = get_access_token() + progress_tokens.put(access.token if access else "") + + server = Server( + "auth-test-server", + on_call_tool=_handle_whoami, + on_list_tools=_handle_list_tools, + on_progress=_handle_progress, + ) + session_manager = StreamableHTTPSessionManager(app=server, stateless=False) + return Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + middleware=[ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())), + Middleware(AuthContextMiddleware), + ], + lifespan=lambda app: session_manager.run(), + ) + + +def run_stateful_auth_server( + port: int, + progress_tokens: Queue[str] | None = None, +) -> None: # pragma: no cover + config = uvicorn.Config( + app=_create_stateful_auth_app(progress_tokens), + host="127.0.0.1", + port=port, + log_level="error", + access_log=False, + ) + uvicorn.Server(config).run() + + +@pytest.fixture +def stateful_auth_server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def stateful_auth_server( + stateful_auth_server_port: int, +) -> Generator[tuple[str, Queue[str]], None, None]: + progress_tokens: Queue[str] = multiprocessing.Queue() + proc = multiprocessing.Process( + target=run_stateful_auth_server, + kwargs={ + "port": stateful_auth_server_port, + "progress_tokens": progress_tokens, + }, + daemon=True, + ) + proc.start() + wait_for_server(stateful_auth_server_port) + + try: + yield f"http://127.0.0.1:{stateful_auth_server_port}/mcp", progress_tokens + finally: + proc.terminate() + proc.join(timeout=2) + if proc.is_alive(): # pragma: no cover + proc.kill() + proc.join(timeout=1) + progress_tokens.close() + progress_tokens.join_thread() + + +@pytest.mark.anyio +async def test_get_access_token_reflects_current_request_in_stateful_session( + stateful_auth_server: tuple[str, Queue[str]], +) -> None: + server_url, _ = stateful_auth_server + auth = _MutableBearerAuth("token-A") + async with httpx.AsyncClient( + auth=auth, + timeout=httpx.Timeout(30, read=30), + follow_redirects=True, + ) as http_client: + async with streamable_http_client(server_url, http_client=http_client) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + first_response = await session.call_tool("whoami", {}) + assert len(first_response.content) == 1 + assert isinstance(first_response.content[0], TextContent) + assert first_response.content[0].text == "token-A" + + auth.token = "token-B" + + second_response = await session.call_tool("whoami", {}) + assert len(second_response.content) == 1 + assert isinstance(second_response.content[0], TextContent) + assert second_response.content[0].text == "token-B" + + +@pytest.mark.anyio +async def test_get_access_token_reflects_current_notification_in_stateful_session( + stateful_auth_server: tuple[str, Queue[str]], +) -> None: + server_url, progress_tokens = stateful_auth_server + auth = _MutableBearerAuth("token-A") + async with httpx.AsyncClient( + auth=auth, + timeout=httpx.Timeout(30, read=30), + follow_redirects=True, + ) as http_client: + async with streamable_http_client(server_url, http_client=http_client) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + auth.token = "token-B" + await session.send_progress_notification(progress_token="progress-1", progress=1) + + with anyio.fail_after(5): + while True: + try: + assert progress_tokens.get_nowait() == "token-B" + break + except queue.Empty: + await anyio.sleep(0.01) pFad - Phonifier reborn

Pfad - The Proxy pFad © 2024 Your Company Name. All rights reserved.





Check this box to remove all script contents from the fetched content.



Check this box to remove all images from the fetched content.


Check this box to remove all CSS styles from the fetched content.


Check this box to keep images inefficiently compressed and original size.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy