Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
2819 lines
111 KiB
Python
2819 lines
111 KiB
Python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""
Tests for the serving layer.
"""
|
|
|
|
import asyncio
import copy
import json
import socket
import time
import unittest
from unittest.mock import MagicMock

import httpx

from transformers.cli.serve import Serve
from transformers.cli.serving.chat_completion import ChatCompletionHandler
from transformers.cli.serving.completion import CompletionHandler
from transformers.cli.serving.model_manager import ModelManager, TimedModel
from transformers.cli.serving.response import ResponseHandler, compute_usage
from transformers.cli.serving.server import build_server
from transformers.cli.serving.transcription import TranscriptionHandler
from transformers.cli.serving.utils import (
    _TOOL_CALL_FALLBACKS,
    BaseHandler,
    GenerationState,
    Modality,
    get_tool_call_config,
    parse_tool_calls,
)
from transformers.testing_utils import (
    require_librosa,
    require_multipart,
    require_serve,
    require_torch_accelerator,
    require_torchcodec,
    require_vision,
    slow,
)
from transformers.utils.chat_parsing_utils import recursive_parse
from transformers.utils.import_utils import is_serve_available


if is_serve_available():
    from fastapi import HTTPException
    from openai import OpenAI
    from openai.types.responses import Response, ResponseCreatedEvent
|
|
|
|
|
|
def _find_free_port() -> int:
|
|
"""Return a free TCP port on localhost."""
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
s.bind(("localhost", 0))
|
|
return s.getsockname()[1]
|
|
|
|
|
|
def _start_serve(**kwargs) -> tuple["Serve", int]:
    """Start a non-blocking Serve instance on a free port and wait until healthy.

    Polls ``GET /health`` once per second, for up to 30 attempts. If the server never
    answers with HTTP 200, it is shut down with ``kill_server()`` before the error is
    raised. This keeps a failed ``setUpClass`` from leaving a server running.

    Args:
        **kwargs: Forwarded to ``Serve``; ``port`` and ``non_blocking`` are set here.

    Returns:
        ``(serve, port)``.

    Raises:
        RuntimeError: If the server did not become healthy in time.
    """
    port = _find_free_port()
    serve = Serve(port=port, non_blocking=True, **kwargs)
    health_url = f"http://localhost:{port}/health"
    for _ in range(30):
        try:
            if httpx.get(health_url, timeout=2).status_code == 200:
                return serve, port
        except httpx.HTTPError:
            # Connection refused / timed out: the server is not up yet, retry below.
            pass
        time.sleep(1)
    serve.kill_server()
    raise RuntimeError(f"Server on port {port} did not become healthy in time")
|
|
|
|
|
|
@require_serve
def test_host_port_blocking(cli):
    """`serve --host/--port` forwards both values to uvicorn.Config and runs the server once."""
    from unittest.mock import Mock, patch

    fake_server = Mock()
    with (
        patch("uvicorn.Config") as config_cls,
        patch("uvicorn.Server", return_value=fake_server) as server_cls,
    ):
        result = cli("serve", "--host", "0.0.0.0", "--port", "9000")
        config_kwargs = config_cls.call_args.kwargs

        assert result.exit_code == 0
        assert config_kwargs["host"] == "0.0.0.0"
        assert config_kwargs["port"] == 9000
        # The Server must be built from the exact Config instance, then started.
        server_cls.assert_called_once_with(config_cls.return_value)
        fake_server.run.assert_called_once()
|
|
|
|
|
|
class TestProcessorInputsFromMessages(unittest.TestCase):
    """Unit tests for ``BaseHandler.get_processor_inputs_from_messages``.

    The method converts OpenAI-style chat messages into processor inputs. The expected
    output depends on the ``Modality`` passed in. LLM content is flattened to a string,
    while VLM/MULTIMODAL content stays a list of typed parts.
    """

    @staticmethod
    def _convert(messages, modality):
        """Call the method under test (shared shorthand for every test below)."""
        return BaseHandler.get_processor_inputs_from_messages(messages, modality)

    def test_llm_string_content(self):
        messages = [{"role": "user", "content": "Hello"}]
        result = self._convert(messages, Modality.LLM)
        self.assertEqual(result, [{"role": "user", "content": "Hello"}])

    def test_llm_list_content_text_only(self):
        messages = [{"role": "user", "content": [{"type": "text", "text": "A"}, {"type": "text", "text": "B"}]}]
        result = self._convert(messages, Modality.LLM)
        self.assertEqual(result, [{"role": "user", "content": "A B"}])

    def test_vlm_string_content_wrapped(self):
        messages = [{"role": "user", "content": "Hello"}]
        result = self._convert(messages, Modality.VLM)
        self.assertEqual(result, [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}])

    def test_vlm_text_and_image_url(self):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is this?"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}},
                ],
            }
        ]
        result = self._convert(messages, Modality.VLM)
        self.assertEqual(len(result[0]["content"]), 2)
        self.assertEqual(result[0]["content"][0]["type"], "text")
        self.assertEqual(result[0]["content"][1], {"type": "image", "url": "https://example.com/img.png"})

    def test_llm_multi_turn_conversation(self):
        """Multi-turn conversation with string content should pass through as-is."""
        messages = [
            {"role": "user", "content": "How are you?"},
            {"role": "assistant", "content": "I'm great!"},
            {"role": "user", "content": "Help me write tests?"},
        ]
        result = self._convert(messages, Modality.LLM)
        self.assertEqual(len(result), 3)
        self.assertEqual(result[0]["content"], "How are you?")
        self.assertEqual(result[1]["role"], "assistant")
        self.assertEqual(result[2]["content"], "Help me write tests?")

    def test_llm_list_content_with_type(self):
        """LLM messages with typed content list should extract text and join."""
        messages = [
            {"role": "user", "content": [{"type": "text", "text": "Hello"}, {"type": "text", "text": "world"}]}
        ]
        result = self._convert(messages, Modality.LLM)
        self.assertEqual(result[0]["content"], "Hello world")

    @require_vision
    def test_vlm_base64_image_passed_through(self):
        """Base64 image URLs should be passed through as-is for the processor to handle."""
        # Minimal valid 1x1 PNG as base64
        base64_url = (
            "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4"
            "2mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
        )
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is this?"},
                    {"type": "image_url", "image_url": {"url": base64_url}},
                ],
            }
        ]
        result = self._convert(messages, Modality.VLM)
        image_item = result[0]["content"][1]
        self.assertEqual(image_item["type"], "image")
        self.assertEqual(image_item["url"], base64_url)

    def test_vlm_multi_turn(self):
        """VLM multi-turn: string content should be wrapped in text type."""
        messages = [
            {"role": "user", "content": "Describe the image"},
            {"role": "assistant", "content": "It shows a cat"},
            {"role": "user", "content": "What color?"},
        ]
        result = self._convert(messages, Modality.VLM)
        self.assertEqual(len(result), 3)
        for msg in result:
            self.assertIsInstance(msg["content"], list)
            self.assertEqual(msg["content"][0]["type"], "text")

    def test_llm_tool_use_fields_forwarded(self):
        """Tool-use fields (tool_calls, tool_call_id) should be forwarded to processor inputs."""
        raw_arguments = '{"city": "Paris"}'
        tool_calls = [
            {"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": raw_arguments}}
        ]
        # Build the expectation from an independent copy BEFORE the call. Mutating `tool_calls`
        # afterwards would also mutate the result if the handler forwarded the list by reference,
        # which would make the equality assertion pass trivially.
        expected_tool_calls = copy.deepcopy(tool_calls)
        # The preprocessor parses function.arguments from JSON string to dict.
        expected_tool_calls[0]["function"]["arguments"] = json.loads(raw_arguments)
        messages = [
            {"role": "user", "content": "What's the weather in Paris?"},
            {"role": "assistant", "tool_calls": tool_calls},
            {"role": "tool", "content": "22°C, sunny", "tool_call_id": "call_1"},
        ]
        result = self._convert(messages, Modality.LLM)
        self.assertEqual(len(result), 3)
        self.assertEqual(result[1]["tool_calls"], expected_tool_calls)
        # The caller's messages must not be modified in place.
        self.assertEqual(tool_calls[0]["function"]["arguments"], raw_arguments)
        self.assertNotIn("tool_calls", result[0])
        self.assertEqual(result[2]["tool_call_id"], "call_1")
        self.assertNotIn("tool_call_id", result[0])

    def test_vlm_tool_use_fields_forwarded(self):
        """Tool-use fields should be forwarded for VLM modality as well."""
        raw_arguments = "{}"
        tool_calls = [{"id": "call_1", "type": "function", "function": {"name": "describe", "arguments": raw_arguments}}]
        # Independent expectation, built before the call (see the LLM variant above for why).
        expected_tool_calls = copy.deepcopy(tool_calls)
        expected_tool_calls[0]["function"]["arguments"] = json.loads(raw_arguments)
        messages = [
            {"role": "user", "content": "Describe this"},
            {"role": "assistant", "tool_calls": tool_calls},
            {"role": "tool", "content": "A landscape photo", "tool_call_id": "call_1"},
        ]
        result = self._convert(messages, Modality.VLM)
        self.assertEqual(len(result), 3)
        self.assertEqual(result[1]["tool_calls"], expected_tool_calls)
        self.assertEqual(tool_calls[0]["function"]["arguments"], raw_arguments)
        self.assertEqual(result[2]["tool_call_id"], "call_1")

    def test_multimodal_base64_input_audio_converted_to_data_uri(self):
        """input_audio with base64 data should be converted to a data URI for the processor to handle."""
        import base64

        # Only the wrapping of the base64 payload into a data URI is checked here (decoding is
        # left to the processor), so synthetic bytes suffice. This avoids a network download in a
        # fast unit test.
        audio_b64 = base64.b64encode(b"fake audio").decode()

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What do you hear?"},
                    {"type": "input_audio", "input_audio": {"data": audio_b64, "format": "mp3"}},
                ],
            }
        ]
        result = self._convert(messages, Modality.MULTIMODAL)
        self.assertEqual(len(result[0]["content"]), 2)
        self.assertEqual(result[0]["content"][0]["type"], "text")
        audio_item = result[0]["content"][1]
        self.assertEqual(audio_item["type"], "audio")
        self.assertTrue(audio_item["url"].startswith("data:audio/mp3;base64,"))
        self.assertIn(audio_b64, audio_item["url"])

    def test_vlm_ignores_audio_content(self):
        """VLM models should ignore audio content parts."""
        import base64

        audio_b64 = base64.b64encode(b"fake audio").decode()
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What do you hear?"},
                    {"type": "input_audio", "input_audio": {"data": audio_b64, "format": "mp3"}},
                ],
            }
        ]
        result = self._convert(messages, Modality.VLM)
        self.assertEqual(len(result[0]["content"]), 1)
        self.assertEqual(result[0]["content"][0]["type"], "text")

    def test_video_url_converted_to_hf_video_format(self):
        """video_url content should be converted to HF video format for both VLM and MULTIMODAL."""
        video_src = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video_url", "video_url": {"url": video_src}},
                    {"type": "text", "text": "What is happening?"},
                ],
            }
        ]
        for modality in (Modality.VLM, Modality.MULTIMODAL):
            with self.subTest(modality=modality):
                result = self._convert(messages, modality)
                self.assertEqual(len(result[0]["content"]), 2)
                video_item = result[0]["content"][0]
                self.assertEqual(video_item, {"type": "video", "url": video_src})
|
|
|
|
|
|
class TestGenerativeModelList(unittest.TestCase):
    """Tests for ``ModelManager.get_gen_models`` (listing cached models that can generate)."""

    def test_lists_only_generative_models(self):
        """A causal LM in the cache is listed; an encoder-only model (BERT) is not."""
        import tempfile

        from huggingface_hub import hf_hub_download

        generative_repo = "Qwen/Qwen2.5-0.5B-Instruct"
        encoder_repo = "google-bert/bert-base-cased"

        with tempfile.TemporaryDirectory() as cache_dir:
            # Populate an isolated cache with only the config.json of each repo.
            for repo_id in (generative_repo, encoder_repo):
                hf_hub_download(repo_id, "config.json", cache_dir=cache_dir)

            listed_ids = {entry["id"] for entry in ModelManager.get_gen_models(cache_dir)}

            self.assertIn(generative_repo, listed_ids)
            self.assertNotIn(encoder_repo, listed_ids)
|
|
|
|
|
|
@require_serve
class TestBuildGenerationConfig(unittest.TestCase):
    """How request-body fields are mapped onto a ``GenerationConfig`` by ``_build_generation_config``."""

    def _make_handler(self):
        return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState())

    def _build(self, body, **base_kwargs):
        """Run ``_build_generation_config(body, GenerationConfig(**base_kwargs))`` on a fresh handler."""
        from transformers import GenerationConfig

        handler = self._make_handler()
        return handler._build_generation_config(body, GenerationConfig(**base_kwargs))

    def test_max_tokens(self):
        built = self._build({"max_tokens": 7})
        self.assertEqual(built.max_new_tokens, 7)

    def test_temperature_zero_disables_sampling(self):
        built = self._build({"temperature": 0.0}, do_sample=True)
        self.assertFalse(built.do_sample)

    def test_frequency_penalty(self):
        built = self._build({"frequency_penalty": 0.5})
        self.assertAlmostEqual(built.repetition_penalty, 1.5)

    def test_logit_bias_tuple_keys(self):
        # String token ids from the request become single-token tuple keys.
        built = self._build({"logit_bias": {"42": 1.0}})
        self.assertEqual(built.sequence_bias, {(42,): 1.0})

    def test_stop_strings(self):
        built = self._build({"stop": ["<END>"]})
        self.assertEqual(built.stop_strings, ["<END>"])

    def test_generation_config_json_overrides(self):
        from transformers import GenerationConfig

        serialized = GenerationConfig(max_new_tokens=5, do_sample=False).to_json_string()
        built = self._build({"generation_config": serialized}, max_new_tokens=100)
        self.assertEqual(built.max_new_tokens, 5)
        self.assertFalse(built.do_sample)

    def test_generation_config_json_no_defaults_applied(self):
        """When generation_config JSON is passed, serving defaults should NOT be applied."""
        from transformers import GenerationConfig

        serialized = GenerationConfig(max_new_tokens=10).to_json_string()
        built = self._build({"generation_config": serialized})
        # The explicit 10 must survive; it must not be raised to the serving default (1024).
        self.assertEqual(built.max_new_tokens, 10)

    def test_default_bumps_short_max_new_tokens(self):
        built = self._build({}, max_new_tokens=20)
        self.assertEqual(built.max_new_tokens, 1024)

    def test_user_max_tokens_overrides_default(self):
        """User's max_tokens should win over the serving default."""
        built = self._build({"max_tokens": 50}, max_new_tokens=20)
        self.assertEqual(built.max_new_tokens, 50)
|
|
|
|
|
|
@require_serve
class TestValidation(unittest.TestCase):
    """Request-body validation performed by ``ChatCompletionHandler._validate_request``."""

    def _make_handler(self):
        return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState())

    def test_valid_request_passes(self):
        handler = self._make_handler()
        body = {"model": "x", "messages": [{"role": "user", "content": "hi"}], "stream": True}
        # A well-formed request must be accepted without raising.
        handler._validate_request(body)

    def test_unexpected_keys_rejected(self):
        handler = self._make_handler()
        body = {"model": "x", "messages": [], "bogus_field": True}
        with self.assertRaises(HTTPException) as raised:
            handler._validate_request(body)
        error = raised.exception
        self.assertEqual(error.status_code, 422)
        self.assertIn("bogus_field", error.detail)

    def test_unsupported_fields_warns(self):
        handler = self._make_handler()
        with self.assertLogs("transformers", level="WARNING") as captured:
            handler._validate_request({"model": "x", "messages": [], "audio": {}})
        audio_warnings = [line for line in captured.output if "audio" in line]
        self.assertTrue(audio_warnings)
|
|
|
|
|
|
@require_serve
class TestResolveModel(unittest.TestCase):
    """``_resolve_model`` behaviour when the server is pinned to a single model (``force_model``).

    Requires serve deps: ``HTTPException`` is only imported when ``is_serve_available()``.
    """

    def _make_handler(self, force_model=None):
        """Build a handler whose ModelManager is mocked but normalises model names for real."""
        mm = MagicMock()
        mm.force_model = force_model
        mm.process_model_name.side_effect = ModelManager.process_model_name
        mm.load_model_and_processor.return_value = (MagicMock(), MagicMock())
        return ChatCompletionHandler(model_manager=mm, generation_state=GenerationState())

    def test_force_model_overrides_when_model_omitted(self):
        handler = self._make_handler(force_model="org/pinned")
        body = {}
        model_id, _, _ = handler._resolve_model(body)
        self.assertEqual(model_id, "org/pinned@main")
        # The pinned model is also written back into the request body.
        self.assertEqual(body["model"], "org/pinned")

    def test_force_model_allows_matching_request(self):
        handler = self._make_handler(force_model="org/pinned")
        body = {"model": "org/pinned"}
        model_id, _, _ = handler._resolve_model(body)
        self.assertEqual(model_id, "org/pinned@main")

    def test_force_model_rejects_mismatched_request(self):
        handler = self._make_handler(force_model="org/pinned")
        with self.assertRaises(HTTPException) as ctx:
            handler._resolve_model({"model": "other/model"})
        self.assertEqual(ctx.exception.status_code, 400)
        self.assertIn("org/pinned", ctx.exception.detail)
        self.assertIn("other/model", ctx.exception.detail)
|
|
|
|
|
|
class TestModelManager(unittest.TestCase):
    """Tests for ``ModelManager`` helpers: model-id normalisation and quantization configs."""

    def test_process_model_name_adds_main(self):
        # A bare repo id is pinned to the ``main`` revision.
        normalized = ModelManager.process_model_name("org/model")
        self.assertEqual(normalized, "org/model@main")

    def test_process_model_name_preserves_revision(self):
        # An explicit ``@revision`` suffix is kept as given.
        normalized = ModelManager.process_model_name("org/model@dev")
        self.assertEqual(normalized, "org/model@dev")

    def test_quantization_config_4bit(self):
        config = ModelManager(quantization="bnb-4bit").get_quantization_config()
        self.assertTrue(config.load_in_4bit)

    def test_quantization_config_8bit(self):
        config = ModelManager(quantization="bnb-8bit").get_quantization_config()
        self.assertTrue(config.load_in_8bit)

    def test_quantization_config_none(self):
        manager = ModelManager()
        self.assertIsNone(manager.get_quantization_config())
|
|
|
|
|
|
class TestTimedModel(unittest.TestCase):
    """Tests for ``TimedModel``: a model wrapper with a ``timeout_seconds`` timer and an unload hook."""

    def _cancel_timer_on_cleanup(self, timed):
        """Register a cleanup that cancels ``timed._timer``, so it never outlives the test.

        The timer is looked up at cleanup time, after anything the test did to it, and
        the cleanup runs even when an assertion fails. A missing/``None`` timer is tolerated.
        """

        def _cancel():
            timer = getattr(timed, "_timer", None)
            if timer is not None:
                timer.cancel()

        self.addCleanup(_cancel)

    def test_delete_model(self):
        mock_model = MagicMock()
        deleted = []
        timed = TimedModel(
            mock_model, timeout_seconds=9999, processor=MagicMock(), on_unload=lambda: deleted.append(True)
        )
        # Without this, a 9999 s timer could be left pending if delete_model() does not cancel it.
        self._cancel_timer_on_cleanup(timed)
        self.assertIsNotNone(timed.model)
        timed.delete_model()
        self.assertIsNone(timed.model)
        # The on_unload callback must fire exactly once.
        self.assertEqual(len(deleted), 1)

    def test_timeout_zero_no_delete(self):
        mock_model = MagicMock()
        timed = TimedModel(mock_model, timeout_seconds=0, processor=MagicMock())
        self._cancel_timer_on_cleanup(timed)
        timed._timeout_reached()
        # A timeout of 0 must not unload the model.
        self.assertIsNotNone(timed.model)
|
|
|
|
|
|
@require_serve
class TestChunkSSE(unittest.TestCase):
    """Formatting of streamed chat-completion chunks as Server-Sent Events."""

    def _make_handler(self):
        return ChatCompletionHandler(model_manager=MagicMock(), generation_state=GenerationState())

    @staticmethod
    def _decode(event):
        """Strip the ``data: `` prefix of an SSE event and parse its JSON payload."""
        return json.loads(event[len("data: ") :].strip())

    def test_build_chunk_sse_content(self):
        event = self._make_handler()._build_chunk_sse(request_id="req1", content="hi", model="m")
        self.assertTrue(event.startswith("data: "))
        self.assertTrue(event.endswith("\n\n"))
        delta = self._decode(event)["choices"][0]["delta"]
        self.assertEqual(delta["content"], "hi")

    def test_build_chunk_sse_role(self):
        event = self._make_handler()._build_chunk_sse(request_id="req1", role="assistant", model="m")
        delta = self._decode(event)["choices"][0]["delta"]
        self.assertEqual(delta["role"], "assistant")
        # A role-only chunk must not carry a content key.
        self.assertNotIn("content", delta)

    def test_build_chunk_sse_finish_reason(self):
        event = self._make_handler()._build_chunk_sse(request_id="req1", finish_reason="stop", model="m")
        first_choice = self._decode(event)["choices"][0]
        self.assertEqual(first_choice["finish_reason"], "stop")

    def test_chunk_to_sse_string_passthrough(self):
        already_framed = "data: already formatted\n\n"
        self.assertEqual(BaseHandler.chunk_to_sse(already_framed), "data: already formatted\n\n")

    def test_chunk_to_sse_wraps_plain_string(self):
        self.assertEqual(BaseHandler.chunk_to_sse("hello"), "data: hello\n\n")
|
|
|
|
|
|
@require_serve
class TestAppRoutes(unittest.TestCase):
    """Route-level tests for the ASGI app returned by ``build_server``, with every handler mocked.

    Requests go through ``httpx.ASGITransport``, so no socket, real server or model is needed.
    """

    @classmethod
    def setUpClass(cls):
        cls.model_manager = MagicMock(spec=ModelManager)
        cls.model_manager.get_gen_models.return_value = [
            {"id": "test/model", "owned_by": "test", "object": "model", "created": 0}
        ]
        cls.chat_handler = MagicMock(spec=ChatCompletionHandler)
        cls.completion_handler = MagicMock(spec=CompletionHandler)
        cls.response_handler = MagicMock(spec=ResponseHandler)
        cls.transcription_handler = MagicMock(spec=TranscriptionHandler)
        cls.app = build_server(
            cls.model_manager,
            cls.chat_handler,
            cls.completion_handler,
            cls.response_handler,
            cls.transcription_handler,
            generation_state=GenerationState(),
        )
        cls.transport = httpx.ASGITransport(app=cls.app)

    async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
        """Send one request to the in-process app and return the response."""
        async with httpx.AsyncClient(transport=self.transport, base_url="http://test") as c:
            return await c.request(method, path, **kwargs)

    def test_health(self):
        resp = asyncio.run(self._request("GET", "/health"))
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.json(), {"status": "ok"})

    def test_models_list(self):
        resp = asyncio.run(self._request("GET", "/v1/models"))
        self.assertEqual(resp.status_code, 200)
        data = resp.json()
        self.assertEqual(data["object"], "list")
        self.assertEqual(len(data["data"]), 1)

    def test_request_id_generated(self):
        resp = asyncio.run(self._request("GET", "/health"))
        self.assertIn("x-request-id", resp.headers)
        self.assertEqual(len(resp.headers["x-request-id"]), 36)  # UUID length

    def test_request_id_passthrough(self):
        resp = asyncio.run(self._request("GET", "/health", headers={"x-request-id": "my-id"}))
        self.assertEqual(resp.headers["x-request-id"], "my-id")
|
|
|
|
|
|
@slow
|
|
@require_serve
|
|
class TestChatCompletion(unittest.TestCase):
|
|
"""Integration tests for /v1/chat/completions with a real model."""
|
|
|
|
MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
cls.serve, port = _start_serve()
|
|
cls.base_url = f"http://localhost:{port}"
|
|
cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
cls.serve.kill_server()
|
|
|
|
def test_non_streaming(self):
|
|
resp = self.client.chat.completions.create(
|
|
model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}]
|
|
)
|
|
self.assertIsNotNone(resp.choices[0].message.content)
|
|
self.assertIn(resp.choices[0].finish_reason, ("stop", "length"))
|
|
|
|
def test_streaming(self):
|
|
text = ""
|
|
for chunk in self.client.chat.completions.create(
|
|
model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}], stream=True
|
|
):
|
|
if chunk.choices[0].delta.content:
|
|
text += chunk.choices[0].delta.content
|
|
self.assertTrue(len(text) > 0)
|
|
|
|
def test_early_return_due_to_length(self):
|
|
"""When max_tokens is hit, finish_reason should be 'length'."""
|
|
chunks = list(
|
|
self.client.chat.completions.create(
|
|
model=self.MODEL,
|
|
messages=[{"role": "user", "content": "Hello, how are you?"}],
|
|
stream=True,
|
|
max_tokens=3,
|
|
)
|
|
)
|
|
last = chunks[-1]
|
|
self.assertEqual(last.choices[0].finish_reason, "length")
|
|
|
|
def test_continues_until_stop(self):
|
|
"""When model stops naturally, finish_reason should be 'stop'."""
|
|
chunks = list(
|
|
self.client.chat.completions.create(
|
|
model=self.MODEL,
|
|
messages=[{"role": "user", "content": 'Please only answer with "Hi."'}],
|
|
stream=True,
|
|
max_tokens=30,
|
|
)
|
|
)
|
|
last = chunks[-1]
|
|
self.assertEqual(last.choices[0].finish_reason, "stop")
|
|
|
|
def test_stop_strings(self):
|
|
resp = self.client.chat.completions.create(
|
|
model=self.MODEL, messages=[{"role": "user", "content": "Count to 10"}], stop=["5"]
|
|
)
|
|
self.assertNotIn("6", resp.choices[0].message.content)
|
|
|
|
def test_multi_turn(self):
|
|
resp = self.client.chat.completions.create(
|
|
model=self.MODEL,
|
|
messages=[
|
|
{"role": "user", "content": "My name is Alice"},
|
|
{"role": "assistant", "content": "Nice to meet you!"},
|
|
{"role": "user", "content": "What is my name?"},
|
|
],
|
|
)
|
|
self.assertIn("Alice", resp.choices[0].message.content)
|
|
|
|
def test_multiple_models_on_demand(self):
|
|
"""Load two different models via separate requests — both should work."""
|
|
model_a = "Qwen/Qwen2.5-0.5B-Instruct"
|
|
model_b = "HuggingFaceTB/SmolLM2-135M-Instruct"
|
|
prompt = [{"role": "user", "content": "Say hello"}]
|
|
|
|
resp_a = self.client.chat.completions.create(model=model_a, messages=prompt)
|
|
self.assertIn(model_a, resp_a.model)
|
|
self.assertIsNotNone(resp_a.choices[0].message.content)
|
|
|
|
resp_b = self.client.chat.completions.create(model=model_b, messages=prompt)
|
|
self.assertIn(model_b, resp_b.model)
|
|
self.assertIsNotNone(resp_b.choices[0].message.content)
|
|
|
|
def test_non_streaming_usage(self):
|
|
resp = self.client.chat.completions.create(
|
|
model=self.MODEL, messages=[{"role": "user", "content": "Say hello"}]
|
|
)
|
|
self.assertIsNotNone(resp.usage)
|
|
self.assertGreater(resp.usage.prompt_tokens, 0)
|
|
self.assertGreater(resp.usage.completion_tokens, 0)
|
|
self.assertEqual(resp.usage.total_tokens, resp.usage.prompt_tokens + resp.usage.completion_tokens)
|
|
|
|
def test_streaming_usage(self):
|
|
chunks = list(
|
|
self.client.chat.completions.create(
|
|
model=self.MODEL,
|
|
messages=[{"role": "user", "content": "Say hello"}],
|
|
stream=True,
|
|
)
|
|
)
|
|
# Last chunk should have usage
|
|
last = chunks[-1]
|
|
self.assertIsNotNone(last.usage)
|
|
self.assertGreater(last.usage.prompt_tokens, 0)
|
|
self.assertGreater(last.usage.completion_tokens, 0)
|
|
self.assertEqual(last.usage.total_tokens, last.usage.prompt_tokens + last.usage.completion_tokens)
|
|
|
|
def test_concurrent_non_streaming(self):
|
|
"""Two concurrent non-streaming requests should both complete without interference."""
|
|
import concurrent.futures
|
|
|
|
prompts = [
|
|
[{"role": "user", "content": "Say hello"}],
|
|
[{"role": "user", "content": "Say goodbye"}],
|
|
]
|
|
results = [None, None]
|
|
|
|
def request_in_thread(index):
|
|
client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
|
|
results[index] = client.chat.completions.create(model=self.MODEL, messages=prompts[index])
|
|
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
|
|
futures = [pool.submit(request_in_thread, i) for i in range(2)]
|
|
concurrent.futures.wait(futures)
|
|
for f in futures:
|
|
f.result() # re-raise exceptions
|
|
|
|
for i in range(2):
|
|
self.assertIsNotNone(results[i])
|
|
self.assertIsNotNone(results[i].choices[0].message.content)
|
|
self.assertTrue(len(results[i].choices[0].message.content) > 0)
|
|
|
|
def test_concurrent_streaming(self):
    """Two concurrent streaming requests should both produce complete, non-empty output."""
    import concurrent.futures

    prompts = [
        [{"role": "user", "content": "Say hello"}],
        [{"role": "user", "content": "Say goodbye"}],
    ]

    def stream_in_thread(index):
        """Stream one chat completion and return the concatenated delta text."""
        client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
        pieces = []
        for chunk in client.chat.completions.create(model=self.MODEL, messages=prompts[index], stream=True):
            # OpenAI-style streams may end with a usage-only chunk whose `choices` list is empty;
            # skip it instead of raising IndexError on choices[0].
            if chunk.choices and chunk.choices[0].delta.content:
                pieces.append(chunk.choices[0].delta.content)
        return "".join(pieces)

    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
        futures = [pool.submit(stream_in_thread, i) for i in range(len(prompts))]
        # result() blocks until the worker finishes and re-raises any exception it hit.
        results = [f.result() for f in futures]

    for i, text in enumerate(results):
        self.assertIsNotNone(text)
        self.assertGreater(len(text), 0, f"Request {i} produced empty output")
