"""Image background removal tool.

Uses the rembg library to remove image backgrounds automatically. With
``--remove-subject`` the detected subject is removed instead and the hole is
filled with an AOT-GAN inpainting model.
"""

import argparse
import importlib
import os
import sys
import tempfile
from pathlib import Path

import cv2
import numpy as np
from PIL import Image, ImageOps

# Give numba a writable cache directory; caching can fail in some environments
# otherwise. This must run before rembg is imported (rembg's import may pull in
# numba). tempfile.gettempdir() keeps this portable instead of assuming /tmp.
os.environ.setdefault("NUMBA_CACHE_DIR", os.path.join(tempfile.gettempdir(), "numba_cache"))

from rembg import new_session, remove  # noqa: E402  (NUMBA_CACHE_DIR must be set first)

# Optional HEIC/HEIF input support.
try:
    from pillow_heif import register_heif_opener

    register_heif_opener()
    HEIC_SUPPORTED = True
except ImportError:
    HEIC_SUPPORTED = False

# Common image suffixes: always accepted as folder inputs, and a "_bgfill"
# output keeps the input suffix only when it is one of these (else JPEG).
_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".webp"})

# Loaded AOT-GAN models keyed by (weights path, device, block_num, rates);
# each value is (model, device).
AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {}


def _resolve_aot_paths(aot_root, aot_pretrain):
    """Resolve the AOT-GAN repository directory and weights path.

    Args:
        aot_root: AOT-GAN repository directory (str or Path).
        aot_pretrain: Weights file path or ``None``. A relative path is taken
            relative to the current working directory.

    Returns:
        ``(aot_root, aot_pretrain)`` as absolute ``Path`` objects;
        ``aot_pretrain`` stays ``None`` when it was not given.
    """
    aot_root = Path(aot_root).resolve()
    if aot_pretrain is not None:
        aot_pretrain = Path(aot_pretrain)
        if not aot_pretrain.is_absolute():
            aot_pretrain = (Path.cwd() / aot_pretrain).resolve()
    return aot_root, aot_pretrain


def _parse_aot_rates(rates_str):
    """Parse a ``+``-separated dilation-rate string, e.g. ``"1+2+4+8"`` -> ``[1, 2, 4, 8]``.

    Blank segments are skipped; a non-integer segment raises ``ValueError``.
    """
    return [int(part) for part in rates_str.split("+") if part.strip()]


def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None):
    """Load (or fetch from cache) an AOT-GAN ``InpaintGenerator``.

    Args:
        aot_root: AOT-GAN repository directory; its ``src`` folder is added to
            ``sys.path`` so ``model.aotgan`` can be imported.
        aot_pretrain: Path to the generator weights (required).
        device: Torch device string the model is moved to.
        block_num: Number of AOT blocks.
        rates: Dilation rates per AOT block (default ``[1, 2, 4, 8]``).

    Returns:
        ``(model, device)`` with the model in eval mode.

    Raises:
        FileNotFoundError: ``aot_root`` or the weights file does not exist.
        ValueError: no weights path was given.
        ImportError: PyTorch is not installed.
    """
    aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain)
    if not aot_root.exists():
        raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}")
    if aot_pretrain is None:
        raise ValueError("AOT-GAN需要指定预训练权重路径(--aot-pretrain)")
    if not aot_pretrain.exists():
        raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}")
    if rates is None:
        rates = [1, 2, 4, 8]

    key = (str(aot_pretrain), device, int(block_num), tuple(rates))
    cached = AOT_MODEL_CACHE.get(key)
    if cached is not None:
        return cached

    src_root = aot_root / "src"
    if str(src_root) not in sys.path:
        sys.path.insert(0, str(src_root))
    try:
        import torch
    except ImportError as exc:
        raise ImportError("未找到 PyTorch,请先按README安装依赖。") from exc
    net = importlib.import_module("model.aotgan")

    # InpaintGenerator is constructed from an args-like object; only
    # block_num and rates are provided here.
    model_args = argparse.Namespace(block_num=int(block_num), rates=rates)
    model = net.InpaintGenerator(model_args)

    # NOTE(review): torch.load unpickles the checkpoint -- only load weight
    # files from trusted sources.
    state = torch.load(str(aot_pretrain), map_location=device)
    # Unwrap checkpoints that nest the weights under a container key.
    if isinstance(state, dict):
        for wrapper_key in ("state_dict", "model", "generator"):
            if wrapper_key in state:
                state = state[wrapper_key]
                break
    # Strip a leading "module." prefix (presumably from DataParallel-style
    # wrappers) so the keys match the bare generator.
    if isinstance(state, dict) and any(k.startswith("module.") for k in state):
        state = {k.removeprefix("module."): v for k, v in state.items()}

    model.load_state_dict(state, strict=True)
    model.to(device)
    model.eval()
    AOT_MODEL_CACHE[key] = (model, device)
    return AOT_MODEL_CACHE[key]


