URL: http://github.com/modelcontextprotocol/python-sdk/pull/2363.diff
MessageMetadata = None, ) -> None: if handler := self._notification_handlers.get(notify.method): logger.debug("Dispatching notification of type %s", type(notify).__name__) try: - client_capabilities = session.client_params.capabilities if session.client_params else None - task_support = self._experimental_handlers.task_support if self._experimental_handlers else None - ctx = ServerRequestContext( - session=session, - lifespan_context=lifespan_context, - experimental=Experimental( - task_metadata=None, - _client_capabilities=client_capabilities, - _session=session, - _task_support=task_support, - ), - ) - await handler(ctx, notify.params) + request_data = None + if isinstance(message_metadata, ServerMessageMetadata): + request_data = message_metadata.request_context + + with _bind_request_auth_context(request_data): + client_capabilities = session.client_params.capabilities if session.client_params else None + task_support = self._experimental_handlers.task_support if self._experimental_handlers else None + ctx = ServerRequestContext( + session=session, + lifespan_context=lifespan_context, + experimental=Experimental( + task_metadata=None, + _client_capabilities=client_capabilities, + _session=session, + _task_support=task_support, + ), + request=request_data, + ) + await handler(ctx, notify.params) except Exception: # pragma: no cover logger.exception("Uncaught exception in notification handler") diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ce467e6c9..0ffda24f6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -43,9 +43,10 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from mcp.shared.exceptions import StatelessModeNotSupported from mcp.shared.experimental.tasks.capabilities import check_tasks_capability from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, + NotificationWithMetadata, RequestResponder, ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -60,7 +61,9 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception + RequestResponder[types.ClientRequest, types.ServerResult] + | NotificationWithMetadata[types.ClientNotification] + | Exception ) @@ -683,7 +686,15 @@ async def send_message(self, message: SessionMessage) -> None: """ await self._write_stream.send(message) - async def _handle_incoming(self, req: ServerRequestResponder) -> None: + async def _handle_incoming( + self, + req: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, + message_metadata: MessageMetadata = None, + ) -> None: + if isinstance(req, types.ClientNotification): + await self._incoming_message_stream_writer.send(NotificationWithMetadata(req, message_metadata)) + return + await self._incoming_message_stream_writer.send(req) @property diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 6fc59923f..b024f6221 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable from contextlib import AsyncExitStack +from dataclasses import dataclass from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -53,6 +54,14 @@ async def __call__( ) -> None: ... # pragma: no branch +@dataclass +class NotificationWithMetadata(Generic[ReceiveNotificationT]): + """A validated notification paired with its transport metadata.""" + + notification: ReceiveNotificationT + message_metadata: MessageMetadata = None + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -396,7 +405,7 @@ async def _receive_loop(self) -> None: except Exception: logging.exception("Progress callback raised an exception") await self._received_notification(notification) - await self._handle_incoming(notification) + await self._handle_incoming(notification, message.metadata) except Exception: # For other validation errors, log and continue logging.warning( # pragma: no cover @@ -515,6 +524,8 @@ async def send_progress_notification( """Sends a progress notification for a request that is currently being processed.""" async def _handle_incoming( - self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception + self, + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, + message_metadata: MessageMetadata = None, ) -> None: """A generic handler for incoming messages. Overridden by subclasses.""" diff --git a/tests/server/auth/middleware/test_auth_context_streamable_http.py b/tests/server/auth/middleware/test_auth_context_streamable_http.py new file mode 100644 index 000000000..28a91da13 --- /dev/null +++ b/tests/server/auth/middleware/test_auth_context_streamable_http.py @@ -0,0 +1,198 @@ +"""Regression tests for auth context in StreamableHTTP servers.""" + +from __future__ import annotations + +import multiprocessing +import queue +import socket +import time +from collections.abc import Generator +from multiprocessing.queues import Queue + +import anyio +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.routing import Mount + +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.provider import AccessToken +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) +from tests.test_helpers import wait_for_server + + +class _EchoTokenVerifier: + async def verify_token(self, token: str) -> AccessToken | None: + return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600) + + +async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + access = get_access_token() + text = access.token if access else "Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: