941 lines
36 KiB
Python
941 lines
36 KiB
Python
"""
|
||
图片去背景工具
|
||
使用rembg库自动去除图片背景
|
||
"""
|
||
import os
|
||
import sys
|
||
import argparse
|
||
from pathlib import Path
|
||
import numpy as np
|
||
import cv2
|
||
from PIL import Image, ImageOps
|
||
|
||
# 避免 numba 在某些环境下缓存失败
|
||
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
|
||
|
||
from rembg import remove, new_session
|
||
|
||
# Memoized AOT-GAN generators keyed by (weights path, device, block_num, rates)
# so repeated images in one run reuse the loaded model.
AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {}
|
||
|
||
def _resolve_aot_paths(aot_root, aot_pretrain):
|
||
"""解析 AOT-GAN 的路径配置。"""
|
||
aot_root = Path(aot_root).resolve()
|
||
if aot_pretrain is None:
|
||
aot_pretrain = None
|
||
else:
|
||
aot_pretrain = Path(aot_pretrain)
|
||
if not aot_pretrain.is_absolute():
|
||
# 预训练权重相对路径以当前工作目录为基准
|
||
aot_pretrain = (Path.cwd() / aot_pretrain).resolve()
|
||
return aot_root, aot_pretrain
|
||
|
||
def _parse_aot_rates(rates_str):
|
||
parts = [p for p in rates_str.split("+") if p]
|
||
return [int(p) for p in parts]
|
||
|
||
def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None):
    """Load and cache an AOT-GAN inpainting generator.

    Args:
        aot_root: directory of the AOT-GAN-for-Inpainting checkout.
        aot_pretrain: path to pretrained generator weights (required).
        device: torch device string, e.g. "cpu" or "cuda".
        block_num: number of AOTBlocks in the generator.
        rates: per-block dilation rates; defaults to [1, 2, 4, 8].

    Returns:
        Cached ``(model, device)`` tuple; the model is in eval mode on *device*.

    Raises:
        FileNotFoundError: if *aot_root* or the weights file does not exist.
        ValueError: if *aot_pretrain* is None.
        ImportError: if PyTorch is not installed.
    """
    aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain)
    if not aot_root.exists():
        raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}")
    if aot_pretrain is None:
        raise ValueError("AOT-GAN需要指定预训练权重路径(--aot-pretrain)")
    if not aot_pretrain.exists():
        raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}")
    if rates is None:
        rates = [1, 2, 4, 8]
    rates_tuple = tuple(rates)
    # One cached model per (weights, device, block count, rates) combination.
    key = (str(aot_pretrain), device, int(block_num), rates_tuple)
    if key in AOT_MODEL_CACHE:
        return AOT_MODEL_CACHE[key]

    # Make the AOT-GAN package importable (model code lives under src/).
    src_root = aot_root / "src"
    if str(src_root) not in sys.path:
        sys.path.insert(0, str(src_root))
    import importlib
    try:
        import torch
    except ImportError as exc:
        raise ImportError("未找到 PyTorch,请先按README安装依赖。") from exc

    net = importlib.import_module("model.aotgan")
    # InpaintGenerator expects an argparse-like namespace; build a minimal one.
    class _Args:
        pass
    args = _Args()
    args.block_num = int(block_num)
    args.rates = rates
    model = net.InpaintGenerator(args)
    state = torch.load(str(aot_pretrain), map_location=device)
    # Unwrap common checkpoint layouts to reach the raw state dict.
    if isinstance(state, dict):
        if "state_dict" in state:
            state = state["state_dict"]
        elif "model" in state:
            state = state["model"]
        elif "generator" in state:
            state = state["generator"]
    # Strip a DataParallel "module." prefix when present.
    if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
        state = {k.replace("module.", "", 1): v for k, v in state.items()}
    model.load_state_dict(state, strict=True)
    model.to(device)
    model.eval()
    AOT_MODEL_CACHE[key] = (model, device)
    return AOT_MODEL_CACHE[key]
|
||
|
||
def _mask_bbox(mask):
|
||
ys, xs = np.where(mask > 0)
|
||
if len(xs) == 0:
|
||
return None
|
||
return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
|
||
|
||
def _expand_bbox(bbox, pad, width, height):
|
||
if bbox is None:
|
||
return None
|
||
x0, y0, x1, y1 = bbox
|
||
pad = max(0, int(pad))
|
||
x0 = max(0, x0 - pad)
|
||
y0 = max(0, y0 - pad)
|
||
x1 = min(width - 1, x1 + pad)
|
||
y1 = min(height - 1, y1 + pad)
|
||
return x0, y0, x1, y1
|
||
|
||
def _expand_bbox_min_size(bbox, pad, width, height, min_size):
|
||
"""在给定 pad 基础上,确保裁剪区域至少为 min_size。"""
|
||
expanded = _expand_bbox(bbox, pad, width, height)
|
||
if expanded is None:
|
||
return None
|
||
x0, y0, x1, y1 = expanded
|
||
roi_w = x1 - x0 + 1
|
||
roi_h = y1 - y0 + 1
|
||
need_w = max(0, int(min_size) - roi_w)
|
||
need_h = max(0, int(min_size) - roi_h)
|
||
if need_w == 0 and need_h == 0:
|
||
return expanded
|
||
extra = max((need_w + 1) // 2, (need_h + 1) // 2)
|
||
return _expand_bbox(bbox, pad + extra, width, height)
|
||
|
||
def _build_noise_prefill(img, mask_t, strength):
|
||
"""为mask区域生成噪声预填充,img取值范围为[-1, 1]。"""
|
||
import torch
|
||
img01 = (img + 1.0) / 2.0
|
||
unmasked = 1.0 - mask_t
|
||
denom = unmasked.sum()
|
||
if denom.item() < 1.0:
|
||
mean = torch.full((1, 3, 1, 1), 0.5, device=img.device)
|
||
std = torch.full((1, 3, 1, 1), 0.2, device=img.device)
|
||
else:
|
||
mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
|
||
var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
|
||
std = torch.sqrt(var + 1e-6)
|
||
noise = mean + std * float(strength) * torch.randn_like(img01)
|
||
noise = noise.clamp(0.0, 1.0)
|
||
return noise * 2.0 - 1.0
|
||
|
||
def _inpaint_with_aot_core(
    bgr,
    mask,
    aot_root,
    aot_pretrain,
    device="cpu",
    block_num=8,
    rates=None,
    noise_prefill=False,
    noise_strength=1.0,
):
    """Inpaint *mask* pixels of *bgr* with AOT-GAN; returns BGR, same size.

    The generator needs dimensions divisible by 4, so only the largest
    top-left region satisfying that is processed; any right/bottom strip
    left over is copied through unchanged.
    """
    # Normalize the mask to uint8 with 0/255 values.
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.max() <= 1:
        mask = mask * 255
    if mask.max() == 0:
        # Nothing to inpaint.
        return bgr

    # Truncate to a size divisible by 4 (model stride requirement).
    h, w = mask.shape
    grid = 4
    h2 = (h // grid) * grid
    w2 = (w // grid) * grid
    if h2 == 0 or w2 == 0:
        return bgr

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img_crop = rgb[:h2, :w2, :]
    mask_crop = mask[:h2, :w2]

    import torch
    # NCHW tensors; image scaled to [-1, 1], mask binarized to {0, 1}.
    img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0
    img = img * 2.0 - 1.0
    mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0)
    img = img.unsqueeze(0)
    mask_t = mask_t.unsqueeze(0)
    if device != "cpu":
        img = img.to(device)
        mask_t = mask_t.to(device)

    model, _ = _get_aot_model(
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=device,
        block_num=block_num,
        rates=rates,
    )
    with torch.no_grad():
        # Masked pixels are pre-filled with noise (optional) or with the
        # constant 1.0 before being fed to the generator.
        if noise_prefill:
            noise = _build_noise_prefill(img, mask_t, noise_strength)
            image_masked = img * (1 - mask_t) + noise * mask_t
        else:
            image_masked = img * (1 - mask_t) + mask_t
        pred = model(image_masked, mask_t)
        # Composite: model output inside the mask, original outside.
        comp = pred * mask_t + img * (1 - mask_t)
    comp = comp[0].clamp(-1.0, 1.0)
    comp = (comp + 1.0) / 2.0 * 255.0
    comp = comp.permute(1, 2, 0).byte().cpu().numpy()
    result_bgr = cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)

    if h2 == h and w2 == w:
        return result_bgr
    # Paste the processed region back over a copy of the original frame.
    output_full = bgr.copy()
    output_full[:h2, :w2, :] = result_bgr
    return output_full
|
||
|
||
def _inpaint_with_aot(
    bgr,
    mask,
    aot_root,
    aot_pretrain,
    device="cpu",
    block_num=8,
    rates=None,
    crop=False,
    crop_pad=0,
    max_size=0,
    noise_prefill=False,
    noise_strength=1.0,
):
    """Inpaint with AOT-GAN, optionally cropping/downscaling for speed.

    Args:
        crop: process only the mask's bounding box (plus *crop_pad* margin).
        crop_pad: margin in pixels around the mask bounding box.
        max_size: cap on the longest side fed to the model (0 = unlimited).
        noise_prefill, noise_strength: see ``_inpaint_with_aot_core``.

    Returns:
        BGR image with the same size as *bgr*.
    """
    min_side = 32  # below this AOT tends to fail; fall back to full frame
    # Normalize the mask to uint8 with 0/255 values.
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.max() <= 1:
        mask = mask * 255
    if mask.max() == 0:
        return bgr

    # Default ROI is the whole image.
    h, w = mask.shape
    x0 = y0 = 0
    x1 = w - 1
    y1 = h - 1
    if crop:
        bbox = _mask_bbox(mask)
        if bbox is None:
            return bgr
        expanded = _expand_bbox_min_size(bbox, crop_pad, w, h, min_side)
        if expanded is None:
            return bgr
        x0, y0, x1, y1 = expanded
        roi_w = x1 - x0 + 1
        roi_h = y1 - y0 + 1
        if roi_w < min_side or roi_h < min_side:
            # ROI too small for AOT; revert to whole-image inpainting.
            crop = False
            x0 = y0 = 0
            x1 = w - 1
            y1 = h - 1

    bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1]
    mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1]
    roi_h, roi_w = bgr_roi.shape[:2]

    # Optionally downscale so the longest ROI side is at most max_size.
    scale = 1.0
    max_size = int(max_size) if max_size else 0
    if max_size > 0:
        max_dim = max(roi_h, roi_w)
        if max_dim > max_size:
            scale = max_size / float(max_dim)
            new_w = max(1, int(round(roi_w * scale)))
            new_h = max(1, int(round(roi_h * scale)))
            if min(new_w, new_h) < min_side:
                # Downscaling would undershoot the minimum side; skip it.
                scale = 1.0
                new_w = roi_w
                new_h = roi_h
    interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
    if scale != 1.0:
        bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp)
        mask_roi = cv2.resize(mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    filled_roi = _inpaint_with_aot_core(
        bgr_roi,
        mask_roi,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=device,
        block_num=block_num,
        rates=rates,
        noise_prefill=noise_prefill,
        noise_strength=noise_strength,
    )

    # Scale the inpainted ROI back to its original size.
    if scale != 1.0:
        filled_roi = cv2.resize(
            filled_roi,
            (roi_w, roi_h),
            interpolation=cv2.INTER_LINEAR,
        )

    if not crop:
        return filled_roi
    # Paste the ROI back into the full frame.
    output_full = bgr.copy()
    output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi
    return output_full
|
||
|
||
# Optional HEIC support: pillow-heif registers a PIL opener for .heic/.heif.
try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
    HEIC_SUPPORTED = True
except ImportError:
    # pillow-heif not installed; HEIC inputs will not be handled.
    HEIC_SUPPORTED = False
|
||
|
||
def remove_background(input_path, output_path, session=None, **kwargs):
    """Strip the background from a single image and save the result.

    Args:
        input_path: source image path.
        output_path: destination image path.
        session: optional rembg session to reuse across calls.
        **kwargs: extra rembg options, e.g. alpha-matting parameters.
    """
    print(f"正在处理: {input_path}")

    # Honor EXIF orientation before segmentation, run rembg, then save.
    source = Image.open(input_path)
    upright = ImageOps.exif_transpose(source)
    result = remove(upright, session=session, **kwargs)
    result.save(output_path)

    print(f"已保存: {output_path}")
|
||
|
||
def _pil_to_bgr(image):
    """Convert a PIL image to an OpenCV-style BGR ndarray."""
    rgb_array = np.array(image.convert("RGB"))
    bgr_array = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
    return bgr_array
|
||
|
||
def _alpha_to_mask(alpha_channel, threshold=10):
|
||
"""将alpha通道转换为二值mask(0或255)"""
|
||
mask = (alpha_channel > threshold).astype(np.uint8) * 255
|
||
return mask
|
||
|
||
def _ensure_rgba_size(image, target_size):
    """Return *image* coerced to RGBA mode and resized to *target_size*.

    Nearest-neighbor resampling keeps the alpha channel hard-edged.
    """
    rgba = image if image.mode == "RGBA" else image.convert("RGBA")
    if rgba.size != target_size:
        rgba = rgba.resize(target_size, resample=Image.NEAREST)
    return rgba
|
||
|
||
def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0):
    """Produce the hard-edged mask and the processed mask for inpainting.

    Returns:
        ``(mask_hard, mask_used)`` — the dilated/grown binary mask, and a
        version whose edge is additionally pushed outward by re-binarizing
        a Gaussian blur (any blurred spill counts as masked).
    """
    if mask_dilate and mask_dilate > 0:
        square = np.ones((mask_dilate, mask_dilate), np.uint8)
        mask = cv2.dilate(mask, square, iterations=1)
    if edge_grow and edge_grow > 0:
        # Push the subject edge outward one 3x3 dilation per step.
        mask = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=int(edge_grow))

    mask_hard = mask.copy()
    mask_used = mask_hard
    if mask_blur and mask_blur > 0:
        # GaussianBlur requires an odd kernel size.
        ksize = mask_blur if mask_blur % 2 == 1 else mask_blur + 1
        blurred = cv2.GaussianBlur(mask_hard, (ksize, ksize), 0)
        mask_used = (blurred > 0).astype(np.uint8) * 255
    return mask_hard, mask_used
