first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
608
tests/models/sam3_video/test_modeling_sam3_video.py
Normal file
608
tests/models/sam3_video/test_modeling_sam3_video.py
Normal file
@@ -0,0 +1,608 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch SAM3 Video model."""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_torch_bf16_available_on_device,
|
||||
is_torch_fp16_available_on_device,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_torch_available
|
||||
from transformers.video_utils import load_video
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import Sam3VideoModel, Sam3VideoProcessor
|
||||
|
||||
|
||||
def prepare_video():
|
||||
video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
|
||||
raw_video, _ = load_video(video_url)
|
||||
return raw_video
|
||||
|
||||
|
||||
@slow
|
||||
class Sam3VideoModelIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
checkpoint_path = "facebook/sam3"
|
||||
self.video_model = Sam3VideoModel.from_pretrained(checkpoint_path).to(torch.float32)
|
||||
self.processor = Sam3VideoProcessor.from_pretrained(checkpoint_path)
|
||||
self.video_model.to(torch_device)
|
||||
self.video_model.eval()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
# clean-up as much as possible GPU memory occupied by PyTorch
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_inference_video_propagate_with_text_prompt(self):
|
||||
raw_video = prepare_video()
|
||||
inference_session = self.processor.init_video_session(
|
||||
video=raw_video,
|
||||
inference_device=torch_device,
|
||||
processing_device="cpu",
|
||||
video_storage_device="cpu",
|
||||
)
|
||||
|
||||
# Add text prompt
|
||||
text = "person"
|
||||
inference_session = self.processor.add_text_prompt(
|
||||
inference_session=inference_session,
|
||||
text=text,
|
||||
)
|
||||
|
||||
# Propagate through video frames
|
||||
outputs_per_frame = {}
|
||||
model_outputs_per_frame = {}
|
||||
for model_outputs in self.video_model.propagate_in_video_iterator(
|
||||
inference_session=inference_session,
|
||||
max_frame_num_to_track=3,
|
||||
):
|
||||
processed_outputs = self.processor.postprocess_outputs(inference_session, model_outputs)
|
||||
outputs_per_frame[model_outputs.frame_idx] = processed_outputs
|
||||
model_outputs_per_frame[model_outputs.frame_idx] = model_outputs
|
||||
|
||||
# Check we processed the expected number of frames
|
||||
self.assertGreaterEqual(len(outputs_per_frame), 1)
|
||||
self.assertLessEqual(len(outputs_per_frame), 4) # frame 0 + up to 3 more
|
||||
|
||||
# Check output structure for each frame
|
||||
for processed_outputs in outputs_per_frame.values():
|
||||
self.assertIn("object_ids", processed_outputs)
|
||||
self.assertIn("scores", processed_outputs)
|
||||
self.assertIn("boxes", processed_outputs)
|
||||
self.assertIn("masks", processed_outputs)
|
||||
|
||||
num_objects = len(processed_outputs["object_ids"])
|
||||
if num_objects > 0:
|
||||
self.assertEqual(processed_outputs["scores"].shape, (num_objects,))
|
||||
self.assertEqual(processed_outputs["boxes"].shape, (num_objects, 4))
|
||||
self.assertEqual(
|
||||
processed_outputs["masks"].shape, (num_objects, raw_video.shape[-3], raw_video.shape[-2])
|
||||
)
|
||||
# Check boxes are in XYXY format (absolute coordinates)
|
||||
boxes = processed_outputs["boxes"]
|
||||
self.assertTrue(torch.all(boxes[:, 2] >= boxes[:, 0])) # x2 >= x1
|
||||
self.assertTrue(torch.all(boxes[:, 3] >= boxes[:, 1])) # y2 >= y1
|
||||
|
||||
# Check numeric values for first frame
|
||||
if len(outputs_per_frame) > 0:
|
||||
first_frame_idx = min(outputs_per_frame.keys())
|
||||
first_outputs = outputs_per_frame[first_frame_idx]
|
||||
num_objects = len(first_outputs["object_ids"])
|
||||
if num_objects > 0:
|
||||
# Move outputs to CPU for comparison (postprocess_outputs may return CPU tensors)
|
||||
object_ids = (
|
||||
first_outputs["object_ids"].cpu()
|
||||
if isinstance(first_outputs["object_ids"], torch.Tensor)
|
||||
else torch.tensor(first_outputs["object_ids"])
|
||||
)
|
||||
scores = (
|
||||
first_outputs["scores"].cpu()
|
||||
if isinstance(first_outputs["scores"], torch.Tensor)
|
||||
else torch.tensor(first_outputs["scores"])
|
||||
)
|
||||
boxes = (
|
||||
first_outputs["boxes"].cpu()
|
||||
if isinstance(first_outputs["boxes"], torch.Tensor)
|
||||
else torch.tensor(first_outputs["boxes"])
|
||||
)
|
||||
masks = (
|
||||
first_outputs["masks"].cpu()
|
||||
if isinstance(first_outputs["masks"], torch.Tensor)
|
||||
else torch.tensor(first_outputs["masks"])
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
object_ids,
|
||||
torch.tensor([0, 1], dtype=torch.int64),
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
scores,
|
||||
torch.tensor([0.968647837638855, 0.9736108779907227], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
boxes[0],
|
||||
torch.tensor([146.0, 135.0, 291.0, 404.0], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
masks[0, :3, :3].float(),
|
||||
torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
|
||||
# Check raw model_outputs mask values for first frame
|
||||
if len(model_outputs_per_frame) > 0:
|
||||
first_frame_idx = min(model_outputs_per_frame.keys())
|
||||
first_model_outputs = model_outputs_per_frame[first_frame_idx]
|
||||
num_objects = len(first_model_outputs.object_ids)
|
||||
if num_objects > 0:
|
||||
# Check raw mask from model_outputs (low-resolution, before post-processing)
|
||||
first_obj_id = first_model_outputs.object_ids[0]
|
||||
raw_mask = first_model_outputs.obj_id_to_mask[first_obj_id].cpu()
|
||||
torch.testing.assert_close(
|
||||
raw_mask[:1, :3, :3].float(),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
[-2.952317476272583, -5.94632625579834, -7.991223335266113],
|
||||
[-6.916913986206055, -10.058566093444824, -11.114638328552246],
|
||||
[-8.195585250854492, -9.787644386291504, -10.39273452758789],
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
atol=5e-3, # Higher tolerance for raw logits
|
||||
rtol=5e-3,
|
||||
)
|
||||
|
||||
# Check numeric values for last frame (to verify propagation consistency)
|
||||
if len(outputs_per_frame) > 1:
|
||||
last_frame_idx = max(outputs_per_frame.keys())
|
||||
last_outputs = outputs_per_frame[last_frame_idx]
|
||||
num_objects = len(last_outputs["object_ids"])
|
||||
if num_objects > 0:
|
||||
# Move outputs to CPU for comparison
|
||||
object_ids = (
|
||||
last_outputs["object_ids"].cpu()
|
||||
if isinstance(last_outputs["object_ids"], torch.Tensor)
|
||||
else torch.tensor(last_outputs["object_ids"])
|
||||
)
|
||||
scores = (
|
||||
last_outputs["scores"].cpu()
|
||||
if isinstance(last_outputs["scores"], torch.Tensor)
|
||||
else torch.tensor(last_outputs["scores"])
|
||||
)
|
||||
boxes = (
|
||||
last_outputs["boxes"].cpu()
|
||||
if isinstance(last_outputs["boxes"], torch.Tensor)
|
||||
else torch.tensor(last_outputs["boxes"])
|
||||
)
|
||||
masks = (
|
||||
last_outputs["masks"].cpu()
|
||||
if isinstance(last_outputs["masks"], torch.Tensor)
|
||||
else torch.tensor(last_outputs["masks"])
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
object_ids,
|
||||
torch.tensor([0, 1], dtype=torch.int64),
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
scores,
|
||||
torch.tensor([0.968647837638855, 0.9736108779907227], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
boxes[0],
|
||||
torch.tensor([157.0, 116.0, 295.0, 382.0], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
masks[0, :3, :3].float(),
|
||||
torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
|
||||
# Check raw model_outputs mask values for last frame
|
||||
if len(model_outputs_per_frame) > 1:
|
||||
last_frame_idx = max(model_outputs_per_frame.keys())
|
||||
last_model_outputs = model_outputs_per_frame[last_frame_idx]
|
||||
num_objects = len(last_model_outputs.object_ids)
|
||||
if num_objects > 0:
|
||||
# Check raw mask from model_outputs (low-resolution, before post-processing)
|
||||
first_obj_id = last_model_outputs.object_ids[0]
|
||||
raw_mask = last_model_outputs.obj_id_to_mask[first_obj_id].cpu()
|
||||
torch.testing.assert_close(
|
||||
raw_mask[:1, :3, :3].float(),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
[-23.023313522338867, -27.02887535095215, -22.29985237121582],
|
||||
[-24.373233795166016, -31.428438186645508, -24.268810272216797],
|
||||
[-24.550016403198242, -32.607383728027344, -26.500947952270508],
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
atol=5e-3, # Higher tolerance for raw logits
|
||||
rtol=5e-3,
|
||||
)
|
||||
|
||||
def test_inference_video_streaming_with_text_prompt(self):
|
||||
raw_video = prepare_video()
|
||||
|
||||
# Initialize session for streaming (no video provided)
|
||||
inference_session = self.processor.init_video_session(
|
||||
inference_device=torch_device,
|
||||
processing_device="cpu",
|
||||
video_storage_device="cpu",
|
||||
)
|
||||
|
||||
# Add text prompt
|
||||
text = "person"
|
||||
inference_session = self.processor.add_text_prompt(
|
||||
inference_session=inference_session,
|
||||
text=text,
|
||||
)
|
||||
|
||||
# Process frames one by one (streaming mode)
|
||||
outputs_per_frame = {}
|
||||
model_outputs_per_frame = {}
|
||||
max_frame_num_to_track = 3
|
||||
for frame_idx, frame in enumerate(raw_video):
|
||||
if frame_idx >= max_frame_num_to_track:
|
||||
break
|
||||
|
||||
# Process frame using processor
|
||||
inputs = self.processor(images=frame, device=torch_device, return_tensors="pt")
|
||||
|
||||
# Process frame using streaming inference
|
||||
model_outputs = self.video_model(
|
||||
inference_session=inference_session,
|
||||
frame=inputs.pixel_values[0], # Provide processed frame - this enables streaming mode
|
||||
reverse=False,
|
||||
)
|
||||
|
||||
# Post-process outputs with original_sizes for proper resolution handling
|
||||
processed_outputs = self.processor.postprocess_outputs(
|
||||
inference_session,
|
||||
model_outputs,
|
||||
original_sizes=inputs.original_sizes, # Required for streaming inference
|
||||
)
|
||||
outputs_per_frame[frame_idx] = processed_outputs
|
||||
model_outputs_per_frame[frame_idx] = model_outputs
|
||||
|
||||
# Check we processed the expected number of frames
|
||||
self.assertEqual(len(outputs_per_frame), max_frame_num_to_track)
|
||||
|
||||
# Check output structure for each frame
|
||||
for frame_idx, processed_outputs in outputs_per_frame.items():
|
||||
self.assertIn("object_ids", processed_outputs)
|
||||
self.assertIn("scores", processed_outputs)
|
||||
self.assertIn("boxes", processed_outputs)
|
||||
self.assertIn("masks", processed_outputs)
|
||||
|
||||
num_objects = len(processed_outputs["object_ids"])
|
||||
if num_objects > 0:
|
||||
self.assertEqual(processed_outputs["scores"].shape, (num_objects,))
|
||||
self.assertEqual(processed_outputs["boxes"].shape, (num_objects, 4))
|
||||
# For streaming, masks should be at original frame resolution
|
||||
H_orig, W_orig = raw_video[frame_idx].shape[0], raw_video[frame_idx].shape[1]
|
||||
self.assertEqual(processed_outputs["masks"].shape, (num_objects, H_orig, W_orig))
|
||||
# Check boxes are in XYXY format (absolute coordinates)
|
||||
boxes = processed_outputs["boxes"]
|
||||
self.assertTrue(torch.all(boxes[:, 2] >= boxes[:, 0])) # x2 >= x1
|
||||
self.assertTrue(torch.all(boxes[:, 3] >= boxes[:, 1])) # y2 >= y1
|
||||
|
||||
# Check numeric values for first frame
|
||||
if len(outputs_per_frame) > 0:
|
||||
first_frame_idx = min(outputs_per_frame.keys())
|
||||
first_outputs = outputs_per_frame[first_frame_idx]
|
||||
num_objects = len(first_outputs["object_ids"])
|
||||
if num_objects > 0:
|
||||
# Move outputs to CPU for comparison (postprocess_outputs may return CPU tensors)
|
||||
object_ids = (
|
||||
first_outputs["object_ids"].cpu()
|
||||
if isinstance(first_outputs["object_ids"], torch.Tensor)
|
||||
else torch.tensor(first_outputs["object_ids"])
|
||||
)
|
||||
scores = (
|
||||
first_outputs["scores"].cpu()
|
||||
if isinstance(first_outputs["scores"], torch.Tensor)
|
||||
else torch.tensor(first_outputs["scores"])
|
||||
)
|
||||
boxes = (
|
||||
first_outputs["boxes"].cpu()
|
||||
if isinstance(first_outputs["boxes"], torch.Tensor)
|
||||
else torch.tensor(first_outputs["boxes"])
|
||||
)
|
||||
masks = (
|
||||
first_outputs["masks"].cpu()
|
||||
if isinstance(first_outputs["masks"], torch.Tensor)
|
||||
else torch.tensor(first_outputs["masks"])
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
object_ids,
|
||||
torch.tensor([0, 1], dtype=torch.int64),
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
scores,
|
||||
torch.tensor([0.9683944582939148, 0.9740181565284729], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
boxes[0],
|
||||
torch.tensor([146.0, 135.0, 291.0, 404.0], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
masks[0, :3, :3].float(),
|
||||
torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
|
||||
# Check raw model_outputs mask values for first frame
|
||||
if len(model_outputs_per_frame) > 0:
|
||||
first_frame_idx = min(model_outputs_per_frame.keys())
|
||||
first_model_outputs = model_outputs_per_frame[first_frame_idx]
|
||||
num_objects = len(first_model_outputs.object_ids)
|
||||
if num_objects > 0:
|
||||
# Check raw mask from model_outputs (low-resolution, before post-processing)
|
||||
first_obj_id = first_model_outputs.object_ids[0]
|
||||
raw_mask = first_model_outputs.obj_id_to_mask[first_obj_id].cpu()
|
||||
torch.testing.assert_close(
|
||||
raw_mask[:1, :3, :3].float(),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
[-2.987567901611328, -5.944897651672363, -7.973854064941406],
|
||||
[-7.017378330230713, -10.088018417358398, -11.089308738708496],
|
||||
[-8.274458885192871, -9.851463317871094, -10.428947448730469],
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
atol=5e-3, # Higher tolerance for raw logits
|
||||
rtol=5e-3,
|
||||
)
|
||||
|
||||
# Check numeric values for last frame (to verify propagation consistency)
|
||||
if len(outputs_per_frame) > 1:
|
||||
last_frame_idx = max(outputs_per_frame.keys())
|
||||
last_outputs = outputs_per_frame[last_frame_idx]
|
||||
num_objects = len(last_outputs["object_ids"])
|
||||
if num_objects > 0:
|
||||
# Move outputs to CPU for comparison
|
||||
object_ids = (
|
||||
last_outputs["object_ids"].cpu()
|
||||
if isinstance(last_outputs["object_ids"], torch.Tensor)
|
||||
else torch.tensor(last_outputs["object_ids"])
|
||||
)
|
||||
scores = (
|
||||
last_outputs["scores"].cpu()
|
||||
if isinstance(last_outputs["scores"], torch.Tensor)
|
||||
else torch.tensor(last_outputs["scores"])
|
||||
)
|
||||
boxes = (
|
||||
last_outputs["boxes"].cpu()
|
||||
if isinstance(last_outputs["boxes"], torch.Tensor)
|
||||
else torch.tensor(last_outputs["boxes"])
|
||||
)
|
||||
masks = (
|
||||
last_outputs["masks"].cpu()
|
||||
if isinstance(last_outputs["masks"], torch.Tensor)
|
||||
else torch.tensor(last_outputs["masks"])
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
object_ids,
|
||||
torch.tensor([0, 1], dtype=torch.int64),
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
scores,
|
||||
torch.tensor([0.9683944582939148, 0.9740181565284729], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
boxes[0],
|
||||
torch.tensor([154.0, 117.0, 294.0, 395.0], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
masks[0, :3, :3].float(),
|
||||
torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=torch.float32),
|
||||
atol=1e-4,
|
||||
rtol=1e-4,
|
||||
)
|
||||
|
||||
# Check raw model_outputs mask values for last frame
|
||||
if len(model_outputs_per_frame) > 1:
|
||||
last_frame_idx = max(model_outputs_per_frame.keys())
|
||||
last_model_outputs = model_outputs_per_frame[last_frame_idx]
|
||||
num_objects = len(last_model_outputs.object_ids)
|
||||
if num_objects > 0:
|
||||
# Check raw mask from model_outputs (low-resolution, before post-processing)
|
||||
first_obj_id = last_model_outputs.object_ids[0]
|
||||
raw_mask = last_model_outputs.obj_id_to_mask[first_obj_id].cpu()
|
||||
torch.testing.assert_close(
|
||||
raw_mask[:1, :3, :3].float(),
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
[-23.935535430908203, -27.967025756835938, -23.519914627075195],
|
||||
[-25.742399215698242, -32.65046310424805, -24.71213150024414],
|
||||
[-25.263212203979492, -33.807132720947266, -27.463823318481445],
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
atol=5e-3, # Higher tolerance for raw logits
|
||||
rtol=5e-3,
|
||||
)
|
||||
|
||||
def test_inference_video_multi_prompt(self):
|
||||
"""Test multi-prompt tracking - detecting multiple object categories in one pass."""
|
||||
raw_video = prepare_video()
|
||||
inference_session = self.processor.init_video_session(
|
||||
video=raw_video,
|
||||
inference_device=torch_device,
|
||||
processing_device="cpu",
|
||||
video_storage_device="cpu",
|
||||
)
|
||||
|
||||
# Add multiple text prompts
|
||||
prompts = ["person", "bed"]
|
||||
self.processor.add_text_prompt(
|
||||
inference_session=inference_session,
|
||||
text=prompts,
|
||||
)
|
||||
|
||||
# Propagate through video frames
|
||||
outputs_per_frame = {}
|
||||
for model_outputs in self.video_model.propagate_in_video_iterator(
|
||||
inference_session=inference_session,
|
||||
max_frame_num_to_track=3,
|
||||
):
|
||||
processed_outputs = self.processor.postprocess_outputs(inference_session, model_outputs)
|
||||
outputs_per_frame[model_outputs.frame_idx] = processed_outputs
|
||||
|
||||
# Check we processed the expected number of frames
|
||||
self.assertGreaterEqual(len(outputs_per_frame), 1)
|
||||
self.assertLessEqual(len(outputs_per_frame), 4)
|
||||
|
||||
# Check output structure for each frame
|
||||
for processed_outputs in outputs_per_frame.values():
|
||||
self.assertIn("object_ids", processed_outputs)
|
||||
self.assertIn("scores", processed_outputs)
|
||||
self.assertIn("boxes", processed_outputs)
|
||||
self.assertIn("masks", processed_outputs)
|
||||
self.assertIn("prompt_to_obj_ids", processed_outputs) # Multi-prompt specific
|
||||
|
||||
# Check prompt_to_obj_ids structure
|
||||
prompt_to_obj_ids = processed_outputs["prompt_to_obj_ids"]
|
||||
self.assertIsInstance(prompt_to_obj_ids, dict)
|
||||
for prompt, obj_ids in prompt_to_obj_ids.items():
|
||||
self.assertIsInstance(prompt, str)
|
||||
self.assertIsInstance(obj_ids, list)
|
||||
# Each object ID should be in the main object_ids list
|
||||
for obj_id in obj_ids:
|
||||
self.assertIn(obj_id, processed_outputs["object_ids"].tolist())
|
||||
|
||||
# Check that we detected objects from multiple prompts
|
||||
first_frame_outputs = outputs_per_frame[min(outputs_per_frame.keys())]
|
||||
prompt_to_obj_ids = first_frame_outputs["prompt_to_obj_ids"]
|
||||
|
||||
# Should have at least one prompt with detections
|
||||
self.assertGreater(len(prompt_to_obj_ids), 0)
|
||||
|
||||
# All prompts in prompt_to_obj_ids should be from our original prompts
|
||||
for prompt in prompt_to_obj_ids.keys():
|
||||
self.assertIn(prompt, prompts)
|
||||
|
||||
def test_custom_image_size(self):
|
||||
"""Test that custom image size can be set and propagates correctly to detector and tracker configs."""
|
||||
from transformers import Sam3VideoConfig
|
||||
|
||||
config = Sam3VideoConfig.from_pretrained("facebook/sam3")
|
||||
config.image_size = 560
|
||||
|
||||
self.assertEqual(config.image_size, 560)
|
||||
self.assertEqual(config.detector_config.image_size, 560)
|
||||
self.assertEqual(config.tracker_config.image_size, 560)
|
||||
self.assertEqual(config.detector_config.vision_config.image_size, 560)
|
||||
self.assertEqual(config.detector_config.vision_config.backbone_config.image_size, 560)
|
||||
|
||||
model = Sam3VideoModel.from_pretrained("facebook/sam3", config=config).to(torch_device).eval()
|
||||
self.assertEqual(model.config.image_size, 560)
|
||||
|
||||
def test_inference_with_different_dtypes(self):
|
||||
"""Test that inference works correctly for float32, bfloat16, and float16 dtypes."""
|
||||
raw_video = prepare_video()
|
||||
dtypes_to_test = [
|
||||
(torch.float32, None), # float32 is always available
|
||||
(torch.bfloat16, is_torch_bf16_available_on_device),
|
||||
(torch.float16, is_torch_fp16_available_on_device),
|
||||
]
|
||||
|
||||
for dtype, availability_check in dtypes_to_test:
|
||||
with self.subTest(dtype=dtype):
|
||||
# Skip if dtype is not available on device
|
||||
if availability_check is not None and not availability_check(torch_device):
|
||||
self.skipTest(f"{dtype} not supported on {torch_device}")
|
||||
|
||||
# Load model with specific dtype
|
||||
video_model = Sam3VideoModel.from_pretrained("facebook/sam3", torch_dtype=dtype).to(torch_device)
|
||||
video_model.eval()
|
||||
|
||||
# Initialize inference session
|
||||
inference_session = self.processor.init_video_session(
|
||||
video=raw_video,
|
||||
inference_device=torch_device,
|
||||
processing_device="cpu",
|
||||
video_storage_device="cpu",
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Add text prompt
|
||||
text = "person"
|
||||
inference_session = self.processor.add_text_prompt(
|
||||
inference_session=inference_session,
|
||||
text=text,
|
||||
)
|
||||
|
||||
# Run inference on first frame
|
||||
outputs_per_frame = {}
|
||||
model_outputs_per_frame = {}
|
||||
max_frame_num_to_track = 2
|
||||
for model_outputs in video_model.propagate_in_video_iterator(
|
||||
inference_session=inference_session,
|
||||
max_frame_num_to_track=max_frame_num_to_track,
|
||||
):
|
||||
processed_outputs = self.processor.postprocess_outputs(inference_session, model_outputs)
|
||||
outputs_per_frame[model_outputs.frame_idx] = processed_outputs
|
||||
model_outputs_per_frame[model_outputs.frame_idx] = model_outputs
|
||||
|
||||
# Verify dtype is maintained in model outputs
|
||||
if len(model_outputs.object_ids) > 0:
|
||||
first_obj_id = model_outputs.object_ids[0]
|
||||
raw_mask = model_outputs.obj_id_to_mask[first_obj_id]
|
||||
self.assertEqual(raw_mask.dtype, dtype)
|
||||
|
||||
# Verify we processed frames
|
||||
self.assertGreaterEqual(len(outputs_per_frame), 1)
|
||||
self.assertLessEqual(len(outputs_per_frame), max_frame_num_to_track + 1)
|
||||
Reference in New Issue
Block a user