def _mask_bbox(mask):
    """Return the inclusive bounding box ``(x0, y0, x1, y1)`` of nonzero mask pixels, or ``None``."""
    ys, xs = np.nonzero(mask > 0)
    if xs.size == 0:
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())


def _expand_bbox(bbox, pad, width, height):
    """Grow an inclusive bbox by ``pad`` pixels on every side, clipped to the image.

    Returns ``None`` when ``bbox`` is ``None``; a negative ``pad`` is treated as 0.
    """
    if bbox is None:
        return None
    x0, y0, x1, y1 = bbox
    pad = max(0, int(pad))
    return (
        max(0, x0 - pad),
        max(0, y0 - pad),
        min(width - 1, x1 + pad),
        min(height - 1, y1 + pad),
    )


def _expand_bbox_min_size(bbox, pad, width, height, min_size):
    """Expand ``bbox`` by ``pad``, then pad further so each side tries to reach ``min_size``.

    Clipping at the image border can still leave the result smaller than
    ``min_size``; callers must check the returned size.
    """
    expanded = _expand_bbox(bbox, pad, width, height)
    if expanded is None:
        return None
    x0, y0, x1, y1 = expanded
    need_w = max(0, int(min_size) - (x1 - x0 + 1))
    need_h = max(0, int(min_size) - (y1 - y0 + 1))
    if need_w == 0 and need_h == 0:
        return expanded
    # Extra padding is applied symmetrically, so half the deficit (rounded up).
    extra = max((need_w + 1) // 2, (need_h + 1) // 2)
    return _expand_bbox(bbox, pad + extra, width, height)


def _build_noise_prefill(img, mask_t, strength):
    """Generate noise to prefill the masked area; ``img`` is in ``[-1, 1]``.

    The noise follows the per-channel mean/std of the unmasked pixels (or a
    fixed mean 0.5 / std 0.2 when almost nothing is unmasked), scaled by
    ``strength``. In this file ``img`` is ``(1, 3, H, W)`` and ``mask_t`` is
    ``(1, 1, H, W)`` with 1 marking pixels to fill.

    Returns:
        A tensor shaped like ``img`` with values in ``[-1, 1]``.
    """
    import torch

    img01 = (img + 1.0) / 2.0
    unmasked = 1.0 - mask_t
    denom = unmasked.sum()
    if denom.item() < 1.0:
        mean = torch.full((1, 3, 1, 1), 0.5, device=img.device)
        std = torch.full((1, 3, 1, 1), 0.2, device=img.device)
    else:
        mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
        var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
        std = torch.sqrt(var + 1e-6)
    noise = mean + std * float(strength) * torch.randn_like(img01)
    noise = noise.clamp(0.0, 1.0)
    return noise * 2.0 - 1.0


def _to_uint8_mask(mask):
    """Return ``mask`` as uint8 where nonzero marks pixels to inpaint (a 0/1 mask becomes 0/255)."""
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.max() <= 1:
        mask = mask * 255
    return mask


def _inpaint_with_aot_core(
    bgr,
    mask,
    aot_root,
    aot_pretrain,
    device="cpu",
    block_num=8,
    rates=None,
    noise_prefill=False,
    noise_strength=1.0,
):
    """Inpaint the masked pixels of ``bgr`` with AOT-GAN.

    The network needs sides divisible by 4, so the image and mask are
    edge-padded up to the next multiple of 4 and the result is cropped back.
    The padding ensures masked pixels in the last rows and columns are
    inpainted too. Unmasked pixels are returned unchanged.

    Args:
        bgr: ``(H, W, 3)`` uint8 BGR image.
        mask: ``(H, W)`` mask; nonzero marks pixels to fill.
        noise_prefill: Prefill the hole with noise instead of a constant.
        noise_strength: Noise scale passed to ``_build_noise_prefill``.
        Other args: forwarded to ``_get_aot_model``.

    Returns:
        A BGR uint8 image with the same size as ``bgr``. ``bgr`` itself is
        returned when the mask is empty or a side is shorter than 4 px.
    """
    mask = _to_uint8_mask(mask)
    if mask.max() == 0:
        return bgr
    h, w = mask.shape
    grid = 4
    if h < grid or w < grid:
        return bgr

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    pad_h = -h % grid
    pad_w = -w % grid
    if pad_h or pad_w:
        rgb = np.pad(rgb, ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")
        mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode="edge")

    # Load the model first so a missing PyTorch gets the friendly error there.
    model, _ = _get_aot_model(
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=device,
        block_num=block_num,
        rates=rates,
    )
    import torch

    img = torch.from_numpy(np.ascontiguousarray(rgb)).permute(2, 0, 1).float() / 255.0
    img = (img * 2.0 - 1.0).unsqueeze(0).to(device)
    mask_t = torch.from_numpy((mask > 0).astype(np.float32))[None, None].to(device)

    with torch.no_grad():
        if noise_prefill:
            noise = _build_noise_prefill(img, mask_t, noise_strength)
            image_masked = img * (1 - mask_t) + noise * mask_t
        else:
            # Masked pixels are set to 1.0 (white in [-1, 1] space).
            image_masked = img * (1 - mask_t) + mask_t
        pred = model(image_masked, mask_t)
        comp = pred * mask_t + img * (1 - mask_t)

    comp = comp[0].clamp(-1.0, 1.0)
    # Round rather than truncate: the float round trip can land just below an
    # integer and would otherwise shift unmasked pixels by one level.
    comp = ((comp + 1.0) / 2.0 * 255.0).round().byte()
    comp = comp.permute(1, 2, 0).cpu().numpy()[:h, :w]
    return cv2.cvtColor(np.ascontiguousarray(comp), cv2.COLOR_RGB2BGR)


def _inpaint_with_aot(
    bgr,
    mask,
    aot_root,
    aot_pretrain,
    device="cpu",
    block_num=8,
    rates=None,
    crop=False,
    crop_pad=0,
    max_size=0,
    noise_prefill=False,
    noise_strength=1.0,
):
    """Inpaint with AOT-GAN, optionally cropping to the mask and capping the input size.

    Args:
        bgr: ``(H, W, 3)`` uint8 BGR image.
        mask: ``(H, W)`` mask; nonzero marks pixels to fill.
        crop: Only process the mask's bounding box, padded by ``crop_pad``.
        crop_pad: Extra pixels kept around the mask bbox when cropping.
        max_size: If > 0, downscale the processed region so its longest side
            is at most this many pixels. Only masked pixels are taken from the
            upscaled result, so unmasked pixels keep full resolution.
        Other args: forwarded to ``_inpaint_with_aot_core``.

    Returns:
        A BGR uint8 image with the same size as ``bgr``.
    """
    # Crops/resizes with a side below this are skipped: too-small inputs make AOT fail.
    min_side = 32
    mask = _to_uint8_mask(mask)
    if mask.max() == 0:
        return bgr
    h, w = mask.shape
    x0, y0, x1, y1 = 0, 0, w - 1, h - 1

    if crop:
        expanded = _expand_bbox_min_size(_mask_bbox(mask), crop_pad, w, h, min_side)
        if expanded is None:
            return bgr
        cx0, cy0, cx1, cy1 = expanded
        if cx1 - cx0 + 1 >= min_side and cy1 - cy0 + 1 >= min_side:
            x0, y0, x1, y1 = cx0, cy0, cx1, cy1
        else:
            # Crop too small for AOT; fall back to inpainting the full image.
            crop = False

    bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1]
    mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1]
    roi_h, roi_w = bgr_roi.shape[:2]

    scale = 1.0
    new_w, new_h = roi_w, roi_h
    max_size = int(max_size) if max_size else 0
    longest = max(roi_h, roi_w)
    if max_size > 0 and longest > max_size:
        scale = max_size / float(longest)
        new_w = max(1, int(round(roi_w * scale)))
        new_h = max(1, int(round(roi_h * scale)))
        if min(new_w, new_h) < min_side:
            scale = 1.0
            new_w, new_h = roi_w, roi_h

    if scale == 1.0:
        net_bgr, net_mask = bgr_roi, mask_roi
    else:
        net_bgr = cv2.resize(bgr_roi, (new_w, new_h), interpolation=cv2.INTER_AREA)
        # Area-average in float and keep any partially covered pixel masked;
        # nearest-neighbour sampling could drop thin strokes of the subject.
        coverage = cv2.resize(mask_roi.astype(np.float32), (new_w, new_h), interpolation=cv2.INTER_AREA)
        net_mask = (coverage > 0).astype(np.uint8) * 255

    filled = _inpaint_with_aot_core(
        net_bgr,
        net_mask,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=device,
        block_num=block_num,
        rates=rates,
        noise_prefill=noise_prefill,
        noise_strength=noise_strength,
    )

    if scale != 1.0:
        filled = cv2.resize(filled, (roi_w, roi_h), interpolation=cv2.INTER_LINEAR)
        # Upscaling blurs the whole region: take only the masked pixels from
        # it and keep the original full-resolution pixels everywhere else.
        filled = np.where(mask_roi[:, :, None] > 0, filled, bgr_roi)

    if not crop:
        return filled
    output_full = bgr.copy()
    output_full[y0 : y1 + 1, x0 : x1 + 1] = filled
    return output_full


def remove_background(input_path, output_path, session=None, **kwargs):
    """Remove the background of one image and save the result.

    Args:
        input_path: Input image path. EXIF orientation is applied first.
        output_path: Output image path.
        session: Optional rembg session (reuse it across images for speed).
        **kwargs: Extra ``rembg.remove`` options, e.g. alpha-matting settings.
    """
    print(f"正在处理: {input_path}")
    with Image.open(input_path) as source:
        input_image = ImageOps.exif_transpose(source)
        output_image = remove(input_image, session=session, **kwargs)
    output_image.save(output_path)
    print(f"已保存: {output_path}")


def _pil_to_bgr(image):
    """Convert a PIL image to an OpenCV-style BGR uint8 array (via RGB)."""
    rgb = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def _alpha_to_mask(alpha_channel, threshold=10):
    """Binarize an alpha channel: pixels with alpha > ``threshold`` become 255, others 0."""
    return (alpha_channel > threshold).astype(np.uint8) * 255


def _ensure_rgba_size(image, target_size):
    """Return ``image`` as RGBA at ``target_size`` (PIL ``(w, h)``); resizes with nearest neighbour."""
    if image.mode != "RGBA":
        image = image.convert("RGBA")
    if image.size != target_size:
        image = image.resize(target_size, resample=Image.NEAREST)
    return image


def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0):
    """Build the hard mask and the mask actually used for inpainting.

    Args:
        mask: uint8 mask, 255 on the subject.
        mask_dilate: Side of a square dilation kernel (0 disables).
        mask_blur: Gaussian kernel size (even values are bumped to odd; 0 disables).
        edge_grow: Extra dilation iterations with a 3x3 kernel (0 disables).

    Returns:
        ``(mask_hard, mask_used)``. ``mask_hard`` is the dilated mask.
        ``mask_used`` is ``mask_hard`` blurred and re-binarized at > 0, which
        grows it slightly.
    """
    if mask_dilate and mask_dilate > 0:
        mask = cv2.dilate(mask, np.ones((mask_dilate, mask_dilate), np.uint8), iterations=1)
    if edge_grow and edge_grow > 0:
        mask = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=int(edge_grow))
    mask_hard = mask.copy()
    mask_used = mask_hard
    if mask_blur and mask_blur > 0:
        k = mask_blur if mask_blur % 2 == 1 else mask_blur + 1
        mask_used = cv2.GaussianBlur(mask_hard, (k, k), 0)
        mask_used = (mask_used > 0).astype(np.uint8) * 255
    return mask_hard, mask_used


def _feather_blend(filled, original, mask_hard, radius):
    """Blend ``filled`` into ``original`` across a ``radius``-px band outside ``mask_hard``.

    Inside the hard mask the filled pixels are used as-is, so the removed
    subject is never blended back. The weight of ``filled`` falls linearly
    from 1 at the mask edge to 0 at ``radius`` px away.
    """
    mask_bin = (mask_hard > 0).astype(np.uint8) * 255
    # Distance from each pixel to the nearest mask pixel (0 inside the mask).
    dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3)
    weight = np.clip(1.0 - dist_out / float(radius), 0.0, 1.0)
    weight[mask_bin > 0] = 1.0
    weight = weight[:, :, None]
    blended = weight * filled + (1.0 - weight) * original
    # Round instead of truncating so pixels where filled == original stay unchanged.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


