-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_stdio.py
More file actions
94 lines (76 loc) · 3.63 KB
/
test_stdio.py
File metadata and controls
94 lines (76 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import io
import sys
from io import TextIOWrapper
import anyio
import pytest
from mcp.server.stdio import stdio_server
from mcp.shared.message import SessionMessage
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter
@pytest.mark.anyio
async def test_stdio_server():
    """Round-trip JSON-RPC messages through ``stdio_server`` over in-memory streams.

    Two newline-delimited messages are preloaded into a fake stdin and must be
    surfaced, in order, on the read stream. Two further messages are then sent
    on the write stream and must show up on the fake stdout as one JSON
    document per line.
    """
    stdin = io.StringIO()
    stdout = io.StringIO()

    incoming = [
        JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"),
        JSONRPCResponse(jsonrpc="2.0", id=2, result={}),
    ]
    # Messages the server side will emit once the inbound half has been verified.
    outgoing = [
        JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"),
        JSONRPCResponse(jsonrpc="2.0", id=4, result={}),
    ]

    for message in incoming:
        stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n")
    stdin.seek(0)

    # Bound the whole exchange: without a deadline, a regression that stops the
    # server from delivering messages would hang the suite instead of failing.
    with anyio.fail_after(5):
        async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as (
            read_stream,
            write_stream,
        ):
            received_messages: list[JSONRPCMessage] = []
            async with read_stream:
                async for item in read_stream:
                    if isinstance(item, Exception):  # pragma: no cover
                        raise item
                    received_messages.append(item.message)
                    # Stop as soon as everything expected has arrived rather
                    # than waiting for the stream to end.
                    if len(received_messages) == len(incoming):
                        break

            # Compare against the fixtures themselves so the expectations
            # cannot drift from what was actually written to stdin.
            assert received_messages == incoming

            # Test sending responses from the server
            async with write_stream:
                for response in outgoing:
                    await write_stream.send(SessionMessage(response))

    # Inspect stdout only after the server context has exited.
    # NOTE(review): presumably exiting the context is what guarantees the
    # writer has drained everything that was sent - confirm against stdio_server.
    stdout.seek(0)
    output_lines = stdout.readlines()
    assert len(output_lines) == len(outgoing)

    received_responses = [jsonrpc_message_adapter.validate_json(line.strip()) for line in output_lines]
    assert received_responses == outgoing
@pytest.mark.anyio
async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
    """Undecodable stdin bytes become a stream error instead of killing the server.

    The first input line holds bytes that are not valid UTF-8; it is expected
    to arrive on the read stream as an ``Exception`` item. The well-formed
    request on the following line must still be delivered afterwards as a
    ``SessionMessage``.
    """
    ping = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")

    # Neither \xff nor \xfe can begin a UTF-8 sequence.
    garbage_line = b"\xff\xfe\n"
    ping_line = ping.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n"

    fake_stdin = TextIOWrapper(io.BytesIO(garbage_line + ping_line), encoding="utf-8")
    fake_stdout = TextIOWrapper(io.BytesIO(), encoding="utf-8")

    # stdio_server() is invoked without explicit streams below, so it takes its
    # default path; the intent is that this path re-wraps sys.stdin's .buffer
    # (our raw bytes) with errors='replace'.
    monkeypatch.setattr(sys, "stdin", fake_stdin)
    monkeypatch.setattr(sys, "stdout", fake_stdout)

    with anyio.fail_after(5):
        async with stdio_server() as (read_stream, write_stream):
            # Nothing is sent in this test, so close the outbound side up front.
            await write_stream.aclose()
            async with read_stream:  # pragma: no branch
                # Line 1: the invalid bytes are expected to decode to U+FFFD
                # characters, fail JSON parsing, and be reported in-stream.
                decode_failure = await read_stream.receive()
                assert isinstance(decode_failure, Exception)

                # Line 2: the valid request must still make it through.
                delivered = await read_stream.receive()
                assert isinstance(delivered, SessionMessage)
                assert delivered.message == ping