|
||
|
||
def remove_subject_and_inpaint(
    input_path,
    output_path,
    session=None,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
    mask_output_path=None,
    **kwargs,
):
    """Remove the main subject and fill in the background behind it.

    Args:
        input_path: input image path.
        output_path: output image path.
        session: optional rembg session.
        aot_root: AOT-GAN checkout directory.
        aot_pretrain: AOT-GAN weights file path.
        aot_device: device for AOT-GAN inference.
        aot_block_num: number of AOTBlocks.
        aot_rates: AOTBlock dilation rates ("+"-separated string).
        aot_crop: inpaint only the cropped mask region.
        aot_crop_pad: margin in pixels around the crop.
        aot_max_size: cap on the AOT input's longest side (0 = unlimited).
        aot_noise_prefill: pre-fill the mask with random noise.
        aot_noise_strength: noise strength factor.
        mask_dilate: mask dilation kernel size.
        mask_blur: mask blur kernel size (odd).
        mask_threshold: alpha threshold for the subject mask.
        edge_grow: extra mask edge-growth iterations.
        black_subject / black_threshold: also mask near-black pixels.
        gray_subject / gray_saturation_threshold / gray_value_threshold:
            also mask low-saturation (gray) pixels.
        feather / feather_radius: blend the fill back into the original
            outside the mask over this radius.
        save_mask: whether to save the mask actually used.
        mask_output_path: where to save the mask (default: next to output).
        **kwargs: forwarded to rembg.
    """
    print(f"正在处理(去主体补背景): {input_path}")

    input_image = ImageOps.exif_transpose(Image.open(input_path))

    # Use rembg to obtain the subject mask (via the output's alpha channel).
    output_image = remove(input_image, session=session, **kwargs)
    output_image = _ensure_rgba_size(output_image, input_image.size)

    alpha = np.array(output_image.getchannel("A"))
    mask = _alpha_to_mask(alpha, threshold=mask_threshold)

    bgr = _pil_to_bgr(input_image)
    if black_subject:
        # Treat near-black pixels as part of the subject as well.
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        black_mask = (gray <= black_threshold).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, black_mask)
    if gray_subject:
        # Treat low-saturation, not-too-bright pixels as subject too.
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        s = hsv[:, :, 1]
        v = hsv[:, :, 2]
        gray_mask = ((s <= gray_saturation_threshold) & (v <= gray_value_threshold)).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, gray_mask)

    mask_hard, mask_used = _prepare_mask(
        mask,
        mask_dilate=mask_dilate,
        mask_blur=mask_blur,
        edge_grow=edge_grow,
    )

    rates = _parse_aot_rates(aot_rates)
    filled = _inpaint_with_aot(
        bgr,
        mask_used,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=aot_device,
        block_num=aot_block_num,
        rates=rates,
        crop=aot_crop,
        crop_pad=aot_crop_pad,
        max_size=aot_max_size,
        noise_prefill=aot_noise_prefill,
        noise_strength=aot_noise_strength,
    )

    if feather and feather_radius > 0:
        # Blend only on the outside of the mask so the original subject
        # cannot bleed back into the filled area.
        mask_bin = (mask_hard > 0).astype(np.uint8) * 255
        dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3)
        alpha = np.ones_like(dist_out, dtype=np.float32)
        outside = mask_bin == 0
        alpha[outside] = np.clip(1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0)
        alpha = alpha[:, :, None]
        blended = (alpha * filled + (1.0 - alpha) * bgr).astype(np.uint8)
        result = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
    else:
        result = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB)
    Image.fromarray(result).save(output_path)

    if save_mask:
        if mask_output_path is None:
            mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png"
        Image.fromarray(mask_used).save(mask_output_path)
        print(f"已保存mask: {mask_output_path}")

    print(f"已保存: {output_path}")