|
|
|
|
def test_request_cancellation(self):
    """Closing a stream early doesn't crash and the server stays healthy."""
    url = f"{self.base_url}/v1/chat/completions"
    payload = {
        "model": self.MODEL,
        "stream": True,
        "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}],
        "max_tokens": 500,
    }

    with httpx.stream("POST", url, json=payload, timeout=30) as resp:
        self.assertEqual(resp.status_code, 200)
        # Read only the first three lines, then leave the context manager, which closes the
        # connection early.
        for lines_seen, _ in enumerate(resp.iter_lines(), start=1):
            if lines_seen >= 3:
                break

    # The server must keep answering requests after the abandoned stream.
    followup = self.client.chat.completions.create(
        model=self.MODEL,
        messages=[{"role": "user", "content": "Say hi"}],
        max_tokens=10,
    )
    self.assertIsNotNone(followup.choices[0].message.content)
|
|
|
|
|
|
@slow
@require_serve
class TestCompletion(unittest.TestCase):
    """Integration tests for /v1/completions with a real model.

    Covers sequential and continuous-batching generation, both streaming and
    non-streaming, plus finish_reason, usage, stop strings, suffix, and
    cancellation behaviour.
    """

    MODEL = "Qwen/Qwen2.5-0.5B"

    @classmethod
    def setUpClass(cls):
        # One shared server for the sequential tests; the continuous-batching tests start their own.
        cls.serve, port = _start_serve()
        cls.base_url = f"http://localhost:{port}"
        cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")

    @classmethod
    def tearDownClass(cls):
        cls.serve.kill_server()

    def _start_cb_serve(self):
        """Start a separate continuous-batching server for ``self.MODEL``.

        Returns:
            The ``(serve, port)`` pair produced by ``_start_serve``. The caller is responsible
            for calling ``serve.kill_server()`` when done.
        """
        # NOTE(review): the CB tests are gated by `require_torch_accelerator`, but the device is
        # pinned to "cuda:0" — confirm this is intended on non-CUDA accelerators.
        return _start_serve(
            force_model=self.MODEL,
            device="cuda:0",
            continuous_batching=True,
            attn_implementation="sdpa",
        )

    # ----- non-streaming -----

    def test_non_streaming(self):
        from openai.types import Completion as OpenAICompletion

        result = self.client.completions.create(
            model=self.MODEL,
            prompt="The capital of France is",
            max_tokens=5,
        )

        self.assertIsInstance(result, OpenAICompletion)
        self.assertEqual(result.object, "text_completion")
        self.assertIsInstance(result.choices[0].text, str)
        self.assertGreater(len(result.choices[0].text), 0)
        self.assertIn(result.choices[0].finish_reason, ("stop", "length"))

    def test_non_streaming_usage(self):
        result = self.client.completions.create(
            model=self.MODEL,
            prompt="The capital of France is",
            max_tokens=5,
        )
        self.assertIsNotNone(result.usage)
        self.assertGreater(result.usage.prompt_tokens, 0)
        self.assertGreater(result.usage.completion_tokens, 0)
        self.assertEqual(result.usage.total_tokens, result.usage.prompt_tokens + result.usage.completion_tokens)

    def test_finish_reason_length(self):
        """A one-token budget must end with finish_reason == "length"."""
        result = self.client.completions.create(model=self.MODEL, prompt="Hello", max_tokens=1)
        self.assertEqual(result.choices[0].finish_reason, "length")

    def test_finish_reason_stop(self):
        """With a large budget the model should stop on its own (finish_reason == "stop")."""
        # Assumes the base model emits EOS for "Hello" within 1024 tokens.
        result = self.client.completions.create(model=self.MODEL, prompt="Hello", max_tokens=1024)
        self.assertEqual(result.choices[0].finish_reason, "stop")

    def test_stop_strings(self):
        """Generation should halt at "six", so later words such as "seven" never appear."""
        result = self.client.completions.create(
            model=self.MODEL,
            prompt="one two three four",
            max_tokens=20,
            stop=["six"],
        )
        self.assertNotIn("seven", result.choices[0].text)

    def test_suffix(self):
        """suffix should be appended to generated text."""
        suffix = " [END]"
        result = self.client.completions.create(
            model=self.MODEL,
            prompt="The capital of France is",
            max_tokens=5,
            suffix=suffix,
        )
        text = result.choices[0].text
        self.assertTrue(text.endswith(suffix), f"Expected text to end with {suffix!r}, got: {text!r}")

    # ----- streaming -----

    def test_streaming(self):
        chunks = list(
            self.client.completions.create(
                model=self.MODEL,
                prompt="The capital of France is",
                max_tokens=5,
                stream=True,
            )
        )
        texts = [c.choices[0].text for c in chunks]
        self.assertTrue(any(t != "" for t in texts), f"All streamed chunks were empty: {texts!r}")
        self.assertIn(chunks[-1].choices[0].finish_reason, ("stop", "length"))

    def test_streaming_usage(self):
        chunks = list(
            self.client.completions.create(
                model=self.MODEL,
                prompt="The capital of France is",
                max_tokens=5,
                stream=True,
            )
        )
        # The last chunk is expected to carry the usage block.
        last = chunks[-1]
        self.assertIsNotNone(last.usage)
        self.assertGreater(last.usage.prompt_tokens, 0)
        self.assertGreater(last.usage.completion_tokens, 0)
        self.assertEqual(last.usage.total_tokens, last.usage.prompt_tokens + last.usage.completion_tokens)

    def test_streaming_finish_reason_length(self):
        chunks = list(self.client.completions.create(model=self.MODEL, prompt="Hello", max_tokens=1, stream=True))
        self.assertEqual(chunks[-1].choices[0].finish_reason, "length")

    def test_streaming_suffix(self):
        """suffix should be emitted as a final text chunk before finish_reason."""
        suffix = " [END]"
        chunks = list(
            self.client.completions.create(
                model=self.MODEL,
                prompt="The capital of France is",
                max_tokens=5,
                suffix=suffix,
                stream=True,
            )
        )
        all_text = "".join(c.choices[0].text for c in chunks)
        self.assertTrue(all_text.endswith(suffix), f"Expected text to end with {suffix!r}, got: {all_text!r}")

    def test_request_cancellation(self):
        """Closing a stream early doesn't crash and the server stays healthy."""
        with httpx.stream(
            "POST",
            f"{self.base_url}/v1/completions",
            json={"model": self.MODEL, "prompt": "Count slowly:", "max_tokens": 500, "stream": True},
            timeout=30,
        ) as resp:
            self.assertEqual(resp.status_code, 200)
            # Read three lines, then leave the context manager so the connection closes early.
            chunks_read = 0
            for _ in resp.iter_lines():
                chunks_read += 1
                if chunks_read >= 3:
                    break

        result = self.client.completions.create(model=self.MODEL, prompt="Hello", max_tokens=5)
        self.assertIsNotNone(result.choices[0].text)

    # ----- continuous batching -----

    @require_torch_accelerator
    def test_cb_streaming(self):
        """Streaming completion with CB produces text."""
        serve, port = self._start_cb_serve()
        try:
            client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="unused")
            chunks = list(
                client.completions.create(
                    model=self.MODEL, prompt="The capital of France is", max_tokens=10, stream=True
                )
            )
            texts = [c.choices[0].text for c in chunks]
            self.assertTrue(any(t != "" for t in texts), f"All streamed chunks were empty: {texts!r}")
            self.assertIn(chunks[-1].choices[0].finish_reason, ("stop", "length"))
        finally:
            serve.kill_server()

    @require_torch_accelerator
    def test_cb_non_streaming(self):
        """Non-streaming completion with CB returns a full response."""
        serve, port = self._start_cb_serve()
        try:
            client = OpenAI(base_url=f"http://localhost:{port}/v1", api_key="unused")
            result = client.completions.create(model=self.MODEL, prompt="The capital of France is", max_tokens=10)
            self.assertIsInstance(result.choices[0].text, str)
            self.assertGreater(len(result.choices[0].text), 0)
            self.assertIn(result.choices[0].finish_reason, ("stop", "length"))
        finally:
            serve.kill_server()
|
|
|
|
|
|
@require_serve
class TestResponseInputConversion(unittest.TestCase):
    """Unit tests for ``ResponseHandler._normalize_input``.

    Each test passes a Responses-API style request dict and inspects the resulting list of
    chat messages.
    """

    def _make_handler(self):
        # A MagicMock stands in for the model manager; these tests only exercise input normalization.
        return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState())

    def test_string_input(self):
        request = {"input": "Hello"}
        messages = self._make_handler()._normalize_input(request)
        self.assertEqual(messages, [{"role": "user", "content": "Hello"}])

    def test_string_input_with_instructions(self):
        request = {"input": "Hello", "instructions": "Be brief"}
        messages = self._make_handler()._normalize_input(request)
        self.assertEqual(len(messages), 2)
        system_message, user_message = messages[0], messages[1]
        self.assertEqual(system_message, {"role": "system", "content": "Be brief"})
        self.assertEqual(user_message, {"role": "user", "content": "Hello"})

    def test_list_input(self):
        conversation = [{"role": "user", "content": "A"}, {"role": "assistant", "content": "B"}]
        messages = self._make_handler()._normalize_input({"input": conversation})
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[0]["content"], "A")

    def test_list_input_with_instructions_prepends_system(self):
        request = {"input": [{"role": "user", "content": "Hi"}], "instructions": "Be helpful"}
        messages = self._make_handler()._normalize_input(request)
        self.assertEqual(len(messages), 2)
        head = messages[0]
        self.assertEqual(head["role"], "system")
        self.assertEqual(head["content"], "Be helpful")

    def test_list_input_with_instructions_replaces_existing_system(self):
        conversation = [{"role": "system", "content": "Old"}, {"role": "user", "content": "Hi"}]
        messages = self._make_handler()._normalize_input({"input": conversation, "instructions": "New"})
        # Still two messages: the instructions take the place of the old system message.
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[0]["content"], "New")

    def test_flat_content_list(self):
        """Flat content list (Responses API native) is wrapped as a single user message."""
        parts = [
            {"type": "input_text", "text": "Hello"},
            {"type": "input_image", "image_url": "https://example.com/img.jpg"},
        ]
        messages = self._make_handler()._normalize_input({"input": parts})
        self.assertEqual(len(messages), 1)
        wrapped = messages[0]
        self.assertEqual(wrapped["role"], "user")
        self.assertEqual(wrapped["content"], parts)

    def test_flat_content_list_with_instructions(self):
        """Flat content list with instructions prepends a system message."""
        parts = [{"type": "input_text", "text": "Hello"}]
        messages = self._make_handler()._normalize_input({"input": parts, "instructions": "Be brief"})
        self.assertEqual(len(messages), 2)
        self.assertEqual(messages[0], {"role": "system", "content": "Be brief"})
        wrapped = messages[1]
        self.assertEqual(wrapped["role"], "user")
        self.assertEqual(wrapped["content"], parts)
