-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtest_sampling_callback.py
More file actions
131 lines (113 loc) · 5.02 KB
/
test_sampling_callback.py
File metadata and controls
131 lines (113 loc) · 5.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import pytest
from mcp import Client
from mcp.client.session import ClientSession
from mcp.server.mcpserver import Context, MCPServer
from mcp.shared._context import RequestContext
from mcp.types import (
CreateMessageRequestParams,
CreateMessageResult,
CreateMessageResultWithTools,
SamplingMessage,
TextContent,
ToolUseContent,
)
@pytest.mark.anyio
async def test_sampling_callback():
    """Sampling round-trips through a client callback, and errors without one.

    A server tool calls ``ctx.session.create_message``. When the client is
    created with a ``sampling_callback``, the callback's result is what the
    tool receives. When the client has no callback, the tool call is reported
    as an error with a "Sampling not supported" message.
    """
    server = MCPServer("test")

    callback_return = CreateMessageResult(
        role="assistant",
        content=TextContent(type="text", text="This is a response from the sampling callback"),
        model="test-model",
        stop_reason="endTurn",
    )

    async def sampling_callback(
        context: RequestContext[ClientSession],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        # Ignore the request details and always answer with the canned result.
        return callback_return

    @server.tool("test_sampling")
    async def test_sampling_tool(message: str, ctx: Context) -> bool:
        value = await ctx.session.create_message(
            messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
            max_tokens=100,
        )
        # An exception raised inside the tool is reported to the client as a
        # tool error (see the no-callback case below), so a mismatch here is
        # caught by the outer `is_error is False` assertion.
        assert value == callback_return
        return True

    # With a sampling callback: the round-trip succeeds and the tool returns True.
    async with Client(server, sampling_callback=sampling_callback) as client:
        result = await client.call_tool("test_sampling", {"message": "Test message for sampling"})
        assert result.is_error is False
        assert isinstance(result.content[0], TextContent)
        assert result.content[0].text == "true"

    # Without a sampling callback: create_message fails inside the tool, and the
    # failure comes back as a tool error result rather than a successful call.
    async with Client(server) as client:
        result = await client.call_tool("test_sampling", {"message": "Test message for sampling"})
        assert result.is_error is True
        assert isinstance(result.content[0], TextContent)
        assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"
@pytest.mark.anyio
async def test_create_message_backwards_compat_single_content():
    """Test backwards compatibility: create_message without tools returns single content."""
    server = MCPServer("test")

    # Callback returns single content (text)
    callback_return = CreateMessageResult(
        role="assistant",
        content=TextContent(type="text", text="Hello from LLM"),
        model="test-model",
        stop_reason="endTurn",
    )

    async def sampling_callback(
        context: RequestContext[ClientSession],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        return callback_return

    @server.tool("test_backwards_compat")
    async def test_tool(message: str, ctx: Context) -> bool:
        # Call create_message WITHOUT tools
        result = await ctx.session.create_message(
            messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
            max_tokens=100,
        )
        # Backwards compat: result should be CreateMessageResult
        assert isinstance(result, CreateMessageResult)
        # Content should be single (not a list) - this is the key backwards compat check
        assert isinstance(result.content, TextContent)
        assert result.content.text == "Hello from LLM"
        # CreateMessageResult should NOT have content_as_list (that's on WithTools).
        # content_as_list is a property returning a list, not a method, so a
        # `callable(...)` check could never detect it; test for absence directly.
        assert not hasattr(result, "content_as_list")
        return True

    async with Client(server, sampling_callback=sampling_callback) as client:
        result = await client.call_tool("test_backwards_compat", {"message": "Test"})
        # A failed assertion inside the tool surfaces here as is_error=True.
        assert result.is_error is False
        assert isinstance(result.content[0], TextContent)
        assert result.content[0].text == "true"
def test_create_message_result_with_tools_type():
    """Test that CreateMessageResultWithTools supports content_as_list.

    This only constructs models and reads a property; nothing is awaited, so it
    runs as a plain synchronous test instead of under the anyio marker.
    """
    # Test the type itself, not the overload (overload requires client capability setup)
    tool_use = ToolUseContent(type="tool_use", id="call_123", name="get_weather", input={"city": "SF"})
    result = CreateMessageResultWithTools(
        role="assistant",
        content=tool_use,
        model="test-model",
        stop_reason="toolUse",
    )

    # Single content is exposed as a one-element list holding that content.
    content_list = result.content_as_list
    assert len(content_list) == 1
    assert content_list[0].type == "tool_use"
    assert content_list[0] == tool_use

    # It should also work with array content, preserving order.
    blocks = [
        TextContent(type="text", text="Let me check the weather"),
        ToolUseContent(type="tool_use", id="call_456", name="get_weather", input={"city": "NYC"}),
    ]
    result_array = CreateMessageResultWithTools(
        role="assistant",
        content=blocks,
        model="test-model",
        stop_reason="toolUse",
    )
    content_list_array = result_array.content_as_list
    assert len(content_list_array) == 2
    assert content_list_array[0].type == "text"
    assert content_list_array[1].type == "tool_use"
    assert content_list_array == blocks