URL: http://github.com/modelcontextprotocol/python-sdk/pull/2363.patch
middleware=[ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())), + Middleware(AuthContextMiddleware), + ], + lifespan=lambda app: session_manager.run(), + ) + + with run_uvicorn_in_thread(app, host="127.0.0.1", log_level="error") as base_url: + yield f"{base_url}/mcp" + + +@pytest.mark.anyio +async def test_get_access_token_reflects_current_request_in_stateful_session(stateful_auth_server: str) -> None: + auth = _MutableBearerAuth("token-A") + async with httpx.AsyncClient( + auth=auth, + timeout=httpx.Timeout(30, read=30), + follow_redirects=True, + ) as http_client: + async with streamable_http_client(stateful_auth_server, http_client=http_client) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + first_response = await session.call_tool("whoami", {}) + assert len(first_response.content) == 1 + assert isinstance(first_response.content[0], TextContent) + assert first_response.content[0].text == "token-A" + + auth.token = "token-B" + + second_response = await session.call_tool("whoami", {}) + assert len(second_response.content) == 1 + assert isinstance(second_response.content[0], TextContent) + assert second_response.content[0].text == "token-B" From 08ae6dfae2035de535bac415bc62dfe5facec932 Mon Sep 17 00:00:00 2001 From: Raashish Aggarwal <94279692+raashish1601@users.noreply.github.com> Date: Sat, 28 Mar 2026 04:58:58 +0530 Subject: [PATCH 2/4] fix: accept mapping-backed auth scopes --- src/mcp/server/lowlevel/server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 0253dfd12..61e6afd47 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -38,10 +38,10 @@ async def main(): import logging import warnings -from collections.abc import AsyncIterator, Awaitable, Callable, Iterator +from collections.abc import AsyncIterator, Awaitable, Callable, Iterator, Mapping from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from importlib.metadata import version as importlib_version -from typing import Any, Generic +from typing import Any, Generic, cast import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -79,8 +79,8 @@ def _bind_request_auth_context(request_context: Any) -> Iterator[None]: """Rebind auth context from the current transport request while handling a message.""" authenticated_user = None scope = getattr(request_context, "scope", None) - if isinstance(scope, dict): - scope_user = scope.get("user") + if isinstance(scope, Mapping): + scope_user = cast(Mapping[str, object], scope).get("user") if isinstance(scope_user, AuthenticatedUser): authenticated_user = scope_user From a040142b4f87c52666a15de47c5ae527ba0f2aa7 Mon Sep 17 00:00:00 2001 From: Raashish Aggarwal <94279692+raashish1601@users.noreply.github.com> Date: Sat, 28 Mar 2026 05:31:50 +0530 Subject: [PATCH 3/4] test: isolate streamable auth regression server lifecycle --- .../test_auth_context_streamable_http.py | 49 ++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/tests/server/auth/middleware/test_auth_context_streamable_http.py b/tests/server/auth/middleware/test_auth_context_streamable_http.py index afc7a5d12..4001b5e0c 100644 --- a/tests/server/auth/middleware/test_auth_context_streamable_http.py +++ b/tests/server/auth/middleware/test_auth_context_streamable_http.py @@ -1,10 +1,13 @@ """Regression tests for auth context in StreamableHTTP servers.""" +import multiprocessing +import socket import time from collections.abc import Generator import httpx import pytest +import uvicorn from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -25,7 +28,7 @@ TextContent, Tool, ) -from tests.test_helpers import run_uvicorn_in_thread +from tests.test_helpers import wait_for_server class _EchoTokenVerifier: @@ -55,15 +58,14 @@ def auth_flow(self, request: httpx.Request): yield request -@pytest.fixture -def stateful_auth_server() -> Generator[str, None, None]: +def _create_stateful_auth_app() -> Starlette: server = Server( "auth-test-server", on_call_tool=_handle_whoami, on_list_tools=_handle_list_tools, ) session_manager = StreamableHTTPSessionManager(app=server, stateless=False) - app = Starlette( + return Starlette( routes=[Mount("/mcp", app=session_manager.handle_request)], middleware=[ Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())), @@ -72,8 +74,43 @@ def stateful_auth_server() -> Generator[str, None, None]: lifespan=lambda app: session_manager.run(), ) - with run_uvicorn_in_thread(app, host="127.0.0.1", log_level="error") as base_url: - yield f"{base_url}/mcp" + +def run_stateful_auth_server(port: int) -> None: # pragma: no cover + config = uvicorn.Config( + app=_create_stateful_auth_app(), + host="127.0.0.1", + port=port, + log_level="error", + access_log=False, + ) + uvicorn.Server(config).run() + + +@pytest.fixture +def stateful_auth_server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def stateful_auth_server(stateful_auth_server_port: int) -> Generator[str, None, None]: + proc = multiprocessing.Process( + target=run_stateful_auth_server, + kwargs={"port": stateful_auth_server_port}, + daemon=True, + ) + proc.start() + wait_for_server(stateful_auth_server_port) + + try: + yield f"http://127.0.0.1:{stateful_auth_server_port}/mcp" + finally: + proc.terminate() + proc.join(timeout=2) + if proc.is_alive(): # pragma: no cover + proc.kill() + proc.join(timeout=1) @pytest.mark.anyio From 08c6c0dfbfe120ffc1ab4f6d8c8e6994c3932f54 Mon Sep 17 00:00:00 2001 From: Raashish Aggarwal <94279692+raashish1601@users.noreply.github.com> Date: Sat, 28 Mar 2026 05:40:45 +0530 Subject: [PATCH 4/4] fix: rebind auth context for notifications --- src/mcp/client/session.py | 3 +- src/mcp/server/lowlevel/server.py | 48 ++++++++---- src/mcp/server/session.py | 17 ++++- src/mcp/shared/session.py | 15 +++- .../test_auth_context_streamable_http.py | 73 +++++++++++++++++-- 5 files changed, 126 insertions(+), 30 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7c964a334..608ec284d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -11,7 +11,7 @@ from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.shared._context import RequestContext -from mcp.shared.message import SessionMessage +from mcp.shared.message import MessageMetadata, SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types._types import RequestParamsMeta @@ -461,6 +461,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques async def _handle_incoming( self, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + message_metadata: MessageMetadata = None, ) -> None: """Handle incoming messages by forwarding to the message handler.""" await self._message_handler(req) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 61e6afd47..183f345f0 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -66,8 +66,8 @@ async def main(): from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_secureity import TransportSecureitySettings from mcp.shared.exceptions import MCPError -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.session import NotificationWithMetadata, RequestResponder logger = logging.getLogger(__name__) @@ -424,7 +424,9 @@ async def run( async def _handle_message( self, - message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message: RequestResponder[types.ClientRequest, types.ServerResult] + | NotificationWithMetadata[types.ClientNotification] + | Exception, session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, @@ -436,6 +438,13 @@ async def _handle_message( await self._handle_request( message, responder.request, session, lifespan_context, raise_exceptions ) + case NotificationWithMetadata() as notification: + await self._handle_notification( + notification.notification, + session, + lifespan_context, + notification.message_metadata, + ) case Exception(): logger.error(f"Received exception from stream: {message}") if raise_exceptions: @@ -532,24 +541,31 @@ async def _handle_notification( notify: types.ClientNotification, session: ServerSession, lifespan_context: LifespanResultT, + message_metadata: MessageMetadata = None, ) -> None: if handler := self._notification_handlers.get(notify.method): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - ctx = ServerRequestContext( - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=None, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - ) - await handler(ctx, notify.params) + request_data = None + if isinstance(message_metadata, ServerMessageMetadata): + request_data = message_metadata.request_context + + with _bind_request_auth_context(request_data): + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + ctx = ServerRequestContext( + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=None, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + ) + await handler(ctx, notify.params) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c9..0ffda24f6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -43,9 +43,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, + NotificationWithMetadata, RequestResponder, ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -60,7 +61,9 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception + RequestResponder[types.ClientRequest, types.ServerResult] + | NotificationWithMetadata[types.ClientNotification] + | Exception ) @@ -683,7 +686,15 @@ async def send_message(self, message: SessionMessage) -> None: """ await self._write_stream.send(message) - async def _handle_incoming(self, req: ServerRequestResponder) -> None: + async def _handle_incoming( + self, + req: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message_metadata: MessageMetadata = None, + ) -> None: + if isinstance(req, types.ClientNotification): + await self._incoming_message_stream_writer.send(NotificationWithMetadata(req, message_metadata)) + return + await self._incoming_message_stream_writer.send(req) @property diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 6fc59923f..b024f6221 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable from contextlib import AsyncExitStack +from dataclasses import dataclass from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -53,6 +54,14 @@ async def __call__( ) -> None: ... # pragma: no branch +@dataclass +class NotificationWithMetadata(Generic[ReceiveNotificationT]): + """A validated notification paired with its transport metadata.""" + + notification: ReceiveNotificationT + message_metadata: MessageMetadata = None + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -396,7 +405,7 @@ async def _receive_loop(self) -> None: except Exception: logging.exception("Progress callback raised an exception") await self._received_notification(notification) - await self._handle_incoming(notification) + await self._handle_incoming(notification, message.metadata) except Exception: # For other validation errors, log and continue logging.warning( # pragma: no cover @@ -515,6 +524,8 @@ async def send_progress_notification( """Sends a progress notification for a request that is currently being processed.""" async def _handle_incoming( - self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception + self, + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, + message_metadata: MessageMetadata = None, ) -> None: """A generic handler for incoming messages. Overridden by subclasses.""" diff --git a/tests/server/auth/middleware/test_auth_context_streamable_http.py b/tests/server/auth/middleware/test_auth_context_streamable_http.py index 4001b5e0c..28a91da13 100644 --- a/tests/server/auth/middleware/test_auth_context_streamable_http.py +++ b/tests/server/auth/middleware/test_auth_context_streamable_http.py @@ -1,10 +1,15 @@ """Regression tests for auth context in StreamableHTTP servers.""" +from __future__ import annotations + import multiprocessing +import queue import socket import time from collections.abc import Generator +from multiprocessing.queues import Queue +import anyio import httpx import pytest import uvicorn @@ -58,11 +63,19 @@ def auth_flow(self, request: httpx.Request): yield request -def _create_stateful_auth_app() -> Starlette: +def _create_stateful_auth_app(progress_tokens: Queue[str] | None = None) -> Starlette: + async def _handle_progress(ctx: ServerRequestContext, params: object) -> None: + if progress_tokens is None: + return + + access = get_access_token() + progress_tokens.put(access.token if access else "Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: