Files
remove_backgroud/remove_background.py
2026-03-28 19:36:19 +08:00

1465 lines
51 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""书画与篆刻作品去背景工具。"""
import os
import sys
import argparse
from pathlib import Path
import numpy as np
import cv2
from PIL import Image, ImageOps
# 避免 numba 在某些环境下缓存失败
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
try:
from rembg import remove, new_session
except ImportError:
remove = None
new_session = None
AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {}
REMBG_SESSION_CACHE: dict[str, object] = {}
def _resolve_aot_paths(aot_root, aot_pretrain):
"""解析 AOT-GAN 的路径配置。"""
aot_root = Path(aot_root).resolve()
if aot_pretrain is None:
aot_pretrain = None
else:
aot_pretrain = Path(aot_pretrain)
if not aot_pretrain.is_absolute():
# 预训练权重相对路径以当前工作目录为基准
aot_pretrain = (Path.cwd() / aot_pretrain).resolve()
return aot_root, aot_pretrain
def _parse_aot_rates(rates_str):
parts = [p for p in rates_str.split("+") if p]
return [int(p) for p in parts]
def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None):
"""加载/缓存 AOT-GAN 模型。"""
aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain)
if not aot_root.exists():
raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}")
if aot_pretrain is None:
raise ValueError("AOT-GAN需要指定预训练权重路径--aot-pretrain")
if not aot_pretrain.exists():
raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}")
if rates is None:
rates = [1, 2, 4, 8]
rates_tuple = tuple(rates)
key = (str(aot_pretrain), device, int(block_num), rates_tuple)
if key in AOT_MODEL_CACHE:
return AOT_MODEL_CACHE[key]
src_root = aot_root / "src"
if str(src_root) not in sys.path:
sys.path.insert(0, str(src_root))
import importlib
try:
import torch
except ImportError as exc:
raise ImportError("未找到 PyTorch请先按README安装依赖。") from exc
net = importlib.import_module("model.aotgan")
class _Args:
pass
args = _Args()
args.block_num = int(block_num)
args.rates = rates
model = net.InpaintGenerator(args)
state = torch.load(str(aot_pretrain), map_location=device)
if isinstance(state, dict):
if "state_dict" in state:
state = state["state_dict"]
elif "model" in state:
state = state["model"]
elif "generator" in state:
state = state["generator"]
if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
state = {k.replace("module.", "", 1): v for k, v in state.items()}
model.load_state_dict(state, strict=True)
model.to(device)
model.eval()
AOT_MODEL_CACHE[key] = (model, device)
return AOT_MODEL_CACHE[key]
def _mask_bbox(mask):
ys, xs = np.where(mask > 0)
if len(xs) == 0:
return None
return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
def _expand_bbox(bbox, pad, width, height):
if bbox is None:
return None
x0, y0, x1, y1 = bbox
pad = max(0, int(pad))
x0 = max(0, x0 - pad)
y0 = max(0, y0 - pad)
x1 = min(width - 1, x1 + pad)
y1 = min(height - 1, y1 + pad)
return x0, y0, x1, y1
def _expand_bbox_min_size(bbox, pad, width, height, min_size):
    """Expand bbox by pad, then widen further so the ROI is at least min_size."""
    grown = _expand_bbox(bbox, pad, width, height)
    if grown is None:
        return None
    x0, y0, x1, y1 = grown
    missing_w = max(0, int(min_size) - (x1 - x0 + 1))
    missing_h = max(0, int(min_size) - (y1 - y0 + 1))
    if not missing_w and not missing_h:
        return grown
    # Half of the larger deficit (rounded up) added on each side covers it.
    extra = max((missing_w + 1) // 2, (missing_h + 1) // 2)
    return _expand_bbox(bbox, pad + extra, width, height)
def _build_noise_prefill(img, mask_t, strength):
"""为mask区域生成噪声预填充img取值范围为[-1, 1]。"""
import torch
img01 = (img + 1.0) / 2.0
unmasked = 1.0 - mask_t
denom = unmasked.sum()
if denom.item() < 1.0:
mean = torch.full((1, 3, 1, 1), 0.5, device=img.device)
std = torch.full((1, 3, 1, 1), 0.2, device=img.device)
else:
mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
std = torch.sqrt(var + 1e-6)
noise = mean + std * float(strength) * torch.randn_like(img01)
noise = noise.clamp(0.0, 1.0)
return noise * 2.0 - 1.0
def _inpaint_with_aot_core(
    bgr,
    mask,
    aot_root,
    aot_pretrain,
    device="cpu",
    block_num=8,
    rates=None,
    noise_prefill=False,
    noise_strength=1.0,
):
    """Inpaint the masked area with AOT-GAN; returns BGR at the input size.

    Args:
        bgr: Input image (H, W, 3) in OpenCV BGR order.
        mask: Region to fill; non-zero pixels are inpainted.
        aot_root / aot_pretrain / device / block_num / rates: Model
            configuration, forwarded to _get_aot_model.
        noise_prefill: Pre-fill the hole with statistics-matched noise
            instead of the constant used otherwise.
        noise_strength: Scale factor for the pre-fill noise.
    """
    # Normalize the mask to uint8 0/255.
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.max() <= 1:
        mask = mask * 255
    if mask.max() == 0:
        # Nothing to fill.
        return bgr
    h, w = mask.shape
    # Crop both image and mask to multiples of 4 — presumably a stride
    # requirement of the generator; TODO confirm against the AOT-GAN repo.
    grid = 4
    h2 = (h // grid) * grid
    w2 = (w // grid) * grid
    if h2 == 0 or w2 == 0:
        return bgr
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img_crop = rgb[:h2, :w2, :]
    mask_crop = mask[:h2, :w2]
    import torch
    # To NCHW float tensors; image scaled to [-1, 1], mask to {0, 1}.
    img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0
    img = img * 2.0 - 1.0
    mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0)
    img = img.unsqueeze(0)
    mask_t = mask_t.unsqueeze(0)
    if device != "cpu":
        img = img.to(device)
        mask_t = mask_t.to(device)
    model, _ = _get_aot_model(
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=device,
        block_num=block_num,
        rates=rates,
    )
    with torch.no_grad():
        if noise_prefill:
            # Fill the hole with noise matched to the unmasked statistics.
            noise = _build_noise_prefill(img, mask_t, noise_strength)
            image_masked = img * (1 - mask_t) + noise * mask_t
        else:
            # Fill the hole with the constant 1.0 (white in [-1, 1] space).
            image_masked = img * (1 - mask_t) + mask_t
        pred = model(image_masked, mask_t)
        # Composite: prediction inside the hole, original pixels outside.
        comp = pred * mask_t + img * (1 - mask_t)
    comp = comp[0].clamp(-1.0, 1.0)
    comp = (comp + 1.0) / 2.0 * 255.0
    comp = comp.permute(1, 2, 0).byte().cpu().numpy()
    result_bgr = cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)
    if h2 == h and w2 == w:
        return result_bgr
    # Paste the inpainted region back; the cropped border strip stays original.
    output_full = bgr.copy()
    output_full[:h2, :w2, :] = result_bgr
    return output_full
def _inpaint_with_aot(
    bgr,
    mask,
    aot_root,
    aot_pretrain,
    device="cpu",
    block_num=8,
    rates=None,
    crop=False,
    crop_pad=0,
    max_size=0,
    noise_prefill=False,
    noise_strength=1.0,
):
    """Inpaint with AOT-GAN, optionally cropping to the mask and capping size.

    Args:
        bgr: Input image in BGR order.
        mask: Non-zero pixels mark the region to fill.
        crop: Restrict inpainting to the mask's (padded) bounding box.
        crop_pad: Extra margin in pixels around the bounding box when cropping.
        max_size: Cap on the ROI's longest side before inference (0 = no cap).
        Remaining parameters are forwarded to _inpaint_with_aot_core.
    """
    # Minimum side length fed to the model; smaller ROIs fall back to full frame.
    min_side = 32
    # Normalize the mask to uint8 0/255.
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.max() <= 1:
        mask = mask * 255
    if mask.max() == 0:
        return bgr
    h, w = mask.shape
    # Default ROI is the whole image.
    x0 = y0 = 0
    x1 = w - 1
    y1 = h - 1
    if crop:
        bbox = _mask_bbox(mask)
        if bbox is None:
            return bgr
        expanded = _expand_bbox_min_size(bbox, crop_pad, w, h, min_side)
        if expanded is None:
            return bgr
        x0, y0, x1, y1 = expanded
        roi_w = x1 - x0 + 1
        roi_h = y1 - y0 + 1
        if roi_w < min_side or roi_h < min_side:
            # ROI still too small for AOT; fall back to full-image inpainting.
            crop = False
            x0 = y0 = 0
            x1 = w - 1
            y1 = h - 1
    bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1]
    mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1]
    roi_h, roi_w = bgr_roi.shape[:2]
    scale = 1.0
    max_size = int(max_size) if max_size else 0
    if max_size > 0:
        max_dim = max(roi_h, roi_w)
        if max_dim > max_size:
            # Downscale the ROI so its longest side equals max_size.
            scale = max_size / float(max_dim)
            new_w = max(1, int(round(roi_w * scale)))
            new_h = max(1, int(round(roi_h * scale)))
            if min(new_w, new_h) < min_side:
                # Downscaling would shrink below the minimum; keep original size.
                scale = 1.0
                new_w = roi_w
                new_h = roi_h
            interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
            if scale != 1.0:
                bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp)
                # Nearest-neighbor keeps the mask strictly binary.
                mask_roi = cv2.resize(
                    mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST
                )
    filled_roi = _inpaint_with_aot_core(
        bgr_roi,
        mask_roi,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=device,
        block_num=block_num,
        rates=rates,
        noise_prefill=noise_prefill,
        noise_strength=noise_strength,
    )
    if scale != 1.0:
        # Restore the ROI to its original resolution.
        filled_roi = cv2.resize(
            filled_roi,
            (roi_w, roi_h),
            interpolation=cv2.INTER_LINEAR,
        )
    if not crop:
        return filled_roi
    # Paste the repaired ROI back into the full frame.
    output_full = bgr.copy()
    output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi
    return output_full
# HEIC/HEIF support is optional: register the Pillow opener when available.
try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
    HEIC_SUPPORTED = True
except ImportError:
    HEIC_SUPPORTED = False
def _get_rembg_session(model_name):
    """Return a cached rembg session per model, to avoid repeated loading."""
    if new_session is None:
        raise ImportError(
            "未找到 rembg请先安装相关依赖或改用 --foreground-mode artwork。"
        )
    try:
        return REMBG_SESSION_CACHE[model_name]
    except KeyError:
        session = new_session(model_name)
        REMBG_SESSION_CACHE[model_name] = session
        return session
def _resize_for_processing(image, max_size):
"""按最大边缩放图片,返回缩放后的图片与缩放比例。"""
height, width = image.shape[:2]
max_size = int(max_size) if max_size else 0
if max_size <= 0 or max(height, width) <= max_size:
return image.copy(), 1.0
scale = max_size / float(max(height, width))
resized = cv2.resize(
image,
(max(1, int(round(width * scale))), max(1, int(round(height * scale)))),
interpolation=cv2.INTER_AREA,
)
return resized, scale
def _otsu_threshold(channel, floor=0):
"""返回 Otsu 阈值,并设置一个最小下限避免纸张纹理误检。"""
if channel.size == 0 or int(channel.max()) <= 0:
return int(floor)
threshold, _ = cv2.threshold(channel, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
return max(int(round(threshold)), int(floor))
def _remove_small_components(mask, min_area):
"""移除面积过小的连通域,降低纸张纹理噪声。"""
if min_area <= 0 or int(mask.max()) == 0:
return mask
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
mask, connectivity=8
)
cleaned = np.zeros_like(mask)
for label in range(1, num_labels):
area = int(stats[label, cv2.CC_STAT_AREA])
if area >= min_area:
cleaned[labels == label] = 255
return cleaned
def _remove_border_frame_components(
mask,
min_width_ratio=0.7,
min_height_ratio=0.7,
max_fill_ratio=0.35,
):
"""移除贴边的大型空心边框连通域,避免纸张外轮廓被误判为前景。"""
if int(mask.max()) == 0:
return mask
height, width = mask.shape
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
cleaned = mask.copy()
for label in range(1, num_labels):
x, y, comp_w, comp_h, area = stats[label]
touches_border = x == 0 or y == 0 or (x + comp_w) == width or (y + comp_h) == height
if not touches_border:
continue
if comp_w < int(width * float(min_width_ratio)):
continue
if comp_h < int(height * float(min_height_ratio)):
continue
fill_ratio = float(area) / float(max(1, comp_w * comp_h))
if fill_ratio <= float(max_fill_ratio):
cleaned[labels == label] = 0
return cleaned
def _artwork_mask_is_reasonable(mask):
"""检查书画掩码是否可信,用于 auto 模式回退。"""
if mask.size == 0 or int(mask.max()) == 0:
return False
coverage = float(np.count_nonzero(mask)) / float(mask.size)
if coverage < 0.0005 or coverage > 0.55:
return False
height, width = mask.shape
border = max(4, min(height, width) // 32)
border_pixels = np.concatenate(
[
mask[:border, :].reshape(-1),
mask[-border:, :].reshape(-1),
mask[:, :border].reshape(-1),
mask[:, -border:].reshape(-1),
]
)
border_ratio = float(np.count_nonzero(border_pixels)) / float(border_pixels.size)
return border_ratio <= 0.28
def _compute_detail_alpha(rgb_image):
    """Build a per-pixel detail alpha from local darkness and strict red-seal cues.

    Returns a float array in [0, 1]: high where the pixel is darker than its
    local background (ink strokes) or strongly red and saturated (seal marks).
    """
    height, width = rgb_image.shape[:2]
    gray = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY)
    # Estimate the paper background with a heavy blur, then measure how much
    # darker each pixel is than that background.
    sigma_bg = max(8.0, max(height, width) / 80.0)
    bg_gray = cv2.GaussianBlur(gray, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg)
    dark_score = cv2.subtract(bg_gray, gray).astype(np.float32)
    dark_alpha = np.clip((dark_score - 18.0) / 48.0, 0.0, 1.0)
    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    hue = hsv[:, :, 0].astype(np.float32)
    saturation = hsv[:, :, 1].astype(np.float32)
    value = hsv[:, :, 2].astype(np.float32)
    red = rgb_image[:, :, 0].astype(np.float32)
    green = rgb_image[:, :, 1].astype(np.float32)
    blue = rgb_image[:, :, 2].astype(np.float32)
    # Seal cue: red must dominate both other channels, sit in the wrap-around
    # red hue band, and be sufficiently saturated and bright.
    red_dominance = red - np.maximum(green, blue)
    red_hue_mask = ((hue <= 12.0) | (hue >= 168.0)).astype(np.float32)
    seal_alpha = np.clip((red_dominance - 20.0) / 40.0, 0.0, 1.0)
    seal_alpha *= np.clip((saturation - 50.0) / 80.0, 0.0, 1.0)
    seal_alpha *= (value >= 40.0).astype(np.float32)
    seal_alpha *= red_hue_mask
    detail_alpha = np.maximum(dark_alpha, seal_alpha)
    # Re-normalize with a small floor cut so faint noise maps to zero alpha.
    return np.clip((detail_alpha - 0.07) / 0.93, 0.0, 1.0)
def _extract_artwork_mask(input_image, artwork_type="auto", max_size=1600):
"""为书画/篆刻作品提取前景掩码,避免依赖通用人像分割模型。"""
rgb = np.array(input_image.convert("RGB"))
resized, scale = _resize_for_processing(rgb, max_size=max_size)
height, width = resized.shape[:2]
gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY)
lab = cv2.cvtColor(resized, cv2.COLOR_RGB2LAB)
hsv = cv2.cvtColor(resized, cv2.COLOR_RGB2HSV)
sigma_bg = max(6.0, max(height, width) / 40.0)
bg_gray = cv2.GaussianBlur(gray, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg)
dark_score = cv2.subtract(bg_gray, gray)
lab_a = lab[:, :, 1]
lab_b = lab[:, :, 2]
bg_a = cv2.GaussianBlur(lab_a, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg)
bg_b = cv2.GaussianBlur(lab_b, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg)
chroma_score = np.sqrt(
(lab_a.astype(np.float32) - bg_a.astype(np.float32)) ** 2
+ (lab_b.astype(np.float32) - bg_b.astype(np.float32)) ** 2
)
chroma_score = np.clip(chroma_score * 3.0, 0, 255).astype(np.uint8)
red = resized[:, :, 0].astype(np.int16)
green = resized[:, :, 1].astype(np.int16)
blue = resized[:, :, 2].astype(np.int16)
hue = hsv[:, :, 0]
saturation = hsv[:, :, 1]
value = hsv[:, :, 2]
red_dominance = np.clip(red - np.maximum(green, blue), 0, 255).astype(np.uint8)
seal_score = np.clip(
red_dominance.astype(np.int16) * 2
+ np.clip(saturation.astype(np.int16) - 45, 0, 255),
0,
255,
).astype(np.uint8)
seal_score = np.where(
((hue <= 12) | (hue >= 168))
& (saturation >= 55)
& (value >= 40)
& (red_dominance >= 20),
seal_score,
0,
).astype(np.uint8)
color_score = np.maximum(chroma_score, seal_score)
if artwork_type == "seal":
combined_score = np.maximum(
np.clip(dark_score.astype(np.float32) * 1.0, 0, 255).astype(np.uint8),
np.clip(color_score.astype(np.float32) * 1.35, 0, 255).astype(np.uint8),
)
dark_floor = 12
color_floor = 18
combined_floor = 20
elif artwork_type == "calligraphy":
combined_score = np.maximum(
np.clip(dark_score.astype(np.float32) * 1.35, 0, 255).astype(np.uint8),
np.clip(color_score.astype(np.float32) * 0.8, 0, 255).astype(np.uint8),
)
dark_floor = 14
color_floor = 28
combined_floor = 18
else:
combined_score = np.maximum(
np.clip(dark_score.astype(np.float32) * 1.2, 0, 255).astype(np.uint8),
np.clip(color_score.astype(np.float32) * 1.2, 0, 255).astype(np.uint8),
)
dark_floor = 14
color_floor = 16
combined_floor = 18
smooth_sigma = max(0.6, max(height, width) / 320.0)
dark_smooth = cv2.GaussianBlur(
dark_score, (0, 0), sigmaX=smooth_sigma, sigmaY=smooth_sigma
)
color_smooth = cv2.GaussianBlur(
color_score, (0, 0), sigmaX=smooth_sigma, sigmaY=smooth_sigma
)
combined_smooth = cv2.GaussianBlur(
combined_score,
(0, 0),
sigmaX=smooth_sigma,
sigmaY=smooth_sigma,
)
dark_mask = (dark_smooth >= _otsu_threshold(dark_smooth, floor=dark_floor)).astype(
np.uint8
) * 255
color_mask = (
color_smooth >= _otsu_threshold(color_smooth, floor=color_floor)
).astype(np.uint8) * 255
combined_mask = (
combined_smooth >= _otsu_threshold(combined_smooth, floor=combined_floor)
).astype(np.uint8) * 255
if artwork_type == "seal":
mask = cv2.bitwise_or(combined_mask, color_mask)
mask = cv2.bitwise_or(mask, dark_mask)
elif artwork_type == "calligraphy":
mask = cv2.bitwise_or(combined_mask, dark_mask)
else:
mask = cv2.bitwise_or(combined_mask, cv2.bitwise_or(dark_mask, color_mask))
close_kernel = np.ones((3, 3), dtype=np.uint8)
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel, iterations=1)
min_area = max(12, (height * width) // 25000)
mask = _remove_small_components(mask, min_area=min_area)
mask = _remove_border_frame_components(mask)
mask = cv2.dilate(mask, close_kernel, iterations=1)
if scale < 1.0:
mask = cv2.resize(
mask, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST
)
return mask
def _mask_to_transparent_image(input_image, mask):
    """Turn a binary mask into an RGBA image with a transparent background."""
    rgb = np.array(input_image.convert("RGB"))
    detail_alpha = _compute_detail_alpha(rgb)
    # Modulate the hard mask with the soft detail alpha, then feather slightly.
    alpha = ((mask.astype(np.float32) / 255.0) * detail_alpha * 255.0).astype(np.uint8)
    alpha = cv2.GaussianBlur(alpha, (0, 0), sigmaX=0.8, sigmaY=0.8)
    # Clamp the blur so no alpha leaks outside the mask.
    alpha = np.where(mask > 0, alpha, 0).astype(np.uint8)
    return Image.fromarray(np.dstack((rgb, alpha)))
def _extract_foreground_mask(
    input_image,
    session=None,
    model_name="isnet-general-use",
    foreground_mode="artwork",
    artwork_type="auto",
    artwork_max_size=1600,
    mask_threshold=10,
    **kwargs,
):
    """Extract a foreground mask using the requested mode.

    Returns a (mask, source) tuple where source is "artwork" or "rembg".
    """
    if foreground_mode in {"artwork", "auto"}:
        artwork_mask = _extract_artwork_mask(
            input_image,
            artwork_type=artwork_type,
            max_size=artwork_max_size,
        )
        # Pure artwork mode always accepts the mask; auto mode falls through
        # to rembg when the mask looks implausible.
        if foreground_mode == "artwork" or _artwork_mask_is_reasonable(artwork_mask):
            return artwork_mask, "artwork"
    if foreground_mode not in {"rembg", "auto"}:
        raise ValueError(f"不支持的前景提取模式: {foreground_mode}")
    if remove is None:
        raise ImportError(
            "未找到 rembg请先安装相关依赖或改用 --foreground-mode artwork。"
        )
    rembg_session = session if session is not None else _get_rembg_session(model_name)
    rgba = _ensure_rgba_size(
        remove(input_image, session=rembg_session, **kwargs),
        input_image.size,
    )
    alpha = np.array(rgba.getchannel("A"))
    return _alpha_to_mask(alpha, threshold=mask_threshold), "rembg"
def remove_background(
    input_path,
    output_path,
    session=None,
    model_name="isnet-general-use",
    foreground_mode="artwork",
    artwork_type="auto",
    artwork_max_size=1600,
    **kwargs,
):
    """
    Remove the image background.

    Args:
        input_path: Input image path.
        output_path: Output image path.
        session: Optional rembg session object.
        model_name: rembg model name.
        foreground_mode: Foreground extraction mode: artwork / auto / rembg.
        artwork_type: Artwork type: auto / calligraphy / seal.
        artwork_max_size: Longest side used when estimating the artwork mask.
        **kwargs: Extra rembg parameters (e.g. alpha_matting options).
    """
    print(f"正在处理: {input_path}")
    # Load the image and apply its EXIF orientation.
    input_image = ImageOps.exif_transpose(Image.open(input_path))
    if foreground_mode in {"artwork", "auto"}:
        mask = _extract_artwork_mask(
            input_image,
            artwork_type=artwork_type,
            max_size=artwork_max_size,
        )
        # artwork mode always uses the mask; auto falls back when implausible.
        if foreground_mode == "artwork" or _artwork_mask_is_reasonable(mask):
            output_image = _mask_to_transparent_image(input_image, mask)
            output_image.save(output_path)
            print(f"已保存: {output_path}")
            return
    if remove is None:
        raise ImportError(
            "未找到 rembg请先安装相关依赖或改用 --foreground-mode artwork。"
        )
    if session is None:
        session = _get_rembg_session(model_name)
    # Fallback / explicit rembg pipeline.
    output_image = remove(input_image, session=session, **kwargs)
    output_image.save(output_path)
    print(f"已保存: {output_path}")
def _pil_to_bgr(image):
    """Convert a PIL image to an OpenCV BGR array."""
    rgb_array = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
def _alpha_to_mask(alpha_channel, threshold=10):
"""将alpha通道转换为二值mask0或255"""
mask = (alpha_channel > threshold).astype(np.uint8) * 255
return mask
def _ensure_rgba_size(image, target_size):
    """Coerce the image to RGBA mode and target_size (nearest-neighbor)."""
    rgba = image if image.mode == "RGBA" else image.convert("RGBA")
    if rgba.size != target_size:
        rgba = rgba.resize(target_size, resample=Image.NEAREST)
    return rgba
def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0):
"""生成硬边mask与处理后mask用于填补/过渡)"""
if mask_dilate and mask_dilate > 0:
kernel = np.ones((mask_dilate, mask_dilate), np.uint8)
mask = cv2.dilate(mask, kernel, iterations=1)
if edge_grow and edge_grow > 0:
kernel = np.ones((3, 3), np.uint8)
mask = cv2.dilate(mask, kernel, iterations=int(edge_grow))
mask_hard = mask.copy()
mask_used = mask_hard
if mask_blur and mask_blur > 0:
k = mask_blur if mask_blur % 2 == 1 else mask_blur + 1
mask_used = cv2.GaussianBlur(mask_hard, (k, k), 0)
mask_used = (mask_used > 0).astype(np.uint8) * 255
return mask_hard, mask_used
def remove_subject_and_inpaint(
    input_path,
    output_path,
    session=None,
    model_name="isnet-general-use",
    foreground_mode="artwork",
    artwork_type="auto",
    artwork_max_size=1600,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
    mask_output_path=None,
    **kwargs,
):
    """
    Remove the subject and fill in the background.

    Args:
        input_path: Input image path.
        output_path: Output image path.
        session: Optional rembg session object.
        model_name: rembg model name.
        foreground_mode: Foreground extraction mode: artwork / auto / rembg.
        artwork_type: Artwork type: auto / calligraphy / seal.
        artwork_max_size: Longest side used when estimating the artwork mask.
        aot_root: AOT-GAN directory.
        aot_pretrain: Path to the AOT-GAN weight file.
        aot_device: AOT-GAN device.
        aot_block_num: Number of AOT-GAN AOTBlocks.
        aot_rates: AOT-GAN AOTBlock dilation rates.
        aot_crop: Inpaint only the cropped mask region.
        aot_crop_pad: Margin (pixels) around the cropped region.
        aot_max_size: Cap on the AOT input's longest side (0 = unlimited).
        aot_noise_prefill: Pre-fill the hole with random noise.
        aot_noise_strength: Noise strength factor.
        mask_dilate: Mask dilation size.
        mask_blur: Mask blur size (odd number).
        mask_threshold: Alpha threshold.
        edge_grow: Extra subject-edge growth in pixels.
        black_subject / black_threshold: Also treat near-black pixels as subject.
        gray_subject / gray_saturation_threshold / gray_value_threshold:
            Also treat low-saturation, not-too-bright pixels as subject.
        feather / feather_radius: Blend the filled result into the original
            image around the mask edge.
        save_mask: Whether to save the mask.
        mask_output_path: Mask output path.
        **kwargs: Extra rembg parameters.
    """
    print(f"正在处理(去主体补背景): {input_path}")
    input_image = ImageOps.exif_transpose(Image.open(input_path))
    mask, mask_source = _extract_foreground_mask(
        input_image,
        session=session,
        model_name=model_name,
        foreground_mode=foreground_mode,
        artwork_type=artwork_type,
        artwork_max_size=artwork_max_size,
        mask_threshold=mask_threshold,
        **kwargs,
    )
    bgr = _pil_to_bgr(input_image)
    if black_subject:
        # Fold near-black pixels into the subject mask.
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        black_mask = (gray <= black_threshold).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, black_mask)
    if gray_subject:
        # Fold low-saturation, not-too-bright pixels into the subject mask.
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        s = hsv[:, :, 1]
        v = hsv[:, :, 2]
        gray_mask = (
            (s <= gray_saturation_threshold) & (v <= gray_value_threshold)
        ).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, gray_mask)
    mask_hard, mask_used = _prepare_mask(
        mask,
        mask_dilate=mask_dilate,
        mask_blur=mask_blur,
        edge_grow=edge_grow,
    )
    rates = _parse_aot_rates(aot_rates)
    filled = _inpaint_with_aot(
        bgr,
        mask_used,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=aot_device,
        block_num=aot_block_num,
        rates=rates,
        crop=aot_crop,
        crop_pad=aot_crop_pad,
        max_size=aot_max_size,
        noise_prefill=aot_noise_prefill,
        noise_strength=aot_noise_strength,
    )
    if feather and feather_radius > 0:
        # Blend only outside the mask so the original subject is not brought back.
        mask_bin = (mask_hard > 0).astype(np.uint8) * 255
        dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3)
        alpha = np.ones_like(dist_out, dtype=np.float32)
        outside = mask_bin == 0
        alpha[outside] = np.clip(
            1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0
        )
        alpha = alpha[:, :, None]
        blended = (alpha * filled + (1.0 - alpha) * bgr).astype(np.uint8)
        result = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
    else:
        result = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB)
    Image.fromarray(result).save(output_path)
    if save_mask:
        if mask_output_path is None:
            # Default mask path: alongside the output, with a _mask.png suffix.
            mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png"
        Image.fromarray(mask_used).save(mask_output_path)
        print(f"已保存mask: {mask_output_path}")
    print(f"前景提取模式: {mask_source}")
    print(f"已保存: {output_path}")
def process_images_folder(
    input_folder,
    output_folder,
    model_name="isnet-general-use",
    foreground_mode="artwork",
    artwork_type="auto",
    artwork_max_size=1600,
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_size=10,
    post_process_mask=False,
    remove_subject=False,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
):
    """
    Batch-process every image in a folder.

    Args:
        input_folder: Input folder path.
        output_folder: Output folder path.
        model_name: rembg model name; options include:
            - u2net (default): general-purpose model
            - u2netp: lightweight u2net
            - u2net_human_seg: human segmentation
            - silueta: compact u2net (43MB)
            - isnet-general-use: newer general-purpose model
            - isnet-anime: high-accuracy anime-character segmentation
            - birefnet-general: general-purpose model
            - birefnet-portrait: portrait model
        foreground_mode: Foreground extraction mode: artwork / auto / rembg.
        artwork_type: Artwork type: auto / calligraphy / seal.
        artwork_max_size: Longest side used when estimating the artwork mask.
        alpha_matting: Enable alpha-matting post-processing for better edges.
        alpha_matting_foreground_threshold: Foreground threshold (0-255);
            larger keeps more foreground.
        alpha_matting_background_threshold: Background threshold (0-255);
            larger removes more background.
        alpha_matting_erode_size: Erosion size used to smooth edges.
        post_process_mask: Enable mask post-processing.
        Remaining parameters are forwarded to remove_subject_and_inpaint.
    """
    # Create the output folder.
    Path(output_folder).mkdir(parents=True, exist_ok=True)
    print(f"前景提取模式: {foreground_mode}")
    print(f"书画类型: {artwork_type}")
    if foreground_mode in {"rembg", "auto"}:
        print(f"rembg模型: {model_name}")
    # Supported image formats.
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
    if HEIC_SUPPORTED:
        image_extensions.update({".heic", ".heif"})
    else:
        print(
            "提示: 未安装pillow-heifHEIC格式不可用。安装方法: pip install pillow-heif"
        )
    # Collect all image files (top level only; no recursion).
    input_path = Path(input_folder)
    image_files = [
        f
        for f in input_path.iterdir()
        if f.is_file() and f.suffix.lower() in image_extensions
    ]
    if not image_files:
        print(f"{input_folder} 中没有找到图片文件")
        return
    print(f"找到 {len(image_files)} 张图片,开始处理...")
    print(f"Alpha Matting: {'启用' if alpha_matting else '禁用'}")
    if alpha_matting:
        print(f" - 前景阈值: {alpha_matting_foreground_threshold}")
        print(f" - 背景阈值: {alpha_matting_background_threshold}")
        print(f" - 侵蚀大小: {alpha_matting_erode_size}")
    print(f"Mask后处理: {'启用' if post_process_mask else '禁用'}")
    print(f"去主体补背景: {'启用' if remove_subject else '禁用'}")
    if remove_subject:
        print(f" - AOT目录: {aot_root}")
        print(f" - AOT权重: {aot_pretrain}")
        print(f" - AOT设备: {aot_device}")
        print(f" - AOT块数: {aot_block_num}")
        print(f" - AOT膨胀率: {aot_rates}")
        print(f" - AOT裁剪: {'' if aot_crop else ''}")
        if aot_crop:
            print(f" - AOT裁剪边界: {aot_crop_pad}")
        print(f" - AOT最大边: {aot_max_size}")
        print(f" - AOT噪声预填充: {'' if aot_noise_prefill else ''}")
        if aot_noise_prefill:
            print(f" - AOT噪声强度: {aot_noise_strength}")
        print(f" - mask膨胀: {mask_dilate}")
        print(f" - mask模糊: {mask_blur}")
        print(f" - mask阈值: {mask_threshold}")
        print(f" - 边缘扩张: {edge_grow}")
        print(f" - 黑色内容作为主体: {'' if black_subject else ''}")
        if black_subject:
            print(f" - 黑色阈值: {black_threshold}")
        print(f" - 灰阶内容作为主体: {'' if gray_subject else ''}")
        if gray_subject:
            print(f" - 灰阶饱和度阈值: {gray_saturation_threshold}")
            print(f" - 灰阶亮度阈值: {gray_value_threshold}")
        print(f" - 边缘过渡: {'' if feather else ''}")
        if feather:
            print(f" - 过渡半径: {feather_radius}")
        print(f" - 保存mask: {'' if save_mask else ''}")
    print("-" * 50)
    # Process each image; failures are reported but do not stop the batch.
    for i, image_file in enumerate(image_files, 1):
        try:
            if remove_subject:
                # Keep a standard raster suffix; anything else becomes .jpg.
                suffix = image_file.suffix.lower()
                if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
                    output_filename = image_file.stem + "_bgfill" + suffix
                else:
                    output_filename = image_file.stem + "_bgfill.jpg"
            else:
                # Background removal defaults to PNG to keep transparency.
                output_filename = image_file.stem + "_nobg.png"
            output_path = Path(output_folder) / output_filename
            print(f"[{i}/{len(image_files)}] ", end="")
            if remove_subject:
                mask_output_path = None
                if save_mask:
                    mask_output_path = str(
                        Path(output_folder) / (image_file.stem + "_mask.png")
                    )
                remove_subject_and_inpaint(
                    str(image_file),
                    str(output_path),
                    model_name=model_name,
                    foreground_mode=foreground_mode,
                    artwork_type=artwork_type,
                    artwork_max_size=artwork_max_size,
                    alpha_matting=alpha_matting,
                    alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold=alpha_matting_background_threshold,
                    alpha_matting_erode_size=alpha_matting_erode_size,
                    post_process_mask=post_process_mask,
                    mask_dilate=mask_dilate,
                    mask_blur=mask_blur,
                    mask_threshold=mask_threshold,
                    edge_grow=edge_grow,
                    aot_root=aot_root,
                    aot_pretrain=aot_pretrain,
                    aot_device=aot_device,
                    aot_block_num=aot_block_num,
                    aot_rates=aot_rates,
                    aot_crop=aot_crop,
                    aot_crop_pad=aot_crop_pad,
                    aot_max_size=aot_max_size,
                    aot_noise_prefill=aot_noise_prefill,
                    aot_noise_strength=aot_noise_strength,
                    black_subject=black_subject,
                    black_threshold=black_threshold,
                    gray_subject=gray_subject,
                    gray_saturation_threshold=gray_saturation_threshold,
                    gray_value_threshold=gray_value_threshold,
                    feather=feather,
                    feather_radius=feather_radius,
                    save_mask=save_mask,
                    mask_output_path=mask_output_path,
                )
            else:
                remove_background(
                    str(image_file),
                    str(output_path),
                    model_name=model_name,
                    foreground_mode=foreground_mode,
                    artwork_type=artwork_type,
                    artwork_max_size=artwork_max_size,
                    alpha_matting=alpha_matting,
                    alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold=alpha_matting_background_threshold,
                    alpha_matting_erode_size=alpha_matting_erode_size,
                    post_process_mask=post_process_mask,
                )
        except Exception as e:
            # Keep batch processing alive even when one image fails.
            print(f"处理 {image_file.name} 时出错: {e}")
    print("-" * 50)
    print(f"处理完成!结果保存在 {output_folder} 文件夹中")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="书画与篆刻作品去背景工具 - 默认使用轻量书画掩码提取",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例用法:
# 使用默认参数处理images文件夹
python remove_background.py
# 处理单个文件
python remove_background.py input.jpg output.png
# 处理指定文件夹
python remove_background.py my_images/ my_output/
# 强制使用 rembg 旧流程
python remove_background.py input.jpg output.png --foreground-mode rembg -m isnet-general-use
# 指定为篆刻模式
python remove_background.py input.jpg output.png --artwork-type seal
# 自定义alpha matting参数
python remove_background.py input.jpg output.png -ft 260 -bt 12 -es 5
""",
)
# 必需参数
parser.add_argument(
"input",
nargs="?",
default="images",
help="输入文件或文件夹路径(默认: images",
)
parser.add_argument(
"output",
nargs="?",
default=None,
help="输出文件或文件夹路径可选默认为output/",
)
# 模型选择
parser.add_argument(
"-m",
"--model",
default="isnet-general-use",
choices=[
"u2net",
"u2netp",
"u2net_human_seg",
"silueta",
"isnet-general-use",
"isnet-anime",
"birefnet-general",
"birefnet-general-lite",
"birefnet-portrait",
"birefnet-dis",
"birefnet-hrsod",
"birefnet-cod",
"birefnet-massive",
],
help="选择 rembg 模型(仅在 --foreground-mode rembg/auto 时使用)",
)
parser.add_argument(
"--foreground-mode",
default="artwork",
choices=["artwork", "auto", "rembg"],
help="前景提取模式artwork 为书画专用快速模式auto 失败时回退 rembgrembg 为旧流程 (默认: artwork)",
)
parser.add_argument(
"--artwork-type",
default="auto",
choices=["auto", "calligraphy", "seal"],
help="书画类型auto 自动兼容书法与篆刻seal 更偏重红色印章calligraphy 更偏重墨色笔画 (默认: auto)",
)
parser.add_argument(
"--artwork-max-size",
type=int,
default=1600,
help="书画掩码估算时的最大边,越小越快,越大越精细 (默认: 1600)",
)
# Alpha Matting参数
parser.add_argument(
"-a",
"--alpha-matting",
"--alpha_matting",
action="store_true",
help="启用alpha matting后处理默认: false",
)
parser.add_argument(
"-ft",
"--foreground-threshold",
type=int,
default=245,
help="前景阈值 (0-255),值越大保留越多细节 (默认: 245)",
)
parser.add_argument(
"-bt",
"--background-threshold",
type=int,
default=8,
help="背景阈值 (0-255),值越大去除越多背景 (默认: 8)",
)
# --erode-size: alpha-matting erosion size (larger = smoother edges, may lose detail).
parser.add_argument(
    "-es",
    "--erode-size",
    type=int,
    default=2,
    help="侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)",
)
# Miscellaneous options.
parser.add_argument(
    "-p",
    "--post-process",
    "--post_process",
    action="store_true",
    help="启用mask后处理默认: false",
)
# Subject-removal / background-inpainting options (AOT-GAN pipeline).
parser.add_argument(
    "--remove-subject",
    "--remove_subject",
    action="store_true",
    help="去掉主体并补全背景(默认: false",
)
# AOT-GAN model location / runtime configuration.
parser.add_argument(
    "--aot-root",
    type=str,
    default="AOT-GAN-for-Inpainting",
    help="AOT-GAN目录默认: AOT-GAN-for-Inpainting",
)
parser.add_argument(
    "--aot-pretrain",
    type=str,
    default=None,
    help="AOT-GAN预训练权重文件路径必填相对路径基于当前目录",
)
parser.add_argument(
    "--aot-device", type=str, default="cpu", help="AOT-GAN设备默认: cpu"
)
parser.add_argument(
    "--aot-block-num", type=int, default=8, help="AOTBlock数量默认: 8"
)
# Dilation rates are passed as a "+"-separated string, e.g. "1+2+4+8".
parser.add_argument(
    "--aot-rates",
    type=str,
    default="1+2+4+8",
    help="AOTBlock膨胀率默认: 1+2+4+8",
)
parser.add_argument(
    "--aot-crop", action="store_true", help="AOT仅对mask区域裁剪修补默认: false"
)
parser.add_argument(
    "--aot-crop-pad", type=int, default=0, help="AOT裁剪边缘留白像素默认: 0"
)
parser.add_argument(
    "--aot-max-size",
    type=int,
    default=0,
    help="AOT输入最大边限制0为不限制默认: 0",
)
parser.add_argument(
    "--aot-noise-prefill",
    action="store_true",
    help="AOT使用随机噪声预填充默认: false",
)
parser.add_argument(
    "--aot-noise-strength",
    type=float,
    default=1.0,
    help="AOT噪声强度系数默认: 1.0",
)
# Mask post-processing knobs (applied before inpainting).
parser.add_argument(
    "--mask-dilate", type=int, default=3, help="mask膨胀大小默认: 3"
)
parser.add_argument(
    "--mask-blur", type=int, default=3, help="mask模糊大小默认: 3建议奇数"
)
parser.add_argument(
    "--mask-threshold", type=int, default=10, help="alpha阈值默认: 10"
)
parser.add_argument(
    "--edge-grow", type=int, default=0, help="主体边缘扩张像素(默认: 0"
)
parser.add_argument(
    "--save-mask",
    "--save_mask",
    action="store_true",
    help="保存mask到output目录默认: false",
)
# Options to treat near-black / grayscale content as part of the subject.
parser.add_argument(
    "--black-subject",
    "--black_subject",
    action="store_true",
    help="将黑色内容也视为主体(默认: false",
)
parser.add_argument(
    "--black-threshold",
    type=int,
    default=50,
    help="黑色阈值(0-255灰度越小越黑默认: 50)",
)
parser.add_argument(
    "--gray-subject",
    "--gray_subject",
    action="store_true",
    help="将灰阶内容也视为主体(默认: false",
)
parser.add_argument(
    "--gray-saturation-threshold",
    type=int,
    default=30,
    help="灰阶饱和度阈值(0-255越小越接近灰阶默认: 30)",
)
parser.add_argument(
    "--gray-value-threshold",
    type=int,
    default=200,
    help="灰阶亮度阈值(0-255越小越暗默认: 200)",
)
# Edge feathering for blending the inpainted region into the original.
parser.add_argument(
    "--feather", action="store_true", help="启用边缘过渡融合(默认: false"
)
parser.add_argument(
    "--feather-radius", type=int, default=5, help="边缘过渡半径(默认: 5"
)
args = parser.parse_args()

print("=" * 50)
print("图片去背景工具")
print("=" * 50)

# Resolve the input path; it may be a single file or a whole directory.
input_path = Path(args.input)
if not input_path.exists():
    print(f"错误: 输入路径不存在: {args.input}")
    # Fix: use sys.exit() instead of the site-provided exit() builtin,
    # which is only guaranteed in interactive sessions.
    sys.exit(1)

# --- Single-file mode ------------------------------------------------------
if input_path.is_file():
    suffix = input_path.suffix.lower()
    if args.remove_subject:
        # Background-fill output keeps the original raster format when it is
        # one we recognise; otherwise fall back to JPEG.
        if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
            output_name = input_path.stem + "_bgfill" + suffix
        else:
            output_name = input_path.stem + "_bgfill.jpg"
    else:
        # Transparent-background output always needs an alpha channel -> PNG.
        output_name = input_path.stem + "_nobg.png"

    # Determine the output file path.
    if args.output is None:
        # Default: an "output" directory next to the input file.
        # (The original branched on args.remove_subject here, but both
        # branches were byte-identical, so the branch is removed.)
        output_path = input_path.parent / "output" / output_name
        output_path.parent.mkdir(parents=True, exist_ok=True)
    else:
        output_candidate = Path(args.output)
        if output_candidate.exists() and output_candidate.is_dir():
            # Existing directory: write the generated file into it.
            output_path = output_candidate / output_name
            output_path.parent.mkdir(parents=True, exist_ok=True)
        elif output_candidate.suffix == "":
            # No extension: treat the argument as a directory to create.
            output_candidate.mkdir(parents=True, exist_ok=True)
            output_path = output_candidate / output_name
        else:
            # Explicit output file path: ensure its parent directory exists.
            output_path = output_candidate
            output_path.parent.mkdir(parents=True, exist_ok=True)

    # Echo the effective configuration before processing.
    print(f"输入文件: {input_path}")
    print(f"输出文件: {output_path}")
    print(f"前景提取模式: {args.foreground_mode}")
    print(f"书画类型: {args.artwork_type}")
    print(f"书画最大边: {args.artwork_max_size}")
    if args.foreground_mode in {"rembg", "auto"}:
        print(f"rembg模型: {args.model}")
        print(f"Alpha Matting: {'启用' if args.alpha_matting else '禁用'}")
        if args.alpha_matting:
            print(f"  - 前景阈值: {args.foreground_threshold}")
            print(f"  - 背景阈值: {args.background_threshold}")
            print(f"  - 侵蚀大小: {args.erode_size}")
        print(f"Mask后处理: {'启用' if args.post_process else '禁用'}")
    print(f"去主体补背景: {'启用' if args.remove_subject else '禁用'}")
    if args.remove_subject:
        print(f"  - AOT目录: {args.aot_root}")
        print(f"  - AOT权重: {args.aot_pretrain}")
        print(f"  - AOT设备: {args.aot_device}")
        print(f"  - AOT块数: {args.aot_block_num}")
        print(f"  - AOT膨胀率: {args.aot_rates}")
        print(f"  - AOT裁剪: {'是' if args.aot_crop else '否'}")
        if args.aot_crop:
            print(f"  - AOT裁剪边界: {args.aot_crop_pad}")
        print(f"  - AOT最大边: {args.aot_max_size}")
        print(f"  - AOT噪声预填充: {'是' if args.aot_noise_prefill else '否'}")
        if args.aot_noise_prefill:
            print(f"  - AOT噪声强度: {args.aot_noise_strength}")
        print(f"  - mask膨胀: {args.mask_dilate}")
        print(f"  - mask模糊: {args.mask_blur}")
        print(f"  - mask阈值: {args.mask_threshold}")
        print(f"  - 边缘扩张: {args.edge_grow}")
        print(f"  - 黑色内容作为主体: {'是' if args.black_subject else '否'}")
        if args.black_subject:
            print(f"  - 黑色阈值: {args.black_threshold}")
        print(f"  - 灰阶内容作为主体: {'是' if args.gray_subject else '否'}")
        if args.gray_subject:
            print(f"  - 灰阶饱和度阈值: {args.gray_saturation_threshold}")
            print(f"  - 灰阶亮度阈值: {args.gray_value_threshold}")
        print(f"  - 边缘过渡: {'是' if args.feather else '否'}")
        if args.feather:
            print(f"  - 过渡半径: {args.feather_radius}")
        print(f"  - 保存mask: {'是' if args.save_mask else '否'}")
    print("-" * 50)

    # Run the requested pipeline on the single image.
    if args.remove_subject:
        mask_output_path = None
        if args.save_mask:
            # Derive the mask path from the output path (extension stripped).
            mask_output_path = str(output_path.with_suffix("")) + "_mask.png"
        remove_subject_and_inpaint(
            str(input_path),
            str(output_path),
            model_name=args.model,
            foreground_mode=args.foreground_mode,
            artwork_type=args.artwork_type,
            artwork_max_size=args.artwork_max_size,
            alpha_matting=args.alpha_matting,
            alpha_matting_foreground_threshold=args.foreground_threshold,
            alpha_matting_background_threshold=args.background_threshold,
            alpha_matting_erode_size=args.erode_size,
            post_process_mask=args.post_process,
            mask_dilate=args.mask_dilate,
            mask_blur=args.mask_blur,
            mask_threshold=args.mask_threshold,
            edge_grow=args.edge_grow,
            aot_root=args.aot_root,
            aot_pretrain=args.aot_pretrain,
            aot_device=args.aot_device,
            aot_block_num=args.aot_block_num,
            aot_rates=args.aot_rates,
            aot_crop=args.aot_crop,
            aot_crop_pad=args.aot_crop_pad,
            aot_max_size=args.aot_max_size,
            aot_noise_prefill=args.aot_noise_prefill,
            aot_noise_strength=args.aot_noise_strength,
            black_subject=args.black_subject,
            black_threshold=args.black_threshold,
            gray_subject=args.gray_subject,
            gray_saturation_threshold=args.gray_saturation_threshold,
            gray_value_threshold=args.gray_value_threshold,
            feather=args.feather,
            feather_radius=args.feather_radius,
            save_mask=args.save_mask,
            mask_output_path=mask_output_path,
        )
    else:
        remove_background(
            str(input_path),
            str(output_path),
            model_name=args.model,
            foreground_mode=args.foreground_mode,
            artwork_type=args.artwork_type,
            artwork_max_size=args.artwork_max_size,
            alpha_matting=args.alpha_matting,
            alpha_matting_foreground_threshold=args.foreground_threshold,
            alpha_matting_background_threshold=args.background_threshold,
            alpha_matting_erode_size=args.erode_size,
            post_process_mask=args.post_process,
        )
    print("-" * 50)
    print(f"处理完成!结果保存在: {output_path}")

# --- Directory mode --------------------------------------------------------
elif input_path.is_dir():
    # Default output folder name when none is given.
    output_folder = args.output if args.output else "output"
    process_images_folder(
        str(input_path),
        output_folder,
        model_name=args.model,
        foreground_mode=args.foreground_mode,
        artwork_type=args.artwork_type,
        artwork_max_size=args.artwork_max_size,
        alpha_matting=args.alpha_matting,
        alpha_matting_foreground_threshold=args.foreground_threshold,
        alpha_matting_background_threshold=args.background_threshold,
        alpha_matting_erode_size=args.erode_size,
        post_process_mask=args.post_process,
        remove_subject=args.remove_subject,
        mask_dilate=args.mask_dilate,
        mask_blur=args.mask_blur,
        mask_threshold=args.mask_threshold,
        edge_grow=args.edge_grow,
        aot_root=args.aot_root,
        aot_pretrain=args.aot_pretrain,
        aot_device=args.aot_device,
        aot_block_num=args.aot_block_num,
        aot_rates=args.aot_rates,
        aot_crop=args.aot_crop,
        aot_crop_pad=args.aot_crop_pad,
        aot_max_size=args.aot_max_size,
        aot_noise_prefill=args.aot_noise_prefill,
        aot_noise_strength=args.aot_noise_strength,
        black_subject=args.black_subject,
        black_threshold=args.black_threshold,
        gray_subject=args.gray_subject,
        gray_saturation_threshold=args.gray_saturation_threshold,
        gray_value_threshold=args.gray_value_threshold,
        feather=args.feather,
        feather_radius=args.feather_radius,
        save_mask=args.save_mask,
    )

# Neither a file nor a directory (e.g. a special device / broken symlink).
else:
    print(f"错误: 不支持的输入类型: {args.input}")
    sys.exit(1)