"""书画与篆刻作品去背景工具。""" import os import sys import argparse from pathlib import Path import numpy as np import cv2 from PIL import Image, ImageOps # 避免 numba 在某些环境下缓存失败 os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache") try: from rembg import remove, new_session except ImportError: remove = None new_session = None AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {} REMBG_SESSION_CACHE: dict[str, object] = {} def _resolve_aot_paths(aot_root, aot_pretrain): """解析 AOT-GAN 的路径配置。""" aot_root = Path(aot_root).resolve() if aot_pretrain is None: aot_pretrain = None else: aot_pretrain = Path(aot_pretrain) if not aot_pretrain.is_absolute(): # 预训练权重相对路径以当前工作目录为基准 aot_pretrain = (Path.cwd() / aot_pretrain).resolve() return aot_root, aot_pretrain def _parse_aot_rates(rates_str): parts = [p for p in rates_str.split("+") if p] return [int(p) for p in parts] def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None): """加载/缓存 AOT-GAN 模型。""" aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain) if not aot_root.exists(): raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}") if aot_pretrain is None: raise ValueError("AOT-GAN需要指定预训练权重路径(--aot-pretrain)") if not aot_pretrain.exists(): raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}") if rates is None: rates = [1, 2, 4, 8] rates_tuple = tuple(rates) key = (str(aot_pretrain), device, int(block_num), rates_tuple) if key in AOT_MODEL_CACHE: return AOT_MODEL_CACHE[key] src_root = aot_root / "src" if str(src_root) not in sys.path: sys.path.insert(0, str(src_root)) import importlib try: import torch except ImportError as exc: raise ImportError("未找到 PyTorch,请先按README安装依赖。") from exc net = importlib.import_module("model.aotgan") class _Args: pass args = _Args() args.block_num = int(block_num) args.rates = rates model = net.InpaintGenerator(args) state = torch.load(str(aot_pretrain), map_location=device) if isinstance(state, dict): if "state_dict" in state: state = state["state_dict"] elif "model" in state: state = state["model"] elif "generator" in state: state = state["generator"] if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()): state = {k.replace("module.", "", 1): v for k, v in state.items()} model.load_state_dict(state, strict=True) model.to(device) model.eval() AOT_MODEL_CACHE[key] = (model, device) return AOT_MODEL_CACHE[key] def _mask_bbox(mask): ys, xs = np.where(mask > 0) if len(xs) == 0: return None return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()) def _expand_bbox(bbox, pad, width, height): if bbox is None: return None x0, y0, x1, y1 = bbox pad = max(0, int(pad)) x0 = max(0, x0 - pad) y0 = max(0, y0 - pad) x1 = min(width - 1, x1 + pad) y1 = min(height - 1, y1 + pad) return x0, y0, x1, y1 def _expand_bbox_min_size(bbox, pad, width, height, min_size): """在给定 pad 基础上,确保裁剪区域至少为 min_size。""" expanded = _expand_bbox(bbox, pad, width, height) if expanded is None: return None x0, y0, x1, y1 = expanded roi_w = x1 - x0 + 1 roi_h = y1 - y0 + 1 need_w = max(0, int(min_size) - roi_w) need_h = max(0, int(min_size) - roi_h) if need_w == 0 and need_h == 0: return expanded extra = max((need_w + 1) // 2, (need_h + 1) // 2) return _expand_bbox(bbox, pad + extra, width, height) def _build_noise_prefill(img, mask_t, strength): """为mask区域生成噪声预填充,img取值范围为[-1, 1]。""" import torch img01 = (img + 1.0) / 2.0 unmasked = 1.0 - mask_t denom = unmasked.sum() if denom.item() < 1.0: mean = torch.full((1, 3, 1, 1), 0.5, device=img.device) std = torch.full((1, 3, 1, 1), 0.2, device=img.device) else: mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom std = torch.sqrt(var + 1e-6) noise = mean + std * float(strength) * torch.randn_like(img01) noise = noise.clamp(0.0, 1.0) return noise * 2.0 - 1.0 def _inpaint_with_aot_core( bgr, mask, aot_root, aot_pretrain, device="cpu", block_num=8, rates=None, noise_prefill=False, noise_strength=1.0, ): """使用 AOT-GAN 进行修补,返回 BGR 图像(与输入同尺寸)。""" if mask.dtype != np.uint8: mask = mask.astype(np.uint8) if mask.max() <= 1: mask = mask * 255 if mask.max() == 0: return bgr h, w = mask.shape grid = 4 h2 = (h // grid) * grid w2 = (w // grid) * grid if h2 == 0 or w2 == 0: return bgr rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) img_crop = rgb[:h2, :w2, :] mask_crop = mask[:h2, :w2] import torch img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0 img = img * 2.0 - 1.0 mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0) img = img.unsqueeze(0) mask_t = mask_t.unsqueeze(0) if device != "cpu": img = img.to(device) mask_t = mask_t.to(device) model, _ = _get_aot_model( aot_root=aot_root, aot_pretrain=aot_pretrain, device=device, block_num=block_num, rates=rates, ) with torch.no_grad(): if noise_prefill: noise = _build_noise_prefill(img, mask_t, noise_strength) image_masked = img * (1 - mask_t) + noise * mask_t else: image_masked = img * (1 - mask_t) + mask_t pred = model(image_masked, mask_t) comp = pred * mask_t + img * (1 - mask_t) comp = comp[0].clamp(-1.0, 1.0) comp = (comp + 1.0) / 2.0 * 255.0 comp = comp.permute(1, 2, 0).byte().cpu().numpy() result_bgr = cv2.cvtColor(comp, cv2.COLOR_RGB2BGR) if h2 == h and w2 == w: return result_bgr output_full = bgr.copy() output_full[:h2, :w2, :] = result_bgr return output_full def _inpaint_with_aot( bgr, mask, aot_root, aot_pretrain, device="cpu", block_num=8, rates=None, crop=False, crop_pad=0, max_size=0, noise_prefill=False, noise_strength=1.0, ): """使用 AOT-GAN 进行修补,支持裁剪与限幅以加速。""" min_side = 32 if mask.dtype != np.uint8: mask = mask.astype(np.uint8) if mask.max() <= 1: mask = mask * 255 if mask.max() == 0: return bgr h, w = mask.shape x0 = y0 = 0 x1 = w - 1 y1 = h - 1 if crop: bbox = _mask_bbox(mask) if bbox is None: return bgr expanded = _expand_bbox_min_size(bbox, crop_pad, w, h, min_side) if expanded is None: return bgr x0, y0, x1, y1 = expanded roi_w = x1 - x0 + 1 roi_h = y1 - y0 + 1 if roi_w < min_side or roi_h < min_side: # 裁剪区域过小会导致 AOT 失败,回退到全图修补 crop = False x0 = y0 = 0 x1 = w - 1 y1 = h - 1 bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1] mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1] roi_h, roi_w = bgr_roi.shape[:2] scale = 1.0 max_size = int(max_size) if max_size else 0 if max_size > 0: max_dim = max(roi_h, roi_w) if max_dim > max_size: scale = max_size / float(max_dim) new_w = max(1, int(round(roi_w * scale))) new_h = max(1, int(round(roi_h * scale))) if min(new_w, new_h) < min_side: scale = 1.0 new_w = roi_w new_h = roi_h interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR if scale != 1.0: bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp) mask_roi = cv2.resize( mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST ) filled_roi = _inpaint_with_aot_core( bgr_roi, mask_roi, aot_root=aot_root, aot_pretrain=aot_pretrain, device=device, block_num=block_num, rates=rates, noise_prefill=noise_prefill, noise_strength=noise_strength, ) if scale != 1.0: filled_roi = cv2.resize( filled_roi, (roi_w, roi_h), interpolation=cv2.INTER_LINEAR, ) if not crop: return filled_roi output_full = bgr.copy() output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi return output_full # 支持HEIC格式 try: from pillow_heif import register_heif_opener register_heif_opener() HEIC_SUPPORTED = True except ImportError: HEIC_SUPPORTED = False def _get_rembg_session(model_name): """按模型缓存 rembg 会话,避免重复加载。""" if new_session is None: raise ImportError( "未找到 rembg,请先安装相关依赖或改用 --foreground-mode artwork。" ) if model_name not in REMBG_SESSION_CACHE: REMBG_SESSION_CACHE[model_name] = new_session(model_name) return REMBG_SESSION_CACHE[model_name] def _resize_for_processing(image, max_size): """按最大边缩放图片,返回缩放后的图片与缩放比例。""" height, width = image.shape[:2] max_size = int(max_size) if max_size else 0 if max_size <= 0 or max(height, width) <= max_size: return image.copy(), 1.0 scale = max_size / float(max(height, width)) resized = cv2.resize( image, (max(1, int(round(width * scale))), max(1, int(round(height * scale)))), interpolation=cv2.INTER_AREA, ) return resized, scale def _otsu_threshold(channel, floor=0): """返回 Otsu 阈值,并设置一个最小下限避免纸张纹理误检。""" if channel.size == 0 or int(channel.max()) <= 0: return int(floor) threshold, _ = cv2.threshold(channel, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) return max(int(round(threshold)), int(floor)) def _remove_small_components(mask, min_area): """移除面积过小的连通域,降低纸张纹理噪声。""" if min_area <= 0 or int(mask.max()) == 0: return mask num_labels, labels, stats, _ = cv2.connectedComponentsWithStats( mask, connectivity=8 ) cleaned = np.zeros_like(mask) for label in range(1, num_labels): area = int(stats[label, cv2.CC_STAT_AREA]) if area >= min_area: cleaned[labels == label] = 255 return cleaned def _remove_border_frame_components( mask, min_width_ratio=0.7, min_height_ratio=0.7, max_fill_ratio=0.35, ): """移除贴边的大型空心边框连通域,避免纸张外轮廓被误判为前景。""" if int(mask.max()) == 0: return mask height, width = mask.shape num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8) cleaned = mask.copy() for label in range(1, num_labels): x, y, comp_w, comp_h, area = stats[label] touches_border = x == 0 or y == 0 or (x + comp_w) == width or (y + comp_h) == height if not touches_border: continue if comp_w < int(width * float(min_width_ratio)): continue if comp_h < int(height * float(min_height_ratio)): continue fill_ratio = float(area) / float(max(1, comp_w * comp_h)) if fill_ratio <= float(max_fill_ratio): cleaned[labels == label] = 0 return cleaned def _artwork_mask_is_reasonable(mask): """检查书画掩码是否可信,用于 auto 模式回退。""" if mask.size == 0 or int(mask.max()) == 0: return False coverage = float(np.count_nonzero(mask)) / float(mask.size) if coverage < 0.0005 or coverage > 0.55: return False height, width = mask.shape border = max(4, min(height, width) // 32) border_pixels = np.concatenate( [ mask[:border, :].reshape(-1), mask[-border:, :].reshape(-1), mask[:, :border].reshape(-1), mask[:, -border:].reshape(-1), ] ) border_ratio = float(np.count_nonzero(border_pixels)) / float(border_pixels.size) return border_ratio <= 0.28 def _compute_detail_alpha(rgb_image): """根据局部暗度与严格红章特征生成细节 alpha。""" height, width = rgb_image.shape[:2] gray = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY) sigma_bg = max(8.0, max(height, width) / 80.0) bg_gray = cv2.GaussianBlur(gray, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg) dark_score = cv2.subtract(bg_gray, gray).astype(np.float32) dark_alpha = np.clip((dark_score - 18.0) / 48.0, 0.0, 1.0) hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV) hue = hsv[:, :, 0].astype(np.float32) saturation = hsv[:, :, 1].astype(np.float32) value = hsv[:, :, 2].astype(np.float32) red = rgb_image[:, :, 0].astype(np.float32) green = rgb_image[:, :, 1].astype(np.float32) blue = rgb_image[:, :, 2].astype(np.float32) red_dominance = red - np.maximum(green, blue) red_hue_mask = ((hue <= 12.0) | (hue >= 168.0)).astype(np.float32) seal_alpha = np.clip((red_dominance - 20.0) / 40.0, 0.0, 1.0) seal_alpha *= np.clip((saturation - 50.0) / 80.0, 0.0, 1.0) seal_alpha *= (value >= 40.0).astype(np.float32) seal_alpha *= red_hue_mask detail_alpha = np.maximum(dark_alpha, seal_alpha) return np.clip((detail_alpha - 0.07) / 0.93, 0.0, 1.0) def _extract_artwork_mask(input_image, artwork_type="auto", max_size=1600): """为书画/篆刻作品提取前景掩码,避免依赖通用人像分割模型。""" rgb = np.array(input_image.convert("RGB")) resized, scale = _resize_for_processing(rgb, max_size=max_size) height, width = resized.shape[:2] gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY) lab = cv2.cvtColor(resized, cv2.COLOR_RGB2LAB) hsv = cv2.cvtColor(resized, cv2.COLOR_RGB2HSV) sigma_bg = max(6.0, max(height, width) / 40.0) bg_gray = cv2.GaussianBlur(gray, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg) dark_score = cv2.subtract(bg_gray, gray) lab_a = lab[:, :, 1] lab_b = lab[:, :, 2] bg_a = cv2.GaussianBlur(lab_a, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg) bg_b = cv2.GaussianBlur(lab_b, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg) chroma_score = np.sqrt( (lab_a.astype(np.float32) - bg_a.astype(np.float32)) ** 2 + (lab_b.astype(np.float32) - bg_b.astype(np.float32)) ** 2 ) chroma_score = np.clip(chroma_score * 3.0, 0, 255).astype(np.uint8) red = resized[:, :, 0].astype(np.int16) green = resized[:, :, 1].astype(np.int16) blue = resized[:, :, 2].astype(np.int16) hue = hsv[:, :, 0] saturation = hsv[:, :, 1] value = hsv[:, :, 2] red_dominance = np.clip(red - np.maximum(green, blue), 0, 255).astype(np.uint8) seal_score = np.clip( red_dominance.astype(np.int16) * 2 + np.clip(saturation.astype(np.int16) - 45, 0, 255), 0, 255, ).astype(np.uint8) seal_score = np.where( ((hue <= 12) | (hue >= 168)) & (saturation >= 55) & (value >= 40) & (red_dominance >= 20), seal_score, 0, ).astype(np.uint8) color_score = np.maximum(chroma_score, seal_score) if artwork_type == "seal": combined_score = np.maximum( np.clip(dark_score.astype(np.float32) * 1.0, 0, 255).astype(np.uint8), np.clip(color_score.astype(np.float32) * 1.35, 0, 255).astype(np.uint8), ) dark_floor = 12 color_floor = 18 combined_floor = 20 elif artwork_type == "calligraphy": combined_score = np.maximum( np.clip(dark_score.astype(np.float32) * 1.35, 0, 255).astype(np.uint8), np.clip(color_score.astype(np.float32) * 0.8, 0, 255).astype(np.uint8), ) dark_floor = 14 color_floor = 28 combined_floor = 18 else: combined_score = np.maximum( np.clip(dark_score.astype(np.float32) * 1.2, 0, 255).astype(np.uint8), np.clip(color_score.astype(np.float32) * 1.2, 0, 255).astype(np.uint8), ) dark_floor = 14 color_floor = 16 combined_floor = 18 smooth_sigma = max(0.6, max(height, width) / 320.0) dark_smooth = cv2.GaussianBlur( dark_score, (0, 0), sigmaX=smooth_sigma, sigmaY=smooth_sigma ) color_smooth = cv2.GaussianBlur( color_score, (0, 0), sigmaX=smooth_sigma, sigmaY=smooth_sigma ) combined_smooth = cv2.GaussianBlur( combined_score, (0, 0), sigmaX=smooth_sigma, sigmaY=smooth_sigma, ) dark_mask = (dark_smooth >= _otsu_threshold(dark_smooth, floor=dark_floor)).astype( np.uint8 ) * 255 color_mask = ( color_smooth >= _otsu_threshold(color_smooth, floor=color_floor) ).astype(np.uint8) * 255 combined_mask = ( combined_smooth >= _otsu_threshold(combined_smooth, floor=combined_floor) ).astype(np.uint8) * 255 if artwork_type == "seal": mask = cv2.bitwise_or(combined_mask, color_mask) mask = cv2.bitwise_or(mask, dark_mask) elif artwork_type == "calligraphy": mask = cv2.bitwise_or(combined_mask, dark_mask) else: mask = cv2.bitwise_or(combined_mask, cv2.bitwise_or(dark_mask, color_mask)) close_kernel = np.ones((3, 3), dtype=np.uint8) mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel, iterations=1) min_area = max(12, (height * width) // 25000) mask = _remove_small_components(mask, min_area=min_area) mask = _remove_border_frame_components(mask) mask = cv2.dilate(mask, close_kernel, iterations=1) if scale < 1.0: mask = cv2.resize( mask, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST ) return mask def _mask_to_transparent_image(input_image, mask): """将二值掩码转为带透明背景的 RGBA 图片。""" rgb = np.array(input_image.convert("RGB")) detail_alpha = _compute_detail_alpha(rgb) alpha = ((mask.astype(np.float32) / 255.0) * detail_alpha * 255.0).astype(np.uint8) alpha = cv2.GaussianBlur(alpha, (0, 0), sigmaX=0.8, sigmaY=0.8) alpha = np.where(mask > 0, alpha, 0).astype(np.uint8) rgba = np.dstack((rgb, alpha)) return Image.fromarray(rgba) def _extract_foreground_mask( input_image, session=None, model_name="isnet-general-use", foreground_mode="artwork", artwork_type="auto", artwork_max_size=1600, mask_threshold=10, **kwargs, ): """按指定模式提取前景掩码。""" if foreground_mode in {"artwork", "auto"}: artwork_mask = _extract_artwork_mask( input_image, artwork_type=artwork_type, max_size=artwork_max_size, ) if foreground_mode == "artwork" or _artwork_mask_is_reasonable(artwork_mask): return artwork_mask, "artwork" if foreground_mode not in {"rembg", "auto"}: raise ValueError(f"不支持的前景提取模式: {foreground_mode}") if remove is None: raise ImportError( "未找到 rembg,请先安装相关依赖或改用 --foreground-mode artwork。" ) if session is None: session = _get_rembg_session(model_name) output_image = remove(input_image, session=session, **kwargs) output_image = _ensure_rgba_size(output_image, input_image.size) alpha = np.array(output_image.getchannel("A")) return _alpha_to_mask(alpha, threshold=mask_threshold), "rembg" def remove_background( input_path, output_path, session=None, model_name="isnet-general-use", foreground_mode="artwork", artwork_type="auto", artwork_max_size=1600, **kwargs, ): """ 去除图片背景。 Args: input_path: 输入图片路径 output_path: 输出图片路径 session: rembg会话对象(可选) model_name: rembg 模型名称 foreground_mode: 前景提取模式,可选 artwork / auto / rembg artwork_type: 书画类型,可选 auto / calligraphy / seal artwork_max_size: 书画掩码估算时的最大边 **kwargs: 其他参数,如alpha_matting相关参数 """ print(f"正在处理: {input_path}") # 读取输入图片 input_image = ImageOps.exif_transpose(Image.open(input_path)) if foreground_mode in {"artwork", "auto"}: mask = _extract_artwork_mask( input_image, artwork_type=artwork_type, max_size=artwork_max_size, ) if foreground_mode == "artwork" or _artwork_mask_is_reasonable(mask): output_image = _mask_to_transparent_image(input_image, mask) output_image.save(output_path) print(f"已保存: {output_path}") return if remove is None: raise ImportError( "未找到 rembg,请先安装相关依赖或改用 --foreground-mode artwork。" ) if session is None: session = _get_rembg_session(model_name) output_image = remove(input_image, session=session, **kwargs) output_image.save(output_path) print(f"已保存: {output_path}") def _pil_to_bgr(image): """将PIL图片转换为OpenCV BGR格式""" rgb = image.convert("RGB") arr = np.array(rgb) return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) def _alpha_to_mask(alpha_channel, threshold=10): """将alpha通道转换为二值mask(0或255)""" mask = (alpha_channel > threshold).astype(np.uint8) * 255 return mask def _ensure_rgba_size(image, target_size): if image.mode != "RGBA": image = image.convert("RGBA") if image.size != target_size: image = image.resize(target_size, resample=Image.NEAREST) return image def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0): """生成硬边mask与处理后mask(用于填补/过渡)""" if mask_dilate and mask_dilate > 0: kernel = np.ones((mask_dilate, mask_dilate), np.uint8) mask = cv2.dilate(mask, kernel, iterations=1) if edge_grow and edge_grow > 0: kernel = np.ones((3, 3), np.uint8) mask = cv2.dilate(mask, kernel, iterations=int(edge_grow)) mask_hard = mask.copy() mask_used = mask_hard if mask_blur and mask_blur > 0: k = mask_blur if mask_blur % 2 == 1 else mask_blur + 1 mask_used = cv2.GaussianBlur(mask_hard, (k, k), 0) mask_used = (mask_used > 0).astype(np.uint8) * 255 return mask_hard, mask_used def remove_subject_and_inpaint( input_path, output_path, session=None, model_name="isnet-general-use", foreground_mode="artwork", artwork_type="auto", artwork_max_size=1600, mask_dilate=3, mask_blur=3, mask_threshold=10, edge_grow=0, aot_root="AOT-GAN-for-Inpainting", aot_pretrain=None, aot_device="cpu", aot_block_num=8, aot_rates="1+2+4+8", aot_crop=False, aot_crop_pad=0, aot_max_size=0, aot_noise_prefill=False, aot_noise_strength=1.0, black_subject=False, black_threshold=50, gray_subject=False, gray_saturation_threshold=30, gray_value_threshold=200, feather=False, feather_radius=5, save_mask=False, mask_output_path=None, **kwargs, ): """ 去掉主体并补全背景 Args: input_path: 输入图片路径 output_path: 输出图片路径 session: rembg会话对象(可选) model_name: rembg 模型名称 foreground_mode: 前景提取模式,可选 artwork / auto / rembg artwork_type: 书画类型,可选 auto / calligraphy / seal artwork_max_size: 书画掩码估算时的最大边 aot_root: AOT-GAN目录 aot_pretrain: AOT-GAN权重文件路径 aot_device: AOT-GAN设备 aot_block_num: AOT-GAN AOTBlock 数量 aot_rates: AOT-GAN AOTBlock 膨胀率 aot_crop: AOT仅裁剪mask区域进行修补 aot_crop_pad: AOT裁剪边缘留白像素 aot_max_size: AOT输入最大边限制(0为不限制) aot_noise_prefill: AOT使用随机噪声预填充 aot_noise_strength: 噪声强度系数 mask_dilate: mask膨胀大小 mask_blur: mask模糊大小(奇数) mask_threshold: alpha阈值 save_mask: 是否保存mask mask_output_path: mask保存路径 **kwargs: rembg参数 """ print(f"正在处理(去主体补背景): {input_path}") input_image = ImageOps.exif_transpose(Image.open(input_path)) mask, mask_source = _extract_foreground_mask( input_image, session=session, model_name=model_name, foreground_mode=foreground_mode, artwork_type=artwork_type, artwork_max_size=artwork_max_size, mask_threshold=mask_threshold, **kwargs, ) bgr = _pil_to_bgr(input_image) if black_subject: gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) black_mask = (gray <= black_threshold).astype(np.uint8) * 255 mask = cv2.bitwise_or(mask, black_mask) if gray_subject: hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) s = hsv[:, :, 1] v = hsv[:, :, 2] gray_mask = ( (s <= gray_saturation_threshold) & (v <= gray_value_threshold) ).astype(np.uint8) * 255 mask = cv2.bitwise_or(mask, gray_mask) mask_hard, mask_used = _prepare_mask( mask, mask_dilate=mask_dilate, mask_blur=mask_blur, edge_grow=edge_grow, ) rates = _parse_aot_rates(aot_rates) filled = _inpaint_with_aot( bgr, mask_used, aot_root=aot_root, aot_pretrain=aot_pretrain, device=aot_device, block_num=aot_block_num, rates=rates, crop=aot_crop, crop_pad=aot_crop_pad, max_size=aot_max_size, noise_prefill=aot_noise_prefill, noise_strength=aot_noise_strength, ) if feather and feather_radius > 0: # 仅在mask外侧做过渡,避免把原主体带回 mask_bin = (mask_hard > 0).astype(np.uint8) * 255 dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3) alpha = np.ones_like(dist_out, dtype=np.float32) outside = mask_bin == 0 alpha[outside] = np.clip( 1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0 ) alpha = alpha[:, :, None] blended = (alpha * filled + (1.0 - alpha) * bgr).astype(np.uint8) result = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB) else: result = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB) Image.fromarray(result).save(output_path) if save_mask: if mask_output_path is None: mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png" Image.fromarray(mask_used).save(mask_output_path) print(f"已保存mask: {mask_output_path}") print(f"前景提取模式: {mask_source}") print(f"已保存: {output_path}") def process_images_folder( input_folder, output_folder, model_name="isnet-general-use", foreground_mode="artwork", artwork_type="auto", artwork_max_size=1600, alpha_matting=False, alpha_matting_foreground_threshold=240, alpha_matting_background_threshold=10, alpha_matting_erode_size=10, post_process_mask=False, remove_subject=False, mask_dilate=3, mask_blur=3, mask_threshold=10, edge_grow=0, aot_root="AOT-GAN-for-Inpainting", aot_pretrain=None, aot_device="cpu", aot_block_num=8, aot_rates="1+2+4+8", aot_crop=False, aot_crop_pad=0, aot_max_size=0, aot_noise_prefill=False, aot_noise_strength=1.0, black_subject=False, black_threshold=50, gray_subject=False, gray_saturation_threshold=30, gray_value_threshold=200, feather=False, feather_radius=5, save_mask=False, ): """ 批量处理文件夹中的所有图片 Args: input_folder: 输入文件夹路径 output_folder: 输出文件夹路径 model_name: rembg 模型名称,可选值: - u2net (默认): 通用模型 - u2netp: 轻量版u2net - u2net_human_seg: 人物分割 - silueta: 精简版u2net (43MB) - isnet-general-use: 新的通用模型 - isnet-anime: 动漫角色高精度分割 - birefnet-general: 通用模型 - birefnet-portrait: 人像模型 foreground_mode: 前景提取模式,可选 artwork / auto / rembg artwork_type: 书画类型,可选 auto / calligraphy / seal artwork_max_size: 书画掩码估算时的最大边 alpha_matting: 是否启用alpha matting后处理(改善边缘质量) alpha_matting_foreground_threshold: 前景阈值 (0-255),值越大保留越多前景 alpha_matting_background_threshold: 背景阈值 (0-255),值越大去除越多背景 alpha_matting_erode_size: 侵蚀大小,用于平滑边缘 post_process_mask: 是否启用mask后处理 """ # 创建输出文件夹 Path(output_folder).mkdir(parents=True, exist_ok=True) print(f"前景提取模式: {foreground_mode}") print(f"书画类型: {artwork_type}") if foreground_mode in {"rembg", "auto"}: print(f"rembg模型: {model_name}") # 支持的图片格式 image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".webp"} if HEIC_SUPPORTED: image_extensions.update({".heic", ".heif"}) else: print( "提示: 未安装pillow-heif,HEIC格式不可用。安装方法: pip install pillow-heif" ) # 获取所有图片文件 input_path = Path(input_folder) image_files = [ f for f in input_path.iterdir() if f.is_file() and f.suffix.lower() in image_extensions ] if not image_files: print(f"在 {input_folder} 中没有找到图片文件") return print(f"找到 {len(image_files)} 张图片,开始处理...") print(f"Alpha Matting: {'启用' if alpha_matting else '禁用'}") if alpha_matting: print(f" - 前景阈值: {alpha_matting_foreground_threshold}") print(f" - 背景阈值: {alpha_matting_background_threshold}") print(f" - 侵蚀大小: {alpha_matting_erode_size}") print(f"Mask后处理: {'启用' if post_process_mask else '禁用'}") print(f"去主体补背景: {'启用' if remove_subject else '禁用'}") if remove_subject: print(f" - AOT目录: {aot_root}") print(f" - AOT权重: {aot_pretrain}") print(f" - AOT设备: {aot_device}") print(f" - AOT块数: {aot_block_num}") print(f" - AOT膨胀率: {aot_rates}") print(f" - AOT裁剪: {'是' if aot_crop else '否'}") if aot_crop: print(f" - AOT裁剪边界: {aot_crop_pad}") print(f" - AOT最大边: {aot_max_size}") print(f" - AOT噪声预填充: {'是' if aot_noise_prefill else '否'}") if aot_noise_prefill: print(f" - AOT噪声强度: {aot_noise_strength}") print(f" - mask膨胀: {mask_dilate}") print(f" - mask模糊: {mask_blur}") print(f" - mask阈值: {mask_threshold}") print(f" - 边缘扩张: {edge_grow}") print(f" - 黑色内容作为主体: {'是' if black_subject else '否'}") if black_subject: print(f" - 黑色阈值: {black_threshold}") print(f" - 灰阶内容作为主体: {'是' if gray_subject else '否'}") if gray_subject: print(f" - 灰阶饱和度阈值: {gray_saturation_threshold}") print(f" - 灰阶亮度阈值: {gray_value_threshold}") print(f" - 边缘过渡: {'是' if feather else '否'}") if feather: print(f" - 过渡半径: {feather_radius}") print(f" - 保存mask: {'是' if save_mask else '否'}") print("-" * 50) # 处理每张图片 for i, image_file in enumerate(image_files, 1): try: if remove_subject: suffix = image_file.suffix.lower() if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}: output_filename = image_file.stem + "_bgfill" + suffix else: output_filename = image_file.stem + "_bgfill.jpg" else: # 去背景默认使用PNG格式以支持透明背景 output_filename = image_file.stem + "_nobg.png" output_path = Path(output_folder) / output_filename print(f"[{i}/{len(image_files)}] ", end="") if remove_subject: mask_output_path = None if save_mask: mask_output_path = str( Path(output_folder) / (image_file.stem + "_mask.png") ) remove_subject_and_inpaint( str(image_file), str(output_path), model_name=model_name, foreground_mode=foreground_mode, artwork_type=artwork_type, artwork_max_size=artwork_max_size, alpha_matting=alpha_matting, alpha_matting_foreground_threshold=alpha_matting_foreground_threshold, alpha_matting_background_threshold=alpha_matting_background_threshold, alpha_matting_erode_size=alpha_matting_erode_size, post_process_mask=post_process_mask, mask_dilate=mask_dilate, mask_blur=mask_blur, mask_threshold=mask_threshold, edge_grow=edge_grow, aot_root=aot_root, aot_pretrain=aot_pretrain, aot_device=aot_device, aot_block_num=aot_block_num, aot_rates=aot_rates, aot_crop=aot_crop, aot_crop_pad=aot_crop_pad, aot_max_size=aot_max_size, aot_noise_prefill=aot_noise_prefill, aot_noise_strength=aot_noise_strength, black_subject=black_subject, black_threshold=black_threshold, gray_subject=gray_subject, gray_saturation_threshold=gray_saturation_threshold, gray_value_threshold=gray_value_threshold, feather=feather, feather_radius=feather_radius, save_mask=save_mask, mask_output_path=mask_output_path, ) else: remove_background( str(image_file), str(output_path), model_name=model_name, foreground_mode=foreground_mode, artwork_type=artwork_type, artwork_max_size=artwork_max_size, alpha_matting=alpha_matting, alpha_matting_foreground_threshold=alpha_matting_foreground_threshold, alpha_matting_background_threshold=alpha_matting_background_threshold, alpha_matting_erode_size=alpha_matting_erode_size, post_process_mask=post_process_mask, ) except Exception as e: print(f"处理 {image_file.name} 时出错: {e}") print("-" * 50) print(f"处理完成!结果保存在 {output_folder} 文件夹中") if __name__ == "__main__": parser = argparse.ArgumentParser( description="书画与篆刻作品去背景工具 - 默认使用轻量书画掩码提取", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 示例用法: # 使用默认参数处理images文件夹 python remove_background.py # 处理单个文件 python remove_background.py input.jpg output.png # 处理指定文件夹 python remove_background.py my_images/ my_output/ # 强制使用 rembg 旧流程 python remove_background.py input.jpg output.png --foreground-mode rembg -m isnet-general-use # 指定为篆刻模式 python remove_background.py input.jpg output.png --artwork-type seal # 自定义alpha matting参数 python remove_background.py input.jpg output.png -ft 260 -bt 12 -es 5 """, ) # 必需参数 parser.add_argument( "input", nargs="?", default="images", help="输入文件或文件夹路径(默认: images)", ) parser.add_argument( "output", nargs="?", default=None, help="输出文件或文件夹路径(可选,默认为output/)", ) # 模型选择 parser.add_argument( "-m", "--model", default="isnet-general-use", choices=[ "u2net", "u2netp", "u2net_human_seg", "silueta", "isnet-general-use", "isnet-anime", "birefnet-general", "birefnet-general-lite", "birefnet-portrait", "birefnet-dis", "birefnet-hrsod", "birefnet-cod", "birefnet-massive", ], help="选择 rembg 模型(仅在 --foreground-mode rembg/auto 时使用)", ) parser.add_argument( "--foreground-mode", default="artwork", choices=["artwork", "auto", "rembg"], help="前景提取模式:artwork 为书画专用快速模式,auto 失败时回退 rembg,rembg 为旧流程 (默认: artwork)", ) parser.add_argument( "--artwork-type", default="auto", choices=["auto", "calligraphy", "seal"], help="书画类型:auto 自动兼容书法与篆刻,seal 更偏重红色印章,calligraphy 更偏重墨色笔画 (默认: auto)", ) parser.add_argument( "--artwork-max-size", type=int, default=1600, help="书画掩码估算时的最大边,越小越快,越大越精细 (默认: 1600)", ) # Alpha Matting参数 parser.add_argument( "-a", "--alpha-matting", "--alpha_matting", action="store_true", help="启用alpha matting后处理(默认: false)", ) parser.add_argument( "-ft", "--foreground-threshold", type=int, default=245, help="前景阈值 (0-255),值越大保留越多细节 (默认: 245)", ) parser.add_argument( "-bt", "--background-threshold", type=int, default=8, help="背景阈值 (0-255),值越大去除越多背景 (默认: 8)", ) parser.add_argument( "-es", "--erode-size", type=int, default=2, help="侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)", ) # 其他选项 parser.add_argument( "-p", "--post-process", "--post_process", action="store_true", help="启用mask后处理(默认: false)", ) # 去主体补背景参数 parser.add_argument( "--remove-subject", "--remove_subject", action="store_true", help="去掉主体并补全背景(默认: false)", ) parser.add_argument( "--aot-root", type=str, default="AOT-GAN-for-Inpainting", help="AOT-GAN目录(默认: AOT-GAN-for-Inpainting)", ) parser.add_argument( "--aot-pretrain", type=str, default=None, help="AOT-GAN预训练权重文件路径(必填,相对路径基于当前目录)", ) parser.add_argument( "--aot-device", type=str, default="cpu", help="AOT-GAN设备(默认: cpu)" ) parser.add_argument( "--aot-block-num", type=int, default=8, help="AOTBlock数量(默认: 8)" ) parser.add_argument( "--aot-rates", type=str, default="1+2+4+8", help="AOTBlock膨胀率(默认: 1+2+4+8)", ) parser.add_argument( "--aot-crop", action="store_true", help="AOT仅对mask区域裁剪修补(默认: false)" ) parser.add_argument( "--aot-crop-pad", type=int, default=0, help="AOT裁剪边缘留白像素(默认: 0)" ) parser.add_argument( "--aot-max-size", type=int, default=0, help="AOT输入最大边限制,0为不限制(默认: 0)", ) parser.add_argument( "--aot-noise-prefill", action="store_true", help="AOT使用随机噪声预填充(默认: false)", ) parser.add_argument( "--aot-noise-strength", type=float, default=1.0, help="AOT噪声强度系数(默认: 1.0)", ) parser.add_argument( "--mask-dilate", type=int, default=3, help="mask膨胀大小(默认: 3)" ) parser.add_argument( "--mask-blur", type=int, default=3, help="mask模糊大小(默认: 3,建议奇数)" ) parser.add_argument( "--mask-threshold", type=int, default=10, help="alpha阈值(默认: 10)" ) parser.add_argument( "--edge-grow", type=int, default=0, help="主体边缘扩张像素(默认: 0)" ) parser.add_argument( "--save-mask", "--save_mask", action="store_true", help="保存mask到output目录(默认: false)", ) parser.add_argument( "--black-subject", "--black_subject", action="store_true", help="将黑色内容也视为主体(默认: false)", ) parser.add_argument( "--black-threshold", type=int, default=50, help="黑色阈值(0-255,灰度越小越黑,默认: 50)", ) parser.add_argument( "--gray-subject", "--gray_subject", action="store_true", help="将灰阶内容也视为主体(默认: false)", ) parser.add_argument( "--gray-saturation-threshold", type=int, default=30, help="灰阶饱和度阈值(0-255,越小越接近灰阶,默认: 30)", ) parser.add_argument( "--gray-value-threshold", type=int, default=200, help="灰阶亮度阈值(0-255,越小越暗,默认: 200)", ) parser.add_argument( "--feather", action="store_true", help="启用边缘过渡融合(默认: false)" ) parser.add_argument( "--feather-radius", type=int, default=5, help="边缘过渡半径(默认: 5)" ) args = parser.parse_args() print("=" * 50) print("图片去背景工具") print("=" * 50) # 判断输入是文件还是文件夹 input_path = Path(args.input) if not input_path.exists(): print(f"错误: 输入路径不存在: {args.input}") exit(1) # 处理单个文件 if input_path.is_file(): suffix = input_path.suffix.lower() if args.remove_subject: if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}: output_name = input_path.stem + "_bgfill" + suffix else: output_name = input_path.stem + "_bgfill.jpg" else: output_name = input_path.stem + "_nobg.png" # 确定输出路径 if args.output is None: if args.remove_subject: output_path = input_path.parent / "output" / output_name else: output_path = input_path.parent / "output" / output_name output_path.parent.mkdir(parents=True, exist_ok=True) else: output_candidate = Path(args.output) if output_candidate.exists() and output_candidate.is_dir(): output_path = output_candidate / output_name output_path.parent.mkdir(parents=True, exist_ok=True) elif output_candidate.suffix == "": output_candidate.mkdir(parents=True, exist_ok=True) output_path = output_candidate / output_name else: output_path = output_candidate output_path.parent.mkdir(parents=True, exist_ok=True) print(f"输入文件: {input_path}") print(f"输出文件: {output_path}") print(f"前景提取模式: {args.foreground_mode}") print(f"书画类型: {args.artwork_type}") print(f"书画最大边: {args.artwork_max_size}") if args.foreground_mode in {"rembg", "auto"}: print(f"rembg模型: {args.model}") print(f"Alpha Matting: {'启用' if args.alpha_matting else '禁用'}") if args.alpha_matting: print(f" - 前景阈值: {args.foreground_threshold}") print(f" - 背景阈值: {args.background_threshold}") print(f" - 侵蚀大小: {args.erode_size}") print(f"Mask后处理: {'启用' if args.post_process else '禁用'}") print(f"去主体补背景: {'启用' if args.remove_subject else '禁用'}") if args.remove_subject: print(f" - AOT目录: {args.aot_root}") print(f" - AOT权重: {args.aot_pretrain}") print(f" - AOT设备: {args.aot_device}") print(f" - AOT块数: {args.aot_block_num}") print(f" - AOT膨胀率: {args.aot_rates}") print(f" - AOT裁剪: {'是' if args.aot_crop else '否'}") if args.aot_crop: print(f" - AOT裁剪边界: {args.aot_crop_pad}") print(f" - AOT最大边: {args.aot_max_size}") print(f" - AOT噪声预填充: {'是' if args.aot_noise_prefill else '否'}") if args.aot_noise_prefill: print(f" - AOT噪声强度: {args.aot_noise_strength}") print(f" - mask膨胀: {args.mask_dilate}") print(f" - mask模糊: {args.mask_blur}") print(f" - mask阈值: {args.mask_threshold}") print(f" - 边缘扩张: {args.edge_grow}") print(f" - 黑色内容作为主体: {'是' if args.black_subject else '否'}") if args.black_subject: print(f" - 黑色阈值: {args.black_threshold}") print(f" - 灰阶内容作为主体: {'是' if args.gray_subject else '否'}") if args.gray_subject: print(f" - 灰阶饱和度阈值: {args.gray_saturation_threshold}") print(f" - 灰阶亮度阈值: {args.gray_value_threshold}") print(f" - 边缘过渡: {'是' if args.feather else '否'}") if args.feather: print(f" - 过渡半径: {args.feather_radius}") print(f" - 保存mask: {'是' if args.save_mask else '否'}") print("-" * 50) # 处理图片 if args.remove_subject: mask_output_path = None if args.save_mask: mask_output_path = str(output_path.with_suffix("")) + "_mask.png" remove_subject_and_inpaint( str(input_path), str(output_path), model_name=args.model, foreground_mode=args.foreground_mode, artwork_type=args.artwork_type, artwork_max_size=args.artwork_max_size, alpha_matting=args.alpha_matting, alpha_matting_foreground_threshold=args.foreground_threshold, alpha_matting_background_threshold=args.background_threshold, alpha_matting_erode_size=args.erode_size, post_process_mask=args.post_process, mask_dilate=args.mask_dilate, mask_blur=args.mask_blur, mask_threshold=args.mask_threshold, edge_grow=args.edge_grow, aot_root=args.aot_root, aot_pretrain=args.aot_pretrain, aot_device=args.aot_device, aot_block_num=args.aot_block_num, aot_rates=args.aot_rates, aot_crop=args.aot_crop, aot_crop_pad=args.aot_crop_pad, aot_max_size=args.aot_max_size, aot_noise_prefill=args.aot_noise_prefill, aot_noise_strength=args.aot_noise_strength, black_subject=args.black_subject, black_threshold=args.black_threshold, gray_subject=args.gray_subject, gray_saturation_threshold=args.gray_saturation_threshold, gray_value_threshold=args.gray_value_threshold, feather=args.feather, feather_radius=args.feather_radius, save_mask=args.save_mask, mask_output_path=mask_output_path, ) else: remove_background( str(input_path), str(output_path), model_name=args.model, foreground_mode=args.foreground_mode, artwork_type=args.artwork_type, artwork_max_size=args.artwork_max_size, alpha_matting=args.alpha_matting, alpha_matting_foreground_threshold=args.foreground_threshold, alpha_matting_background_threshold=args.background_threshold, alpha_matting_erode_size=args.erode_size, post_process_mask=args.post_process, ) print("-" * 50) print(f"处理完成!结果保存在: {output_path}") # 处理文件夹 elif input_path.is_dir(): output_folder = args.output if args.output else "output" process_images_folder( str(input_path), output_folder, model_name=args.model, foreground_mode=args.foreground_mode, artwork_type=args.artwork_type, artwork_max_size=args.artwork_max_size, alpha_matting=args.alpha_matting, alpha_matting_foreground_threshold=args.foreground_threshold, alpha_matting_background_threshold=args.background_threshold, alpha_matting_erode_size=args.erode_size, post_process_mask=args.post_process, remove_subject=args.remove_subject, mask_dilate=args.mask_dilate, mask_blur=args.mask_blur, mask_threshold=args.mask_threshold, edge_grow=args.edge_grow, aot_root=args.aot_root, aot_pretrain=args.aot_pretrain, aot_device=args.aot_device, aot_block_num=args.aot_block_num, aot_rates=args.aot_rates, aot_crop=args.aot_crop, aot_crop_pad=args.aot_crop_pad, aot_max_size=args.aot_max_size, aot_noise_prefill=args.aot_noise_prefill, aot_noise_strength=args.aot_noise_strength, black_subject=args.black_subject, black_threshold=args.black_threshold, gray_subject=args.gray_subject, gray_saturation_threshold=args.gray_saturation_threshold, gray_value_threshold=args.gray_value_threshold, feather=args.feather, feather_radius=args.feather_radius, save_mask=args.save_mask, ) else: print(f"错误: 不支持的输入类型: {args.input}") exit(1)