|
||
|
||
def process_images_folder(
    input_folder,
    output_folder,
    model_name="u2net",
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_size=10,
    post_process_mask=False,
    remove_subject=False,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
):
    """Batch-process every image in a folder.

    Args:
        input_folder: input folder path.
        output_folder: output folder path.
        model_name: rembg model to use, one of:
            - u2net (default): general-purpose model
            - u2netp: lightweight u2net
            - u2net_human_seg: human segmentation
            - silueta: slimmed-down u2net (43MB)
            - isnet-general-use: newer general-purpose model
            - isnet-anime: high-accuracy anime character segmentation
            - birefnet-general: general-purpose model
            - birefnet-portrait: portrait model
        alpha_matting: enable alpha-matting post-processing (better edges).
        alpha_matting_foreground_threshold: foreground threshold (0-255);
            higher keeps more foreground.
        alpha_matting_background_threshold: background threshold (0-255);
            higher removes more background.
        alpha_matting_erode_size: erosion size used to smooth edges.
        post_process_mask: enable mask post-processing.
        remove_subject: remove the subject and inpaint the background
            instead of producing a transparent cut-out; the remaining
            parameters are forwarded to ``remove_subject_and_inpaint``.
    """
    # Create the output folder.
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    # Create one session up front (reusing it improves performance).
    print(f"使用模型: {model_name}")
    session = new_session(model_name)

    # Supported image formats.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
    if HEIC_SUPPORTED:
        image_extensions.update({'.heic', '.heif'})
    else:
        print("提示: 未安装pillow-heif,HEIC格式不可用。安装方法: pip install pillow-heif")

    # Collect all image files (non-recursive).
    input_path = Path(input_folder)
    image_files = [f for f in input_path.iterdir()
                   if f.is_file() and f.suffix.lower() in image_extensions]

    if not image_files:
        print(f"在 {input_folder} 中没有找到图片文件")
        return

    # Echo the effective configuration before processing.
    print(f"找到 {len(image_files)} 张图片,开始处理...")
    print(f"Alpha Matting: {'启用' if alpha_matting else '禁用'}")
    if alpha_matting:
        print(f"  - 前景阈值: {alpha_matting_foreground_threshold}")
        print(f"  - 背景阈值: {alpha_matting_background_threshold}")
        print(f"  - 侵蚀大小: {alpha_matting_erode_size}")
    print(f"Mask后处理: {'启用' if post_process_mask else '禁用'}")
    print(f"去主体补背景: {'启用' if remove_subject else '禁用'}")
    if remove_subject:
        print(f"  - AOT目录: {aot_root}")
        print(f"  - AOT权重: {aot_pretrain}")
        print(f"  - AOT设备: {aot_device}")
        print(f"  - AOT块数: {aot_block_num}")
        print(f"  - AOT膨胀率: {aot_rates}")
        print(f"  - AOT裁剪: {'是' if aot_crop else '否'}")
        if aot_crop:
            print(f"  - AOT裁剪边界: {aot_crop_pad}")
        print(f"  - AOT最大边: {aot_max_size}")
        print(f"  - AOT噪声预填充: {'是' if aot_noise_prefill else '否'}")
        if aot_noise_prefill:
            print(f"  - AOT噪声强度: {aot_noise_strength}")
        print(f"  - mask膨胀: {mask_dilate}")
        print(f"  - mask模糊: {mask_blur}")
        print(f"  - mask阈值: {mask_threshold}")
        print(f"  - 边缘扩张: {edge_grow}")
        print(f"  - 黑色内容作为主体: {'是' if black_subject else '否'}")
        if black_subject:
            print(f"  - 黑色阈值: {black_threshold}")
        print(f"  - 灰阶内容作为主体: {'是' if gray_subject else '否'}")
        if gray_subject:
            print(f"  - 灰阶饱和度阈值: {gray_saturation_threshold}")
            print(f"  - 灰阶亮度阈值: {gray_value_threshold}")
        print(f"  - 边缘过渡: {'是' if feather else '否'}")
        if feather:
            print(f"  - 过渡半径: {feather_radius}")
        print(f"  - 保存mask: {'是' if save_mask else '否'}")
    print("-" * 50)

    # Process each image; failures are reported but do not stop the batch.
    for i, image_file in enumerate(image_files, 1):
        try:
            if remove_subject:
                # Inpainted outputs keep the original (opaque) format.
                suffix = image_file.suffix.lower()
                if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
                    output_filename = image_file.stem + "_bgfill" + suffix
                else:
                    output_filename = image_file.stem + "_bgfill.jpg"
            else:
                # Background removal defaults to PNG to keep transparency.
                output_filename = image_file.stem + "_nobg.png"
            output_path = Path(output_folder) / output_filename

            print(f"[{i}/{len(image_files)}] ", end="")
            if remove_subject:
                mask_output_path = None
                if save_mask:
                    mask_output_path = str(Path(output_folder) / (image_file.stem + "_mask.png"))
                remove_subject_and_inpaint(
                    str(image_file),
                    str(output_path),
                    session=session,
                    alpha_matting=alpha_matting,
                    alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold=alpha_matting_background_threshold,
                    alpha_matting_erode_size=alpha_matting_erode_size,
                    post_process_mask=post_process_mask,
                    mask_dilate=mask_dilate,
                    mask_blur=mask_blur,
                    mask_threshold=mask_threshold,
                    edge_grow=edge_grow,
                    aot_root=aot_root,
                    aot_pretrain=aot_pretrain,
                    aot_device=aot_device,
                    aot_block_num=aot_block_num,
                    aot_rates=aot_rates,
                    aot_crop=aot_crop,
                    aot_crop_pad=aot_crop_pad,
                    aot_max_size=aot_max_size,
                    aot_noise_prefill=aot_noise_prefill,
                    aot_noise_strength=aot_noise_strength,
                    black_subject=black_subject,
                    black_threshold=black_threshold,
                    gray_subject=gray_subject,
                    gray_saturation_threshold=gray_saturation_threshold,
                    gray_value_threshold=gray_value_threshold,
                    feather=feather,
                    feather_radius=feather_radius,
                    save_mask=save_mask,
                    mask_output_path=mask_output_path,
                )
            else:
                remove_background(
                    str(image_file),
                    str(output_path),
                    session=session,
                    alpha_matting=alpha_matting,
                    alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold=alpha_matting_background_threshold,
                    alpha_matting_erode_size=alpha_matting_erode_size,
                    post_process_mask=post_process_mask,
                )

        except Exception as e:
            print(f"处理 {image_file.name} 时出错: {e}")

    print("-" * 50)
    print(f"处理完成!结果保存在 {output_folder} 文件夹中")