def remove_subject_and_inpaint(
    input_path,
    output_path,
    session=None,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
    mask_output_path=None,
    **kwargs,
):
    """Remove the subject from an image and fill the background with AOT-GAN.

    Args:
        input_path: Input image path (EXIF orientation is applied).
        output_path: Output image path.
        session: Optional rembg session.
        mask_dilate: Square dilation kernel size for the subject mask.
        mask_blur: Blur kernel size (odd) used to grow the inpainting mask.
        mask_threshold: rembg alpha above this value counts as subject.
        edge_grow: Extra 3x3 dilation iterations around the subject.
        aot_root: AOT-GAN repository directory.
        aot_pretrain: AOT-GAN weights file path.
        aot_device: Torch device for AOT-GAN.
        aot_block_num: Number of AOT blocks.
        aot_rates: ``+``-separated AOT block dilation rates.
        aot_crop: Only inpaint a crop around the mask.
        aot_crop_pad: Extra pixels kept around the crop.
        aot_max_size: Longest side fed to AOT-GAN (0 = unlimited).
        aot_noise_prefill: Prefill the hole with random noise.
        aot_noise_strength: Noise strength factor.
        black_subject: Also treat pixels with grayscale <= ``black_threshold`` as subject.
        black_threshold: Grayscale threshold for ``black_subject``.
        gray_subject: Also treat pixels with HSV saturation <=
            ``gray_saturation_threshold`` and value <= ``gray_value_threshold``
            as subject.
        gray_saturation_threshold: Saturation threshold for ``gray_subject``.
        gray_value_threshold: Value (brightness) threshold for ``gray_subject``.
        feather: Blend the fill into the original over a band outside the mask.
        feather_radius: Width of that band in pixels.
        save_mask: Also save the inpainting mask.
        mask_output_path: Mask path (default ``<output without suffix>_mask.png``).
        **kwargs: Extra ``rembg.remove`` options.
    """
    print(f"正在处理(去主体补背景): {input_path}")
    with Image.open(input_path) as source:
        input_image = ImageOps.exif_transpose(source)
        # rembg's alpha channel marks the subject.
        cutout = remove(input_image, session=session, **kwargs)
        cutout = _ensure_rgba_size(cutout, input_image.size)
        bgr = _pil_to_bgr(input_image)

    mask = _alpha_to_mask(np.array(cutout.getchannel("A")), threshold=mask_threshold)

    if black_subject:
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        mask = cv2.bitwise_or(mask, (gray <= black_threshold).astype(np.uint8) * 255)
    if gray_subject:
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        low_saturation = hsv[:, :, 1] <= gray_saturation_threshold
        not_bright = hsv[:, :, 2] <= gray_value_threshold
        mask = cv2.bitwise_or(mask, (low_saturation & not_bright).astype(np.uint8) * 255)

    mask_hard, mask_used = _prepare_mask(
        mask,
        mask_dilate=mask_dilate,
        mask_blur=mask_blur,
        edge_grow=edge_grow,
    )

    filled = _inpaint_with_aot(
        bgr,
        mask_used,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=aot_device,
        block_num=aot_block_num,
        rates=_parse_aot_rates(aot_rates),
        crop=aot_crop,
        crop_pad=aot_crop_pad,
        max_size=aot_max_size,
        noise_prefill=aot_noise_prefill,
        noise_strength=aot_noise_strength,
    )

    if feather and feather_radius > 0:
        result_bgr = _feather_blend(filled, bgr, mask_hard, feather_radius)
    else:
        result_bgr = filled
    Image.fromarray(cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)).save(output_path)

    if save_mask:
        if mask_output_path is None:
            mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png"
        Image.fromarray(mask_used).save(mask_output_path)
        print(f"已保存mask: {mask_output_path}")
    print(f"已保存: {output_path}")


def _output_name(path, remove_subject):
    """Build the output file name for input ``path``.

    Background removal always writes ``<stem>_nobg.png`` (PNG keeps the
    transparency). Subject removal writes ``<stem>_bgfill`` with the input
    suffix when it is a common format, otherwise ``.jpg``.
    """
    path = Path(path)
    if not remove_subject:
        return path.stem + "_nobg.png"
    suffix = path.suffix.lower()
    if suffix in _IMAGE_SUFFIXES:
        return path.stem + "_bgfill" + suffix
    return path.stem + "_bgfill.jpg"


def _print_options(rembg_kwargs, remove_subject, inpaint_kwargs):
    """Print the rembg and subject-removal settings (shared by file and folder modes).

    ``rembg_kwargs`` holds the alpha-matting/post-process options and
    ``inpaint_kwargs`` the ``remove_subject_and_inpaint`` options, keyed by
    their parameter names.
    """

    def yes_no(flag):
        return '是' if flag else '否'

    alpha_matting = rembg_kwargs["alpha_matting"]
    print(f"Alpha Matting: {'启用' if alpha_matting else '禁用'}")
    if alpha_matting:
        print(f" - 前景阈值: {rembg_kwargs['alpha_matting_foreground_threshold']}")
        print(f" - 背景阈值: {rembg_kwargs['alpha_matting_background_threshold']}")
        print(f" - 侵蚀大小: {rembg_kwargs['alpha_matting_erode_size']}")
    print(f"Mask后处理: {'启用' if rembg_kwargs['post_process_mask'] else '禁用'}")
    print(f"去主体补背景: {'启用' if remove_subject else '禁用'}")
    if not remove_subject:
        return

    opts = inpaint_kwargs
    print(f" - AOT目录: {opts['aot_root']}")
    print(f" - AOT权重: {opts['aot_pretrain']}")
    print(f" - AOT设备: {opts['aot_device']}")
    print(f" - AOT块数: {opts['aot_block_num']}")
    print(f" - AOT膨胀率: {opts['aot_rates']}")
    print(f" - AOT裁剪: {yes_no(opts['aot_crop'])}")
    if opts["aot_crop"]:
        print(f" - AOT裁剪边界: {opts['aot_crop_pad']}")
    print(f" - AOT最大边: {opts['aot_max_size']}")
    print(f" - AOT噪声预填充: {yes_no(opts['aot_noise_prefill'])}")
    if opts["aot_noise_prefill"]:
        print(f" - AOT噪声强度: {opts['aot_noise_strength']}")
    print(f" - mask膨胀: {opts['mask_dilate']}")
    print(f" - mask模糊: {opts['mask_blur']}")
    print(f" - mask阈值: {opts['mask_threshold']}")
    print(f" - 边缘扩张: {opts['edge_grow']}")
    print(f" - 黑色内容作为主体: {yes_no(opts['black_subject'])}")
    if opts["black_subject"]:
        print(f" - 黑色阈值: {opts['black_threshold']}")
    print(f" - 灰阶内容作为主体: {yes_no(opts['gray_subject'])}")
    if opts["gray_subject"]:
        print(f" - 灰阶饱和度阈值: {opts['gray_saturation_threshold']}")
        print(f" - 灰阶亮度阈值: {opts['gray_value_threshold']}")
    print(f" - 边缘过渡: {yes_no(opts['feather'])}")
    if opts["feather"]:
        print(f" - 过渡半径: {opts['feather_radius']}")
    print(f" - 保存mask: {yes_no(opts['save_mask'])}")


