diff --git a/.DS_Store b/.DS_Store index a9c3e12..ac1148b 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/AOT-GAN-for-Inpainting b/AOT-GAN-for-Inpainting new file mode 160000 index 0000000..2cd1afd --- /dev/null +++ b/AOT-GAN-for-Inpainting @@ -0,0 +1 @@ +Subproject commit 2cd1afd8fdfabb101c678f6062d14bc7d302509e diff --git a/README.md b/README.md index 105efc3..42a424d 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,9 @@ python remove_background.py # 处理单个文件 python remove_background.py input.jpg output.png -# 查看所有参数-m, --model) +# 查看所有参数 +python remove_background.py -h +``` 不同模型适用于不同场景: @@ -39,7 +41,19 @@ python remove_background.py input.jpg output.png -m birefnet-portrait # 使用快速模型 python remove_background.py input.jpg output.png -m u2netp -`` + +# 书画类文字偏浅的温和参数示例 +python remove_background.py images output \ + --remove-subject --black_subject --gray_subject --save_mask \ + --black-threshold 30\ + --gray-saturation-threshold 30 --gray-value-threshold 30 \ + --edge-grow 2 \ + --feather --feather-radius 4 \ + --aot-pretrain experiments/G0000000.pt \ + --aot-max-size 1000 +``` + +查看已下载模型: ```bash ls -lh ~/.u2net/ ``` @@ -75,6 +89,40 @@ rm -rf ~/.u2net/ rm ~/.u2net/u2net.onnx ``` +## AOT-GAN 修补后端 + +`--remove-subject` 默认使用 AOT-GAN 修补。 +AOT-GAN 依赖 PyTorch(官方仓库测试 Python 3.8 / torch 1.8.1)。建议使用独立虚拟环境或确保兼容版本。 + +```bash +# 安装依赖(示例) +pip install torch torchvision +``` + +下载预训练权重后,运行示例: +```bash +python remove_background.py "images/IMG_9259 2.JPG" \ + --remove-subject --black-subject --gray-subject --save-mask \ + --aot-pretrain experiments/places2.pth +``` + +CPU 无 GPU 时的加速建议(只裁剪主体区域并限制最大边): +```bash +python remove_background.py "images/IMG_9259 2.JPG" \ + --remove-subject --black-subject --gray-subject --save-mask \ + --aot-pretrain experiments/places2.pth \ + --aot-crop --aot-crop-pad 24 --aot-max-size 1400 +``` + +减少“补脸”倾向:启用随机噪声预填充 +```bash +python remove_background.py "images/IMG_9259 2.JPG" \ + --remove-subject --black-subject --gray-subject 
--save-mask \ + --aot-pretrain experiments/places2.pth \ + --aot-crop --aot-crop-pad 64 --aot-max-size 900 \ + --aot-noise-prefill --aot-noise-strength 1.0 +``` + ## 可调整参数说明 ### 1. 模型选择 (model_name) @@ -97,10 +145,11 @@ rm ~/.u2net/u2net.onnx Alpha Matting 是后处理步骤,可以显著改善边缘质量,特别是头发、毛发等细节。 -#### alpha_matting (True/False) -- **作用**: 是否启用alpha matting -- **默认**: False -- **建议**: 如果边缘不自然,启用此选项 +#### alpha_matting(开关) +- **作用**: 是否启用 alpha matting,提升边缘质量 +- **默认**: 关闭(不传 `-a/--alpha-matting`) +- **启用方式**: 传入 `-a` 或 `--alpha-matting` +- **效果**: 有利于细节边缘(毛发/细线),但速度稍慢 #### alpha_matting_foreground_threshold (0-255) - **作用**: 前景阈值,控制哪些区域被认为是前景 @@ -128,9 +177,45 @@ Alpha Matting 是后处理步骤,可以显著改善边缘质量,特别是头 ### 3. Mask后处理 (post_process_mask) -- **作用**: 对mask进行额外的后处理 -- **默认**: False -- **建议**: 可以尝试启用看效果是否改善 +- **作用**: 对 mask 进行额外后处理 +- **默认**: 关闭(不传 `-p/--post-process`) +- **启用方式**: 传入 `-p` 或 `--post-process` +- **效果**: 有助于减少毛边,但可能略损失细节 + +### 4. 去主体补背景 (remove_subject) + +用于“去掉主体并补全背景”。当前仅使用 AOT-GAN 修补。 + +- **remove_subject(开关)**: 启用去主体补背景(默认关闭,传 `--remove-subject` 开启) +- **aot_root**: AOT-GAN 目录(默认: `AOT-GAN-for-Inpainting`) +- **aot_pretrain**: AOT-GAN 权重文件路径(必填) +- **aot_device**: AOT-GAN 设备(默认: `cpu`) +- **aot_block_num**: AOTBlock 数量(默认: 8) +- **aot_rates**: AOTBlock 膨胀率(默认: `1+2+4+8`) +- **aot_crop(开关)**: 仅对 mask 覆盖区域裁剪修补(默认关闭,传 `--aot-crop` 开启) +- **aot_crop_pad (像素)**: 裁剪边缘留白像素(默认: 0) +- **aot_max_size (像素)**: AOT 输入最大边限制(默认: 0 表示不限制) +- **aot_noise_prefill(开关)**: AOT使用随机噪声预填充(默认关闭) +- **aot_noise_strength (系数)**: 噪声强度(默认: 1.0) +- **mask_dilate (像素)**: mask 膨胀大小(默认: 3)。越大去除范围越大,风险更高 +- **mask_blur (像素)**: mask 模糊大小(默认: 3)。越大边缘越柔和但易过度 +- **mask_threshold (0-255)**: alpha 阈值(默认: 10)。越大保留越多主体 +- **edge_grow (像素)**: 主体边缘额外扩张(默认: 0)。用于清理残留边缘 +- **save_mask(开关)**: 保存 mask 方便检查(默认关闭,传 `--save-mask` 开启) +- **black_subject(开关)**: 将黑色内容也视为主体(默认关闭,传 `--black-subject` 开启) +- **black_threshold (0-255)**: 黑色阈值(默认: 50)。越大越容易把浅灰当黑 +- **gray_subject(开关)**: 将灰阶内容也视为主体(默认关闭,传 
`--gray-subject` 开启) +- **gray_saturation_threshold (0-255)**: 灰阶饱和度阈值(默认: 30)。越大越容易把彩色当灰 +- **gray_value_threshold (0-255)**: 灰阶亮度阈值(默认: 200)。越大越容易把浅灰当灰 +- **feather(开关)**: 启用边缘过渡(默认关闭,传 `--feather` 开启) +- **feather_radius (像素)**: 过渡半径(默认: 5)。越大过渡越柔和但可能变糊 +- **说明**: 过渡仅在 mask 外侧进行,避免把主体边缘带回 + +### 5. 参数调优建议(针对书画/字迹) +- 先开启 `--remove-subject`,仅看主体遮罩是否覆盖到字迹 +- 文字残留:提高 `--black-threshold` 或 `--gray-*` 阈值 +- 过度修补:降低 `--black-threshold`、`--gray-value-threshold`,并减小 `--mask-dilate/--mask-blur` +- 边缘不自然:尝试开启 `--feather` 并使用较小的 `--feather-radius` ## 常见问题解决 diff --git a/__pycache__/remove_background.cpython-312.pyc b/__pycache__/remove_background.cpython-312.pyc new file mode 100644 index 0000000..1e78c22 Binary files /dev/null and b/__pycache__/remove_background.cpython-312.pyc differ diff --git a/experiments/D0000000.pt b/experiments/D0000000.pt new file mode 100644 index 0000000..2398008 Binary files /dev/null and b/experiments/D0000000.pt differ diff --git a/experiments/G0000000.pt b/experiments/G0000000.pt new file mode 100644 index 0000000..4911b7f Binary files /dev/null and b/experiments/G0000000.pt differ diff --git a/experiments/O0000000.pt b/experiments/O0000000.pt new file mode 100644 index 0000000..1d03a3c Binary files /dev/null and b/experiments/O0000000.pt differ diff --git a/output/.DS_Store b/images/.DS_Store similarity index 100% rename from output/.DS_Store rename to images/.DS_Store diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0033d2f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[tool.mypy] +python_version = "3.12" +ignore_missing_imports = true +exclude = '^(generative_inpainting|images|output)/' + +[tool.ruff] +exclude = [ + "AOT-GAN-for-Inpainting", + "generative_inpainting", + "images", + "output", + "__pycache__", + ".venv", + "venv", + ".git", +] diff --git a/remove_background.py b/remove_background.py index 43531c2..a59690f 100644 --- a/remove_background.py +++ b/remove_background.py @@ -3,10 +3,288 @@ 
使用rembg库自动去除图片背景 """ import os +import sys import argparse from pathlib import Path +import numpy as np +import cv2 +from PIL import Image, ImageOps + +# 避免 numba 在某些环境下缓存失败 +os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache") + from rembg import remove, new_session -from PIL import Image + +AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {} + +def _resolve_aot_paths(aot_root, aot_pretrain): + """解析 AOT-GAN 的路径配置。""" + aot_root = Path(aot_root).resolve() + if aot_pretrain is None: + aot_pretrain = None + else: + aot_pretrain = Path(aot_pretrain) + if not aot_pretrain.is_absolute(): + # 预训练权重相对路径以当前工作目录为基准 + aot_pretrain = (Path.cwd() / aot_pretrain).resolve() + return aot_root, aot_pretrain + +def _parse_aot_rates(rates_str): + parts = [p for p in rates_str.split("+") if p] + return [int(p) for p in parts] + +def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None): + """加载/缓存 AOT-GAN 模型。""" + aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain) + if not aot_root.exists(): + raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}") + if aot_pretrain is None: + raise ValueError("AOT-GAN需要指定预训练权重路径(--aot-pretrain)") + if not aot_pretrain.exists(): + raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}") + if rates is None: + rates = [1, 2, 4, 8] + rates_tuple = tuple(rates) + key = (str(aot_pretrain), device, int(block_num), rates_tuple) + if key in AOT_MODEL_CACHE: + return AOT_MODEL_CACHE[key] + + src_root = aot_root / "src" + if str(src_root) not in sys.path: + sys.path.insert(0, str(src_root)) + import importlib + try: + import torch + except ImportError as exc: + raise ImportError("未找到 PyTorch,请先按README安装依赖。") from exc + + net = importlib.import_module("model.aotgan") + class _Args: + pass + args = _Args() + args.block_num = int(block_num) + args.rates = rates + model = net.InpaintGenerator(args) + state = torch.load(str(aot_pretrain), map_location=device) + if isinstance(state, 
dict): + if "state_dict" in state: + state = state["state_dict"] + elif "model" in state: + state = state["model"] + elif "generator" in state: + state = state["generator"] + if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()): + state = {k.replace("module.", "", 1): v for k, v in state.items()} + model.load_state_dict(state, strict=True) + model.to(device) + model.eval() + AOT_MODEL_CACHE[key] = (model, device) + return AOT_MODEL_CACHE[key] + +def _mask_bbox(mask): + ys, xs = np.where(mask > 0) + if len(xs) == 0: + return None + return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()) + +def _expand_bbox(bbox, pad, width, height): + if bbox is None: + return None + x0, y0, x1, y1 = bbox + pad = max(0, int(pad)) + x0 = max(0, x0 - pad) + y0 = max(0, y0 - pad) + x1 = min(width - 1, x1 + pad) + y1 = min(height - 1, y1 + pad) + return x0, y0, x1, y1 + +def _expand_bbox_min_size(bbox, pad, width, height, min_size): + """在给定 pad 基础上,确保裁剪区域至少为 min_size。""" + expanded = _expand_bbox(bbox, pad, width, height) + if expanded is None: + return None + x0, y0, x1, y1 = expanded + roi_w = x1 - x0 + 1 + roi_h = y1 - y0 + 1 + need_w = max(0, int(min_size) - roi_w) + need_h = max(0, int(min_size) - roi_h) + if need_w == 0 and need_h == 0: + return expanded + extra = max((need_w + 1) // 2, (need_h + 1) // 2) + return _expand_bbox(bbox, pad + extra, width, height) + +def _build_noise_prefill(img, mask_t, strength): + """为mask区域生成噪声预填充,img取值范围为[-1, 1]。""" + import torch + img01 = (img + 1.0) / 2.0 + unmasked = 1.0 - mask_t + denom = unmasked.sum() + if denom.item() < 1.0: + mean = torch.full((1, 3, 1, 1), 0.5, device=img.device) + std = torch.full((1, 3, 1, 1), 0.2, device=img.device) + else: + mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom + var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom + std = torch.sqrt(var + 1e-6) + noise = mean + std * float(strength) * torch.randn_like(img01) + noise = 
noise.clamp(0.0, 1.0) + return noise * 2.0 - 1.0 + +def _inpaint_with_aot_core( + bgr, + mask, + aot_root, + aot_pretrain, + device="cpu", + block_num=8, + rates=None, + noise_prefill=False, + noise_strength=1.0, +): + """使用 AOT-GAN 进行修补,返回 BGR 图像(与输入同尺寸)。""" + if mask.dtype != np.uint8: + mask = mask.astype(np.uint8) + if mask.max() <= 1: + mask = mask * 255 + if mask.max() == 0: + return bgr + + h, w = mask.shape + grid = 4 + h2 = (h // grid) * grid + w2 = (w // grid) * grid + if h2 == 0 or w2 == 0: + return bgr + + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + img_crop = rgb[:h2, :w2, :] + mask_crop = mask[:h2, :w2] + + import torch + img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0 + img = img * 2.0 - 1.0 + mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0) + img = img.unsqueeze(0) + mask_t = mask_t.unsqueeze(0) + if device != "cpu": + img = img.to(device) + mask_t = mask_t.to(device) + + model, _ = _get_aot_model( + aot_root=aot_root, + aot_pretrain=aot_pretrain, + device=device, + block_num=block_num, + rates=rates, + ) + with torch.no_grad(): + if noise_prefill: + noise = _build_noise_prefill(img, mask_t, noise_strength) + image_masked = img * (1 - mask_t) + noise * mask_t + else: + image_masked = img * (1 - mask_t) + mask_t + pred = model(image_masked, mask_t) + comp = pred * mask_t + img * (1 - mask_t) + comp = comp[0].clamp(-1.0, 1.0) + comp = (comp + 1.0) / 2.0 * 255.0 + comp = comp.permute(1, 2, 0).byte().cpu().numpy() + result_bgr = cv2.cvtColor(comp, cv2.COLOR_RGB2BGR) + + if h2 == h and w2 == w: + return result_bgr + output_full = bgr.copy() + output_full[:h2, :w2, :] = result_bgr + return output_full + +def _inpaint_with_aot( + bgr, + mask, + aot_root, + aot_pretrain, + device="cpu", + block_num=8, + rates=None, + crop=False, + crop_pad=0, + max_size=0, + noise_prefill=False, + noise_strength=1.0, +): + """使用 AOT-GAN 进行修补,支持裁剪与限幅以加速。""" + min_side = 32 + if mask.dtype != np.uint8: + mask = 
mask.astype(np.uint8) + if mask.max() <= 1: + mask = mask * 255 + if mask.max() == 0: + return bgr + + h, w = mask.shape + x0 = y0 = 0 + x1 = w - 1 + y1 = h - 1 + if crop: + bbox = _mask_bbox(mask) + if bbox is None: + return bgr + expanded = _expand_bbox_min_size(bbox, crop_pad, w, h, min_side) + if expanded is None: + return bgr + x0, y0, x1, y1 = expanded + roi_w = x1 - x0 + 1 + roi_h = y1 - y0 + 1 + if roi_w < min_side or roi_h < min_side: + # 裁剪区域过小会导致 AOT 失败,回退到全图修补 + crop = False + x0 = y0 = 0 + x1 = w - 1 + y1 = h - 1 + + bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1] + mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1] + roi_h, roi_w = bgr_roi.shape[:2] + + scale = 1.0 + max_size = int(max_size) if max_size else 0 + if max_size > 0: + max_dim = max(roi_h, roi_w) + if max_dim > max_size: + scale = max_size / float(max_dim) + new_w = max(1, int(round(roi_w * scale))) + new_h = max(1, int(round(roi_h * scale))) + if min(new_w, new_h) < min_side: + scale = 1.0 + new_w = roi_w + new_h = roi_h + interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR + if scale != 1.0: + bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp) + mask_roi = cv2.resize(mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST) + + filled_roi = _inpaint_with_aot_core( + bgr_roi, + mask_roi, + aot_root=aot_root, + aot_pretrain=aot_pretrain, + device=device, + block_num=block_num, + rates=rates, + noise_prefill=noise_prefill, + noise_strength=noise_strength, + ) + + if scale != 1.0: + filled_roi = cv2.resize( + filled_roi, + (roi_w, roi_h), + interpolation=cv2.INTER_LINEAR, + ) + + if not crop: + return filled_roi + output_full = bgr.copy() + output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi + return output_full # 支持HEIC格式 try: @@ -16,17 +294,6 @@ try: except ImportError: HEIC_SUPPORTED = False -def str2bool(v): - """将字符串转换为布尔值""" - if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - 
return False - else: - raise argparse.ArgumentTypeError('布尔值应为true或false') - def remove_background(input_path, output_path, session=None, **kwargs): """ 去除图片背景 @@ -40,7 +307,7 @@ def remove_background(input_path, output_path, session=None, **kwargs): print(f"正在处理: {input_path}") # 读取输入图片 - input_image = Image.open(input_path) + input_image = ImageOps.exif_transpose(Image.open(input_path)) # 去除背景 output_image = remove(input_image, session=session, **kwargs) @@ -50,10 +317,195 @@ def remove_background(input_path, output_path, session=None, **kwargs): print(f"已保存: {output_path}") -def process_images_folder(input_folder, output_folder, model_name="u2net", - alpha_matting=False, alpha_matting_foreground_threshold=240, - alpha_matting_background_threshold=10, alpha_matting_erode_size=10, - post_process_mask=False): +def _pil_to_bgr(image): + """将PIL图片转换为OpenCV BGR格式""" + rgb = image.convert("RGB") + arr = np.array(rgb) + return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) + +def _alpha_to_mask(alpha_channel, threshold=10): + """将alpha通道转换为二值mask(0或255)""" + mask = (alpha_channel > threshold).astype(np.uint8) * 255 + return mask + +def _ensure_rgba_size(image, target_size): + if image.mode != "RGBA": + image = image.convert("RGBA") + if image.size != target_size: + image = image.resize(target_size, resample=Image.NEAREST) + return image + +def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0): + """生成硬边mask与处理后mask(用于填补/过渡)""" + if mask_dilate and mask_dilate > 0: + kernel = np.ones((mask_dilate, mask_dilate), np.uint8) + mask = cv2.dilate(mask, kernel, iterations=1) + if edge_grow and edge_grow > 0: + kernel = np.ones((3, 3), np.uint8) + mask = cv2.dilate(mask, kernel, iterations=int(edge_grow)) + + mask_hard = mask.copy() + mask_used = mask_hard + if mask_blur and mask_blur > 0: + k = mask_blur if mask_blur % 2 == 1 else mask_blur + 1 + mask_used = cv2.GaussianBlur(mask_hard, (k, k), 0) + mask_used = (mask_used > 0).astype(np.uint8) * 255 + return mask_hard, 
mask_used + +def remove_subject_and_inpaint( + input_path, + output_path, + session=None, + mask_dilate=3, + mask_blur=3, + mask_threshold=10, + edge_grow=0, + aot_root="AOT-GAN-for-Inpainting", + aot_pretrain=None, + aot_device="cpu", + aot_block_num=8, + aot_rates="1+2+4+8", + aot_crop=False, + aot_crop_pad=0, + aot_max_size=0, + aot_noise_prefill=False, + aot_noise_strength=1.0, + black_subject=False, + black_threshold=50, + gray_subject=False, + gray_saturation_threshold=30, + gray_value_threshold=200, + feather=False, + feather_radius=5, + save_mask=False, + mask_output_path=None, + **kwargs, +): + """ + 去掉主体并补全背景 + + Args: + input_path: 输入图片路径 + output_path: 输出图片路径 + session: rembg会话对象(可选) + aot_root: AOT-GAN目录 + aot_pretrain: AOT-GAN权重文件路径 + aot_device: AOT-GAN设备 + aot_block_num: AOT-GAN AOTBlock 数量 + aot_rates: AOT-GAN AOTBlock 膨胀率 + aot_crop: AOT仅裁剪mask区域进行修补 + aot_crop_pad: AOT裁剪边缘留白像素 + aot_max_size: AOT输入最大边限制(0为不限制) + aot_noise_prefill: AOT使用随机噪声预填充 + aot_noise_strength: 噪声强度系数 + mask_dilate: mask膨胀大小 + mask_blur: mask模糊大小(奇数) + mask_threshold: alpha阈值 + save_mask: 是否保存mask + mask_output_path: mask保存路径 + **kwargs: rembg参数 + """ + print(f"正在处理(去主体补背景): {input_path}") + + input_image = ImageOps.exif_transpose(Image.open(input_path)) + + # 使用rembg获取主体mask + output_image = remove(input_image, session=session, **kwargs) + output_image = _ensure_rgba_size(output_image, input_image.size) + + alpha = np.array(output_image.getchannel("A")) + mask = _alpha_to_mask(alpha, threshold=mask_threshold) + + bgr = _pil_to_bgr(input_image) + if black_subject: + gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) + black_mask = (gray <= black_threshold).astype(np.uint8) * 255 + mask = cv2.bitwise_or(mask, black_mask) + if gray_subject: + hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) + s = hsv[:, :, 1] + v = hsv[:, :, 2] + gray_mask = ((s <= gray_saturation_threshold) & (v <= gray_value_threshold)).astype(np.uint8) * 255 + mask = cv2.bitwise_or(mask, gray_mask) + + mask_hard, 
mask_used = _prepare_mask( + mask, + mask_dilate=mask_dilate, + mask_blur=mask_blur, + edge_grow=edge_grow, + ) + + rates = _parse_aot_rates(aot_rates) + filled = _inpaint_with_aot( + bgr, + mask_used, + aot_root=aot_root, + aot_pretrain=aot_pretrain, + device=aot_device, + block_num=aot_block_num, + rates=rates, + crop=aot_crop, + crop_pad=aot_crop_pad, + max_size=aot_max_size, + noise_prefill=aot_noise_prefill, + noise_strength=aot_noise_strength, + ) + + if feather and feather_radius > 0: + # 仅在mask外侧做过渡,避免把原主体带回 + mask_bin = (mask_hard > 0).astype(np.uint8) * 255 + dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3) + alpha = np.ones_like(dist_out, dtype=np.float32) + outside = mask_bin == 0 + alpha[outside] = np.clip(1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0) + alpha = alpha[:, :, None] + blended = (alpha * filled + (1.0 - alpha) * bgr).astype(np.uint8) + result = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB) + else: + result = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB) + Image.fromarray(result).save(output_path) + + if save_mask: + if mask_output_path is None: + mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png" + Image.fromarray(mask_used).save(mask_output_path) + print(f"已保存mask: {mask_output_path}") + + print(f"已保存: {output_path}") + +def process_images_folder( + input_folder, + output_folder, + model_name="u2net", + alpha_matting=False, + alpha_matting_foreground_threshold=240, + alpha_matting_background_threshold=10, + alpha_matting_erode_size=10, + post_process_mask=False, + remove_subject=False, + mask_dilate=3, + mask_blur=3, + mask_threshold=10, + edge_grow=0, + aot_root="AOT-GAN-for-Inpainting", + aot_pretrain=None, + aot_device="cpu", + aot_block_num=8, + aot_rates="1+2+4+8", + aot_crop=False, + aot_crop_pad=0, + aot_max_size=0, + aot_noise_prefill=False, + aot_noise_strength=1.0, + black_subject=False, + black_threshold=50, + gray_subject=False, + gray_saturation_threshold=30, + gray_value_threshold=200, 
+ feather=False, + feather_radius=5, + save_mask=False, +): """ 批量处理文件夹中的所有图片 @@ -105,26 +557,100 @@ def process_images_folder(input_folder, output_folder, model_name="u2net", print(f" - 背景阈值: {alpha_matting_background_threshold}") print(f" - 侵蚀大小: {alpha_matting_erode_size}") print(f"Mask后处理: {'启用' if post_process_mask else '禁用'}") + print(f"去主体补背景: {'启用' if remove_subject else '禁用'}") + if remove_subject: + print(f" - AOT目录: {aot_root}") + print(f" - AOT权重: {aot_pretrain}") + print(f" - AOT设备: {aot_device}") + print(f" - AOT块数: {aot_block_num}") + print(f" - AOT膨胀率: {aot_rates}") + print(f" - AOT裁剪: {'是' if aot_crop else '否'}") + if aot_crop: + print(f" - AOT裁剪边界: {aot_crop_pad}") + print(f" - AOT最大边: {aot_max_size}") + print(f" - AOT噪声预填充: {'是' if aot_noise_prefill else '否'}") + if aot_noise_prefill: + print(f" - AOT噪声强度: {aot_noise_strength}") + print(f" - mask膨胀: {mask_dilate}") + print(f" - mask模糊: {mask_blur}") + print(f" - mask阈值: {mask_threshold}") + print(f" - 边缘扩张: {edge_grow}") + print(f" - 黑色内容作为主体: {'是' if black_subject else '否'}") + if black_subject: + print(f" - 黑色阈值: {black_threshold}") + print(f" - 灰阶内容作为主体: {'是' if gray_subject else '否'}") + if gray_subject: + print(f" - 灰阶饱和度阈值: {gray_saturation_threshold}") + print(f" - 灰阶亮度阈值: {gray_value_threshold}") + print(f" - 边缘过渡: {'是' if feather else '否'}") + if feather: + print(f" - 过渡半径: {feather_radius}") + print(f" - 保存mask: {'是' if save_mask else '否'}") print("-" * 50) # 处理每张图片 for i, image_file in enumerate(image_files, 1): try: - # 输出文件名(保持原始名称,改为PNG格式以支持透明背景) - output_filename = image_file.stem + '_nobg.png' + if remove_subject: + suffix = image_file.suffix.lower() + if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}: + output_filename = image_file.stem + "_bgfill" + suffix + else: + output_filename = image_file.stem + "_bgfill.jpg" + else: + # 去背景默认使用PNG格式以支持透明背景 + output_filename = image_file.stem + "_nobg.png" output_path = Path(output_folder) / output_filename 
print(f"[{i}/{len(image_files)}] ", end="") - remove_background( - str(image_file), - str(output_path), - session=session, - alpha_matting=alpha_matting, - alpha_matting_foreground_threshold=alpha_matting_foreground_threshold, - alpha_matting_background_threshold=alpha_matting_background_threshold, - alpha_matting_erode_size=alpha_matting_erode_size, - post_process_mask=post_process_mask - ) + if remove_subject: + mask_output_path = None + if save_mask: + mask_output_path = str(Path(output_folder) / (image_file.stem + "_mask.png")) + remove_subject_and_inpaint( + str(image_file), + str(output_path), + session=session, + alpha_matting=alpha_matting, + alpha_matting_foreground_threshold=alpha_matting_foreground_threshold, + alpha_matting_background_threshold=alpha_matting_background_threshold, + alpha_matting_erode_size=alpha_matting_erode_size, + post_process_mask=post_process_mask, + mask_dilate=mask_dilate, + mask_blur=mask_blur, + mask_threshold=mask_threshold, + edge_grow=edge_grow, + aot_root=aot_root, + aot_pretrain=aot_pretrain, + aot_device=aot_device, + aot_block_num=aot_block_num, + aot_rates=aot_rates, + aot_crop=aot_crop, + aot_crop_pad=aot_crop_pad, + aot_max_size=aot_max_size, + aot_noise_prefill=aot_noise_prefill, + aot_noise_strength=aot_noise_strength, + black_subject=black_subject, + black_threshold=black_threshold, + gray_subject=gray_subject, + gray_saturation_threshold=gray_saturation_threshold, + gray_value_threshold=gray_value_threshold, + feather=feather, + feather_radius=feather_radius, + save_mask=save_mask, + mask_output_path=mask_output_path, + ) + else: + remove_background( + str(image_file), + str(output_path), + session=session, + alpha_matting=alpha_matting, + alpha_matting_foreground_threshold=alpha_matting_foreground_threshold, + alpha_matting_background_threshold=alpha_matting_background_threshold, + alpha_matting_erode_size=alpha_matting_erode_size, + post_process_mask=post_process_mask, + ) except Exception as e: print(f"处理 
{image_file.name} 时出错: {e}") @@ -171,9 +697,8 @@ if __name__ == "__main__": help='选择使用的模型 (默认: isnet-general-use)') # Alpha Matting参数 - parser.add_argument('-a', '--alpha-matting', type=str2bool, nargs='?', const=True, default=True, - metavar='true/false', - help='启用alpha matting后处理(默认: true)。用法: -a 或 -a true 或 -a false') + parser.add_argument('-a', '--alpha-matting', '--alpha_matting', action='store_true', + help='启用alpha matting后处理(默认: false)') parser.add_argument('-ft', '--foreground-threshold', type=int, default=245, help='前景阈值 (0-255),值越大保留越多细节 (默认: 245)') parser.add_argument('-bt', '--background-threshold', type=int, default=8, @@ -182,9 +707,56 @@ if __name__ == "__main__": help='侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)') # 其他选项 - parser.add_argument('-p', '--post-process', type=str2bool, nargs='?', const=True, default=True, - metavar='true/false', - help='启用mask后处理(默认: true)。用法: -p 或 -p true 或 -p false') + parser.add_argument('-p', '--post-process', '--post_process', action='store_true', + help='启用mask后处理(默认: false)') + + # 去主体补背景参数 + parser.add_argument('--remove-subject', '--remove_subject', action='store_true', + help='去掉主体并补全背景(默认: false)') + parser.add_argument('--aot-root', type=str, default='AOT-GAN-for-Inpainting', + help='AOT-GAN目录(默认: AOT-GAN-for-Inpainting)') + parser.add_argument('--aot-pretrain', type=str, default=None, + help='AOT-GAN预训练权重文件路径(必填,相对路径基于当前目录)') + parser.add_argument('--aot-device', type=str, default='cpu', + help='AOT-GAN设备(默认: cpu)') + parser.add_argument('--aot-block-num', type=int, default=8, + help='AOTBlock数量(默认: 8)') + parser.add_argument('--aot-rates', type=str, default='1+2+4+8', + help='AOTBlock膨胀率(默认: 1+2+4+8)') + parser.add_argument('--aot-crop', action='store_true', + help='AOT仅对mask区域裁剪修补(默认: false)') + parser.add_argument('--aot-crop-pad', type=int, default=0, + help='AOT裁剪边缘留白像素(默认: 0)') + parser.add_argument('--aot-max-size', type=int, default=0, + help='AOT输入最大边限制,0为不限制(默认: 0)') + 
parser.add_argument('--aot-noise-prefill', action='store_true', + help='AOT使用随机噪声预填充(默认: false)') + parser.add_argument('--aot-noise-strength', type=float, default=1.0, + help='AOT噪声强度系数(默认: 1.0)') + parser.add_argument('--mask-dilate', type=int, default=3, + help='mask膨胀大小(默认: 3)') + parser.add_argument('--mask-blur', type=int, default=3, + help='mask模糊大小(默认: 3,建议奇数)') + parser.add_argument('--mask-threshold', type=int, default=10, + help='alpha阈值(默认: 10)') + parser.add_argument('--edge-grow', type=int, default=0, + help='主体边缘扩张像素(默认: 0)') + parser.add_argument('--save-mask', '--save_mask', action='store_true', + help='保存mask到output目录(默认: false)') + parser.add_argument('--black-subject', '--black_subject', action='store_true', + help='将黑色内容也视为主体(默认: false)') + parser.add_argument('--black-threshold', type=int, default=50, + help='黑色阈值(0-255,灰度越小越黑,默认: 50)') + parser.add_argument('--gray-subject', '--gray_subject', action='store_true', + help='将灰阶内容也视为主体(默认: false)') + parser.add_argument('--gray-saturation-threshold', type=int, default=30, + help='灰阶饱和度阈值(0-255,越小越接近灰阶,默认: 30)') + parser.add_argument('--gray-value-threshold', type=int, default=200, + help='灰阶亮度阈值(0-255,越小越暗,默认: 200)') + parser.add_argument('--feather', action='store_true', + help='启用边缘过渡融合(默认: false)') + parser.add_argument('--feather-radius', type=int, default=5, + help='边缘过渡半径(默认: 5)') args = parser.parse_args() @@ -201,13 +773,33 @@ if __name__ == "__main__": # 处理单个文件 if input_path.is_file(): + suffix = input_path.suffix.lower() + if args.remove_subject: + if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}: + output_name = input_path.stem + "_bgfill" + suffix + else: + output_name = input_path.stem + "_bgfill.jpg" + else: + output_name = input_path.stem + "_nobg.png" + # 确定输出路径 if args.output is None: - output_path = input_path.parent / 'output' / (input_path.stem + '_nobg.png') + if args.remove_subject: + output_path = input_path.parent / "output" / output_name + else: + output_path = 
input_path.parent / 'output' / output_name output_path.parent.mkdir(parents=True, exist_ok=True) else: - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) + output_candidate = Path(args.output) + if output_candidate.exists() and output_candidate.is_dir(): + output_path = output_candidate / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + elif output_candidate.suffix == "": + output_candidate.mkdir(parents=True, exist_ok=True) + output_path = output_candidate / output_name + else: + output_path = output_candidate + output_path.parent.mkdir(parents=True, exist_ok=True) print(f"输入文件: {input_path}") print(f"输出文件: {output_path}") @@ -218,22 +810,89 @@ if __name__ == "__main__": print(f" - 背景阈值: {args.background_threshold}") print(f" - 侵蚀大小: {args.erode_size}") print(f"Mask后处理: {'启用' if args.post_process else '禁用'}") + print(f"去主体补背景: {'启用' if args.remove_subject else '禁用'}") + if args.remove_subject: + print(f" - AOT目录: {args.aot_root}") + print(f" - AOT权重: {args.aot_pretrain}") + print(f" - AOT设备: {args.aot_device}") + print(f" - AOT块数: {args.aot_block_num}") + print(f" - AOT膨胀率: {args.aot_rates}") + print(f" - AOT裁剪: {'是' if args.aot_crop else '否'}") + if args.aot_crop: + print(f" - AOT裁剪边界: {args.aot_crop_pad}") + print(f" - AOT最大边: {args.aot_max_size}") + print(f" - AOT噪声预填充: {'是' if args.aot_noise_prefill else '否'}") + if args.aot_noise_prefill: + print(f" - AOT噪声强度: {args.aot_noise_strength}") + print(f" - mask膨胀: {args.mask_dilate}") + print(f" - mask模糊: {args.mask_blur}") + print(f" - mask阈值: {args.mask_threshold}") + print(f" - 边缘扩张: {args.edge_grow}") + print(f" - 黑色内容作为主体: {'是' if args.black_subject else '否'}") + if args.black_subject: + print(f" - 黑色阈值: {args.black_threshold}") + print(f" - 灰阶内容作为主体: {'是' if args.gray_subject else '否'}") + if args.gray_subject: + print(f" - 灰阶饱和度阈值: {args.gray_saturation_threshold}") + print(f" - 灰阶亮度阈值: {args.gray_value_threshold}") + print(f" - 边缘过渡: {'是' if 
args.feather else '否'}") + if args.feather: + print(f" - 过渡半径: {args.feather_radius}") + print(f" - 保存mask: {'是' if args.save_mask else '否'}") print("-" * 50) # 创建会话 session = new_session(args.model) # 处理图片 - remove_background( - str(input_path), - str(output_path), - session=session, - alpha_matting=args.alpha_matting, - alpha_matting_foreground_threshold=args.foreground_threshold, - alpha_matting_background_threshold=args.background_threshold, - alpha_matting_erode_size=args.erode_size, - post_process_mask=args.post_process - ) + if args.remove_subject: + mask_output_path = None + if args.save_mask: + mask_output_path = str(output_path.with_suffix("")) + "_mask.png" + remove_subject_and_inpaint( + str(input_path), + str(output_path), + session=session, + alpha_matting=args.alpha_matting, + alpha_matting_foreground_threshold=args.foreground_threshold, + alpha_matting_background_threshold=args.background_threshold, + alpha_matting_erode_size=args.erode_size, + post_process_mask=args.post_process, + mask_dilate=args.mask_dilate, + mask_blur=args.mask_blur, + mask_threshold=args.mask_threshold, + edge_grow=args.edge_grow, + aot_root=args.aot_root, + aot_pretrain=args.aot_pretrain, + aot_device=args.aot_device, + aot_block_num=args.aot_block_num, + aot_rates=args.aot_rates, + aot_crop=args.aot_crop, + aot_crop_pad=args.aot_crop_pad, + aot_max_size=args.aot_max_size, + aot_noise_prefill=args.aot_noise_prefill, + aot_noise_strength=args.aot_noise_strength, + black_subject=args.black_subject, + black_threshold=args.black_threshold, + gray_subject=args.gray_subject, + gray_saturation_threshold=args.gray_saturation_threshold, + gray_value_threshold=args.gray_value_threshold, + feather=args.feather, + feather_radius=args.feather_radius, + save_mask=args.save_mask, + mask_output_path=mask_output_path, + ) + else: + remove_background( + str(input_path), + str(output_path), + session=session, + alpha_matting=args.alpha_matting, + 
alpha_matting_foreground_threshold=args.foreground_threshold, + alpha_matting_background_threshold=args.background_threshold, + alpha_matting_erode_size=args.erode_size, + post_process_mask=args.post_process, + ) print("-" * 50) print(f"处理完成!结果保存在: {output_path}") @@ -250,7 +909,30 @@ if __name__ == "__main__": alpha_matting_foreground_threshold=args.foreground_threshold, alpha_matting_background_threshold=args.background_threshold, alpha_matting_erode_size=args.erode_size, - post_process_mask=args.post_process + post_process_mask=args.post_process, + remove_subject=args.remove_subject, + mask_dilate=args.mask_dilate, + mask_blur=args.mask_blur, + mask_threshold=args.mask_threshold, + edge_grow=args.edge_grow, + aot_root=args.aot_root, + aot_pretrain=args.aot_pretrain, + aot_device=args.aot_device, + aot_block_num=args.aot_block_num, + aot_rates=args.aot_rates, + aot_crop=args.aot_crop, + aot_crop_pad=args.aot_crop_pad, + aot_max_size=args.aot_max_size, + aot_noise_prefill=args.aot_noise_prefill, + aot_noise_strength=args.aot_noise_strength, + black_subject=args.black_subject, + black_threshold=args.black_threshold, + gray_subject=args.gray_subject, + gray_saturation_threshold=args.gray_saturation_threshold, + gray_value_threshold=args.gray_value_threshold, + feather=args.feather, + feather_radius=args.feather_radius, + save_mask=args.save_mask, ) else: diff --git a/requirements.txt b/requirements.txt index 6844ac3..0cb5bf4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ rembg[gpu] pillow pillow-heif +opencv-python diff --git a/tests/__pycache__/test_remove_background.cpython-312.pyc b/tests/__pycache__/test_remove_background.cpython-312.pyc new file mode 100644 index 0000000..c0feb53 Binary files /dev/null and b/tests/__pycache__/test_remove_background.cpython-312.pyc differ diff --git a/tests/test_remove_background.py b/tests/test_remove_background.py new file mode 100644 index 0000000..b9469e8 --- /dev/null +++ 
b/tests/test_remove_background.py
@@ -0,0 +1,130 @@
+"""Unit tests for the pure helper functions of ``remove_background``.
+
+``rembg`` is swapped for a stub in ``sys.modules`` before the import, so the
+module under test binds to the stub instead of the real package.
+"""
+
+import importlib
+import sys
+import types
+import unittest
+from pathlib import PurePath
+
+# Snapshot marker for "name was not in sys.modules", so that a genuine
+# ``None`` entry is restored as-is rather than deleted.
+_ABSENT = object()
+
+# sys.modules entries overwritten while importing the module under test.
+_PATCHED_MODULES = ("rembg", "remove_background")
+
+
+def _import_remove_background():
+    """Import ``remove_background`` afresh against a stub ``rembg``.
+
+    Any cached ``remove_background`` is dropped first so the import below
+    re-executes the module. The stub exposes only ``remove`` and
+    ``new_session``, both of which accept anything and return ``None``.
+    """
+    sys.modules.pop("remove_background", None)
+    rembg_stub = types.ModuleType("rembg")
+    rembg_stub.remove = lambda *args, **kwargs: None
+    rembg_stub.new_session = lambda *args, **kwargs: None
+    sys.modules["rembg"] = rembg_stub
+    return importlib.import_module("remove_background")
+
+
+def _restore_modules(snapshot):
+    """Restore the ``sys.modules`` entries captured in *snapshot*."""
+    for name, module in snapshot.items():
+        if module is _ABSENT:
+            sys.modules.pop(name, None)
+        else:
+            sys.modules[name] = module
+
+
+class RemoveBackgroundTests(unittest.TestCase):
+    """Checks for the mask, bounding-box and AOT path/rate helpers."""
+
+    def setUp(self):
+        # Register the restore before importing: cleanups run even when setUp
+        # fails, so the stub never leaks into other tests in this process.
+        snapshot = {
+            name: sys.modules.get(name, _ABSENT) for name in _PATCHED_MODULES
+        }
+        self.addCleanup(_restore_modules, snapshot)
+        self.mod = _import_remove_background()
+
+    def test_alpha_to_mask_threshold(self):
+        # An alpha equal to the threshold must stay 0; only larger values map
+        # to 255.
+        alpha = self.mod.np.array([[0, 9, 10, 255]], dtype=self.mod.np.uint8)
+        mask = self.mod._alpha_to_mask(alpha, threshold=10)
+        self.assertEqual(mask.shape, alpha.shape)
+        self.assertEqual(mask[0, 0], 0)
+        self.assertEqual(mask[0, 1], 0)
+        self.assertEqual(mask[0, 2], 0)
+        self.assertEqual(mask[0, 3], 255)
+
+    def test_prepare_mask_outputs_binary(self):
+        # Even with a blur requested, the second returned mask may contain only
+        # the values 0 and 255.
+        mask = self.mod.np.zeros((5, 5), dtype=self.mod.np.uint8)
+        mask[2, 2] = 255
+        mask_hard, mask_used = self.mod._prepare_mask(mask, mask_dilate=0, mask_blur=3)
+        self.assertEqual(mask_hard.shape, mask.shape)
+        self.assertEqual(mask_used.shape, mask.shape)
+        self.assertTrue(self.mod.np.all(self.mod.np.isin(mask_used, [0, 255])))
+
+    def test_resolve_aot_paths_defaults(self):
+        # Without a weights path, the pretrain result stays None.
+        aot_root, aot_pretrain = self.mod._resolve_aot_paths(
+            "AOT-GAN-for-Inpainting", None
+        )
+        self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
+        self.assertIsNone(aot_pretrain)
+
+    def test_resolve_aot_paths_relative_pretrain(self):
+        aot_root, aot_pretrain = self.mod._resolve_aot_paths(
+            "AOT-GAN-for-Inpainting", "experiments/foo.pt"
+        )
+        self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
+        # Normalise separators first: a native Windows path uses backslashes
+        # and would never end with the forward-slash literal.
+        pretrain_posix = PurePath(str(aot_pretrain)).as_posix()
+        self.assertTrue(pretrain_posix.endswith("experiments/foo.pt"))
+
+    def test_parse_aot_rates(self):
+        # The "+"-separated rate string becomes a list of ints.
+        rates = self.mod._parse_aot_rates("1+2+4+8")
+        self.assertEqual(rates, [1, 2, 4, 8])
+
+    def test_mask_bbox_empty(self):
+        # An all-zero mask has no bounding box.
+        mask = self.mod.np.zeros((3, 3), dtype=self.mod.np.uint8)
+        self.assertIsNone(self.mod._mask_bbox(mask))
+
+    def test_mask_bbox_and_expand(self):
+        # Rows 1-2 / columns 2-3 are set, and the expected bbox (2, 1, 3, 2)
+        # reads as inclusive (x_min, y_min, x_max, y_max).
+        mask = self.mod.np.zeros((5, 5), dtype=self.mod.np.uint8)
+        mask[1:3, 2:4] = 255
+        bbox = self.mod._mask_bbox(mask)
+        self.assertEqual(bbox, (2, 1, 3, 2))
+        # Expanding by 2 is expected to clamp to the 5x5 image bounds.
+        expanded = self.mod._expand_bbox(bbox, 2, 5, 5)
+        self.assertEqual(expanded, (0, 0, 4, 4))
+        # NOTE(review): the last argument looks like a minimum side length
+        # (the 2x2 box grows to 4x4) -- confirm against _expand_bbox_min_size.
+        expanded_min = self.mod._expand_bbox_min_size(bbox, 0, 5, 5, 4)
+        self.assertEqual(expanded_min, (1, 0, 4, 3))
+
+    def test_ensure_rgba_size_resizes(self):
+        # An RGB input of another size comes back as RGBA at the target size.
+        img = self.mod.Image.new("RGB", (2, 3), color=(0, 0, 0))
+        out = self.mod._ensure_rgba_size(img, (4, 5))
+        self.assertEqual(out.mode, "RGBA")
+        self.assertEqual(out.size, (4, 5))
+
+
+if __name__ == "__main__":
+    unittest.main()