1465 lines
51 KiB
Python
1465 lines
51 KiB
Python
"""书画与篆刻作品去背景工具。"""
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
from pathlib import Path
|
||
import numpy as np
|
||
import cv2
|
||
from PIL import Image, ImageOps
|
||
|
||
# Some environments cannot write numba's default cache location; point it at a
# writable directory before rembg (which may pull in numba) is imported.
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")

# rembg is optional: when it is missing both names are None and only the
# "artwork" foreground mode can be used.
try:
    from rembg import new_session, remove
except ImportError:
    new_session = remove = None

# Loaded AOT-GAN generators keyed by (weights path, device, block_num, rates).
AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {}
# rembg sessions keyed by model name.
REMBG_SESSION_CACHE: dict[str, object] = {}
|
||
|
||
|
||
def _resolve_aot_paths(aot_root, aot_pretrain):
|
||
"""解析 AOT-GAN 的路径配置。"""
|
||
aot_root = Path(aot_root).resolve()
|
||
if aot_pretrain is None:
|
||
aot_pretrain = None
|
||
else:
|
||
aot_pretrain = Path(aot_pretrain)
|
||
if not aot_pretrain.is_absolute():
|
||
# 预训练权重相对路径以当前工作目录为基准
|
||
aot_pretrain = (Path.cwd() / aot_pretrain).resolve()
|
||
return aot_root, aot_pretrain
|
||
|
||
|
||
def _parse_aot_rates(rates_str):
|
||
parts = [p for p in rates_str.split("+") if p]
|
||
return [int(p) for p in parts]
|
||
|
||
|
||
def _unwrap_state_dict(state):
    """Return the bare generator state dict from a loaded AOT-GAN checkpoint.

    Unwraps one level of ``state_dict`` / ``model`` / ``generator`` nesting
    (checked in that order). If any key starts with ``"module."`` (typically
    left by a DataParallel-style wrapper), that prefix is stripped.
    """
    if isinstance(state, dict):
        for wrapper_key in ("state_dict", "model", "generator"):
            if wrapper_key in state:
                state = state[wrapper_key]
                break
    if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
        state = {k.removeprefix("module."): v for k, v in state.items()}
    return state


def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None):
    """Load the AOT-GAN ``InpaintGenerator``, caching one instance per config.

    Args:
        aot_root: AOT-GAN checkout. Its ``src`` directory is put on ``sys.path``
            so that ``model.aotgan`` can be imported.
        aot_pretrain: Generator weights file (relative paths use the cwd).
        device: Torch device string the model is moved to.
        block_num: Number of AOT blocks, forwarded to the generator.
        rates: Dilation rates per AOT block; defaults to ``[1, 2, 4, 8]``.

    Returns:
        ``(model, device)``, where the model is in eval mode on ``device``.
        Results are memoised in ``AOT_MODEL_CACHE``.

    Raises:
        FileNotFoundError: If the checkout or the weights file does not exist.
        ValueError: If no weights path was given.
        ImportError: If PyTorch is not installed.
    """
    import importlib
    import types

    aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain)
    if not aot_root.exists():
        raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}")
    if aot_pretrain is None:
        raise ValueError("AOT-GAN需要指定预训练权重路径(--aot-pretrain)")
    if not aot_pretrain.exists():
        raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}")

    rates_tuple = tuple(rates) if rates is not None else (1, 2, 4, 8)
    key = (str(aot_pretrain), device, int(block_num), rates_tuple)
    if key in AOT_MODEL_CACHE:
        return AOT_MODEL_CACHE[key]

    src_root = str(aot_root / "src")
    if src_root not in sys.path:
        sys.path.insert(0, src_root)

    try:
        import torch
    except ImportError as exc:
        raise ImportError("未找到 PyTorch,请先按README安装依赖。") from exc

    net = importlib.import_module("model.aotgan")
    # Give the model its own copy of the rates, so later mutation of the
    # caller's list cannot diverge from the cache key.
    args = types.SimpleNamespace(block_num=int(block_num), rates=list(rates_tuple))
    model = net.InpaintGenerator(args)

    # SECURITY: torch.load unpickles the checkpoint, which can execute
    # arbitrary code. Only load weights from trusted sources.
    state = _unwrap_state_dict(torch.load(str(aot_pretrain), map_location=device))
    model.load_state_dict(state, strict=True)
    model.to(device)
    model.eval()

    AOT_MODEL_CACHE[key] = (model, device)
    return AOT_MODEL_CACHE[key]
|
||
|
||
|
||
def _mask_bbox(mask):
|
||
ys, xs = np.where(mask > 0)
|
||
if len(xs) == 0:
|
||
return None
|
||
return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
|
||
|
||
|
||
def _expand_bbox(bbox, pad, width, height):
|
||
if bbox is None:
|
||
return None
|
||
x0, y0, x1, y1 = bbox
|
||
pad = max(0, int(pad))
|
||
x0 = max(0, x0 - pad)
|
||
y0 = max(0, y0 - pad)
|
||
x1 = min(width - 1, x1 + pad)
|
||
y1 = min(height - 1, y1 + pad)
|
||
return x0, y0, x1, y1
|
||
|
||
|
||
def _expand_bbox_min_size(bbox, pad, width, height, min_size):
|
||
"""在给定 pad 基础上,确保裁剪区域至少为 min_size。"""
|
||
expanded = _expand_bbox(bbox, pad, width, height)
|
||
if expanded is None:
|
||
return None
|
||
x0, y0, x1, y1 = expanded
|
||
roi_w = x1 - x0 + 1
|
||
roi_h = y1 - y0 + 1
|
||
need_w = max(0, int(min_size) - roi_w)
|
||
need_h = max(0, int(min_size) - roi_h)
|
||
if need_w == 0 and need_h == 0:
|
||
return expanded
|
||
extra = max((need_w + 1) // 2, (need_h + 1) // 2)
|
||
return _expand_bbox(bbox, pad + extra, width, height)
|
||
|
||
|
||
def _build_noise_prefill(img, mask_t, strength):
|
||
"""为mask区域生成噪声预填充,img取值范围为[-1, 1]。"""
|
||
import torch
|
||
|
||
img01 = (img + 1.0) / 2.0
|
||
unmasked = 1.0 - mask_t
|
||
denom = unmasked.sum()
|
||
if denom.item() < 1.0:
|
||
mean = torch.full((1, 3, 1, 1), 0.5, device=img.device)
|
||
std = torch.full((1, 3, 1, 1), 0.2, device=img.device)
|
||
else:
|
||
mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
|
||
var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
|
||
std = torch.sqrt(var + 1e-6)
|
||
noise = mean + std * float(strength) * torch.randn_like(img01)
|
||
noise = noise.clamp(0.0, 1.0)
|
||
return noise * 2.0 - 1.0
|
||
|
||
|
||
def _inpaint_with_aot_core(
|
||
bgr,
|
||
mask,
|
||
aot_root,
|
||
aot_pretrain,
|
||
device="cpu",
|
||
block_num=8,
|
||
rates=None,
|
||
noise_prefill=False,
|
||
noise_strength=1.0,
|
||
):
|
||
"""使用 AOT-GAN 进行修补,返回 BGR 图像(与输入同尺寸)。"""
|
||
if mask.dtype != np.uint8:
|
||
mask = mask.astype(np.uint8)
|
||
if mask.max() <= 1:
|
||
mask = mask * 255
|
||
if mask.max() == 0:
|
||
return bgr
|
||
|
||
h, w = mask.shape
|
||
grid = 4
|
||
h2 = (h // grid) * grid
|
||
w2 = (w // grid) * grid
|
||
if h2 == 0 or w2 == 0:
|
||
return bgr
|
||
|
||
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||
img_crop = rgb[:h2, :w2, :]
|
||
mask_crop = mask[:h2, :w2]
|
||
|
||
import torch
|
||
|
||
img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0
|
||
img = img * 2.0 - 1.0
|
||
mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0)
|
||
img = img.unsqueeze(0)
|
||
mask_t = mask_t.unsqueeze(0)
|
||
if device != "cpu":
|
||
img = img.to(device)
|
||
mask_t = mask_t.to(device)
|
||
|
||
model, _ = _get_aot_model(
|
||
aot_root=aot_root,
|
||
aot_pretrain=aot_pretrain,
|
||
device=device,
|
||
block_num=block_num,
|
||
rates=rates,
|
||
)
|
||
with torch.no_grad():
|
||
if noise_prefill:
|
||
noise = _build_noise_prefill(img, mask_t, noise_strength)
|
||
image_masked = img * (1 - mask_t) + noise * mask_t
|
||
else:
|
||
image_masked = img * (1 - mask_t) + mask_t
|
||
pred = model(image_masked, mask_t)
|
||
comp = pred * mask_t + img * (1 - mask_t)
|
||
comp = comp[0].clamp(-1.0, 1.0)
|
||
comp = (comp + 1.0) / 2.0 * 255.0
|
||
comp = comp.permute(1, 2, 0).byte().cpu().numpy()
|
||
result_bgr = cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)
|
||
|
||
if h2 == h and w2 == w:
|
||
return result_bgr
|
||
output_full = bgr.copy()
|
||
output_full[:h2, :w2, :] = result_bgr
|
||
return output_full
|
||
|
||
|
||
def _inpaint_with_aot(
|
||
bgr,
|
||
mask,
|
||
aot_root,
|
||
aot_pretrain,
|
||
device="cpu",
|
||
block_num=8,
|
||
rates=None,
|
||
crop=False,
|
||
crop_pad=0,
|
||
max_size=0,
|
||
noise_prefill=False,
|
||
noise_strength=1.0,
|
||
):
|
||
"""使用 AOT-GAN 进行修补,支持裁剪与限幅以加速。"""
|
||
min_side = 32
|
||
if mask.dtype != np.uint8:
|
||
mask = mask.astype(np.uint8)
|
||
if mask.max() <= 1:
|
||
mask = mask * 255
|
||
if mask.max() == 0:
|
||
return bgr
|
||
|
||
h, w = mask.shape
|
||
x0 = y0 = 0
|
||
x1 = w - 1
|
||
y1 = h - 1
|
||
if crop:
|
||
bbox = _mask_bbox(mask)
|
||
if bbox is None:
|
||
return bgr
|
||
expanded = _expand_bbox_min_size(bbox, crop_pad, w, h, min_side)
|
||
if expanded is None:
|
||
return bgr
|
||
x0, y0, x1, y1 = expanded
|
||
roi_w = x1 - x0 + 1
|
||
roi_h = y1 - y0 + 1
|
||
if roi_w < min_side or roi_h < min_side:
|
||
# 裁剪区域过小会导致 AOT 失败,回退到全图修补
|
||
crop = False
|
||
x0 = y0 = 0
|
||
x1 = w - 1
|
||
y1 = h - 1
|
||
|
||
bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1]
|
||
mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1]
|
||
roi_h, roi_w = bgr_roi.shape[:2]
|
||
|
||
scale = 1.0
|
||
max_size = int(max_size) if max_size else 0
|
||
if max_size > 0:
|
||
max_dim = max(roi_h, roi_w)
|
||
if max_dim > max_size:
|
||
scale = max_size / float(max_dim)
|
||
new_w = max(1, int(round(roi_w * scale)))
|
||
new_h = max(1, int(round(roi_h * scale)))
|
||
if min(new_w, new_h) < min_side:
|
||
scale = 1.0
|
||
new_w = roi_w
|
||
new_h = roi_h
|
||
interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
|
||
if scale != 1.0:
|
||
bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp)
|
||
mask_roi = cv2.resize(
|
||
mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST
|
||
)
|
||
|
||
filled_roi = _inpaint_with_aot_core(
|
||
bgr_roi,
|
||
mask_roi,
|
||
aot_root=aot_root,
|
||
aot_pretrain=aot_pretrain,
|
||
device=device,
|
||
block_num=block_num,
|
||
rates=rates,
|
||
noise_prefill=noise_prefill,
|
||
noise_strength=noise_strength,
|
||
)
|
||
|
||
if scale != 1.0:
|
||
filled_roi = cv2.resize(
|
||
filled_roi,
|
||
(roi_w, roi_h),
|
||
interpolation=cv2.INTER_LINEAR,
|
||
)
|
||
|
||
if not crop:
|
||
return filled_roi
|
||
output_full = bgr.copy()
|
||
output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi
|
||
return output_full
|
||
|
||
|
||
# Register Pillow's HEIC/HEIF opener when the optional pillow-heif is installed.
try:
    from pillow_heif import register_heif_opener

    register_heif_opener()
except ImportError:
    HEIC_SUPPORTED = False
else:
    HEIC_SUPPORTED = True
|
||
|
||
|
||
def _get_rembg_session(model_name):
    """Return the rembg session for ``model_name``, creating it on first use.

    Sessions are memoised in ``REMBG_SESSION_CACHE`` so each model loads once.

    Raises:
        ImportError: If rembg could not be imported.
    """
    if new_session is None:
        raise ImportError(
            "未找到 rembg,请先安装相关依赖或改用 --foreground-mode artwork。"
        )
    if model_name in REMBG_SESSION_CACHE:
        return REMBG_SESSION_CACHE[model_name]
    session = new_session(model_name)
    REMBG_SESSION_CACHE[model_name] = session
    return session
|
||
|
||
|
||
def _resize_for_processing(image, max_size):
|
||
"""按最大边缩放图片,返回缩放后的图片与缩放比例。"""
|
||
height, width = image.shape[:2]
|
||
max_size = int(max_size) if max_size else 0
|
||
if max_size <= 0 or max(height, width) <= max_size:
|
||
return image.copy(), 1.0
|
||
scale = max_size / float(max(height, width))
|
||
resized = cv2.resize(
|
||
image,
|
||
(max(1, int(round(width * scale))), max(1, int(round(height * scale)))),
|
||
interpolation=cv2.INTER_AREA,
|
||
)
|
||
return resized, scale
|
||
|
||
|
||
def _otsu_threshold(channel, floor=0):
|
||
"""返回 Otsu 阈值,并设置一个最小下限避免纸张纹理误检。"""
|
||
if channel.size == 0 or int(channel.max()) <= 0:
|
||
return int(floor)
|
||
threshold, _ = cv2.threshold(channel, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||
return max(int(round(threshold)), int(floor))
|
||
|
||
|
||
def _remove_small_components(mask, min_area):
|
||
"""移除面积过小的连通域,降低纸张纹理噪声。"""
|
||
if min_area <= 0 or int(mask.max()) == 0:
|
||
return mask
|
||
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
|
||
mask, connectivity=8
|
||
)
|
||
cleaned = np.zeros_like(mask)
|
||
for label in range(1, num_labels):
|
||
area = int(stats[label, cv2.CC_STAT_AREA])
|
||
if area >= min_area:
|
||
cleaned[labels == label] = 255
|
||
return cleaned
|
||
|
||
|
||
def _remove_border_frame_components(
|
||
mask,
|
||
min_width_ratio=0.7,
|
||
min_height_ratio=0.7,
|
||
max_fill_ratio=0.35,
|
||
):
|
||
"""移除贴边的大型空心边框连通域,避免纸张外轮廓被误判为前景。"""
|
||
if int(mask.max()) == 0:
|
||
return mask
|
||
|
||
height, width = mask.shape
|
||
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
|
||
cleaned = mask.copy()
|
||
for label in range(1, num_labels):
|
||
x, y, comp_w, comp_h, area = stats[label]
|
||
touches_border = x == 0 or y == 0 or (x + comp_w) == width or (y + comp_h) == height
|
||
if not touches_border:
|
||
continue
|
||
if comp_w < int(width * float(min_width_ratio)):
|
||
continue
|
||
if comp_h < int(height * float(min_height_ratio)):
|
||
continue
|
||
fill_ratio = float(area) / float(max(1, comp_w * comp_h))
|
||
if fill_ratio <= float(max_fill_ratio):
|
||
cleaned[labels == label] = 0
|
||
return cleaned
|
||
|
||
|
||
def _artwork_mask_is_reasonable(mask):
|
||
"""检查书画掩码是否可信,用于 auto 模式回退。"""
|
||
if mask.size == 0 or int(mask.max()) == 0:
|
||
return False
|
||
|
||
coverage = float(np.count_nonzero(mask)) / float(mask.size)
|
||
if coverage < 0.0005 or coverage > 0.55:
|
||
return False
|
||
|
||
height, width = mask.shape
|
||
border = max(4, min(height, width) // 32)
|
||
border_pixels = np.concatenate(
|
||
[
|
||
mask[:border, :].reshape(-1),
|
||
mask[-border:, :].reshape(-1),
|
||
mask[:, :border].reshape(-1),
|
||
mask[:, -border:].reshape(-1),
|
||
]
|
||
)
|
||
border_ratio = float(np.count_nonzero(border_pixels)) / float(border_pixels.size)
|
||
return border_ratio <= 0.28
|
||
|
||
|
||
def _compute_detail_alpha(rgb_image):
    """Compute a soft per-pixel alpha in [0, 1] for ink strokes and red seal marks.

    Ink cue: how much darker a pixel is than a heavily blurred local background.
    Seal cue: red dominance over green/blue, saturation, a minimum value and a
    red hue. The cues are max-combined, and a 0.07 floor is cut off and the rest
    renormalised.
    """

    def ramp(values, start, span):
        # Linear 0..1 ramp that starts at `start` and saturates after `span`.
        return np.clip((values - start) / span, 0.0, 1.0)

    height, width = rgb_image.shape[:2]
    blur_sigma = max(8.0, max(height, width) / 80.0)
    gray = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY)
    local_bg = cv2.GaussianBlur(gray, (0, 0), sigmaX=blur_sigma, sigmaY=blur_sigma)
    ink = ramp(cv2.subtract(local_bg, gray).astype(np.float32), 18.0, 48.0)

    hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV)
    hue, sat, val = (hsv[:, :, i].astype(np.float32) for i in range(3))
    r, g, b = (rgb_image[:, :, i].astype(np.float32) for i in range(3))
    is_red_hue = ((hue <= 12.0) | (hue >= 168.0)).astype(np.float32)

    seal = ramp(r - np.maximum(g, b), 20.0, 40.0)
    seal *= ramp(sat, 50.0, 80.0)
    seal *= (val >= 40.0).astype(np.float32)
    seal *= is_red_hue

    return ramp(np.maximum(ink, seal), 0.07, 0.93)
|
||
|
||
|
||
def _extract_artwork_mask(input_image, artwork_type="auto", max_size=1600):
    """Estimate a 0/255 foreground mask for calligraphy/seal artwork.

    Works on a copy whose longer side is at most ``max_size``. Three cues are
    measured against a blurred local background: darkness (ink), Lab chroma
    deviation, and a strict red-seal score. ``artwork_type`` selects the
    weights and floors: ``"seal"`` favours colour, ``"calligraphy"`` favours
    darkness and ignores the colour-only mask, and any other value uses the
    balanced ``"auto"`` profile. The mask is cleaned (3x3 close, small-speck
    and border-frame removal, 3x3 dilate) and resized back to the input size.

    Returns:
        A uint8 mask with the height/width of ``input_image``.
    """
    # (dark weight, colour weight, dark floor, colour floor, combined floor)
    profiles = {
        "seal": (1.0, 1.35, 12, 18, 20),
        "calligraphy": (1.35, 0.8, 14, 28, 18),
    }
    dark_w, color_w, dark_floor, color_floor, combined_floor = profiles.get(
        artwork_type, (1.2, 1.2, 14, 16, 18)
    )

    rgb = np.array(input_image.convert("RGB"))
    small, scale = _resize_for_processing(rgb, max_size=max_size)
    height, width = small.shape[:2]
    bg_sigma = max(6.0, max(height, width) / 40.0)

    def local_background(plane):
        return cv2.GaussianBlur(plane, (0, 0), sigmaX=bg_sigma, sigmaY=bg_sigma)

    # Ink cue: saturating difference to the local background.
    gray = cv2.cvtColor(small, cv2.COLOR_RGB2GRAY)
    dark_score = cv2.subtract(local_background(gray), gray)

    # Chroma cue: Lab a/b distance from the local background, amplified x3.
    lab = cv2.cvtColor(small, cv2.COLOR_RGB2LAB)
    plane_a = lab[:, :, 1]
    plane_b = lab[:, :, 2]
    delta_a = plane_a.astype(np.float32) - local_background(plane_a).astype(np.float32)
    delta_b = plane_b.astype(np.float32) - local_background(plane_b).astype(np.float32)
    chroma_score = np.clip(np.sqrt(delta_a ** 2 + delta_b ** 2) * 3.0, 0, 255).astype(
        np.uint8
    )

    # Seal cue: red dominance plus saturation, gated to strongly red pixels.
    hsv = cv2.cvtColor(small, cv2.COLOR_RGB2HSV)
    hue, saturation, value = hsv[:, :, 0], hsv[:, :, 1], hsv[:, :, 2]
    r16 = small[:, :, 0].astype(np.int16)
    g16 = small[:, :, 1].astype(np.int16)
    b16 = small[:, :, 2].astype(np.int16)
    red_dominance = np.clip(r16 - np.maximum(g16, b16), 0, 255).astype(np.uint8)
    raw_seal = np.clip(
        red_dominance.astype(np.int16) * 2
        + np.clip(saturation.astype(np.int16) - 45, 0, 255),
        0,
        255,
    ).astype(np.uint8)
    seal_gate = (
        ((hue <= 12) | (hue >= 168))
        & (saturation >= 55)
        & (value >= 40)
        & (red_dominance >= 20)
    )
    seal_score = np.where(seal_gate, raw_seal, 0).astype(np.uint8)
    color_score = np.maximum(chroma_score, seal_score)

    def weighted(score, weight):
        return np.clip(score.astype(np.float32) * weight, 0, 255).astype(np.uint8)

    combined_score = np.maximum(weighted(dark_score, dark_w), weighted(color_score, color_w))

    smooth_sigma = max(0.6, max(height, width) / 320.0)

    def binarize(score, floor):
        # Lightly smooth, then threshold at Otsu (never below `floor`).
        smooth = cv2.GaussianBlur(score, (0, 0), sigmaX=smooth_sigma, sigmaY=smooth_sigma)
        return (smooth >= _otsu_threshold(smooth, floor=floor)).astype(np.uint8) * 255

    dark_mask = binarize(dark_score, dark_floor)
    color_mask = binarize(color_score, color_floor)
    combined_mask = binarize(combined_score, combined_floor)

    mask = cv2.bitwise_or(combined_mask, dark_mask)
    if artwork_type != "calligraphy":
        # The seal and auto profiles also accept pixels flagged by colour alone.
        mask = cv2.bitwise_or(mask, color_mask)

    kernel = np.ones((3, 3), dtype=np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1)
    mask = _remove_small_components(mask, min_area=max(12, (height * width) // 25000))
    mask = _remove_border_frame_components(mask)
    mask = cv2.dilate(mask, kernel, iterations=1)

    if scale < 1.0:
        mask = cv2.resize(
            mask, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST
        )
    return mask
|
||
|
||
|
||
def _mask_to_transparent_image(input_image, mask):
    """Cut the artwork out as an RGBA PIL image with a transparent background.

    Alpha is the 0/255 ``mask`` scaled by the soft detail alpha, then lightly
    blurred (sigma 0.8), then forced to 0 outside the mask so the blur cannot
    leak onto the background.
    """
    rgb = np.array(input_image.convert("RGB"))
    soft_alpha = (mask.astype(np.float32) / 255.0) * _compute_detail_alpha(rgb) * 255.0
    blurred = cv2.GaussianBlur(soft_alpha.astype(np.uint8), (0, 0), sigmaX=0.8, sigmaY=0.8)
    alpha = np.where(mask > 0, blurred, 0).astype(np.uint8)
    return Image.fromarray(np.dstack((rgb, alpha)))
|
||
|
||
|
||
def _extract_foreground_mask(
|
||
input_image,
|
||
session=None,
|
||
model_name="isnet-general-use",
|
||
foreground_mode="artwork",
|
||
artwork_type="auto",
|
||
artwork_max_size=1600,
|
||
mask_threshold=10,
|
||
**kwargs,
|
||
):
|
||
"""按指定模式提取前景掩码。"""
|
||
if foreground_mode in {"artwork", "auto"}:
|
||
artwork_mask = _extract_artwork_mask(
|
||
input_image,
|
||
artwork_type=artwork_type,
|
||
max_size=artwork_max_size,
|
||
)
|
||
if foreground_mode == "artwork" or _artwork_mask_is_reasonable(artwork_mask):
|
||
return artwork_mask, "artwork"
|
||
|
||
if foreground_mode not in {"rembg", "auto"}:
|
||
raise ValueError(f"不支持的前景提取模式: {foreground_mode}")
|
||
|
||
if remove is None:
|
||
raise ImportError(
|
||
"未找到 rembg,请先安装相关依赖或改用 --foreground-mode artwork。"
|
||
)
|
||
if session is None:
|
||
session = _get_rembg_session(model_name)
|
||
output_image = remove(input_image, session=session, **kwargs)
|
||
output_image = _ensure_rgba_size(output_image, input_image.size)
|
||
alpha = np.array(output_image.getchannel("A"))
|
||
return _alpha_to_mask(alpha, threshold=mask_threshold), "rembg"
|
||
|
||
|
||
def remove_background(
    input_path,
    output_path,
    session=None,
    model_name="isnet-general-use",
    foreground_mode="artwork",
    artwork_type="auto",
    artwork_max_size=1600,
    **kwargs,
):
    """Remove the background of an image and save the cut-out.

    Args:
        input_path: Input image path.
        output_path: Output image path.
        session: Optional rembg session object.
        model_name: rembg model name.
        foreground_mode: Foreground extraction mode: artwork / auto / rembg.
        artwork_type: Artwork type: auto / calligraphy / seal.
        artwork_max_size: Longest side used when estimating the artwork mask.
        **kwargs: Extra arguments for ``rembg.remove`` (e.g. alpha matting).

    Raises:
        ValueError: If ``foreground_mode`` is not artwork / auto / rembg.
        ImportError: If the rembg path is needed but rembg is not installed.
    """
    # Reject unknown modes up front (consistent with _extract_foreground_mask)
    # instead of silently treating them as "rembg".
    if foreground_mode not in {"artwork", "auto", "rembg"}:
        raise ValueError(f"不支持的前景提取模式: {foreground_mode}")

    print(f"正在处理: {input_path}")

    # Read the image and close the file right away. Pixel data is loaded
    # explicitly so the returned image no longer depends on the file handle.
    with Image.open(input_path) as source:
        input_image = ImageOps.exif_transpose(source)
        input_image.load()

    if foreground_mode in {"artwork", "auto"}:
        mask = _extract_artwork_mask(
            input_image,
            artwork_type=artwork_type,
            max_size=artwork_max_size,
        )
        if foreground_mode == "artwork" or _artwork_mask_is_reasonable(mask):
            output_image = _mask_to_transparent_image(input_image, mask)
            output_image.save(output_path)
            print(f"已保存: {output_path}")
            return

    if remove is None:
        raise ImportError(
            "未找到 rembg,请先安装相关依赖或改用 --foreground-mode artwork。"
        )
    if session is None:
        session = _get_rembg_session(model_name)

    output_image = remove(input_image, session=session, **kwargs)
    output_image.save(output_path)

    print(f"已保存: {output_path}")
|
||
|
||
|
||
def _pil_to_bgr(image):
    """Convert a PIL image of any mode to an OpenCV-style BGR uint8 array."""
    rgb_array = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
|
||
|
||
|
||
def _alpha_to_mask(alpha_channel, threshold=10):
|
||
"""将alpha通道转换为二值mask(0或255)"""
|
||
mask = (alpha_channel > threshold).astype(np.uint8) * 255
|
||
return mask
|
||
|
||
|
||
def _ensure_rgba_size(image, target_size):
|
||
if image.mode != "RGBA":
|
||
image = image.convert("RGBA")
|
||
if image.size != target_size:
|
||
image = image.resize(target_size, resample=Image.NEAREST)
|
||
return image
|
||
|
||
|
||
def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0):
|
||
"""生成硬边mask与处理后mask(用于填补/过渡)"""
|
||
if mask_dilate and mask_dilate > 0:
|
||
kernel = np.ones((mask_dilate, mask_dilate), np.uint8)
|
||
mask = cv2.dilate(mask, kernel, iterations=1)
|
||
if edge_grow and edge_grow > 0:
|
||
kernel = np.ones((3, 3), np.uint8)
|
||
mask = cv2.dilate(mask, kernel, iterations=int(edge_grow))
|
||
|
||
mask_hard = mask.copy()
|
||
mask_used = mask_hard
|
||
if mask_blur and mask_blur > 0:
|
||
k = mask_blur if mask_blur % 2 == 1 else mask_blur + 1
|
||
mask_used = cv2.GaussianBlur(mask_hard, (k, k), 0)
|
||
mask_used = (mask_used > 0).astype(np.uint8) * 255
|
||
return mask_hard, mask_used
|
||
|
||
|
||
def _feather_blend(filled, original, mask_hard, feather_radius):
    """Blend ``filled`` into ``original`` with a linear falloff outside the mask.

    Inside ``mask_hard`` the filled result is used unchanged. Outside it, the
    weight of ``filled`` drops linearly to 0 over ``feather_radius`` pixels of
    L2 distance. Returns a uint8 BGR array.
    """
    mask_bin = (mask_hard > 0).astype(np.uint8) * 255
    dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3)
    alpha = np.ones_like(dist_out, dtype=np.float32)
    outside = mask_bin == 0
    alpha[outside] = np.clip(1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0)
    alpha = alpha[:, :, None]
    blended = alpha * filled + (1.0 - alpha) * original
    # Round rather than truncate. Where filled == original the float blend can
    # land just below the integer (e.g. 199.99998), and truncation would shift
    # untouched pixels down by one level.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


def remove_subject_and_inpaint(
    input_path,
    output_path,
    session=None,
    model_name="isnet-general-use",
    foreground_mode="artwork",
    artwork_type="auto",
    artwork_max_size=1600,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
    mask_output_path=None,
    **kwargs,
):
    """Remove the subject from an image and inpaint the background.

    Args:
        input_path: Input image path.
        output_path: Output image path.
        session: Optional rembg session object.
        model_name: rembg model name.
        foreground_mode: Foreground extraction mode: artwork / auto / rembg.
        artwork_type: Artwork type: auto / calligraphy / seal.
        artwork_max_size: Longest side used when estimating the artwork mask.
        mask_dilate: Square dilation kernel size applied to the mask.
        mask_blur: Blur kernel size used to widen the inpainting mask (odd).
        mask_threshold: Alpha threshold for the rembg mask.
        edge_grow: Extra 3x3 dilation iterations applied to the mask.
        aot_root: AOT-GAN checkout directory.
        aot_pretrain: AOT-GAN weights file path.
        aot_device: Torch device for AOT-GAN.
        aot_block_num: Number of AOT blocks.
        aot_rates: AOT block dilation rates, e.g. ``"1+2+4+8"``.
        aot_crop: Only inpaint the mask's bounding box.
        aot_crop_pad: Margin in pixels around the cropped box.
        aot_max_size: Longest side of the AOT input (0 = unlimited).
        aot_noise_prefill: Pre-fill holes with random noise.
        aot_noise_strength: Noise strength multiplier.
        black_subject: Also mask pixels whose gray value <= ``black_threshold``.
        black_threshold: Gray threshold for ``black_subject``.
        gray_subject: Also mask low-saturation pixels (HSV S <=
            ``gray_saturation_threshold`` and V <= ``gray_value_threshold``).
        gray_saturation_threshold: Saturation threshold for ``gray_subject``.
        gray_value_threshold: Value threshold for ``gray_subject``.
        feather: Blend the inpainted result into the original outside the mask.
        feather_radius: Falloff distance in pixels for ``feather``.
        save_mask: Whether to save the inpainting mask.
        mask_output_path: Mask output path (default: ``<output>_mask.png``).
        **kwargs: Extra arguments for ``rembg.remove``.
    """
    print(f"正在处理(去主体补背景): {input_path}")

    # Parse the rates before the expensive mask extraction so that a malformed
    # value fails fast.
    rates = _parse_aot_rates(aot_rates)

    # Read the image and close the file right away. Pixel data is loaded
    # explicitly so the image no longer depends on the file handle.
    with Image.open(input_path) as source:
        input_image = ImageOps.exif_transpose(source)
        input_image.load()

    mask, mask_source = _extract_foreground_mask(
        input_image,
        session=session,
        model_name=model_name,
        foreground_mode=foreground_mode,
        artwork_type=artwork_type,
        artwork_max_size=artwork_max_size,
        mask_threshold=mask_threshold,
        **kwargs,
    )

    bgr = _pil_to_bgr(input_image)
    if black_subject:
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        black_mask = (gray <= black_threshold).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, black_mask)
    if gray_subject:
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        low_saturation = hsv[:, :, 1] <= gray_saturation_threshold
        not_bright = hsv[:, :, 2] <= gray_value_threshold
        gray_mask = (low_saturation & not_bright).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, gray_mask)

    mask_hard, mask_used = _prepare_mask(
        mask,
        mask_dilate=mask_dilate,
        mask_blur=mask_blur,
        edge_grow=edge_grow,
    )

    filled = _inpaint_with_aot(
        bgr,
        mask_used,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=aot_device,
        block_num=aot_block_num,
        rates=rates,
        crop=aot_crop,
        crop_pad=aot_crop_pad,
        max_size=aot_max_size,
        noise_prefill=aot_noise_prefill,
        noise_strength=aot_noise_strength,
    )

    if feather and feather_radius > 0:
        # The transition is only added outside the mask, so the removed
        # subject is never blended back in.
        result_bgr = _feather_blend(filled, bgr, mask_hard, feather_radius)
    else:
        result_bgr = filled
    Image.fromarray(cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)).save(output_path)

    if save_mask:
        if mask_output_path is None:
            mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png"
        Image.fromarray(mask_used).save(mask_output_path)
        print(f"已保存mask: {mask_output_path}")

    print(f"前景提取模式: {mask_source}")
    print(f"已保存: {output_path}")