|
|
|
|
|
|
@require_serve
class TestResponseValidation(unittest.TestCase):
    """Unit tests for ``ResponseHandler._validate_request``."""

    def _make_handler(self):
        return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState())

    def test_unsupported_fields_warns(self):
        handler = self._make_handler()
        request = {"model": "x", "input": "hi", "previous_response_id": "abc"}
        with self.assertLogs("transformers", level="WARNING") as captured:
            handler._validate_request(request)
        # At least one warning line must name the unsupported field.
        mentions = [line for line in captured.output if "previous_response_id" in line]
        self.assertTrue(mentions)

    def test_valid_request_passes(self):
        handler = self._make_handler()
        # A minimal request must be accepted without raising.
        handler._validate_request({"model": "x", "input": "hi"})
|
|
|
|
|
|
@require_serve
class TestResponseGenerationConfig(unittest.TestCase):
    """Unit tests for ``ResponseHandler._build_generation_config``."""

    def _make_handler(self):
        return ResponseHandler(model_manager=MagicMock(), generation_state=GenerationState())

    def test_max_output_tokens(self):
        from transformers import GenerationConfig

        handler = self._make_handler()
        config = handler._build_generation_config({"max_output_tokens": 42}, GenerationConfig())
        self.assertEqual(config.max_new_tokens, 42)

    def test_default_bumps_short_max_new_tokens(self):
        from transformers import GenerationConfig

        base_config = GenerationConfig(max_new_tokens=20)
        config = self._make_handler()._build_generation_config({}, base_config)
        # With no max_output_tokens in the request, a short default of 20 is expected to become 1024.
        self.assertEqual(config.max_new_tokens, 1024)
|
|
|
|
|
|
@require_serve
class TestResponseUsage(unittest.TestCase):
    """Tests for ``compute_usage`` and how its result serializes inside a ``Response``."""

    # Renamed from the typo'd ``testcompute_usage`` (missing underscore after ``test``).
    def test_compute_usage(self):
        """compute_usage sums input/output tokens and reports zero cached/reasoning tokens."""
        usage = compute_usage(input_tokens=100, output_tokens=50)
        self.assertEqual(usage.input_tokens, 100)
        self.assertEqual(usage.output_tokens, 50)
        self.assertEqual(usage.total_tokens, 150)
        self.assertEqual(usage.input_tokens_details.cached_tokens, 0)
        self.assertEqual(usage.output_tokens_details.reasoning_tokens, 0)

    def test_usage_in_completed_response(self):
        """Usage should serialize correctly inside a Response."""
        usage = compute_usage(10, 5)
        response = Response(
            id="resp_test",
            created_at=0,
            status="completed",
            model="test",
            output=[],
            object="response",
            tools=[],
            parallel_tool_calls=False,
            tool_choice="auto",
            usage=usage,
        )
        dumped = response.model_dump(exclude_none=True)
        self.assertEqual(dumped["usage"]["input_tokens"], 10)
        self.assertEqual(dumped["usage"]["output_tokens"], 5)
        self.assertEqual(dumped["usage"]["total_tokens"], 15)
|
|
|
|
|
|
@require_serve
class TestResponseSSEFormat(unittest.TestCase):
    """Checks that ``BaseHandler.chunk_to_sse`` produces a well-formed SSE ``data:`` frame."""

    def test_sse_format(self):
        queued_response = Response(
            id="resp_test",
            created_at=0,
            status="queued",
            model="test",
            text={"format": {"type": "text"}},
            object="response",
            tools=[],
            output=[],
            parallel_tool_calls=False,
            tool_choice="auto",
        )
        event = ResponseCreatedEvent(type="response.created", sequence_number=0, response=queued_response)

        frame = BaseHandler.chunk_to_sse(event)

        prefix = "data: "
        self.assertTrue(frame.startswith(prefix))
        # SSE events are terminated by a blank line.
        self.assertTrue(frame.endswith("\n\n"))
        payload = json.loads(frame[len(prefix) :].strip())
        self.assertEqual(payload["type"], "response.created")
        self.assertEqual(payload["response"]["status"], "queued")
|
|
|
|
|
|
@slow
@require_serve
class TestResponsesIntegration(unittest.TestCase):
    """Integration tests for /v1/responses with a real model."""

    MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

    @classmethod
    def setUpClass(cls):
        cls.serve, port = _start_serve()
        cls.base_url = f"http://localhost:{port}"
        cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")

    @classmethod
    def tearDownClass(cls):
        cls.serve.kill_server()

    def test_streaming(self):
        events = list(
            self.client.responses.create(
                model=self.MODEL,
                input="Say hello",
                stream=True,
                max_output_tokens=1,
            )
        )
        # 8 fixed lifecycle events (created, in_progress, output_item_added, content_part_added,
        # text_done, content_part_done, output_item_done, completed) plus at least one
        # output_text.delta => at least 9 events. (A bound of 8 let a delta-less stream through
        # to the delta check below, which then failed with a less helpful message.)
        self.assertGreaterEqual(len(events), 9, f"Too few events: {[e.type for e in events]}")

        # Start markers (fixed order)
        self.assertEqual(events[0].type, "response.created")
        self.assertEqual(events[1].type, "response.in_progress")
        self.assertEqual(events[2].type, "response.output_item.added")
        self.assertEqual(events[3].type, "response.content_part.added")

        # At least one delta between the start and closing markers
        self.assertTrue(any(e.type == "response.output_text.delta" for e in events[4:-4]))

        # Closing markers (fixed order from the end)
        self.assertEqual(events[-4].type, "response.output_text.done")
        self.assertEqual(events[-3].type, "response.content_part.done")
        self.assertEqual(events[-2].type, "response.output_item.done")
        self.assertEqual(events[-1].type, "response.completed")

    def test_non_streaming(self):
        resp = self.client.responses.create(
            model=self.MODEL,
            input="Say hello",
            stream=False,
        )
        self.assertEqual(resp.status, "completed")
        self.assertGreater(len(resp.output), 0)
        self.assertGreater(len(resp.output[0].content[0].text), 0)

    def test_flat_content_list(self):
        """Flat content list input (Responses API native format)."""
        resp = self.client.responses.create(
            model=self.MODEL,
            input=[{"type": "input_text", "text": "Say hello"}],
            stream=False,
        )
        self.assertEqual(resp.status, "completed")
        self.assertGreater(len(resp.output[0].content[0].text), 0)

    def test_non_streaming_usage(self):
        resp = self.client.responses.create(
            model=self.MODEL,
            input="Say hello",
            stream=False,
        )
        self.assertIsNotNone(resp.usage)
        self.assertGreater(resp.usage.input_tokens, 0)
        self.assertGreater(resp.usage.output_tokens, 0)
        self.assertEqual(resp.usage.total_tokens, resp.usage.input_tokens + resp.usage.output_tokens)

    def test_streaming_usage(self):
        events = list(
            self.client.responses.create(
                model=self.MODEL,
                input="Say hello",
                stream=True,
                max_output_tokens=5,
            )
        )
        # Usage is carried by the terminal response.completed event.
        completed = events[-1]
        self.assertEqual(completed.type, "response.completed")
        usage = completed.response.usage
        self.assertIsNotNone(usage)
        self.assertGreater(usage.input_tokens, 0)
        self.assertGreater(usage.output_tokens, 0)
        self.assertEqual(usage.total_tokens, usage.input_tokens + usage.output_tokens)

    def test_multi_turn(self):
        """Multi-turn conversation via list input."""
        resp = self.client.responses.create(
            model=self.MODEL,
            input=[
                {"role": "user", "content": "My name is Alice"},
                {"role": "assistant", "content": "Nice to meet you!"},
                {"role": "user", "content": "What is my name?"},
            ],
            stream=False,
        )
        self.assertEqual(resp.status, "completed")
        self.assertIn("Alice", resp.output[0].content[0].text)

    def test_concurrent_non_streaming(self):
        """Two concurrent non-streaming responses requests should both complete."""
        import concurrent.futures

        inputs = ["Say hello", "Say goodbye"]
        results = [None, None]

        def request_in_thread(index):
            client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
            results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False)

        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            futures = [pool.submit(request_in_thread, i) for i in range(2)]
            concurrent.futures.wait(futures)
            for f in futures:
                f.result()  # re-raise exceptions from the worker threads

        for i in range(2):
            self.assertIsNotNone(results[i])
            self.assertEqual(results[i].status, "completed")
            self.assertGreater(len(results[i].output[0].content[0].text), 0)

    def test_concurrent_streaming(self):
        """Two concurrent streaming responses requests should both produce complete event streams."""
        import concurrent.futures

        inputs = ["Say hello", "Say goodbye"]
        results = [None, None]

        def stream_in_thread(index):
            client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
            results[index] = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True))

        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            futures = [pool.submit(stream_in_thread, i) for i in range(2)]
            concurrent.futures.wait(futures)
            for f in futures:
                f.result()  # re-raise exceptions from the worker threads

        for i in range(2):
            types = [e.type for e in results[i]]
            self.assertIn("response.created", types, f"Request {i} missing created event")
            self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events")
            self.assertIn("response.completed", types, f"Request {i} missing completed event")
|
|
|
|
|
|
def _parse_sse_events(response):
    """Parse the JSON payloads of SSE ``data`` lines from a streaming httpx response.

    Args:
        response: An object exposing ``iter_lines()`` that yields decoded text lines, such as
            the response of an ``httpx.stream(...)`` context.

    Returns:
        The decoded JSON payloads, in stream order. Blank lines, non-``data`` fields
        (``event:``, ``id:``, ``:`` comments), empty ``data`` lines and the OpenAI-style
        ``[DONE]`` terminator are skipped.

    Raises:
        json.JSONDecodeError: if a non-empty ``data`` payload is not valid JSON.
    """
    events = []
    for line in response.iter_lines():
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:") :]
        # Per the SSE spec, a single space after the colon is optional and not part of the value.
        if payload.startswith(" "):
            payload = payload[1:]
        if payload.strip() in ("", "[DONE]"):
            continue
        events.append(json.loads(payload))
    return events