def process_images_folder(
    input_folder,
    output_folder,
    model_name="u2net",
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_size=10,
    post_process_mask=False,
    remove_subject=False,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
):
    """Process every image directly inside a folder (non-recursive), in sorted order.

    Per-image errors are printed and the remaining images are still processed.

    Args:
        input_folder: Folder to scan for images.
        output_folder: Folder for the results (created if missing).
        model_name: rembg model name, for example:
            - u2net (default): general-purpose model
            - u2netp: lightweight u2net
            - u2net_human_seg: human segmentation
            - silueta: slimmed-down u2net (43MB)
            - isnet-general-use: newer general-purpose model
            - isnet-anime: high-accuracy anime character segmentation
            - birefnet-general: general-purpose model
            - birefnet-portrait: portrait model
        alpha_matting: Enable alpha-matting post-processing (better edges).
        alpha_matting_foreground_threshold: Foreground threshold (0-255);
            higher keeps more foreground.
        alpha_matting_background_threshold: Background threshold (0-255);
            higher removes more background.
        alpha_matting_erode_size: Erosion size used to smooth edges.
        post_process_mask: Enable rembg mask post-processing.
        remove_subject: Remove the subject and inpaint the background instead
            of cutting the subject out.
        mask_dilate ... save_mask: Forwarded to ``remove_subject_and_inpaint``.
    """
    output_dir = Path(output_folder)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Reusing one session across images is faster.
    print(f"使用模型: {model_name}")
    session = new_session(model_name)

    image_extensions = set(_IMAGE_SUFFIXES)
    if HEIC_SUPPORTED:
        image_extensions.update({'.heic', '.heif'})
    else:
        print("提示: 未安装pillow-heif,HEIC格式不可用。安装方法: pip install pillow-heif")

    image_files = sorted(
        f for f in Path(input_folder).iterdir()
        if f.is_file() and f.suffix.lower() in image_extensions
    )
    if not image_files:
        print(f"在 {input_folder} 中没有找到图片文件")
        return

    rembg_kwargs = dict(
        alpha_matting=alpha_matting,
        alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
        alpha_matting_background_threshold=alpha_matting_background_threshold,
        alpha_matting_erode_size=alpha_matting_erode_size,
        post_process_mask=post_process_mask,
    )
    inpaint_kwargs = dict(
        mask_dilate=mask_dilate,
        mask_blur=mask_blur,
        mask_threshold=mask_threshold,
        edge_grow=edge_grow,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        aot_device=aot_device,
        aot_block_num=aot_block_num,
        aot_rates=aot_rates,
        aot_crop=aot_crop,
        aot_crop_pad=aot_crop_pad,
        aot_max_size=aot_max_size,
        aot_noise_prefill=aot_noise_prefill,
        aot_noise_strength=aot_noise_strength,
        black_subject=black_subject,
        black_threshold=black_threshold,
        gray_subject=gray_subject,
        gray_saturation_threshold=gray_saturation_threshold,
        gray_value_threshold=gray_value_threshold,
        feather=feather,
        feather_radius=feather_radius,
        save_mask=save_mask,
    )

    total = len(image_files)
    print(f"找到 {total} 张图片,开始处理...")
    _print_options(rembg_kwargs, remove_subject, inpaint_kwargs)
    print("-" * 50)

    for i, image_file in enumerate(image_files, 1):
        try:
            output_path = output_dir / _output_name(image_file, remove_subject)
            print(f"[{i}/{total}] ", end="")
            if remove_subject:
                mask_output_path = None
                if save_mask:
                    mask_output_path = str(output_dir / (image_file.stem + "_mask.png"))
                remove_subject_and_inpaint(
                    str(image_file),
                    str(output_path),
                    session=session,
                    mask_output_path=mask_output_path,
                    **rembg_kwargs,
                    **inpaint_kwargs,
                )
            else:
                remove_background(str(image_file), str(output_path), session=session, **rembg_kwargs)
        except Exception as e:
            # Deliberate best-effort: report the failure and continue with the next image.
            print(f"处理 {image_file.name} 时出错: {e}")

    print("-" * 50)
    print(f"处理完成!结果保存在 {output_folder} 文件夹中")


