"""Unit tests for the ``remove_background`` module.

A stub ``rembg`` module is placed in ``sys.modules`` before
``remove_background`` is imported; each test restores the previous
``sys.modules`` entries afterwards.
"""

import importlib
import sys
import tempfile
import types
import unittest
from pathlib import Path

# Sentinel meaning "this name had no entry in sys.modules".
_ABSENT = object()

# sys.modules entries that _import_remove_background() replaces.
_PATCHED_MODULES = ("rembg", "remove_background")


def _snapshot_sys_modules(names=_PATCHED_MODULES):
    """Return the current ``sys.modules`` entry (or ``_ABSENT``) for each name."""
    return {name: sys.modules.get(name, _ABSENT) for name in names}


def _restore_sys_modules(snapshot):
    """Put ``sys.modules`` back to the state captured by ``_snapshot_sys_modules``."""
    for name, module in snapshot.items():
        if module is _ABSENT:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = module


def _import_remove_background():
    """Import a fresh copy of ``remove_background`` with ``rembg`` stubbed out.

    Any cached ``remove_background`` module is discarded first, then a stub
    ``rembg`` module whose ``remove`` and ``new_session`` accept any arguments
    and return ``None`` is installed in ``sys.modules``.

    Returns:
        The newly imported ``remove_background`` module.
    """
    # Drop the cached module so the import below executes it again.
    sys.modules.pop("remove_background", None)

    rembg_stub = types.ModuleType("rembg")
    rembg_stub.remove = lambda *args, **kwargs: None
    rembg_stub.new_session = lambda *args, **kwargs: None
    sys.modules["rembg"] = rembg_stub

    return importlib.import_module("remove_background")