|
|
|
|
|
|
@slow
@require_serve
class TestLoadModel(unittest.TestCase):
    """Integration tests for POST /load_model SSE endpoint."""

    MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

    @classmethod
    def setUpClass(cls):
        cls.serve, port = _start_serve()
        cls.base_url = f"http://localhost:{port}"

    @classmethod
    def tearDownClass(cls):
        cls.serve.kill_server()

    def setUp(self):
        # Clear model cache so each test starts fresh
        self.serve.reset_loaded_models()

    def _load_model(self, model: str):
        """POST ``model`` to /load_model and collect the streamed SSE events.

        Returns:
            ``(resp, events)``: the httpx response (closed once the stream context exits;
            ``status_code`` remains readable) and the list of parsed event dicts.
        """
        with httpx.stream("POST", f"{self.base_url}/load_model", json={"model": model}, timeout=120) as resp:
            events = _parse_sse_events(resp)
        return resp, events

    def test_load_model_fresh(self):
        """POST /load_model returns SSE events ending with ready."""
        response, events = self._load_model(self.MODEL)

        self.assertEqual(response.status_code, 200)
        # Fail with a clear message instead of an IndexError on events[-1] below.
        self.assertTrue(events, "No SSE events received from /load_model")

        stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e]
        self.assertIn("processor", stages)
        self.assertIn("weights", stages)

        last = events[-1]
        self.assertEqual(last["status"], "ready")
        self.assertFalse(last["cached"])

        for event in events:
            self.assertIn("status", event)
            self.assertIn("model", event)

    def test_load_model_cached(self):
        """Loading an already-loaded model returns a single ready event with cached: true."""
        self._load_model(self.MODEL)

        _, events = self._load_model(self.MODEL)

        ready_events = [e for e in events if e["status"] == "ready"]
        self.assertEqual(len(ready_events), 1)
        self.assertTrue(ready_events[0]["cached"])

        loading_events = [e for e in events if e["status"] == "loading"]
        self.assertEqual(len(loading_events), 0)

    def test_load_model_error(self):
        """Loading a nonexistent model produces an error event."""
        _, events = self._load_model("nonexistent/model-that-does-not-exist")

        error_events = [e for e in events if e["status"] == "error"]
        self.assertGreaterEqual(len(error_events), 1)
        self.assertIn("message", error_events[0])

    def test_load_model_missing_field(self):
        """POST /load_model with no model field returns 422."""
        response = httpx.post(f"{self.base_url}/load_model", json={}, timeout=30)
        self.assertEqual(response.status_code, 422)

    def test_load_model_event_schema(self):
        """Every event conforms to the expected schema."""
        _, events = self._load_model(self.MODEL)

        for event in events:
            self.assertIsInstance(event["status"], str)
            self.assertIsInstance(event["model"], str)

            if event["status"] == "loading":
                self.assertIn("stage", event)
                if event["stage"] in ("download", "weights") and "progress" in event:
                    progress = event["progress"]
                    self.assertIn("current", progress)
                    self.assertIn("total", progress)
                    self.assertIsInstance(progress["current"], int)

            if event["status"] == "ready":
                self.assertIn("cached", event)
                self.assertIsInstance(event["cached"], bool)

    def test_load_model_stage_ordering(self):
        """Stages appear in the expected order."""
        _, events = self._load_model(self.MODEL)

        stages = [e["stage"] for e in events if e["status"] == "loading" and "stage" in e]
        # First occurrence of each stage, in stream order (dicts preserve insertion order).
        unique_stages = list(dict.fromkeys(stages))

        expected_order = ["processor", "config", "download", "weights"]
        expected_present = [s for s in expected_order if s in unique_stages]
        self.assertEqual(unique_stages, expected_present, "Stages appeared out of order")

    def test_concurrent_load_same_model(self):
        """Two concurrent /load_model requests both get events and a ready event."""
        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            futures = [pool.submit(self._load_model, self.MODEL) for _ in range(2)]
            # result() blocks until the worker finishes and re-raises any exception it hit.
            results = [(resp.status_code, events) for resp, events in (f.result() for f in futures)]

        for i, (status_code, events) in enumerate(results):
            self.assertEqual(status_code, 200, f"Caller {i} got non-200 status")
            self.assertGreater(len(events), 0, f"Caller {i} received no events")
            ready_events = [e for e in events if e["status"] == "ready"]
            self.assertEqual(len(ready_events), 1, f"Caller {i} should get exactly one ready event")

    def test_concurrent_load_second_caller_gets_cached(self):
        """If the first /load_model finishes before the second, the second gets cached: true."""
        _, events1 = self._load_model(self.MODEL)
        ready1 = [e for e in events1 if e["status"] == "ready"]
        self.assertEqual(len(ready1), 1)
        self.assertFalse(ready1[0]["cached"])

        _, events2 = self._load_model(self.MODEL)
        ready2 = [e for e in events2 if e["status"] == "ready"]
        self.assertEqual(len(ready2), 1)
        self.assertTrue(ready2[0]["cached"])

        loading2 = [e for e in events2 if e["status"] == "loading"]
        self.assertEqual(len(loading2), 0)

    def test_load_model_weights_progress_complete(self):
        """Weights progress should go from 1 to total, with total matching across events."""
        _, events = self._load_model(self.MODEL)

        weights_events = [e for e in events if e.get("stage") == "weights" and "progress" in e]
        self.assertGreater(len(weights_events), 0, "No weights progress events emitted")

        # All events should have the same total
        totals = {e["progress"]["total"] for e in weights_events}
        self.assertEqual(len(totals), 1, f"Inconsistent totals: {totals}")
        total = totals.pop()
        self.assertIsNotNone(total)
        self.assertGreater(total, 0)

        # First should be 1, last should be total
        self.assertEqual(weights_events[0]["progress"]["current"], 1)
        self.assertEqual(weights_events[-1]["progress"]["current"], total)

        # Progress should be monotonically non-decreasing
        currents = [e["progress"]["current"] for e in weights_events]
        self.assertEqual(currents, sorted(currents))

    def test_load_model_exactly_one_ready(self):
        """A fresh load should produce exactly one ready event as the last event."""
        _, events = self._load_model(self.MODEL)

        ready_events = [e for e in events if e["status"] == "ready"]
        self.assertEqual(len(ready_events), 1)
        self.assertEqual(events[-1]["status"], "ready")

    def test_load_model_usable_after_load(self):
        """After /load_model completes, the model should be usable for inference."""
        self._load_model(self.MODEL)

        client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
        resp = client.chat.completions.create(
            model=self.MODEL,
            messages=[{"role": "user", "content": "Say hi"}],
            max_tokens=5,
        )
        self.assertIsNotNone(resp.choices[0].message.content)
        self.assertGreater(len(resp.choices[0].message.content), 0)

    def test_load_model_model_field_matches(self):
        """The model field in every event should match the canonical model ID."""
        _, events = self._load_model(self.MODEL)

        for event in events:
            self.assertTrue(
                event["model"].startswith(self.MODEL),
                f"Event model '{event['model']}' doesn't match '{self.MODEL}'",
            )

    # NOTE(review): the two tests below mirror TestResponsesIntegration. Here setUp has just
    # cleared the model cache, so they presumably also exercise on-demand loading under
    # concurrent /v1/responses traffic — confirm that is why they live in this class.

    def test_concurrent_non_streaming(self):
        """Two concurrent non-streaming responses requests should both complete."""
        import concurrent.futures

        inputs = ["Say hello", "Say goodbye"]
        results = [None, None]

        def request_in_thread(index):
            client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
            results[index] = client.responses.create(model=self.MODEL, input=inputs[index], stream=False)

        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            futures = [pool.submit(request_in_thread, i) for i in range(2)]
            concurrent.futures.wait(futures)
            for f in futures:
                f.result()  # re-raise exceptions from the worker threads

        for i in range(2):
            self.assertIsNotNone(results[i])
            self.assertEqual(results[i].status, "completed")
            self.assertGreater(len(results[i].output[0].content[0].text), 0)

    def test_concurrent_streaming(self):
        """Two concurrent streaming responses requests should both produce complete event streams."""
        import concurrent.futures

        inputs = ["Say hello", "Say goodbye"]
        results = [None, None]

        def stream_in_thread(index):
            client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
            results[index] = list(client.responses.create(model=self.MODEL, input=inputs[index], stream=True))

        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            futures = [pool.submit(stream_in_thread, i) for i in range(2)]
            concurrent.futures.wait(futures)
            for f in futures:
                f.result()  # re-raise exceptions from the worker threads

        for i in range(2):
            types = [e.type for e in results[i]]
            self.assertIn("response.created", types, f"Request {i} missing created event")
            self.assertIn("response.output_text.delta", types, f"Request {i} missing delta events")
            self.assertIn("response.completed", types, f"Request {i} missing completed event")
|
|
|
|
|
|
# Real image URL for VLM tests (person + dog on a beach). The URL is sent to the server as
# image content, so the server presumably downloads it and these tests need network access.
_DOG_IMAGE_URL = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg"
|
|
|
|
|
|
@slow
@require_vision
@require_serve
class TestVLM(unittest.TestCase):
    """Integration tests for VLM (vision-language model) support. Requires torchvision."""

    MODEL = "HuggingFaceTB/SmolVLM-256M-Instruct"

    # At least one of these words should appear in any reasonable description of _DOG_IMAGE_URL.
    SCENE_WORDS = ("dog", "beach", "person")

    @classmethod
    def setUpClass(cls):
        cls.serve, port = _start_serve()
        cls.base_url = f"http://localhost:{port}"
        cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")

    @classmethod
    def tearDownClass(cls):
        cls.serve.kill_server()

    def _assert_describes_scene(self, text):
        """Assert ``text`` is not None and mentions at least one of ``SCENE_WORDS``."""
        self.assertIsNotNone(text)
        lowered = text.lower()
        self.assertTrue(
            any(word in lowered for word in self.SCENE_WORDS),
            f"Expected dog/beach/person in response, got: {text}",
        )

    def test_chat_completion_with_image(self):
        """Chat completions should accept image_url content and produce a meaningful response."""
        resp = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What do you see in this image?"},
                        {"type": "image_url", "image_url": {"url": _DOG_IMAGE_URL}},
                    ],
                }
            ],
            max_tokens=50,
        )
        self._assert_describes_scene(resp.choices[0].message.content)

    def test_responses_with_image(self):
        """Responses API should accept input_image content and produce a meaningful response."""
        resp = self.client.responses.create(
            model=self.MODEL,
            input=[
                {"type": "input_text", "text": "What do you see in this image?"},
                {"type": "input_image", "image_url": _DOG_IMAGE_URL},
            ],
            stream=False,
            max_output_tokens=50,
        )
        self.assertEqual(resp.status, "completed")
        self._assert_describes_scene(resp.output[0].content[0].text)
|
|
|
|
|
|
# Public media fixtures for the multimodal (audio/video) tests below. Both are fetched over the
# network at test time.
# Speech clip (per its file name, the first 45 s of a speech); the transcription check looks for
# the word "chicago".
_AUDIO_URL = (
    "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/obama_first_45_secs.mp3"
)
# Concert video clip; presumably consumed by the video tests, whose description check looks for
# concert/music/stage/perform.
_VIDEO_URL = "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/concert.mp4"
|
|
|
|
|
|
@slow
|
|
@require_serve
|
|
class TestMultimodalLM(unittest.TestCase):
|
|
"""Integration tests for multimodal (audio, video) chat completions with Gemma 4."""
|
|
|
|
MODEL = "google/gemma-4-E2B-it"
|
|
|
|
@classmethod
def setUpClass(cls):
    """Start one shared server for the whole class and point an OpenAI client at it."""
    cls.serve, port = _start_serve()
    root = f"http://localhost:{port}"
    cls.base_url = root
    cls.client = OpenAI(base_url=f"{root}/v1", api_key="unused")
|
|
|
|
@classmethod
def tearDownClass(cls):
    """Shut down the server started in ``setUpClass``."""
    server = cls.serve
    server.kill_server()
|
|
|
|
def _get_audio_messages(self):
    """Build a chat-completions message list asking the model to transcribe ``_AUDIO_URL``.

    The clip is downloaded and inlined as base64 ``input_audio`` content with format ``mp3``.

    Raises:
        httpx.HTTPStatusError: if the audio download ends with an unsuccessful HTTP status.
    """
    import base64

    # Fail loudly on a bad download instead of base64-sending an HTML error page as "audio".
    # httpx's default timeout (5 s) is also raised, since this is a multi-second media file.
    download = httpx.get(_AUDIO_URL, follow_redirects=True, timeout=60)
    download.raise_for_status()
    audio_b64 = base64.b64encode(download.content).decode()
    return [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Transcribe this audio."},
                {"type": "input_audio", "input_audio": {"data": audio_b64, "format": "mp3"}},
            ],
        }
    ]
|
|
|
|
def _get_audio_flat_input(self):
|
|
"""Flat content list format for Responses API."""
|
|
import base64
|
|
|
|
audio_bytes = httpx.get(_AUDIO_URL, follow_redirects=True).content
|
|
audio_b64 = base64.b64encode(audio_bytes).decode()
|
|
return [
|
|
{"type": "input_text", "text": "Transcribe this audio."},
|
|
{"type": "input_audio", "input_audio": {"data": audio_b64, "format": "mp3"}},
|
|
]
|
|
|
|
def _assert_audio_transcription(self, text):
|
|
self.assertIsNotNone(text)
|
|
self.assertIn("chicago", text.lower(), f"Expected 'chicago' in transcription, got: {text}")
|
|
|
|
def _assert_video_description(self, text):
|
|
self.assertIsNotNone(text)
|
|
self.assertTrue(
|
|
any(word in text.lower() for word in ["concert", "music", "stage", "perform"]),
|
|
f"Expected concert/music/stage/perform in response, got: {text}",
|
|
)
|
|
|
|
def test_chat_completion_with_audio(self):
|
|
"""Chat completions should accept input_audio (base64) content and transcribe audio."""
|
|
resp = self.client.chat.completions.create(
|
|
model=self.MODEL,
|
|
messages=self._get_audio_messages(),
|
|
max_tokens=200,
|
|
)
|
|
self._assert_audio_transcription(resp.choices[0].message.content)
|
|
|
|
def test_chat_completion_with_audio_streaming(self):
|
|
"""Streaming chat completions should accept input_audio (base64) content and transcribe audio."""
|
|
stream = self.client.chat.completions.create(
|
|
model=self.MODEL,
|
|
messages=self._get_audio_messages(),
|
|
max_tokens=200,
|
|
stream=True,
|
|
)
|
|
chunks = []
|
|
for chunk in stream:
|
|
if chunk.choices[0].delta.content:
|
|
chunks.append(chunk.choices[0].delta.content)
|
|
self._assert_audio_transcription("".join(chunks))
|
|
|
|
@require_torchcodec
|
|
def test_chat_completion_with_video(self):
|
|
"""Chat completions should accept video_url content and describe video."""
|
|
resp = self.client.chat.completions.create(
|
|
model=self.MODEL,
|
|
messages=[
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "video_url", "video_url": {"url": _VIDEO_URL}},
|
|
{"type": "text", "text": "What is happening in the video?"},
|
|
],
|
|
}
|
|
],
|
|
max_tokens=200,
|
|
)
|
|
self._assert_video_description(resp.choices[0].message.content)
|
|
|
|
@require_torchcodec
|
|
def test_chat_completion_with_video_streaming(self):
|
|
"""Streaming chat completions should accept video_url content and describe video."""
|
|
stream = self.client.chat.completions.create(
|
|
model=self.MODEL,
|
|
messages=[
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "video_url", "video_url": {"url": _VIDEO_URL}},
|
|
{"type": "text", "text": "What is happening in the video?"},
|
|
],
|
|
}
|
|
],
|
|
max_tokens=200,
|
|
stream=True,
|
|
)
|
|
chunks = []
|
|
for chunk in stream:
|
|
if chunk.choices[0].delta.content:
|
|
chunks.append(chunk.choices[0].delta.content)
|
|
self._assert_video_description("".join(chunks))
|
|
|
|
def test_responses_with_audio(self):
|
|
"""Responses API should accept input_audio (base64) content and transcribe audio."""
|
|
resp = self.client.responses.create(
|
|
model=self.MODEL,
|
|
input=self._get_audio_flat_input(),
|
|
stream=False,
|
|
max_output_tokens=200,
|
|
)
|
|
self.assertEqual(resp.status, "completed")
|
|
self._assert_audio_transcription(resp.output[0].content[0].text)
|
|
|
|
def test_responses_with_audio_streaming(self):
|
|
"""Streaming responses API should accept input_audio (base64) content and transcribe audio."""
|
|
stream = self.client.responses.create(
|
|
model=self.MODEL,
|
|
input=self._get_audio_flat_input(),
|
|
stream=True,
|
|
max_output_tokens=200,
|
|
)
|
|
text = ""
|
|
for event in stream:
|
|
if event.type == "response.output_text.delta":
|
|
text += event.delta
|
|
self._assert_audio_transcription(text)
|
|
|
|
@require_torchcodec
|
|
def test_responses_with_video(self):
|
|
"""Responses API should accept video_url content and describe video."""
|
|
resp = self.client.responses.create(
|
|
model=self.MODEL,
|
|
input=[
|
|
{"type": "input_text", "text": "What is happening in the video?"},
|
|
{"type": "video_url", "video_url": {"url": _VIDEO_URL}},
|
|
],
|
|
stream=False,
|
|
max_output_tokens=200,
|
|
)
|
|
self.assertEqual(resp.status, "completed")
|
|
self._assert_video_description(resp.output[0].content[0].text)
|
|
|
|
@require_torchcodec
|
|
def test_responses_with_video_streaming(self):
|
|
"""Streaming responses API should accept video_url content and describe video."""
|
|
stream = self.client.responses.create(
|
|
model=self.MODEL,
|
|
input=[
|
|
{"type": "input_text", "text": "What is happening in the video?"},
|
|
{"type": "video_url", "video_url": {"url": _VIDEO_URL}},
|
|
],
|
|
stream=True,
|
|
max_output_tokens=200,
|
|
)
|
|
text = ""
|
|
for event in stream:
|
|
if event.type == "response.output_text.delta":
|
|
text += event.delta
|
|
self._assert_video_description(text)
|
|
|
|
|
|
class TestToolCallUnit(unittest.TestCase):
    """Server-free unit tests for the tool-call config lookup and tool-call text parsing helpers."""

    @staticmethod
    def _qwen_parsing_setup():
        """Return ``(processor, schema)``: a processor mock delegating to ``recursive_parse`` and the Qwen2 fallback schema."""
        processor = MagicMock()
        processor.parse_response = lambda t, s: recursive_parse(t, s)
        schema = next(entry["schema"] for family, entry in _TOOL_CALL_FALLBACKS.items() if "qwen2" in family)
        return processor, schema

    def test_get_tool_call_config_fallback(self):
        """Fallback config is returned for known model families (Qwen)."""
        model = MagicMock()
        model.config.model_type = "qwen2"
        processor = MagicMock(spec=["convert_tokens_to_ids"])
        processor.convert_tokens_to_ids.return_value = 151657

        config = get_tool_call_config(processor, model)

        self.assertIsNotNone(config)
        # The mock resolves every token to the same id, so both ids must equal it.
        for key in ("stc_id", "etc_id"):
            self.assertEqual(config[key], 151657)

    def test_get_tool_call_config_unsupported(self):
        """None is returned for models without tool call support."""
        model = MagicMock()
        model.config.model_type = "llama"
        self.assertIsNone(get_tool_call_config(MagicMock(spec=[]), model))

    def test_parse_tool_calls_from_text(self):
        """A single Qwen-style ``<tool_call>`` block parses into one call with the right name."""
        processor, schema = self._qwen_parsing_setup()
        raw = '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>'

        calls = parse_tool_calls(processor, raw, schema)

        self.assertEqual(len(calls), 1)
        self.assertEqual(calls[0]["name"], "get_weather")

    def test_parse_multiple_tool_calls_from_text(self):
        """Two consecutive ``<tool_call>`` blocks parse into two calls."""
        processor, schema = self._qwen_parsing_setup()
        raw = "\n".join(
            [
                '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>',
                '<tool_call>\n{"name": "get_weather", "arguments": {"city": "London"}}\n</tool_call>',
            ]
        )

        self.assertEqual(len(parse_tool_calls(processor, raw, schema)), 2)