_EPILOG = '''
示例用法:
  # 使用默认参数处理images文件夹
  python remove_background.py

  # 处理单个文件
  python remove_background.py input.jpg output.png

  # 处理指定文件夹
  python remove_background.py my_images/ my_output/

  # 使用不同模型
  python remove_background.py input.jpg output.png -m birefnet-portrait

  # 自定义alpha matting参数
  python remove_background.py input.jpg output.png -a -ft 240 -bt 12 -es 5
'''


def _build_parser():
    """Create the command-line argument parser."""
    parser = argparse.ArgumentParser(
        description='图片去背景工具 - 使用rembg自动去除图片背景',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_EPILOG,
    )
    # Positional arguments
    parser.add_argument('input', nargs='?', default='images', help='输入文件或文件夹路径(默认: images)')
    parser.add_argument('output', nargs='?', default=None, help='输出文件或文件夹路径(可选,默认为output/)')
    # Model selection
    parser.add_argument(
        '-m', '--model', default='isnet-general-use',
        choices=[
            'u2net', 'u2netp', 'u2net_human_seg', 'silueta', 'isnet-general-use', 'isnet-anime',
            'birefnet-general', 'birefnet-general-lite', 'birefnet-portrait', 'birefnet-dis',
            'birefnet-hrsod', 'birefnet-cod', 'birefnet-massive',
        ],
        help='选择使用的模型 (默认: isnet-general-use)',
    )
    # Alpha matting
    parser.add_argument('-a', '--alpha-matting', '--alpha_matting', action='store_true', help='启用alpha matting后处理(默认: false)')
    parser.add_argument('-ft', '--foreground-threshold', type=int, default=245, help='前景阈值 (0-255),值越大保留越多细节 (默认: 245)')
    parser.add_argument('-bt', '--background-threshold', type=int, default=8, help='背景阈值 (0-255),值越大去除越多背景 (默认: 8)')
    parser.add_argument('-es', '--erode-size', type=int, default=2, help='侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)')
    # Other options
    parser.add_argument('-p', '--post-process', '--post_process', action='store_true', help='启用mask后处理(默认: false)')
    # Subject removal + background inpainting
    parser.add_argument('--remove-subject', '--remove_subject', action='store_true', help='去掉主体并补全背景(默认: false)')
    parser.add_argument('--aot-root', type=str, default='AOT-GAN-for-Inpainting', help='AOT-GAN目录(默认: AOT-GAN-for-Inpainting)')
    parser.add_argument('--aot-pretrain', type=str, default=None, help='AOT-GAN预训练权重文件路径(必填,相对路径基于当前目录)')
    parser.add_argument('--aot-device', type=str, default='cpu', help='AOT-GAN设备(默认: cpu)')
    parser.add_argument('--aot-block-num', type=int, default=8, help='AOTBlock数量(默认: 8)')
    parser.add_argument('--aot-rates', type=str, default='1+2+4+8', help='AOTBlock膨胀率(默认: 1+2+4+8)')
    parser.add_argument('--aot-crop', action='store_true', help='AOT仅对mask区域裁剪修补(默认: false)')
    parser.add_argument('--aot-crop-pad', type=int, default=0, help='AOT裁剪边缘留白像素(默认: 0)')
    parser.add_argument('--aot-max-size', type=int, default=0, help='AOT输入最大边限制,0为不限制(默认: 0)')
    parser.add_argument('--aot-noise-prefill', action='store_true', help='AOT使用随机噪声预填充(默认: false)')
    parser.add_argument('--aot-noise-strength', type=float, default=1.0, help='AOT噪声强度系数(默认: 1.0)')
    parser.add_argument('--mask-dilate', type=int, default=3, help='mask膨胀大小(默认: 3)')
    parser.add_argument('--mask-blur', type=int, default=3, help='mask模糊大小(默认: 3,建议奇数)')
    parser.add_argument('--mask-threshold', type=int, default=10, help='alpha阈值(默认: 10)')
    parser.add_argument('--edge-grow', type=int, default=0, help='主体边缘扩张像素(默认: 0)')
    parser.add_argument('--save-mask', '--save_mask', action='store_true', help='保存mask到output目录(默认: false)')
    parser.add_argument('--black-subject', '--black_subject', action='store_true', help='将黑色内容也视为主体(默认: false)')
    parser.add_argument('--black-threshold', type=int, default=50, help='黑色阈值(0-255,灰度越小越黑,默认: 50)')
    parser.add_argument('--gray-subject', '--gray_subject', action='store_true', help='将灰阶内容也视为主体(默认: false)')
    parser.add_argument('--gray-saturation-threshold', type=int, default=30, help='灰阶饱和度阈值(0-255,越小越接近灰阶,默认: 30)')
    parser.add_argument('--gray-value-threshold', type=int, default=200, help='灰阶亮度阈值(0-255,越小越暗,默认: 200)')
    parser.add_argument('--feather', action='store_true', help='启用边缘过渡融合(默认: false)')
    parser.add_argument('--feather-radius', type=int, default=5, help='边缘过渡半径(默认: 5)')
    return parser