class RemoveBackgroundTests(unittest.TestCase):
    """Tests for the helpers and entry point exposed by ``remove_background``."""

    def setUp(self):
        """Import the module under test and schedule ``sys.modules`` restoration."""
        # Register the cleanup before importing: cleanups also run when setUp
        # fails, so the rembg stub never outlives the test that installed it.
        self.addCleanup(_restore_sys_modules, _snapshot_sys_modules())
        self.mod = _import_remove_background()

    def _make_paper_image(self, width=160, height=120):
        """Return a ``(height, width, 3)`` uint8 image of grey 230 with seeded noise.

        The noise is drawn from ``[-8, 8]`` with a fixed seed, so the image is
        identical on every call with the same size.
        """
        np = self.mod.np
        rng = np.random.default_rng(0)
        base = np.full((height, width, 3), 230, dtype=np.uint8)
        noise = rng.integers(-8, 9, size=(height, width, 3), dtype=np.int16)
        # Add in int16 so the noise cannot wrap around before clipping.
        noisy = np.clip(base.astype(np.int16) + noise, 0, 255)
        return noisy.astype(np.uint8)

    def test_alpha_to_mask_threshold(self):
        """Alpha 0, 9 and 10 map to 0 and alpha 255 maps to 255 at threshold 10."""
        np = self.mod.np
        alpha = np.array([[0, 9, 10, 255]], dtype=np.uint8)

        mask = self.mod._alpha_to_mask(alpha, threshold=10)

        self.assertEqual(mask.shape, alpha.shape)
        for column, expected in enumerate((0, 0, 0, 255)):
            with self.subTest(column=column):
                self.assertEqual(mask[0, column], expected)

    def test_prepare_mask_outputs_binary(self):
        """Both returned masks keep the input shape; ``mask_used`` is only 0/255."""
        np = self.mod.np
        mask = np.zeros((5, 5), dtype=np.uint8)
        mask[2, 2] = 255

        mask_hard, mask_used = self.mod._prepare_mask(
            mask, mask_dilate=0, mask_blur=3
        )

        self.assertEqual(mask_hard.shape, mask.shape)
        self.assertEqual(mask_used.shape, mask.shape)
        self.assertTrue(np.all(np.isin(mask_used, [0, 255])))

    def test_resolve_aot_paths_defaults(self):
        """Without a pretrain argument the pretrain path resolves to ``None``."""
        aot_root, aot_pretrain = self.mod._resolve_aot_paths(
            "AOT-GAN-for-Inpainting", None
        )

        self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
        self.assertIsNone(aot_pretrain)

    def test_resolve_aot_paths_relative_pretrain(self):
        """A relative pretrain path is kept as the tail of the resolved path."""
        aot_root, aot_pretrain = self.mod._resolve_aot_paths(
            "AOT-GAN-for-Inpainting", "experiments/foo.pt"
        )

        self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
        # Compare in POSIX form so the check does not depend on the platform's
        # path separator (str(Path) uses backslashes on Windows).
        self.assertTrue(
            Path(aot_pretrain).as_posix().endswith("experiments/foo.pt")
        )

    def test_parse_aot_rates(self):
        """A ``+``-separated rate string is parsed into a list of ints."""
        rates = self.mod._parse_aot_rates("1+2+4+8")
        self.assertEqual(rates, [1, 2, 4, 8])

    def test_mask_bbox_empty(self):
        """An all-zero mask has no bounding box."""
        mask = self.mod.np.zeros((3, 3), dtype=self.mod.np.uint8)
        self.assertIsNone(self.mod._mask_bbox(mask))

    def test_mask_bbox_and_expand(self):
        """Check the bbox of a small blob and both expansion helpers."""
        mask = self.mod.np.zeros((5, 5), dtype=self.mod.np.uint8)
        mask[1:3, 2:4] = 255

        bbox = self.mod._mask_bbox(mask)
        self.assertEqual(bbox, (2, 1, 3, 2))

        expanded = self.mod._expand_bbox(bbox, 2, 5, 5)
        self.assertEqual(expanded, (0, 0, 4, 4))

        expanded_min = self.mod._expand_bbox_min_size(bbox, 0, 5, 5, 4)
        self.assertEqual(expanded_min, (1, 0, 4, 3))

    def test_remove_border_frame_components_keeps_inner_content(self):
        """A one-pixel frame is removed while an inner blob stays set."""
        mask = self.mod.np.zeros((12, 12), dtype=self.mod.np.uint8)
        # One-pixel frame along all four image edges.
        mask[0, :] = 255
        mask[-1, :] = 255
        mask[:, 0] = 255
        mask[:, -1] = 255
        # Inner content that must be preserved.
        mask[4:8, 5:7] = 255

        cleaned = self.mod._remove_border_frame_components(
            mask,
            min_width_ratio=0.7,
            min_height_ratio=0.7,
            max_fill_ratio=0.35,
        )

        self.assertEqual(int(cleaned[0, 0]), 0)
        self.assertGreater(cleaned[4:8, 5:7].mean(), 200)

    def test_ensure_rgba_size_resizes(self):
        """An RGB image comes back as RGBA at the requested size."""
        img = self.mod.Image.new("RGB", (2, 3), color=(0, 0, 0))

        out = self.mod._ensure_rgba_size(img, (4, 5))

        self.assertEqual(out.mode, "RGBA")
        self.assertEqual(out.size, (4, 5))

    def test_extract_artwork_mask_detects_calligraphy_strokes(self):
        """Dark strokes on noisy paper are masked; the empty corner is not."""
        image = self._make_paper_image()
        self.mod.cv2.line(image, (20, 20), (130, 90), (20, 20, 20), thickness=6)
        self.mod.cv2.line(image, (45, 15), (45, 105), (30, 30, 30), thickness=5)

        mask = self.mod._extract_artwork_mask(
            self.mod.Image.fromarray(image),
            artwork_type="calligraphy",
            max_size=120,
        )

        # Region along the vertical stroke at x=45.
        self.assertGreater(mask[25:95, 42:49].mean(), 180)
        self.assertLess(mask[:15, :15].mean(), 20)

    def test_extract_artwork_mask_detects_red_seal(self):
        """A red outlined seal is partly masked; the empty corner is not."""
        image = self._make_paper_image()
        self.mod.cv2.rectangle(
            image, (40, 30), (120, 95), (165, 30, 40), thickness=5
        )
        self.mod.cv2.line(image, (55, 40), (105, 85), (170, 25, 35), thickness=6)

        mask = self.mod._extract_artwork_mask(
            self.mod.Image.fromarray(image),
            artwork_type="seal",
            max_size=120,
        )

        self.assertGreater(mask[38:92, 38:122].mean(), 40)
        self.assertLess(mask[:15, :15].mean(), 20)

    def test_extract_artwork_mask_ignores_warm_paper_cast(self):
        """A uniformly warm-tinted background is not reported as artwork."""
        image = self.mod.np.full(
            (120, 160, 3), (222, 188, 150), dtype=self.mod.np.uint8
        )
        self.mod.cv2.line(image, (30, 20), (30, 100), (28, 28, 28), thickness=6)
        self.mod.cv2.line(image, (55, 20), (125, 95), (32, 32, 32), thickness=5)

        mask = self.mod._extract_artwork_mask(
            self.mod.Image.fromarray(image),
            artwork_type="calligraphy",
            max_size=120,
        )

        # Region along the vertical stroke at x=30.
        self.assertGreater(mask[25:100, 26:34].mean(), 180)
        self.assertLess(mask[:12, :12].mean(), 20)

    def test_mask_to_transparent_image_refines_alpha_inside_coarse_mask(self):
        """Inside a coarse mask, artwork keeps alpha and plain paper loses it."""
        image = self.mod.np.full(
            (120, 160, 3), (224, 190, 156), dtype=self.mod.np.uint8
        )
        self.mod.cv2.line(image, (40, 18), (40, 102), (24, 24, 24), thickness=6)
        self.mod.cv2.rectangle(
            image, (102, 24), (132, 54), (170, 30, 40), thickness=-1
        )
        # The coarse mask covers both marks and the paper between them.
        coarse_mask = self.mod.np.zeros((120, 160), dtype=self.mod.np.uint8)
        coarse_mask[10:110, 20:140] = 255

        out = self.mod._mask_to_transparent_image(
            self.mod.Image.fromarray(image),
            coarse_mask,
        )
        alpha = self.mod.np.array(out.getchannel("A"))

        self.assertGreater(alpha[25:95, 36:45].mean(), 110)  # dark stroke
        self.assertGreater(alpha[30:50, 108:128].mean(), 90)  # red square
        self.assertLess(alpha[70:90, 70:90].mean(), 25)  # paper inside the mask

    def test_remove_background_artwork_mode_skips_rembg(self):
        """Artwork mode writes an RGBA result without calling ``remove``."""
        image = self._make_paper_image(width=100, height=80)
        self.mod.cv2.line(image, (15, 15), (85, 60), (20, 20, 20), thickness=5)

        original_remove = self.mod.remove

        def _unexpected_remove(*args, **kwargs):
            # The message reads: "artwork mode must not call rembg".
            raise AssertionError("artwork 模式不应调用 rembg")

        self.mod.remove = _unexpected_remove
        try:
            with tempfile.TemporaryDirectory() as tmpdir:
                input_path = Path(tmpdir) / "input.png"
                output_path = Path(tmpdir) / "output.png"
                self.mod.Image.fromarray(image).save(input_path)

                self.mod.remove_background(
                    str(input_path),
                    str(output_path),
                    foreground_mode="artwork",
                    artwork_type="calligraphy",
                    artwork_max_size=120,
                )

                # Close the image before the temporary directory is deleted so
                # no open handle is left on the output file.
                with self.mod.Image.open(output_path) as out:
                    # Check the mode first: getchannel("A") needs an alpha band.
                    self.assertEqual(out.mode, "RGBA")
                    alpha = self.mod.np.array(out.getchannel("A"))

                self.assertGreater(alpha[30:55, 35:70].mean(), 70)
                self.assertLess(alpha[:10, :10].mean(), 10)
        finally:
            self.mod.remove = original_remove


if __name__ == "__main__":
    unittest.main()