|
|
|
|
|
|
class TestCBWorkerDeadServerIntegration(unittest.TestCase):
    """FastAPI-level behavior once the continuous-batching (CB) worker has died.

    Checks that a ``CBWorkerDeadError`` raised inside a request handler, or a dead CB worker seen by
    ``/health``, is turned into an HTTP 503 that carries the cause. Orchestrators rely on this
    contract to recycle the pod.
    """

    def _build_app(self, generation_state, chat_handler=None):
        """Build the FastAPI app with mocked handlers, optionally substituting ``chat_handler``."""
        from transformers.cli.serving.server import build_server

        handlers = {
            "chat_handler": chat_handler or MagicMock(),
            "completion_handler": MagicMock(),
            "response_handler": MagicMock(),
            "transcription_handler": MagicMock(),
        }
        return build_server(model_manager=MagicMock(), generation_state=generation_state, **handlers)

    def test_health_returns_503_when_cb_dead(self):
        """``/health`` answers 503 with reason ``cb_worker_dead`` when the CB has a fatal error."""
        from fastapi.testclient import TestClient

        from transformers.cli.serving.utils import CBGenerateManager

        state = GenerationState(continuous_batching=True)
        # A manager whose underlying CB carries a fatal_error -> is_alive() returns False.
        manager = CBGenerateManager()
        manager._cb = MagicMock()
        manager._cb.fatal_error = RuntimeError("CUDA illegal memory access")
        state._cb_manager = manager

        client = TestClient(self._build_app(state))
        response = client.get("/health")

        self.assertEqual(response.status_code, 503)
        self.assertEqual(response.json(), {"status": "unhealthy", "reason": "cb_worker_dead"})

    def test_chat_endpoint_returns_503_with_cause(self):
        """A CBWorkerDeadError raised from the chat route maps to 503 with the cause in the body."""
        from fastapi.testclient import TestClient

        from transformers.cli.serving.utils import CBWorkerDeadError

        async def handle_request(_body, _request_id):
            raise CBWorkerDeadError("CB worker is dead and cannot accept request: CUDA illegal memory access")

        chat_handler = MagicMock()
        chat_handler.handle_request = handle_request
        app = self._build_app(GenerationState(continuous_batching=True), chat_handler=chat_handler)

        response = TestClient(app).post(
            "/v1/chat/completions",
            json={"model": "x", "messages": [{"role": "user", "content": "hi"}]},
        )

        self.assertEqual(response.status_code, 503)
        # The body must carry the original cause so the client can tell why the server is broken.
        self.assertIn("CUDA illegal memory access", response.json()["error"])
|
|
|
|
|
|
class _TestToolCallBase:
    """Base class for tool call integration tests. Subclasses set MODEL and inherit all tests.

    A single ``transformers serve`` process is shared by all tests of a subclass. Every request,
    including Responses API requests, is sent with ``temperature=0.0`` so generations are deterministic
    and the streaming-vs-non-streaming equivalence tests compare like with like.
    """

    MODEL: str

    # Prompts and the canned tool result shared by the tests below.
    PARIS_PROMPT = "What is the weather in Paris?"
    PARIS_LONDON_PROMPT = "What is the weather in Paris and London?"
    TOOL_RESULT = '{"temperature": 22, "condition": "sunny"}'

    @classmethod
    def setUpClass(cls):
        cls.serve, port = _start_serve()
        cls.base_url = f"http://localhost:{port}"
        cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")

    @classmethod
    def tearDownClass(cls):
        cls.serve.kill_server()

    def _get_tool_def(self):
        """OpenAI-style ``function`` tool definition for ``get_weather`` taking a string ``city``."""
        return {
            "function": {
                "name": "get_weather",
                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
                "description": "Get the weather for a city.",
            },
            "type": "function",
        }

    def _assert_references_tool_result(self, content):
        """Assert a follow-up answer mentions the temperature (22) or condition (sunny) from ``TOOL_RESULT``."""
        self.assertIsNotNone(content)
        self.assertTrue(
            "22" in content.lower() or "sunny" in content.lower(),
            f"Expected model to reference tool result, got: {content}",
        )

    # ----- chat completions -----

    def test_chat_non_streaming(self):
        resp = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[{"role": "user", "content": self.PARIS_PROMPT}],
            stream=False,
            max_tokens=50,
            temperature=0.0,
            tools=[self._get_tool_def()],
        )
        self.assertEqual(resp.choices[0].finish_reason, "tool_calls")
        self.assertIsNotNone(resp.choices[0].message.tool_calls)
        tc = resp.choices[0].message.tool_calls[0]
        self.assertEqual(tc.function.name, "get_weather")
        self.assertIsInstance(json.loads(tc.function.arguments), dict)

    def test_chat_streaming(self):
        chunks = list(
            self.client.chat.completions.create(
                model=self.MODEL,
                messages=[{"role": "user", "content": self.PARIS_PROMPT}],
                stream=True,
                max_tokens=50,
                temperature=0.0,
                tools=[self._get_tool_def()],
            )
        )
        tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls]
        self.assertGreater(len(tool_chunks), 0, "Model did not produce a tool call")
        first_tool = tool_chunks[0].choices[0].delta.tool_calls[0]
        self.assertEqual(first_tool.function.name, "get_weather")
        self.assertEqual(chunks[-1].choices[0].finish_reason, "tool_calls")
        self.assertIsInstance(json.loads(first_tool.function.arguments), dict)

    def test_chat_multiple_tool_calls_non_streaming(self):
        resp = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[{"role": "user", "content": self.PARIS_LONDON_PROMPT}],
            stream=False,
            max_tokens=100,
            temperature=0.0,
            tools=[self._get_tool_def()],
        )
        self.assertEqual(resp.choices[0].finish_reason, "tool_calls")
        # `or []` turns a missing tool_calls list into an assertion failure instead of a TypeError.
        tool_calls = resp.choices[0].message.tool_calls or []
        self.assertEqual(len(tool_calls), 2, f"Expected 2 tool calls, got {len(tool_calls)}")

    def test_chat_multiple_tool_calls_streaming(self):
        chunks = list(
            self.client.chat.completions.create(
                model=self.MODEL,
                messages=[{"role": "user", "content": self.PARIS_LONDON_PROMPT}],
                stream=True,
                max_tokens=100,
                temperature=0.0,
                tools=[self._get_tool_def()],
            )
        )
        tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls]
        self.assertEqual(len(tool_chunks), 2, f"Expected 2 tool calls, got {len(tool_chunks)}")
        self.assertEqual(chunks[-1].choices[0].finish_reason, "tool_calls")

    def test_chat_multi_turn_non_streaming(self):
        tool_def = self._get_tool_def()
        resp1 = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[{"role": "user", "content": self.PARIS_PROMPT}],
            stream=False,
            max_tokens=50,
            temperature=0.0,
            tools=[tool_def],
        )
        self.assertEqual(resp1.choices[0].finish_reason, "tool_calls")
        self.assertTrue(resp1.choices[0].message.tool_calls, "Model did not produce a tool call")
        tc = resp1.choices[0].message.tool_calls[0]

        resp2 = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[
                {"role": "user", "content": self.PARIS_PROMPT},
                resp1.choices[0].message,
                {"role": "tool", "tool_call_id": tc.id, "content": self.TOOL_RESULT},
            ],
            stream=False,
            max_tokens=100,
            temperature=0.0,
            tools=[tool_def],
        )
        self.assertIn(resp2.choices[0].finish_reason, ("stop", "length"))
        self._assert_references_tool_result(resp2.choices[0].message.content)

    def test_chat_multi_turn_streaming(self):
        tool_def = self._get_tool_def()

        # Turn 1: streaming — accumulate tool call from deltas
        chunks = list(
            self.client.chat.completions.create(
                model=self.MODEL,
                messages=[{"role": "user", "content": self.PARIS_PROMPT}],
                stream=True,
                max_tokens=50,
                temperature=0.0,
                tools=[tool_def],
            )
        )
        self.assertEqual(chunks[-1].choices[0].finish_reason, "tool_calls")
        tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls]
        self.assertGreater(len(tool_chunks), 0, "Model did not produce a tool call")
        tc = tool_chunks[0].choices[0].delta.tool_calls[0]

        # Reconstruct assistant message from deltas
        turn1_content = "".join(c.choices[0].delta.content for c in chunks if c.choices[0].delta.content)
        assistant_msg = {
            "role": "assistant",
            "content": turn1_content,
            "tool_calls": [{"id": tc.id, "type": "function", "function": tc.function.model_dump()}],
        }

        # Turn 2: streaming — send back tool result
        chunks2 = list(
            self.client.chat.completions.create(
                model=self.MODEL,
                messages=[
                    {"role": "user", "content": self.PARIS_PROMPT},
                    assistant_msg,
                    {"role": "tool", "tool_call_id": tc.id, "content": self.TOOL_RESULT},
                ],
                stream=True,
                max_tokens=100,
                temperature=0.0,
                tools=[tool_def],
            )
        )
        content = "".join(c.choices[0].delta.content for c in chunks2 if c.choices[0].delta.content)
        self._assert_references_tool_result(content)

    # ----- responses -----

    def test_responses_non_streaming(self):
        resp = self.client.responses.create(
            model=self.MODEL,
            input=self.PARIS_PROMPT,
            stream=False,
            max_output_tokens=50,
            temperature=0.0,
            tools=[self._get_tool_def()],
        )
        self.assertEqual(resp.status, "completed")
        fc_items = [o for o in resp.output if o.type == "function_call"]
        self.assertGreater(len(fc_items), 0)
        self.assertEqual(fc_items[0].name, "get_weather")
        self.assertIsInstance(json.loads(fc_items[0].arguments), dict)

    def test_responses_streaming(self):
        events = list(
            self.client.responses.create(
                model=self.MODEL,
                input=self.PARIS_PROMPT,
                stream=True,
                max_output_tokens=50,
                temperature=0.0,
                tools=[self._get_tool_def()],
            )
        )
        types = [e.type for e in events]
        self.assertIn("response.created", types)
        self.assertIn("response.completed", types)
        self.assertIn("response.function_call_arguments.done", types)

        args_done = [e for e in events if e.type == "response.function_call_arguments.done"]
        self.assertGreater(len(args_done), 0)
        self.assertEqual(args_done[0].name, "get_weather")
        self.assertIsInstance(json.loads(args_done[0].arguments), dict)

    def test_responses_multiple_tool_calls_non_streaming(self):
        resp = self.client.responses.create(
            model=self.MODEL,
            input=self.PARIS_LONDON_PROMPT,
            stream=False,
            max_output_tokens=100,
            temperature=0.0,
            tools=[self._get_tool_def()],
        )
        self.assertEqual(resp.status, "completed")
        fc_items = [o for o in resp.output if o.type == "function_call"]
        self.assertEqual(len(fc_items), 2, f"Expected 2 tool calls, got {len(fc_items)}")

    def test_responses_multiple_tool_calls_streaming(self):
        events = list(
            self.client.responses.create(
                model=self.MODEL,
                input=self.PARIS_LONDON_PROMPT,
                stream=True,
                max_output_tokens=100,
                temperature=0.0,
                tools=[self._get_tool_def()],
            )
        )
        args_done = [e for e in events if e.type == "response.function_call_arguments.done"]
        self.assertEqual(len(args_done), 2, f"Expected 2 tool calls, got {len(args_done)}")
        self.assertEqual(events[-1].type, "response.completed")

    def test_responses_multi_turn_non_streaming(self):
        tool_def = self._get_tool_def()
        resp1 = self.client.responses.create(
            model=self.MODEL,
            input=self.PARIS_PROMPT,
            stream=False,
            max_output_tokens=50,
            temperature=0.0,
            tools=[tool_def],
        )
        self.assertEqual(resp1.status, "completed")
        fc_items = [o for o in resp1.output if o.type == "function_call"]
        self.assertGreater(len(fc_items), 0)

        input_list = [{"role": "user", "content": self.PARIS_PROMPT}]
        input_list += resp1.output
        input_list.append(
            {
                "type": "function_call_output",
                "call_id": fc_items[0].call_id,
                "output": self.TOOL_RESULT,
            }
        )
        resp2 = self.client.responses.create(
            model=self.MODEL,
            input=input_list,
            stream=False,
            max_output_tokens=100,
            temperature=0.0,
            tools=[tool_def],
        )
        self.assertEqual(resp2.status, "completed")
        msg_items = [o for o in resp2.output if o.type == "message"]
        self.assertGreater(len(msg_items), 0)
        self._assert_references_tool_result(msg_items[0].content[0].text)

    def test_responses_multi_turn_streaming(self):
        tool_def = self._get_tool_def()

        # Turn 1: streaming — get completed response with tool calls
        events = list(
            self.client.responses.create(
                model=self.MODEL,
                input=self.PARIS_PROMPT,
                stream=True,
                max_output_tokens=50,
                temperature=0.0,
                tools=[tool_def],
            )
        )
        completed = [e for e in events if e.type == "response.completed"]
        self.assertEqual(len(completed), 1)
        resp1_output = completed[0].response.output
        fc_items = [o for o in resp1_output if o.type == "function_call"]
        self.assertGreater(len(fc_items), 0)

        # Turn 2: streaming — send back tool result
        input_list = [{"role": "user", "content": self.PARIS_PROMPT}]
        input_list += resp1_output
        input_list.append(
            {
                "type": "function_call_output",
                "call_id": fc_items[0].call_id,
                "output": self.TOOL_RESULT,
            }
        )
        events2 = list(
            self.client.responses.create(
                model=self.MODEL,
                input=input_list,
                stream=True,
                max_output_tokens=100,
                temperature=0.0,
                tools=[tool_def],
            )
        )
        content = "".join(e.delta for e in events2 if e.type == "response.output_text.delta")
        self._assert_references_tool_result(content)

    # ----- parser equivalence -----

    def test_chat_streaming_matches_non_streaming(self):
        """Streaming and non-streaming chat completions yield the same tool call at T=0."""
        msgs = [{"role": "user", "content": self.PARIS_PROMPT}]
        kwargs = {"model": self.MODEL, "max_tokens": 50, "temperature": 0.0, "tools": [self._get_tool_def()]}

        ns = self.client.chat.completions.create(messages=msgs, stream=False, **kwargs)
        ns_tool_calls = ns.choices[0].message.tool_calls
        self.assertTrue(ns_tool_calls, "Non-streaming request did not produce a tool call")
        ns_tc = ns_tool_calls[0]

        chunks = list(self.client.chat.completions.create(messages=msgs, stream=True, **kwargs))
        tool_chunks = [c for c in chunks if c.choices[0].delta.tool_calls]
        self.assertGreater(len(tool_chunks), 0, "Streaming request did not produce a tool call")
        # Arguments may be split across deltas, accumulate then compare parsed dicts.
        s_name = tool_chunks[0].choices[0].delta.tool_calls[0].function.name
        s_args = "".join(c.choices[0].delta.tool_calls[0].function.arguments or "" for c in tool_chunks)

        self.assertEqual(s_name, ns_tc.function.name)
        self.assertEqual(json.loads(s_args), json.loads(ns_tc.function.arguments))

    def test_responses_streaming_matches_non_streaming(self):
        """Streaming and non-streaming Responses API yield the same tool call at T=0."""
        kwargs = {
            "model": self.MODEL,
            "input": self.PARIS_PROMPT,
            "max_output_tokens": 50,
            # Without this the server may sample, making the two requests diverge.
            "temperature": 0.0,
            "tools": [self._get_tool_def()],
        }

        ns = self.client.responses.create(stream=False, **kwargs)
        ns_tc = next((i for i in ns.output if i.type == "function_call"), None)
        self.assertIsNotNone(ns_tc, "Non-streaming request did not produce a function call")

        events = list(self.client.responses.create(stream=True, **kwargs))
        completed = next((e for e in events if e.type == "response.completed"), None)
        self.assertIsNotNone(completed, "Stream did not emit response.completed")
        s_tc = next((i for i in completed.response.output if i.type == "function_call"), None)
        self.assertIsNotNone(s_tc, "Streaming request did not produce a function call")

        self.assertEqual(s_tc.name, ns_tc.name)
        self.assertEqual(json.loads(s_tc.arguments), json.loads(ns_tc.arguments))