def _resolve_single_output(input_path, output_arg, output_name):
    """Decide the output file path for single-file mode and create its parent folder.

    No output argument writes into ``<input dir>/output/``. An existing
    directory, or a path without a suffix, is treated as a folder. Anything
    else is used as the output file path.
    """
    if output_arg is None:
        output_path = input_path.parent / "output" / output_name
    else:
        candidate = Path(output_arg)
        if candidate.is_dir() or candidate.suffix == "":
            output_path = candidate / output_name
        else:
            output_path = candidate
    output_path.parent.mkdir(parents=True, exist_ok=True)
    return output_path


def main():
    """Command-line entry point: process a single file or every image in a folder."""
    parser = _build_parser()
    args = parser.parse_args()
    # Fail fast instead of after running rembg (or once per image in folder mode).
    if args.remove_subject and args.aot_pretrain is None:
        parser.error("--remove-subject 需要同时指定 --aot-pretrain")

    print("=" * 50)
    print("图片去背景工具")
    print("=" * 50)

    input_path = Path(args.input)
    if not input_path.exists():
        print(f"错误: 输入路径不存在: {args.input}")
        sys.exit(1)

    rembg_kwargs = dict(
        alpha_matting=args.alpha_matting,
        alpha_matting_foreground_threshold=args.foreground_threshold,
        alpha_matting_background_threshold=args.background_threshold,
        alpha_matting_erode_size=args.erode_size,
        post_process_mask=args.post_process,
    )
    inpaint_kwargs = dict(
        mask_dilate=args.mask_dilate,
        mask_blur=args.mask_blur,
        mask_threshold=args.mask_threshold,
        edge_grow=args.edge_grow,
        aot_root=args.aot_root,
        aot_pretrain=args.aot_pretrain,
        aot_device=args.aot_device,
        aot_block_num=args.aot_block_num,
        aot_rates=args.aot_rates,
        aot_crop=args.aot_crop,
        aot_crop_pad=args.aot_crop_pad,
        aot_max_size=args.aot_max_size,
        aot_noise_prefill=args.aot_noise_prefill,
        aot_noise_strength=args.aot_noise_strength,
        black_subject=args.black_subject,
        black_threshold=args.black_threshold,
        gray_subject=args.gray_subject,
        gray_saturation_threshold=args.gray_saturation_threshold,
        gray_value_threshold=args.gray_value_threshold,
        feather=args.feather,
        feather_radius=args.feather_radius,
        save_mask=args.save_mask,
    )

    if input_path.is_file():
        output_path = _resolve_single_output(
            input_path, args.output, _output_name(input_path, args.remove_subject)
        )
        print(f"输入文件: {input_path}")
        print(f"输出文件: {output_path}")
        print(f"模型: {args.model}")
        _print_options(rembg_kwargs, args.remove_subject, inpaint_kwargs)
        print("-" * 50)

        session = new_session(args.model)
        if args.remove_subject:
            mask_output_path = None
            if args.save_mask:
                mask_output_path = str(output_path.with_suffix("")) + "_mask.png"
            remove_subject_and_inpaint(
                str(input_path),
                str(output_path),
                session=session,
                mask_output_path=mask_output_path,
                **rembg_kwargs,
                **inpaint_kwargs,
            )
        else:
            remove_background(str(input_path), str(output_path), session=session, **rembg_kwargs)
        print("-" * 50)
        print(f"处理完成!结果保存在: {output_path}")
    elif input_path.is_dir():
        process_images_folder(
            str(input_path),
            args.output if args.output else 'output',
            model_name=args.model,
            remove_subject=args.remove_subject,
            **rembg_kwargs,
            **inpaint_kwargs,
        )
    else:
        print(f"错误: 不支持的输入类型: {args.input}")
        sys.exit(1)


if __name__ == "__main__":
    main()