# (file-viewer header residue: 79 lines, 2.9 KiB, Python)
import importlib
|
|
import sys
|
|
import types
|
|
import unittest
|
|
|
|
|
|
def _import_remove_background():
    """Import a fresh ``remove_background`` module backed by a stub ``rembg``.

    Any already-imported ``remove_background`` is evicted from
    ``sys.modules`` first, so its module-level code runs again. Then
    ``sys.modules["rembg"]`` is replaced with a bare module whose ``remove``
    and ``new_session`` accept any arguments and return ``None``. The stub is
    left installed after the import.

    Returns:
        The newly imported ``remove_background`` module object.
    """
    sys.modules.pop("remove_background", None)

    stub = types.ModuleType("rembg")
    for attr_name in ("remove", "new_session"):
        # Each iteration creates its own lambda, mirroring two separate stubs.
        setattr(stub, attr_name, lambda *args, **kwargs: None)
    sys.modules["rembg"] = stub

    return importlib.import_module("remove_background")
|
|
|
|
|
|
class RemoveBackgroundTests(unittest.TestCase):
    """Unit tests for the helper functions in ``remove_background``.

    Every test runs against a freshly imported ``remove_background`` whose
    ``rembg`` dependency is a stub installed by ``_import_remove_background``.
    The ``sys.modules`` entries that helper touches are snapshotted before
    each test and restored afterwards. This keeps the stub from leaking into
    other test modules that run in the same interpreter.
    """

    # sys.modules keys that _import_remove_background overwrites or creates.
    _PATCHED_MODULES = ("rembg", "remove_background")

    def setUp(self):
        saved = {
            name: sys.modules[name]
            for name in self._PATCHED_MODULES
            if name in sys.modules
        }
        # Register before importing: cleanups still run if setUp then fails.
        self.addCleanup(self._restore_modules, saved)
        self.mod = _import_remove_background()

    def _restore_modules(self, saved):
        """Reinstate the ``sys.modules`` entries captured in ``setUp``.

        Names that were absent before the test are removed again.
        """
        for name in self._PATCHED_MODULES:
            if name in saved:
                sys.modules[name] = saved[name]
            else:
                sys.modules.pop(name, None)

    def test_alpha_to_mask_threshold(self):
        np = self.mod.np
        alpha = np.array([[0, 9, 10, 255]], dtype=np.uint8)
        mask = self.mod._alpha_to_mask(alpha, threshold=10)
        self.assertEqual(mask.shape, alpha.shape)
        # An alpha exactly equal to the threshold (10) must still map to 0.
        np.testing.assert_array_equal(mask, [[0, 0, 0, 255]])

    def test_prepare_mask_outputs_binary(self):
        np = self.mod.np
        mask = np.zeros((5, 5), dtype=np.uint8)
        mask[2, 2] = 255
        mask_hard, mask_used = self.mod._prepare_mask(
            mask, mask_dilate=0, mask_blur=3
        )
        self.assertEqual(mask_hard.shape, mask.shape)
        self.assertEqual(mask_used.shape, mask.shape)
        # Even with mask_blur=3, mask_used must contain only 0 and 255.
        self.assertTrue(np.all(np.isin(mask_used, [0, 255])))

    def test_resolve_aot_paths_defaults(self):
        from pathlib import PurePath

        aot_root, aot_pretrain = self.mod._resolve_aot_paths(
            "AOT-GAN-for-Inpainting", None
        )
        self.assertEqual(PurePath(aot_root).name, "AOT-GAN-for-Inpainting")
        self.assertIsNone(aot_pretrain)

    def test_resolve_aot_paths_relative_pretrain(self):
        from pathlib import PurePath

        aot_root, aot_pretrain = self.mod._resolve_aot_paths(
            "AOT-GAN-for-Inpainting", "experiments/foo.pt"
        )
        self.assertEqual(PurePath(aot_root).name, "AOT-GAN-for-Inpainting")
        self.assertIsNotNone(aot_pretrain)
        # Compare path components instead of a "/"-suffix on str(): if the
        # helper returns a pathlib.Path, str() uses backslashes on Windows.
        self.assertEqual(
            PurePath(aot_pretrain).parts[-2:], ("experiments", "foo.pt")
        )

    def test_parse_aot_rates(self):
        rates = self.mod._parse_aot_rates("1+2+4+8")
        self.assertEqual(rates, [1, 2, 4, 8])

    def test_mask_bbox_empty(self):
        np = self.mod.np
        mask = np.zeros((3, 3), dtype=np.uint8)
        self.assertIsNone(self.mod._mask_bbox(mask))

    def test_mask_bbox_and_expand(self):
        np = self.mod.np
        mask = np.zeros((5, 5), dtype=np.uint8)
        mask[1:3, 2:4] = 255  # rows 1-2, columns 2-3
        bbox = self.mod._mask_bbox(mask)
        # Expected as (x_min, y_min, x_max, y_max) with inclusive maxima.
        self.assertEqual(bbox, (2, 1, 3, 2))

        # Growing by 2 is clamped to the 5x5 image bounds (0..4).
        expanded = self.mod._expand_bbox(bbox, 2, 5, 5)
        self.assertEqual(expanded, (0, 0, 4, 4))

        # The 2x2 box is grown to 4x4 (inclusive); presumably 0 is the
        # margin and 4 the minimum side length - confirm against the helper.
        expanded_min = self.mod._expand_bbox_min_size(bbox, 0, 5, 5, 4)
        self.assertEqual(expanded_min, (1, 0, 4, 3))

    def test_ensure_rgba_size_resizes(self):
        img = self.mod.Image.new("RGB", (2, 3), color=(0, 0, 0))
        out = self.mod._ensure_rgba_size(img, (4, 5))
        self.assertEqual(out.mode, "RGBA")
        self.assertEqual(out.size, (4, 5))
|
|
|
|
|
|
if __name__ == "__main__":
    # Executing the file directly hands control to unittest's CLI runner.
    unittest.main()
|