|
|
|
|
|
|
@slow
@require_serve
@require_torch_accelerator
class TestToolCallQwen(_TestToolCallBase, unittest.TestCase):
    """Runs the shared tool-call suite against Qwen2.5.

    This model has no ``response_schema``, so tool calls go through the fallback config.
    """

    MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
|
|
|
|
|
|
@slow
@require_serve
@require_torch_accelerator
class TestToolCallGemma(_TestToolCallBase, unittest.TestCase):
    """Runs the shared tool-call suite against Gemma 4.

    This model declares a ``response_schema`` and uses dedicated stc/etc special tokens for tool calls.
    """

    MODEL = "google/gemma-4-E2B-it"
|
|
|
|
|
|
class _TestReasoningBase:
    """Shared reasoning integration tests; concrete subclasses set ``MODEL``.

    ``setUpClass`` starts one server that every test in the subclass reuses. Each test sends
    ``USER_PROMPT`` at temperature 0 and checks that ``EXPECTED_ANSWER`` appears in both the reasoning
    and the visible answer, and that no thinking delimiters leak into the visible answer.
    """

    MODEL: str
    USER_PROMPT = "What is 17 * 23? Think briefly, then answer in one sentence."
    EXPECTED_ANSWER = "391"
    MAX_TOKENS = 512

    @classmethod
    def setUpClass(cls):
        cls.serve, port = _start_serve()
        cls.base_url = f"http://localhost:{port}"
        cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")

    @classmethod
    def tearDownClass(cls):
        cls.serve.kill_server()

    @staticmethod
    def _reasoning_field(obj):
        """Read ``reasoning_content`` off a chat message or delta, falling back to ``model_extra``."""
        value = getattr(obj, "reasoning_content", None)
        if value:
            return value
        return (obj.model_extra or {}).get("reasoning_content")

    # ----- request / inspection helpers -----

    def _prompt_messages(self):
        """Fresh single-turn chat history holding only ``USER_PROMPT``."""
        return [{"role": "user", "content": self.USER_PROMPT}]

    def _chat(self, messages, stream):
        """Greedy chat completion request using the class-wide token budget."""
        return self.client.chat.completions.create(
            model=self.MODEL,
            messages=messages,
            stream=stream,
            max_tokens=self.MAX_TOKENS,
            temperature=0.0,
        )

    def _respond(self, payload, stream):
        """Greedy Responses API request using the class-wide token budget."""
        return self.client.responses.create(
            model=self.MODEL,
            input=payload,
            stream=stream,
            max_output_tokens=self.MAX_TOKENS,
            temperature=0.0,
        )

    @staticmethod
    def _joined_deltas(events, event_type):
        """Concatenate the ``delta`` payloads of every streamed event of ``event_type``."""
        return "".join(event.delta for event in events if event.type == event_type)

    def _assert_no_delimiters(self, text):
        """The visible answer must not contain raw thinking delimiters."""
        self.assertNotIn("<think>", text)
        self.assertNotIn("<|channel>", text)

    # ----- chat completions -----

    def test_chat_non_streaming(self):
        """Chat completions: non-streaming surfaces ``reasoning_content`` and strips delimiters."""
        message = self._chat(self._prompt_messages(), stream=False).choices[0].message
        reasoning = self._reasoning_field(message)
        answer = message.content or ""

        self.assertIn(self.EXPECTED_ANSWER, reasoning or "", f"answer missing from reasoning: {reasoning!r}")
        self.assertIn(self.EXPECTED_ANSWER, answer, f"answer missing from content: {message.content!r}")
        self._assert_no_delimiters(answer)
        # The start of the reasoning must not be echoed into the visible answer.
        self.assertNotIn(reasoning.strip()[:30], answer)

    def test_chat_streaming(self):
        """Chat completions: streaming emits ``reasoning_content`` deltas; content stays clean."""
        deltas = [chunk.choices[0].delta for chunk in self._chat(self._prompt_messages(), stream=True)]

        reasoning_text = "".join(self._reasoning_field(delta) or "" for delta in deltas)
        self.assertIn(self.EXPECTED_ANSWER, reasoning_text, f"answer missing from reasoning: {reasoning_text!r}")

        content = "".join(delta.content or "" for delta in deltas)
        self.assertIn(self.EXPECTED_ANSWER, content, f"answer missing from content: {content!r}")
        self._assert_no_delimiters(content)

    def test_chat_multi_turn_round_trips_reasoning(self):
        """Chat completions: reasoning_content from a prior turn round-trips through input."""
        first = self._chat(self._prompt_messages(), stream=False).choices[0].message
        reasoning = self._reasoning_field(first)
        self.assertTrue(reasoning)

        history = [
            {"role": "user", "content": self.USER_PROMPT},
            {"role": "assistant", "content": first.content or "", "reasoning_content": reasoning},
            {"role": "user", "content": "Now multiply that result by 2."},
        ]
        second = self._chat(history, stream=False)
        self.assertIsNotNone(second.choices[0].message.content)

    # ----- responses -----

    def test_response_non_streaming(self):
        """Responses API: non-streaming includes a reasoning item before the message item."""
        resp = self._respond(self.USER_PROMPT, stream=False)
        types = [item.type for item in resp.output]
        self.assertIn("reasoning", types, f"expected reasoning item, got types: {types}")
        self.assertIn("message", types)
        reasoning_pos = types.index("reasoning")
        message_pos = types.index("message")
        self.assertLess(reasoning_pos, message_pos)

        reasoning_text = resp.output[reasoning_pos].content[0].text
        self.assertIn(self.EXPECTED_ANSWER, reasoning_text, f"answer missing from reasoning: {reasoning_text!r}")

        message_text = resp.output[message_pos].content[0].text
        self.assertIn(self.EXPECTED_ANSWER, message_text, f"answer missing from message: {message_text!r}")
        self._assert_no_delimiters(message_text)

    def test_response_streaming(self):
        """Responses API: streaming emits reasoning_text events and a separate reasoning item."""
        events = list(self._respond(self.USER_PROMPT, stream=True))

        added = [event for event in events if event.type == "response.output_item.added"]
        self.assertGreaterEqual(len(added), 2)
        self.assertEqual(added[0].item.type, "reasoning")
        self.assertEqual(added[1].item.type, "message")

        # The concatenated reasoning deltas must equal the reasoning_text.done payload and contain the answer.
        reasoning_text = self._joined_deltas(events, "response.reasoning_text.delta")
        done = next(event for event in events if event.type == "response.reasoning_text.done")
        self.assertEqual(reasoning_text, done.text)
        self.assertIn(self.EXPECTED_ANSWER, reasoning_text, f"answer missing from reasoning: {reasoning_text!r}")

        content = self._joined_deltas(events, "response.output_text.delta")
        self.assertIn(self.EXPECTED_ANSWER, content, f"answer missing from content: {content!r}")
        self._assert_no_delimiters(content)

    def test_response_multi_turn_round_trips_reasoning(self):
        """Responses API: ``reasoning`` items echoed back as input are accepted."""
        first = self._respond(self.USER_PROMPT, stream=False)
        reasoning_item = next((item for item in first.output if item.type == "reasoning"), None)
        message_item = next((item for item in first.output if item.type == "message"), None)
        self.assertIsNotNone(reasoning_item)
        self.assertIsNotNone(message_item)

        history = [
            {"role": "user", "content": self.USER_PROMPT},
            reasoning_item.model_dump(exclude_none=True),
            {"role": "assistant", "content": message_item.content[0].text},
            {"role": "user", "content": "Now multiply that result by 2."},
        ]
        second = self._respond(history, stream=False)
        self.assertEqual(second.status, "completed")

    # ----- parser equivalence -----

    def test_chat_streaming_matches_non_streaming(self):
        """Streaming and non-streaming chat completions yield the same content + reasoning at T=0."""
        ns_msg = self._chat(self._prompt_messages(), stream=False).choices[0].message
        deltas = [chunk.choices[0].delta for chunk in self._chat(self._prompt_messages(), stream=True)]
        stream_content = "".join(delta.content or "" for delta in deltas)
        stream_reasoning = "".join(self._reasoning_field(delta) or "" for delta in deltas)

        self.assertEqual(stream_content, ns_msg.content or "")
        self.assertEqual(stream_reasoning, self._reasoning_field(ns_msg) or "")

    def test_response_streaming_matches_non_streaming(self):
        """Streaming and non-streaming Responses API yield the same content + reasoning at T=0."""
        ns = self._respond(self.USER_PROMPT, stream=False)
        ns_message = next(item for item in ns.output if item.type == "message")
        ns_reasoning_item = next((item for item in ns.output if item.type == "reasoning"), None)
        ns_content = ns_message.content[0].text
        ns_reasoning = ns_reasoning_item.content[0].text if ns_reasoning_item else ""

        events = list(self._respond(self.USER_PROMPT, stream=True))
        stream_content = self._joined_deltas(events, "response.output_text.delta")
        stream_reasoning = self._joined_deltas(events, "response.reasoning_text.delta")

        self.assertEqual(stream_content, ns_content)
        self.assertEqual(stream_reasoning, ns_reasoning)