|
||
|
||
|
||
def process_images_folder(
    input_folder,
    output_folder,
    model_name="isnet-general-use",
    foreground_mode="artwork",
    artwork_type="auto",
    artwork_max_size=1600,
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_size=10,
    post_process_mask=False,
    remove_subject=False,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
):
    """Process every image file directly inside a folder (non-recursive).

    Default mode calls ``remove_background`` for each image and writes
    ``<stem>_nobg.png``. When ``remove_subject`` is true, each image is sent
    to ``remove_subject_and_inpaint`` instead and written as
    ``<stem>_bgfill<ext>``: the original extension is kept for
    jpg/jpeg/png/bmp/webp, and anything else (e.g. HEIC) becomes ``.jpg``.

    Files are processed in sorted name order so runs are reproducible. An
    exception raised while processing one file is reported and the loop
    continues. Failed files are listed in a summary at the end.

    Args:
        input_folder: Folder scanned (non-recursively) for images.
        output_folder: Destination folder; created if it does not exist.
        model_name: rembg model name, used when ``foreground_mode`` is
            "rembg" or "auto". Options include u2net, u2netp,
            u2net_human_seg, silueta, isnet-general-use (the default here),
            isnet-anime, birefnet-general and birefnet-portrait.
        foreground_mode: Foreground extraction mode: "artwork", "auto" or
            "rembg".
        artwork_type: Artwork type: "auto", "calligraphy" or "seal".
        artwork_max_size: Longest side used when estimating artwork masks.
        alpha_matting: Enable alpha-matting post-processing.
        alpha_matting_foreground_threshold: Alpha-matting foreground
            threshold (0-255).
        alpha_matting_background_threshold: Alpha-matting background
            threshold (0-255).
        alpha_matting_erode_size: Alpha-matting erode size.
            NOTE: the three alpha-matting defaults here (240/10/10) differ
            from the command-line defaults (245/8/2).
        post_process_mask: Enable mask post-processing.
        remove_subject: Remove the subject and inpaint the background
            instead of producing a transparent cut-out.
        mask_dilate, mask_blur, mask_threshold, edge_grow: Mask preparation
            settings forwarded to ``remove_subject_and_inpaint``.
        aot_root, aot_pretrain, aot_device, aot_block_num, aot_rates,
        aot_crop, aot_crop_pad, aot_max_size, aot_noise_prefill,
        aot_noise_strength: AOT-GAN inpainting settings forwarded to
            ``remove_subject_and_inpaint``.
        black_subject, black_threshold: Also treat dark pixels as subject.
        gray_subject, gray_saturation_threshold, gray_value_threshold:
            Also treat low-saturation, darker pixels as subject.
        feather, feather_radius: Blend the inpainted result into the
            original just outside the mask.
        save_mask: In remove-subject mode, also save the mask as
            ``<stem>_mask.png`` inside ``output_folder``.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    # Extensions that inpainted outputs keep verbatim; they are also the
    # baseline set of readable inputs.
    passthrough_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

    out_dir = Path(output_folder)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"前景提取模式: {foreground_mode}")
    print(f"书画类型: {artwork_type}")
    if foreground_mode in {"rembg", "auto"}:
        print(f"rembg模型: {model_name}")

    image_extensions = set(passthrough_extensions)
    if HEIC_SUPPORTED:
        image_extensions.update({".heic", ".heif"})
    else:
        print(
            "提示: 未安装pillow-heif,HEIC格式不可用。安装方法: pip install pillow-heif"
        )

    input_path = Path(input_folder)
    # iterdir() order is filesystem-dependent; sort so batches are reproducible.
    image_files = sorted(
        (
            f
            for f in input_path.iterdir()
            if f.is_file() and f.suffix.lower() in image_extensions
        ),
        key=lambda p: p.name,
    )

    if not image_files:
        print(f"在 {input_folder} 中没有找到图片文件")
        return

    total = len(image_files)
    print(f"找到 {total} 张图片,开始处理...")
    print(f"Alpha Matting: {'启用' if alpha_matting else '禁用'}")
    if alpha_matting:
        print(f" - 前景阈值: {alpha_matting_foreground_threshold}")
        print(f" - 背景阈值: {alpha_matting_background_threshold}")
        print(f" - 侵蚀大小: {alpha_matting_erode_size}")
    print(f"Mask后处理: {'启用' if post_process_mask else '禁用'}")
    print(f"去主体补背景: {'启用' if remove_subject else '禁用'}")
    if remove_subject:
        print(f" - AOT目录: {aot_root}")
        print(f" - AOT权重: {aot_pretrain}")
        print(f" - AOT设备: {aot_device}")
        print(f" - AOT块数: {aot_block_num}")
        print(f" - AOT膨胀率: {aot_rates}")
        print(f" - AOT裁剪: {'是' if aot_crop else '否'}")
        if aot_crop:
            print(f" - AOT裁剪边界: {aot_crop_pad}")
        print(f" - AOT最大边: {aot_max_size}")
        print(f" - AOT噪声预填充: {'是' if aot_noise_prefill else '否'}")
        if aot_noise_prefill:
            print(f" - AOT噪声强度: {aot_noise_strength}")
        print(f" - mask膨胀: {mask_dilate}")
        print(f" - mask模糊: {mask_blur}")
        print(f" - mask阈值: {mask_threshold}")
        print(f" - 边缘扩张: {edge_grow}")
        print(f" - 黑色内容作为主体: {'是' if black_subject else '否'}")
        if black_subject:
            print(f" - 黑色阈值: {black_threshold}")
        print(f" - 灰阶内容作为主体: {'是' if gray_subject else '否'}")
        if gray_subject:
            print(f" - 灰阶饱和度阈值: {gray_saturation_threshold}")
            print(f" - 灰阶亮度阈值: {gray_value_threshold}")
        print(f" - 边缘过渡: {'是' if feather else '否'}")
        if feather:
            print(f" - 过渡半径: {feather_radius}")
        print(f" - 保存mask: {'是' if save_mask else '否'}")
    print("-" * 50)

    failures = []
    # Inputs sharing a stem (e.g. a.jpg and a.png) map to the same output
    # name; track names so the silent overwrite is at least reported.
    seen_outputs = set()

    for i, image_file in enumerate(image_files, 1):
        try:
            if remove_subject:
                suffix = image_file.suffix.lower()
                if suffix in passthrough_extensions:
                    output_filename = image_file.stem + "_bgfill" + suffix
                else:
                    output_filename = image_file.stem + "_bgfill.jpg"
            else:
                # Background removal always writes PNG to keep transparency.
                output_filename = image_file.stem + "_nobg.png"
            output_path = out_dir / output_filename

            if output_filename in seen_outputs:
                print(f"警告: 输出文件名重复,将覆盖之前的结果: {output_filename}")
            seen_outputs.add(output_filename)

            print(f"[{i}/{total}] ", end="")
            if remove_subject:
                mask_output_path = None
                if save_mask:
                    mask_output_path = str(out_dir / (image_file.stem + "_mask.png"))
                remove_subject_and_inpaint(
                    str(image_file),
                    str(output_path),
                    model_name=model_name,
                    foreground_mode=foreground_mode,
                    artwork_type=artwork_type,
                    artwork_max_size=artwork_max_size,
                    alpha_matting=alpha_matting,
                    alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold=alpha_matting_background_threshold,
                    alpha_matting_erode_size=alpha_matting_erode_size,
                    post_process_mask=post_process_mask,
                    mask_dilate=mask_dilate,
                    mask_blur=mask_blur,
                    mask_threshold=mask_threshold,
                    edge_grow=edge_grow,
                    aot_root=aot_root,
                    aot_pretrain=aot_pretrain,
                    aot_device=aot_device,
                    aot_block_num=aot_block_num,
                    aot_rates=aot_rates,
                    aot_crop=aot_crop,
                    aot_crop_pad=aot_crop_pad,
                    aot_max_size=aot_max_size,
                    aot_noise_prefill=aot_noise_prefill,
                    aot_noise_strength=aot_noise_strength,
                    black_subject=black_subject,
                    black_threshold=black_threshold,
                    gray_subject=gray_subject,
                    gray_saturation_threshold=gray_saturation_threshold,
                    gray_value_threshold=gray_value_threshold,
                    feather=feather,
                    feather_radius=feather_radius,
                    save_mask=save_mask,
                    mask_output_path=mask_output_path,
                )
            else:
                remove_background(
                    str(image_file),
                    str(output_path),
                    model_name=model_name,
                    foreground_mode=foreground_mode,
                    artwork_type=artwork_type,
                    artwork_max_size=artwork_max_size,
                    alpha_matting=alpha_matting,
                    alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold=alpha_matting_background_threshold,
                    alpha_matting_erode_size=alpha_matting_erode_size,
                    post_process_mask=post_process_mask,
                )
        except Exception as e:
            # Best-effort batch: report the failure and keep going, but
            # remember it so the final summary does not claim full success.
            failures.append(image_file.name)
            print(f"处理 {image_file.name} 时出错: {e}")

    print("-" * 50)
    if failures:
        print(f"失败 {len(failures)}/{total} 张: {', '.join(failures)}")
    print(f"处理完成!结果保存在 {output_folder} 文件夹中")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="书画与篆刻作品去背景工具 - 默认使用轻量书画掩码提取",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例用法:
# 使用默认参数处理images文件夹
python remove_background.py

# 处理单个文件
python remove_background.py input.jpg output.png

# 处理指定文件夹
python remove_background.py my_images/ my_output/

# 强制使用 rembg 旧流程
python remove_background.py input.jpg output.png --foreground-mode rembg -m isnet-general-use

# 指定为篆刻模式
python remove_background.py input.jpg output.png --artwork-type seal

# 自定义alpha matting参数
python remove_background.py input.jpg output.png -a -ft 250 -bt 12 -es 5
""",
    )

    # Positional arguments (both optional thanks to nargs="?").
    parser.add_argument(
        "input",
        nargs="?",
        default="images",
        help="输入文件或文件夹路径(默认: images)",
    )
    parser.add_argument(
        "output",
        nargs="?",
        default=None,
        help="输出文件或文件夹路径(可选,默认为output/)",
    )

    # rembg model selection.
    parser.add_argument(
        "-m",
        "--model",
        default="isnet-general-use",
        choices=[
            "u2net",
            "u2netp",
            "u2net_human_seg",
            "silueta",
            "isnet-general-use",
            "isnet-anime",
            "birefnet-general",
            "birefnet-general-lite",
            "birefnet-portrait",
            "birefnet-dis",
            "birefnet-hrsod",
            "birefnet-cod",
            "birefnet-massive",
        ],
        help="选择 rembg 模型(仅在 --foreground-mode rembg/auto 时使用)",
    )
    parser.add_argument(
        "--foreground-mode",
        default="artwork",
        choices=["artwork", "auto", "rembg"],
        help="前景提取模式:artwork 为书画专用快速模式,auto 失败时回退 rembg,rembg 为旧流程 (默认: artwork)",
    )
    parser.add_argument(
        "--artwork-type",
        default="auto",
        choices=["auto", "calligraphy", "seal"],
        help="书画类型:auto 自动兼容书法与篆刻,seal 更偏重红色印章,calligraphy 更偏重墨色笔画 (默认: auto)",
    )
    parser.add_argument(
        "--artwork-max-size",
        type=int,
        default=1600,
        help="书画掩码估算时的最大边,越小越快,越大越精细 (默认: 1600)",
    )

    # Alpha-matting options (only take effect together with -a).
    parser.add_argument(
        "-a",
        "--alpha-matting",
        "--alpha_matting",
        action="store_true",
        help="启用alpha matting后处理(默认: false)",
    )
    parser.add_argument(
        "-ft",
        "--foreground-threshold",
        type=int,
        default=245,
        help="前景阈值 (0-255),值越大保留越多细节 (默认: 245)",
    )
    parser.add_argument(
        "-bt",
        "--background-threshold",
        type=int,
        default=8,
        help="背景阈值 (0-255),值越大去除越多背景 (默认: 8)",
    )
    parser.add_argument(
        "-es",
        "--erode-size",
        type=int,
        default=2,
        help="侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)",
    )

    # Other options.
    parser.add_argument(
        "-p",
        "--post-process",
        "--post_process",
        action="store_true",
        help="启用mask后处理(默认: false)",
    )

    # Remove-subject / background inpainting options.
    parser.add_argument(
        "--remove-subject",
        "--remove_subject",
        action="store_true",
        help="去掉主体并补全背景(默认: false)",
    )
    parser.add_argument(
        "--aot-root",
        type=str,
        default="AOT-GAN-for-Inpainting",
        help="AOT-GAN目录(默认: AOT-GAN-for-Inpainting)",
    )
    parser.add_argument(
        "--aot-pretrain",
        type=str,
        default=None,
        help="AOT-GAN预训练权重文件路径(必填,相对路径基于当前目录)",
    )
    parser.add_argument(
        "--aot-device", type=str, default="cpu", help="AOT-GAN设备(默认: cpu)"
    )
    parser.add_argument(
        "--aot-block-num", type=int, default=8, help="AOTBlock数量(默认: 8)"
    )
    parser.add_argument(
        "--aot-rates",
        type=str,
        default="1+2+4+8",
        help="AOTBlock膨胀率(默认: 1+2+4+8)",
    )
    parser.add_argument(
        "--aot-crop", action="store_true", help="AOT仅对mask区域裁剪修补(默认: false)"
    )
    parser.add_argument(
        "--aot-crop-pad", type=int, default=0, help="AOT裁剪边缘留白像素(默认: 0)"
    )
    parser.add_argument(
        "--aot-max-size",
        type=int,
        default=0,
        help="AOT输入最大边限制,0为不限制(默认: 0)",
    )
    parser.add_argument(
        "--aot-noise-prefill",
        action="store_true",
        help="AOT使用随机噪声预填充(默认: false)",
    )
    parser.add_argument(
        "--aot-noise-strength",
        type=float,
        default=1.0,
        help="AOT噪声强度系数(默认: 1.0)",
    )
    parser.add_argument(
        "--mask-dilate", type=int, default=3, help="mask膨胀大小(默认: 3)"
    )
    parser.add_argument(
        "--mask-blur", type=int, default=3, help="mask模糊大小(默认: 3,建议奇数)"
    )
    parser.add_argument(
        "--mask-threshold", type=int, default=10, help="alpha阈值(默认: 10)"
    )
    parser.add_argument(
        "--edge-grow", type=int, default=0, help="主体边缘扩张像素(默认: 0)"
    )
    parser.add_argument(
        "--save-mask",
        "--save_mask",
        action="store_true",
        help="保存mask到output目录(默认: false)",
    )
    parser.add_argument(
        "--black-subject",
        "--black_subject",
        action="store_true",
        help="将黑色内容也视为主体(默认: false)",
    )
    parser.add_argument(
        "--black-threshold",
        type=int,
        default=50,
        help="黑色阈值(0-255,灰度越小越黑,默认: 50)",
    )
    parser.add_argument(
        "--gray-subject",
        "--gray_subject",
        action="store_true",
        help="将灰阶内容也视为主体(默认: false)",
    )
    parser.add_argument(
        "--gray-saturation-threshold",
        type=int,
        default=30,
        help="灰阶饱和度阈值(0-255,越小越接近灰阶,默认: 30)",
    )
    parser.add_argument(
        "--gray-value-threshold",
        type=int,
        default=200,
        help="灰阶亮度阈值(0-255,越小越暗,默认: 200)",
    )
    parser.add_argument(
        "--feather", action="store_true", help="启用边缘过渡融合(默认: false)"
    )
    parser.add_argument(
        "--feather-radius", type=int, default=5, help="边缘过渡半径(默认: 5)"
    )

    args = parser.parse_args()

    print("=" * 50)
    print("图片去背景工具")
    print("=" * 50)

    # Dispatch on whether the input is a single file or a folder.
    input_path = Path(args.input)

    if not input_path.exists():
        print(f"错误: 输入路径不存在: {args.input}", file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S` or frozen builds).
        sys.exit(1)

    if input_path.is_file():
        # Single-file mode: derive the default output file name first.
        suffix = input_path.suffix.lower()
        if args.remove_subject:
            if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
                output_name = input_path.stem + "_bgfill" + suffix
            else:
                output_name = input_path.stem + "_bgfill.jpg"
        else:
            output_name = input_path.stem + "_nobg.png"

        # Resolve the output path. With no argument, use an "output" folder
        # next to the input file. An existing directory, or a path without
        # a suffix, is treated as a folder; anything else is a file path.
        if args.output is None:
            output_path = input_path.parent / "output" / output_name
            output_path.parent.mkdir(parents=True, exist_ok=True)
        else:
            output_candidate = Path(args.output)
            if output_candidate.is_dir():
                output_path = output_candidate / output_name
            elif output_candidate.suffix == "":
                output_candidate.mkdir(parents=True, exist_ok=True)
                output_path = output_candidate / output_name
            else:
                output_path = output_candidate
                output_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"输入文件: {input_path}")
        print(f"输出文件: {output_path}")
        print(f"前景提取模式: {args.foreground_mode}")
        print(f"书画类型: {args.artwork_type}")
        print(f"书画最大边: {args.artwork_max_size}")
        if args.foreground_mode in {"rembg", "auto"}:
            print(f"rembg模型: {args.model}")
        print(f"Alpha Matting: {'启用' if args.alpha_matting else '禁用'}")
        if args.alpha_matting:
            print(f" - 前景阈值: {args.foreground_threshold}")
            print(f" - 背景阈值: {args.background_threshold}")
            print(f" - 侵蚀大小: {args.erode_size}")
        print(f"Mask后处理: {'启用' if args.post_process else '禁用'}")
        print(f"去主体补背景: {'启用' if args.remove_subject else '禁用'}")
        if args.remove_subject:
            print(f" - AOT目录: {args.aot_root}")
            print(f" - AOT权重: {args.aot_pretrain}")
            print(f" - AOT设备: {args.aot_device}")
            print(f" - AOT块数: {args.aot_block_num}")
            print(f" - AOT膨胀率: {args.aot_rates}")
            print(f" - AOT裁剪: {'是' if args.aot_crop else '否'}")
            if args.aot_crop:
                print(f" - AOT裁剪边界: {args.aot_crop_pad}")
            print(f" - AOT最大边: {args.aot_max_size}")
            print(f" - AOT噪声预填充: {'是' if args.aot_noise_prefill else '否'}")
            if args.aot_noise_prefill:
                print(f" - AOT噪声强度: {args.aot_noise_strength}")
            print(f" - mask膨胀: {args.mask_dilate}")
            print(f" - mask模糊: {args.mask_blur}")
            print(f" - mask阈值: {args.mask_threshold}")
            print(f" - 边缘扩张: {args.edge_grow}")
            print(f" - 黑色内容作为主体: {'是' if args.black_subject else '否'}")
            if args.black_subject:
                print(f" - 黑色阈值: {args.black_threshold}")
            print(f" - 灰阶内容作为主体: {'是' if args.gray_subject else '否'}")
            if args.gray_subject:
                print(f" - 灰阶饱和度阈值: {args.gray_saturation_threshold}")
                print(f" - 灰阶亮度阈值: {args.gray_value_threshold}")
            print(f" - 边缘过渡: {'是' if args.feather else '否'}")
            if args.feather:
                print(f" - 过渡半径: {args.feather_radius}")
            print(f" - 保存mask: {'是' if args.save_mask else '否'}")
        print("-" * 50)

        if args.remove_subject:
            mask_output_path = None
            if args.save_mask:
                mask_output_path = str(output_path.with_suffix("")) + "_mask.png"
            remove_subject_and_inpaint(
                str(input_path),
                str(output_path),
                model_name=args.model,
                foreground_mode=args.foreground_mode,
                artwork_type=args.artwork_type,
                artwork_max_size=args.artwork_max_size,
                alpha_matting=args.alpha_matting,
                alpha_matting_foreground_threshold=args.foreground_threshold,
                alpha_matting_background_threshold=args.background_threshold,
                alpha_matting_erode_size=args.erode_size,
                post_process_mask=args.post_process,
                mask_dilate=args.mask_dilate,
                mask_blur=args.mask_blur,
                mask_threshold=args.mask_threshold,
                edge_grow=args.edge_grow,
                aot_root=args.aot_root,
                aot_pretrain=args.aot_pretrain,
                aot_device=args.aot_device,
                aot_block_num=args.aot_block_num,
                aot_rates=args.aot_rates,
                aot_crop=args.aot_crop,
                aot_crop_pad=args.aot_crop_pad,
                aot_max_size=args.aot_max_size,
                aot_noise_prefill=args.aot_noise_prefill,
                aot_noise_strength=args.aot_noise_strength,
                black_subject=args.black_subject,
                black_threshold=args.black_threshold,
                gray_subject=args.gray_subject,
                gray_saturation_threshold=args.gray_saturation_threshold,
                gray_value_threshold=args.gray_value_threshold,
                feather=args.feather,
                feather_radius=args.feather_radius,
                save_mask=args.save_mask,
                mask_output_path=mask_output_path,
            )
        else:
            remove_background(
                str(input_path),
                str(output_path),
                model_name=args.model,
                foreground_mode=args.foreground_mode,
                artwork_type=args.artwork_type,
                artwork_max_size=args.artwork_max_size,
                alpha_matting=args.alpha_matting,
                alpha_matting_foreground_threshold=args.foreground_threshold,
                alpha_matting_background_threshold=args.background_threshold,
                alpha_matting_erode_size=args.erode_size,
                post_process_mask=args.post_process,
            )

        print("-" * 50)
        print(f"处理完成!结果保存在: {output_path}")

    elif input_path.is_dir():
        # Folder mode: default output folder is "output" relative to the CWD.
        output_folder = args.output if args.output else "output"

        process_images_folder(
            str(input_path),
            output_folder,
            model_name=args.model,
            foreground_mode=args.foreground_mode,
            artwork_type=args.artwork_type,
            artwork_max_size=args.artwork_max_size,
            alpha_matting=args.alpha_matting,
            alpha_matting_foreground_threshold=args.foreground_threshold,
            alpha_matting_background_threshold=args.background_threshold,
            alpha_matting_erode_size=args.erode_size,
            post_process_mask=args.post_process,
            remove_subject=args.remove_subject,
            mask_dilate=args.mask_dilate,
            mask_blur=args.mask_blur,
            mask_threshold=args.mask_threshold,
            edge_grow=args.edge_grow,
            aot_root=args.aot_root,
            aot_pretrain=args.aot_pretrain,
            aot_device=args.aot_device,
            aot_block_num=args.aot_block_num,
            aot_rates=args.aot_rates,
            aot_crop=args.aot_crop,
            aot_crop_pad=args.aot_crop_pad,
            aot_max_size=args.aot_max_size,
            aot_noise_prefill=args.aot_noise_prefill,
            aot_noise_strength=args.aot_noise_strength,
            black_subject=args.black_subject,
            black_threshold=args.black_threshold,
            gray_subject=args.gray_subject,
            gray_saturation_threshold=args.gray_saturation_threshold,
            gray_value_threshold=args.gray_value_threshold,
            feather=args.feather,
            feather_radius=args.feather_radius,
            save_mask=args.save_mask,
        )

    else:
        # Exists but is neither a regular file nor a directory (e.g. a FIFO).
        print(f"错误: 不支持的输入类型: {args.input}", file=sys.stderr)
        sys.exit(1)