|
||
|
||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='图片去背景工具 - 使用rembg自动去除图片背景',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
示例用法:
  # 使用默认参数处理images文件夹
  python remove_background.py

  # 处理单个文件
  python remove_background.py input.jpg output.png

  # 处理指定文件夹
  python remove_background.py my_images/ my_output/

  # 使用不同模型
  python remove_background.py input.jpg output.png -m birefnet-portrait

  # 自定义alpha matting参数
  python remove_background.py input.jpg output.png -ft 260 -bt 12 -es 5
        '''
    )

    # Positional arguments.
    parser.add_argument('input', nargs='?', default='images',
                        help='输入文件或文件夹路径(默认: images)')
    parser.add_argument('output', nargs='?', default=None,
                        help='输出文件或文件夹路径(可选,默认为output/)')

    # Model selection.
    parser.add_argument('-m', '--model', default='isnet-general-use',
                        choices=['u2net', 'u2netp', 'u2net_human_seg', 'silueta',
                                 'isnet-general-use', 'isnet-anime',
                                 'birefnet-general', 'birefnet-general-lite',
                                 'birefnet-portrait', 'birefnet-dis',
                                 'birefnet-hrsod', 'birefnet-cod', 'birefnet-massive'],
                        help='选择使用的模型 (默认: isnet-general-use)')

    # Alpha-matting parameters.
    parser.add_argument('-a', '--alpha-matting', '--alpha_matting', action='store_true',
                        help='启用alpha matting后处理(默认: false)')
    parser.add_argument('-ft', '--foreground-threshold', type=int, default=245,
                        help='前景阈值 (0-255),值越大保留越多细节 (默认: 245)')
    parser.add_argument('-bt', '--background-threshold', type=int, default=8,
                        help='背景阈值 (0-255),值越大去除越多背景 (默认: 8)')
    parser.add_argument('-es', '--erode-size', type=int, default=2,
                        help='侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)')

    # Other options.
    parser.add_argument('-p', '--post-process', '--post_process', action='store_true',
                        help='启用mask后处理(默认: false)')

    # Subject-removal / inpainting parameters.
    parser.add_argument('--remove-subject', '--remove_subject', action='store_true',
                        help='去掉主体并补全背景(默认: false)')
    parser.add_argument('--aot-root', type=str, default='AOT-GAN-for-Inpainting',
                        help='AOT-GAN目录(默认: AOT-GAN-for-Inpainting)')
    parser.add_argument('--aot-pretrain', type=str, default=None,
                        help='AOT-GAN预训练权重文件路径(必填,相对路径基于当前目录)')
    parser.add_argument('--aot-device', type=str, default='cpu',
                        help='AOT-GAN设备(默认: cpu)')
    parser.add_argument('--aot-block-num', type=int, default=8,
                        help='AOTBlock数量(默认: 8)')
    parser.add_argument('--aot-rates', type=str, default='1+2+4+8',
                        help='AOTBlock膨胀率(默认: 1+2+4+8)')
    parser.add_argument('--aot-crop', action='store_true',
                        help='AOT仅对mask区域裁剪修补(默认: false)')
    parser.add_argument('--aot-crop-pad', type=int, default=0,
                        help='AOT裁剪边缘留白像素(默认: 0)')
    parser.add_argument('--aot-max-size', type=int, default=0,
                        help='AOT输入最大边限制,0为不限制(默认: 0)')
    parser.add_argument('--aot-noise-prefill', action='store_true',
                        help='AOT使用随机噪声预填充(默认: false)')
    parser.add_argument('--aot-noise-strength', type=float, default=1.0,
                        help='AOT噪声强度系数(默认: 1.0)')
    parser.add_argument('--mask-dilate', type=int, default=3,
                        help='mask膨胀大小(默认: 3)')
    parser.add_argument('--mask-blur', type=int, default=3,
                        help='mask模糊大小(默认: 3,建议奇数)')
    parser.add_argument('--mask-threshold', type=int, default=10,
                        help='alpha阈值(默认: 10)')
    parser.add_argument('--edge-grow', type=int, default=0,
                        help='主体边缘扩张像素(默认: 0)')
    parser.add_argument('--save-mask', '--save_mask', action='store_true',
                        help='保存mask到output目录(默认: false)')
    parser.add_argument('--black-subject', '--black_subject', action='store_true',
                        help='将黑色内容也视为主体(默认: false)')
    parser.add_argument('--black-threshold', type=int, default=50,
                        help='黑色阈值(0-255,灰度越小越黑,默认: 50)')
    parser.add_argument('--gray-subject', '--gray_subject', action='store_true',
                        help='将灰阶内容也视为主体(默认: false)')
    parser.add_argument('--gray-saturation-threshold', type=int, default=30,
                        help='灰阶饱和度阈值(0-255,越小越接近灰阶,默认: 30)')
    parser.add_argument('--gray-value-threshold', type=int, default=200,
                        help='灰阶亮度阈值(0-255,越小越暗,默认: 200)')
    parser.add_argument('--feather', action='store_true',
                        help='启用边缘过渡融合(默认: false)')
    parser.add_argument('--feather-radius', type=int, default=5,
                        help='边缘过渡半径(默认: 5)')

    args = parser.parse_args()

    print("=" * 50)
    print("图片去背景工具")
    print("=" * 50)

    # Decide whether the input is a file or a folder.
    input_path = Path(args.input)

    if not input_path.exists():
        print(f"错误: 输入路径不存在: {args.input}")
        # sys.exit instead of the site-module `exit` builtin.
        sys.exit(1)

    # Single-file processing.
    if input_path.is_file():
        suffix = input_path.suffix.lower()
        if args.remove_subject:
            # Inpainted output keeps the original (opaque) format.
            if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
                output_name = input_path.stem + "_bgfill" + suffix
            else:
                output_name = input_path.stem + "_bgfill.jpg"
        else:
            output_name = input_path.stem + "_nobg.png"

        # Determine the output path.
        if args.output is None:
            # Both modes used an identical "output/" default; the duplicated
            # if/else (differing only in quote style) is collapsed here.
            output_path = input_path.parent / "output" / output_name
            output_path.parent.mkdir(parents=True, exist_ok=True)
        else:
            output_candidate = Path(args.output)
            if output_candidate.exists() and output_candidate.is_dir():
                output_path = output_candidate / output_name
                output_path.parent.mkdir(parents=True, exist_ok=True)
            elif output_candidate.suffix == "":
                # No extension: treat the argument as a directory to create.
                output_candidate.mkdir(parents=True, exist_ok=True)
                output_path = output_candidate / output_name
            else:
                output_path = output_candidate
                output_path.parent.mkdir(parents=True, exist_ok=True)

        # Echo the effective configuration.
        print(f"输入文件: {input_path}")
        print(f"输出文件: {output_path}")
        print(f"模型: {args.model}")
        print(f"Alpha Matting: {'启用' if args.alpha_matting else '禁用'}")
        if args.alpha_matting:
            print(f"  - 前景阈值: {args.foreground_threshold}")
            print(f"  - 背景阈值: {args.background_threshold}")
            print(f"  - 侵蚀大小: {args.erode_size}")
        print(f"Mask后处理: {'启用' if args.post_process else '禁用'}")
        print(f"去主体补背景: {'启用' if args.remove_subject else '禁用'}")
        if args.remove_subject:
            print(f"  - AOT目录: {args.aot_root}")
            print(f"  - AOT权重: {args.aot_pretrain}")
            print(f"  - AOT设备: {args.aot_device}")
            print(f"  - AOT块数: {args.aot_block_num}")
            print(f"  - AOT膨胀率: {args.aot_rates}")
            print(f"  - AOT裁剪: {'是' if args.aot_crop else '否'}")
            if args.aot_crop:
                print(f"  - AOT裁剪边界: {args.aot_crop_pad}")
            print(f"  - AOT最大边: {args.aot_max_size}")
            print(f"  - AOT噪声预填充: {'是' if args.aot_noise_prefill else '否'}")
            if args.aot_noise_prefill:
                print(f"  - AOT噪声强度: {args.aot_noise_strength}")
            print(f"  - mask膨胀: {args.mask_dilate}")
            print(f"  - mask模糊: {args.mask_blur}")
            print(f"  - mask阈值: {args.mask_threshold}")
            print(f"  - 边缘扩张: {args.edge_grow}")
            print(f"  - 黑色内容作为主体: {'是' if args.black_subject else '否'}")
            if args.black_subject:
                print(f"  - 黑色阈值: {args.black_threshold}")
            print(f"  - 灰阶内容作为主体: {'是' if args.gray_subject else '否'}")
            if args.gray_subject:
                print(f"  - 灰阶饱和度阈值: {args.gray_saturation_threshold}")
                print(f"  - 灰阶亮度阈值: {args.gray_value_threshold}")
            print(f"  - 边缘过渡: {'是' if args.feather else '否'}")
            if args.feather:
                print(f"  - 过渡半径: {args.feather_radius}")
            print(f"  - 保存mask: {'是' if args.save_mask else '否'}")
        print("-" * 50)

        # Create the rembg session.
        session = new_session(args.model)

        # Run the selected pipeline.
        if args.remove_subject:
            mask_output_path = None
            if args.save_mask:
                mask_output_path = str(output_path.with_suffix("")) + "_mask.png"
            remove_subject_and_inpaint(
                str(input_path),
                str(output_path),
                session=session,
                alpha_matting=args.alpha_matting,
                alpha_matting_foreground_threshold=args.foreground_threshold,
                alpha_matting_background_threshold=args.background_threshold,
                alpha_matting_erode_size=args.erode_size,
                post_process_mask=args.post_process,
                mask_dilate=args.mask_dilate,
                mask_blur=args.mask_blur,
                mask_threshold=args.mask_threshold,
                edge_grow=args.edge_grow,
                aot_root=args.aot_root,
                aot_pretrain=args.aot_pretrain,
                aot_device=args.aot_device,
                aot_block_num=args.aot_block_num,
                aot_rates=args.aot_rates,
                aot_crop=args.aot_crop,
                aot_crop_pad=args.aot_crop_pad,
                aot_max_size=args.aot_max_size,
                aot_noise_prefill=args.aot_noise_prefill,
                aot_noise_strength=args.aot_noise_strength,
                black_subject=args.black_subject,
                black_threshold=args.black_threshold,
                gray_subject=args.gray_subject,
                gray_saturation_threshold=args.gray_saturation_threshold,
                gray_value_threshold=args.gray_value_threshold,
                feather=args.feather,
                feather_radius=args.feather_radius,
                save_mask=args.save_mask,
                mask_output_path=mask_output_path,
            )
        else:
            remove_background(
                str(input_path),
                str(output_path),
                session=session,
                alpha_matting=args.alpha_matting,
                alpha_matting_foreground_threshold=args.foreground_threshold,
                alpha_matting_background_threshold=args.background_threshold,
                alpha_matting_erode_size=args.erode_size,
                post_process_mask=args.post_process,
            )

        print("-" * 50)
        print(f"处理完成!结果保存在: {output_path}")

    # Folder processing.
    elif input_path.is_dir():
        output_folder = args.output if args.output else 'output'

        process_images_folder(
            str(input_path),
            output_folder,
            model_name=args.model,
            alpha_matting=args.alpha_matting,
            alpha_matting_foreground_threshold=args.foreground_threshold,
            alpha_matting_background_threshold=args.background_threshold,
            alpha_matting_erode_size=args.erode_size,
            post_process_mask=args.post_process,
            remove_subject=args.remove_subject,
            mask_dilate=args.mask_dilate,
            mask_blur=args.mask_blur,
            mask_threshold=args.mask_threshold,
            edge_grow=args.edge_grow,
            aot_root=args.aot_root,
            aot_pretrain=args.aot_pretrain,
            aot_device=args.aot_device,
            aot_block_num=args.aot_block_num,
            aot_rates=args.aot_rates,
            aot_crop=args.aot_crop,
            aot_crop_pad=args.aot_crop_pad,
            aot_max_size=args.aot_max_size,
            aot_noise_prefill=args.aot_noise_prefill,
            aot_noise_strength=args.aot_noise_strength,
            black_subject=args.black_subject,
            black_threshold=args.black_threshold,
            gray_subject=args.gray_subject,
            gray_saturation_threshold=args.gray_saturation_threshold,
            gray_value_threshold=args.gray_value_threshold,
            feather=args.feather,
            feather_radius=args.feather_radius,
            save_mask=args.save_mask,
        )

    else:
        print(f"错误: 不支持的输入类型: {args.input}")
        sys.exit(1)