|
|
|
|
|
|
@slow
@require_serve
@require_torch_accelerator
class TestReasoningQwen(_TestReasoningBase, unittest.TestCase):
    """Runs the shared reasoning suite against Qwen3, which emits inline <think>...</think> tags."""

    MODEL = "Qwen/Qwen3-1.7B"
|
|
|
|
|
|
@slow
@require_serve
@require_torch_accelerator
class TestReasoningGemma(_TestReasoningBase, unittest.TestCase):
    """Runs the shared reasoning suite against Gemma 4, whose thinking channel comes from its response_schema."""

    MODEL = "google/gemma-4-E2B-it"

    @classmethod
    def setUpClass(cls):
        # Gemma 4's chat template only thinks when `enable_thinking` is set (it defaults to "false"),
        # so this server is started with reasoning switched on.
        server, port = _start_serve(reasoning="on")
        cls.serve = server
        cls.base_url = f"http://localhost:{port}"
        cls.client = OpenAI(base_url=cls.base_url + "/v1", api_key="unused")
|
|
|
|
|
|
@slow
@require_librosa
@require_multipart
@require_serve
class TestTranscription(unittest.TestCase):
    """Integration tests for POST /v1/audio/transcriptions with whisper-tiny."""

    MODEL = "openai/whisper-tiny"

    @classmethod
    def setUpClass(cls):
        # One server for the whole class; the model is selected per request via the "model" form field.
        cls.serve, port = _start_serve()
        cls.base_url = f"http://localhost:{port}"

    @classmethod
    def tearDownClass(cls):
        cls.serve.kill_server()

    @classmethod
    def _get_audio_bytes(cls):
        """Download the MLK 'I have a dream' speech sample from HF Hub.

        The raw FLAC bytes are cached on the class, so the file is downloaded and read at most once.

        Returns:
            bytes: contents of ``mlk.flac`` from the ``Narsil/asr_dummy`` dataset repo.
        """
        if not hasattr(cls, "_audio_bytes"):
            from huggingface_hub import hf_hub_download

            path = hf_hub_download("Narsil/asr_dummy", "mlk.flac", repo_type="dataset")
            with open(path, "rb") as f:
                cls._audio_bytes = f.read()
        return cls._audio_bytes

    def test_transcription_returns_text(self):
        """POST /v1/audio/transcriptions with real speech returns meaningful transcription."""
        audio_bytes = self._get_audio_bytes()
        resp = httpx.post(
            f"{self.base_url}/v1/audio/transcriptions",
            files={"file": ("mlk.flac", audio_bytes, "audio/flac")},
            data={"model": self.MODEL},
            timeout=120,
        )
        self.assertEqual(resp.status_code, 200)
        data = resp.json()
        self.assertIn("text", data)
        self.assertIsInstance(data["text"], str)
        # Whisper-tiny should recognize at least "dream" from the MLK speech
        self.assertIn("dream", data["text"].lower())

    def test_transcription_openai_client(self):
        """Transcription should work via the OpenAI Python client."""
        audio_bytes = self._get_audio_bytes()
        client = OpenAI(base_url=f"{self.base_url}/v1", api_key="unused")
        result = client.audio.transcriptions.create(
            model=self.MODEL,
            file=("mlk.flac", audio_bytes),
        )
        self.assertIsInstance(result.text, str)
        # assertGreater reports the actual length on failure, unlike assertTrue(len(...) > 10).
        self.assertGreater(len(result.text), 10)

    def test_transcription_streaming(self):
        """Streaming transcription should yield text chunks via SSE."""
        audio_bytes = self._get_audio_bytes()
        with httpx.stream(
            "POST",
            f"{self.base_url}/v1/audio/transcriptions",
            files={"file": ("mlk.flac", audio_bytes, "audio/flac")},
            data={"model": self.MODEL, "stream": "true"},
            timeout=120,
        ) as resp:
            self.assertEqual(resp.status_code, 200)

            # Keep only the payload of SSE "data: " lines; blank separator lines and other fields are ignored.
            chunks = [line[len("data: ") :] for line in resp.iter_lines() if line and line.startswith("data: ")]

        self.assertGreater(len(chunks), 0, "No streaming chunks received")
        full_text = "".join(chunks)
        self.assertIn("dream", full_text.lower())

    def test_transcription_missing_file(self):
        """POST without a file should fail."""
        resp = httpx.post(
            f"{self.base_url}/v1/audio/transcriptions",
            data={"model": self.MODEL},
            timeout=30,
        )
        self.assertNotEqual(resp.status_code, 200)
|
|
|
|
|
|
@slow
@require_serve
@require_torch_accelerator
class TestContinuousBatchingChatCompletion(unittest.TestCase):
    """Integration tests for /v1/chat/completions with continuous batching."""

    MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

    @classmethod
    def setUpClass(cls):
        # A single CB-enabled server (fixed model, device, attention backend and seed) shared by every test.
        cls.serve, port = _start_serve(
            force_model=cls.MODEL,
            device="cuda:0",
            continuous_batching=True,
            attn_implementation="sdpa",
            default_seed=42,
        )
        cls.base_url = f"http://localhost:{port}"
        cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")

    @classmethod
    def tearDownClass(cls):
        cls.serve.kill_server()

    def test_streaming(self):
        """Streaming chat completion with CB produces text."""
        stream = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[
                {"role": "system", "content": "You are a sports assistant designed to craft sports programs."},
                {"role": "user", "content": "Tell me what you can do."},
            ],
            stream=True,
            max_tokens=30,
        )
        # Chunks whose delta has no content (None or "") contribute nothing.
        text = "".join(chunk.choices[0].delta.content or "" for chunk in stream)
        # assertGreater reports the actual length on failure, unlike assertTrue(len(...) > 0).
        self.assertGreater(len(text), 0)

    def test_non_streaming(self):
        """Non-streaming chat completion with CB returns a full response."""
        resp = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[{"role": "user", "content": "Say hello"}],
            max_tokens=20,
        )
        content = resp.choices[0].message.content
        self.assertIsNotNone(content)
        self.assertGreater(len(content), 0)

    def test_non_streaming_response_json_format(self):
        """Non-streaming CB responses return proper JSON objects, not double-encoded strings."""
        response = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[{"role": "user", "content": "Say hello"}],
            stream=False,
            max_tokens=5,
        )
        self.assertIsNotNone(response)
        self.assertIsNotNone(response.id)
        self.assertIsNotNone(response.choices)
        self.assertEqual(len(response.choices), 1)

        choice = response.choices[0]
        self.assertIsNotNone(choice.message)
        self.assertIsNotNone(choice.message.content)
        self.assertEqual(choice.message.role, "assistant")
        self.assertIsInstance(choice.message.content, str)

    def test_multi_turn(self):
        """Multi-turn conversation works with CB."""
        resp = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[
                {"role": "user", "content": "My name is Alice"},
                {"role": "assistant", "content": "Nice to meet you!"},
                {"role": "user", "content": "What is my name?"},
            ],
            max_tokens=20,
        )
        self.assertIn("Alice", resp.choices[0].message.content)

    def test_request_cancellation(self):
        """Opening a stream and closing it early triggers CB cancellation."""
        request_id = "test-cb-cancel"

        # Open a streaming request and stop after the first three SSE lines (blank separators included);
        # leaving the `with` block closes the connection, which is what the server should react to.
        with httpx.stream(
            "POST",
            f"{self.base_url}/v1/chat/completions",
            headers={"X-Request-ID": request_id},
            json={
                "model": self.MODEL,
                "stream": True,
                "messages": [{"role": "user", "content": "Count slowly so I can cancel you."}],
            },
            timeout=30,
        ) as resp:
            self.assertEqual(resp.status_code, 200)
            for lines_read, _ in enumerate(resp.iter_lines(), start=1):
                if lines_read >= 3:
                    break

        # Poll the CB scheduler for up to 8 s. time.monotonic() is used for the deadline because,
        # unlike time.time(), it cannot jump when the system clock is adjusted.
        scheduler = self.serve._generation_state._cb_manager.scheduler
        deadline = time.monotonic() + 8.0
        while time.monotonic() < deadline and not scheduler.request_is_cancelled(request_id):
            time.sleep(0.1)

        self.assertTrue(
            scheduler.request_is_cancelled(request_id),
            f"Request {request_id} not cancelled in scheduler after stream close.",
        )

        # Server should still be healthy and serve subsequent requests
        follow_up = self.client.chat.completions.create(
            model=self.MODEL,
            messages=[{"role": "user", "content": "Say hi"}],
            max_tokens=10,
        )
        self.assertIsNotNone(follow_up.choices[0].message.content)
|
|
|
|
|
|
@slow
@require_serve
@require_torch_accelerator
class TestContinuousBatchingResponses(unittest.TestCase):
    """Integration tests for /v1/responses with continuous batching."""

    MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

    @classmethod
    def setUpClass(cls):
        # A single CB-enabled server (fixed model, device, attention backend and seed) shared by every test.
        cls.serve, port = _start_serve(
            force_model=cls.MODEL,
            device="cuda:0",
            continuous_batching=True,
            attn_implementation="sdpa",
            default_seed=42,
        )
        cls.base_url = f"http://localhost:{port}"
        cls.client = OpenAI(base_url=f"{cls.base_url}/v1", api_key="unused")

    @classmethod
    def tearDownClass(cls):
        cls.serve.kill_server()

    def test_streaming(self):
        """Streaming response with CB produces text."""
        stream = self.client.responses.create(
            model=self.MODEL,
            input="Say hello in one sentence.",
            stream=True,
            max_output_tokens=30,
        )
        # Only text deltas carry generated output; lifecycle events are ignored.
        text = "".join(event.delta for event in stream if event.type == "response.output_text.delta")
        # assertGreater reports the actual length on failure, unlike assertTrue(len(...) > 0).
        self.assertGreater(len(text), 0)

    def test_non_streaming(self):
        """Non-streaming response with CB returns text."""
        resp = self.client.responses.create(
            model=self.MODEL,
            input="Say hello in one sentence.",
            stream=False,
            max_output_tokens=30,
        )
        # NOTE(review): assumes output[0] is the assistant message item (no reasoning item precedes it).
        content = resp.output[0].content[0].text
        self.assertGreater(len(content), 0)

    def test_multi_turn(self):
        """Multi-turn conversation works with CB via Responses API."""
        resp = self.client.responses.create(
            model=self.MODEL,
            input=[
                {"role": "user", "content": "My name is Alice"},
                {"role": "assistant", "content": "Nice to meet you!"},
                {"role": "user", "content": "What is my name?"},
            ],
            stream=False,
            max_output_tokens=20,
        )
        content = resp.output[0].content[0].text
        self.assertIn("Alice", content)

    def test_request_cancellation(self):
        """Opening a stream and closing it early triggers CB cancellation."""
        request_id = "test-cb-resp-cancel"

        with httpx.stream(
            "POST",
            f"{self.base_url}/v1/responses",
            headers={"X-Request-ID": request_id},
            json={
                "model": self.MODEL,
                "stream": True,
                "input": "Count slowly so I can cancel you.",
                "max_output_tokens": 500,
            },
            timeout=30,
        ) as resp:
            self.assertEqual(resp.status_code, 200)
            # Read until the first text delta shows up so CB generation has certainly started, then close.
            # Bytes are accumulated so the marker is found even if it straddles a chunk boundary.
            received = b""
            for chunk in resp.iter_bytes(chunk_size=512):
                received += chunk
                if b"output_text.delta" in received:
                    break

        # Poll the CB scheduler for up to 8 s. time.monotonic() is used for the deadline because,
        # unlike time.time(), it cannot jump when the system clock is adjusted.
        scheduler = self.serve._generation_state._cb_manager.scheduler
        deadline = time.monotonic() + 8.0
        while time.monotonic() < deadline and not scheduler.request_is_cancelled(request_id):
            time.sleep(0.1)

        self.assertTrue(
            scheduler.request_is_cancelled(request_id),
            f"Request {request_id} not cancelled in scheduler after stream close.",
        )

        # Server should still serve subsequent requests
        follow_up = self.client.responses.create(
            model=self.MODEL,
            input="Say hi",
            stream=False,
            max_output_tokens=10,
        )
        self.assertGreater(len(follow_up.output[0].content[0].text), 0)
|