Files
remove_backgroud/remove_background.py
2026-03-28 16:46:40 +08:00

941 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图片去背景工具
使用rembg库自动去除图片背景
"""
import os
import sys
import argparse
from pathlib import Path
import numpy as np
import cv2
from PIL import Image, ImageOps
# 避免 numba 在某些环境下缓存失败
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
from rembg import remove, new_session
# Module-level cache of loaded AOT-GAN generators.
# Key: (weights path as str, device, block_num, rates tuple); value: (model, device).
AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {}
def _resolve_aot_paths(aot_root, aot_pretrain):
"""解析 AOT-GAN 的路径配置。"""
aot_root = Path(aot_root).resolve()
if aot_pretrain is None:
aot_pretrain = None
else:
aot_pretrain = Path(aot_pretrain)
if not aot_pretrain.is_absolute():
# 预训练权重相对路径以当前工作目录为基准
aot_pretrain = (Path.cwd() / aot_pretrain).resolve()
return aot_root, aot_pretrain
def _parse_aot_rates(rates_str):
parts = [p for p in rates_str.split("+") if p]
return [int(p) for p in parts]
def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None):
    """Load an AOT-GAN generator, reusing a cached instance when one exists.

    Args:
        aot_root: AOT-GAN directory; its ``src`` sub-directory is put on
            ``sys.path`` so that ``model.aotgan`` can be imported.
        aot_pretrain: Path to the pretrained generator weights (required).
        device: Device used for ``torch.load(map_location=...)`` and ``model.to``.
        block_num: Stored as ``args.block_num`` for ``InpaintGenerator``.
        rates: Dilation rates stored as ``args.rates``; defaults to
            ``[1, 2, 4, 8]``.

    Returns:
        The ``(model, device)`` tuple stored in ``AOT_MODEL_CACHE``.

    Raises:
        FileNotFoundError: The AOT-GAN directory or the weights file is missing.
        ValueError: No weights path was given.
        ImportError: PyTorch is not installed.
    """
    aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain)
    if not aot_root.exists():
        raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}")
    if aot_pretrain is None:
        raise ValueError("AOT-GAN需要指定预训练权重路径--aot-pretrain")
    if not aot_pretrain.exists():
        raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}")

    if rates is None:
        rates = [1, 2, 4, 8]
    rates_tuple = tuple(rates)
    key = (str(aot_pretrain), device, int(block_num), rates_tuple)
    if key in AOT_MODEL_CACHE:
        return AOT_MODEL_CACHE[key]

    # Make the AOT-GAN sources importable before importing its model module.
    src_root = aot_root / "src"
    if str(src_root) not in sys.path:
        sys.path.insert(0, str(src_root))

    import importlib

    try:
        import torch
    except ImportError as exc:
        raise ImportError("未找到 PyTorch请先按README安装依赖。") from exc

    net = importlib.import_module("model.aotgan")

    class _Args:
        """Attribute bag carrying the options handed to ``InpaintGenerator``."""

    args = _Args()
    args.block_num = int(block_num)
    args.rates = rates
    model = net.InpaintGenerator(args)

    # NOTE: torch.load unpickles the checkpoint; only load weights from a trusted source.
    state = torch.load(str(aot_pretrain), map_location=device)

    # Checkpoints may wrap the weights under one of several container keys.
    if isinstance(state, dict):
        if "state_dict" in state:
            state = state["state_dict"]
        elif "model" in state:
            state = state["model"]
        elif "generator" in state:
            state = state["generator"]

    # Drop only a *leading* "module." prefix. The previous
    # ``k.replace("module.", "", 1)`` also rewrote keys that merely contained
    # "module." somewhere in the middle of the name.
    if isinstance(state, dict) and any(k.startswith("module.") for k in state):
        state = {k.removeprefix("module."): v for k, v in state.items()}

    model.load_state_dict(state, strict=True)
    model.to(device)
    model.eval()

    AOT_MODEL_CACHE[key] = (model, device)
    return AOT_MODEL_CACHE[key]
def _mask_bbox(mask):
ys, xs = np.where(mask > 0)
if len(xs) == 0:
return None
return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
def _expand_bbox(bbox, pad, width, height):
if bbox is None:
return None
x0, y0, x1, y1 = bbox
pad = max(0, int(pad))
x0 = max(0, x0 - pad)
y0 = max(0, y0 - pad)
x1 = min(width - 1, x1 + pad)
y1 = min(height - 1, y1 + pad)
return x0, y0, x1, y1
def _expand_bbox_min_size(bbox, pad, width, height, min_size):
"""在给定 pad 基础上,确保裁剪区域至少为 min_size。"""
expanded = _expand_bbox(bbox, pad, width, height)
if expanded is None:
return None
x0, y0, x1, y1 = expanded
roi_w = x1 - x0 + 1
roi_h = y1 - y0 + 1
need_w = max(0, int(min_size) - roi_w)
need_h = max(0, int(min_size) - roi_h)
if need_w == 0 and need_h == 0:
return expanded
extra = max((need_w + 1) // 2, (need_h + 1) // 2)
return _expand_bbox(bbox, pad + extra, width, height)
def _build_noise_prefill(img, mask_t, strength):
"""为mask区域生成噪声预填充img取值范围为[-1, 1]。"""
import torch
img01 = (img + 1.0) / 2.0
unmasked = 1.0 - mask_t
denom = unmasked.sum()
if denom.item() < 1.0:
mean = torch.full((1, 3, 1, 1), 0.5, device=img.device)
std = torch.full((1, 3, 1, 1), 0.2, device=img.device)
else:
mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
std = torch.sqrt(var + 1e-6)
noise = mean + std * float(strength) * torch.randn_like(img01)
noise = noise.clamp(0.0, 1.0)
return noise * 2.0 - 1.0
def _inpaint_with_aot_core(
bgr,
mask,
aot_root,
aot_pretrain,
device="cpu",
block_num=8,
rates=None,
noise_prefill=False,
noise_strength=1.0,
):
"""使用 AOT-GAN 进行修补,返回 BGR 图像(与输入同尺寸)。"""
if mask.dtype != np.uint8:
mask = mask.astype(np.uint8)
if mask.max() <= 1:
mask = mask * 255
if mask.max() == 0:
return bgr
h, w = mask.shape
grid = 4
h2 = (h // grid) * grid
w2 = (w // grid) * grid
if h2 == 0 or w2 == 0:
return bgr
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
img_crop = rgb[:h2, :w2, :]
mask_crop = mask[:h2, :w2]
import torch
img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0
img = img * 2.0 - 1.0
mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0)
img = img.unsqueeze(0)
mask_t = mask_t.unsqueeze(0)
if device != "cpu":
img = img.to(device)
mask_t = mask_t.to(device)
model, _ = _get_aot_model(
aot_root=aot_root,
aot_pretrain=aot_pretrain,
device=device,
block_num=block_num,
rates=rates,
)
with torch.no_grad():
if noise_prefill:
noise = _build_noise_prefill(img, mask_t, noise_strength)
image_masked = img * (1 - mask_t) + noise * mask_t
else:
image_masked = img * (1 - mask_t) + mask_t
pred = model(image_masked, mask_t)
comp = pred * mask_t + img * (1 - mask_t)
comp = comp[0].clamp(-1.0, 1.0)
comp = (comp + 1.0) / 2.0 * 255.0
comp = comp.permute(1, 2, 0).byte().cpu().numpy()
result_bgr = cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)
if h2 == h and w2 == w:
return result_bgr
output_full = bgr.copy()
output_full[:h2, :w2, :] = result_bgr
return output_full
def _inpaint_with_aot(
bgr,
mask,
aot_root,
aot_pretrain,
device="cpu",
block_num=8,
rates=None,
crop=False,
crop_pad=0,
max_size=0,
noise_prefill=False,
noise_strength=1.0,
):
"""使用 AOT-GAN 进行修补,支持裁剪与限幅以加速。"""
min_side = 32
if mask.dtype != np.uint8:
mask = mask.astype(np.uint8)
if mask.max() <= 1:
mask = mask * 255
if mask.max() == 0:
return bgr
h, w = mask.shape
x0 = y0 = 0
x1 = w - 1
y1 = h - 1
if crop:
bbox = _mask_bbox(mask)
if bbox is None:
return bgr
expanded = _expand_bbox_min_size(bbox, crop_pad, w, h, min_side)
if expanded is None:
return bgr
x0, y0, x1, y1 = expanded
roi_w = x1 - x0 + 1
roi_h = y1 - y0 + 1
if roi_w < min_side or roi_h < min_side:
# 裁剪区域过小会导致 AOT 失败,回退到全图修补
crop = False
x0 = y0 = 0
x1 = w - 1
y1 = h - 1
bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1]
mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1]
roi_h, roi_w = bgr_roi.shape[:2]
scale = 1.0
max_size = int(max_size) if max_size else 0
if max_size > 0:
max_dim = max(roi_h, roi_w)
if max_dim > max_size:
scale = max_size / float(max_dim)
new_w = max(1, int(round(roi_w * scale)))
new_h = max(1, int(round(roi_h * scale)))
if min(new_w, new_h) < min_side:
scale = 1.0
new_w = roi_w
new_h = roi_h
interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
if scale != 1.0:
bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp)
mask_roi = cv2.resize(mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
filled_roi = _inpaint_with_aot_core(
bgr_roi,
mask_roi,
aot_root=aot_root,
aot_pretrain=aot_pretrain,
device=device,
block_num=block_num,
rates=rates,
noise_prefill=noise_prefill,
noise_strength=noise_strength,
)
if scale != 1.0:
filled_roi = cv2.resize(
filled_roi,
(roi_w, roi_h),
interpolation=cv2.INTER_LINEAR,
)
if not crop:
return filled_roi
output_full = bgr.copy()
output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi
return output_full
# Optional HEIC/HEIF support: register the pillow_heif opener when the package
# can be imported and record the outcome in HEIC_SUPPORTED.
try:
    from pillow_heif import register_heif_opener
    register_heif_opener()
    HEIC_SUPPORTED = True
except ImportError:
    # pillow_heif is not available; HEIC/HEIF inputs are reported as unsupported.
    HEIC_SUPPORTED = False
def remove_background(input_path, output_path, session=None, **kwargs):
    """Remove the background of one image and save the result.

    Args:
        input_path: Path of the input image.
        output_path: Path the result is written to.
        session: Optional rembg session object.
        **kwargs: Extra options forwarded to ``rembg.remove`` (for example the
            alpha-matting parameters).
    """
    print(f"正在处理: {input_path}")
    # Close the source file as soon as the pixels are loaded. exif_transpose
    # returns a new image, so nothing refers to the file handle afterwards.
    with Image.open(input_path) as source:
        input_image = ImageOps.exif_transpose(source)
    # Remove the background.
    output_image = remove(input_image, session=session, **kwargs)
    # Save the result.
    output_image.save(output_path)
    print(f"已保存: {output_path}")
def _pil_to_bgr(image):
    """Convert a PIL image into an OpenCV BGR array (via an RGB conversion)."""
    rgb_pixels = np.array(image.convert("RGB"))
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
def _alpha_to_mask(alpha_channel, threshold=10):
"""将alpha通道转换为二值mask0或255"""
mask = (alpha_channel > threshold).astype(np.uint8) * 255
return mask
def _ensure_rgba_size(image, target_size):
    """Return ``image`` in RGBA mode at ``target_size``.

    The image is converted only when its mode is not "RGBA" and resized (with
    nearest-neighbour resampling) only when its size differs from
    ``target_size``.
    """
    rgba = image if image.mode == "RGBA" else image.convert("RGBA")
    if rgba.size == target_size:
        return rgba
    return rgba.resize(target_size, resample=Image.NEAREST)
def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0):
"""生成硬边mask与处理后mask用于填补/过渡)"""
if mask_dilate and mask_dilate > 0:
kernel = np.ones((mask_dilate, mask_dilate), np.uint8)
mask = cv2.dilate(mask, kernel, iterations=1)
if edge_grow and edge_grow > 0:
kernel = np.ones((3, 3), np.uint8)
mask = cv2.dilate(mask, kernel, iterations=int(edge_grow))
mask_hard = mask.copy()
mask_used = mask_hard
if mask_blur and mask_blur > 0:
k = mask_blur if mask_blur % 2 == 1 else mask_blur + 1
mask_used = cv2.GaussianBlur(mask_hard, (k, k), 0)
mask_used = (mask_used > 0).astype(np.uint8) * 255
return mask_hard, mask_used
def remove_subject_and_inpaint(
    input_path,
    output_path,
    session=None,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
    mask_output_path=None,
    **kwargs,
):
    """Remove the subject of an image and fill the hole with AOT-GAN.

    Args:
        input_path: Path of the input image.
        output_path: Path the inpainted image is written to.
        session: Optional rembg session object.
        mask_dilate: Mask dilation kernel size.
        mask_blur: Mask blur kernel size (odd).
        mask_threshold: Alpha threshold above which a pixel belongs to the subject.
        edge_grow: Extra 3x3 dilation iterations applied to the mask.
        aot_root: AOT-GAN directory.
        aot_pretrain: AOT-GAN weights file.
        aot_device: Device AOT-GAN runs on.
        aot_block_num: Number of AOT blocks.
        aot_rates: AOT block dilation rates as a "+"-separated string.
        aot_crop: Only inpaint the cropped mask region.
        aot_crop_pad: Padding in pixels around the crop.
        aot_max_size: Limit for the longer side of the AOT input (0 = unlimited).
        aot_noise_prefill: Pre-fill the hole with random noise.
        aot_noise_strength: Noise strength factor.
        black_subject: Also mask pixels whose gray value is <= ``black_threshold``.
        black_threshold: Gray-value threshold used by ``black_subject``.
        gray_subject: Also mask pixels with saturation <=
            ``gray_saturation_threshold`` and value <= ``gray_value_threshold``.
        gray_saturation_threshold: HSV saturation threshold used by ``gray_subject``.
        gray_value_threshold: HSV value threshold used by ``gray_subject``.
        feather: Blend the fill into the original just outside the hard mask.
        feather_radius: Width in pixels of that blend; must be > 0 to take effect.
        save_mask: Also save the mask that was used for inpainting.
        mask_output_path: Where to save the mask; defaults to the output path
            with its suffix replaced by "_mask.png".
        **kwargs: Extra options forwarded to ``rembg.remove``.
    """
    print(f"正在处理(去主体补背景): {input_path}")
    # Close the source file as soon as the pixels are loaded. exif_transpose
    # returns a new image, so nothing refers to the file handle afterwards.
    with Image.open(input_path) as source:
        input_image = ImageOps.exif_transpose(source)

    # Use rembg to obtain the subject mask from the alpha channel.
    output_image = remove(input_image, session=session, **kwargs)
    output_image = _ensure_rgba_size(output_image, input_image.size)
    alpha = np.array(output_image.getchannel("A"))
    mask = _alpha_to_mask(alpha, threshold=mask_threshold)
    bgr = _pil_to_bgr(input_image)

    if black_subject:
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        black_mask = (gray <= black_threshold).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, black_mask)
    if gray_subject:
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        s = hsv[:, :, 1]
        v = hsv[:, :, 2]
        gray_mask = ((s <= gray_saturation_threshold) & (v <= gray_value_threshold)).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, gray_mask)

    mask_hard, mask_used = _prepare_mask(
        mask,
        mask_dilate=mask_dilate,
        mask_blur=mask_blur,
        edge_grow=edge_grow,
    )
    rates = _parse_aot_rates(aot_rates)
    filled = _inpaint_with_aot(
        bgr,
        mask_used,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=aot_device,
        block_num=aot_block_num,
        rates=rates,
        crop=aot_crop,
        crop_pad=aot_crop_pad,
        max_size=aot_max_size,
        noise_prefill=aot_noise_prefill,
        noise_strength=aot_noise_strength,
    )

    if feather and feather_radius > 0:
        # Only blend outside the mask so the original subject is not brought back.
        mask_bin = (mask_hard > 0).astype(np.uint8) * 255
        dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3)
        weight = np.ones_like(dist_out, dtype=np.float32)
        outside = mask_bin == 0
        # Weight of the fill drops linearly from 1 at the mask edge to 0 at feather_radius.
        weight[outside] = np.clip(1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0)
        weight = weight[:, :, None]
        mixed = weight * filled + (1.0 - weight) * bgr
        # Round instead of truncating so float error cannot darken a pixel by one level.
        blended = np.clip(np.rint(mixed), 0, 255).astype(np.uint8)
        result = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
    else:
        result = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB)
    Image.fromarray(result).save(output_path)

    if save_mask:
        if mask_output_path is None:
            mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png"
        Image.fromarray(mask_used).save(mask_output_path)
        print(f"已保存mask: {mask_output_path}")
    print(f"已保存: {output_path}")
def process_images_folder(
    input_folder,
    output_folder,
    model_name="u2net",
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_size=10,
    post_process_mask=False,
    remove_subject=False,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
):
    """Process every supported image directly inside a folder.

    Files are handled in sorted order. A failure on one file is reported and
    the batch continues with the next file.

    Args:
        input_folder: Folder containing the input images (not searched recursively).
        output_folder: Folder the results are written to; created when missing.
        model_name: rembg model name, for example:
            - u2net (default): general-purpose model
            - u2netp: lightweight u2net
            - u2net_human_seg: human segmentation
            - silueta: reduced u2net (43MB)
            - isnet-general-use: newer general-purpose model
            - isnet-anime: high-accuracy anime character segmentation
            - birefnet-general: general-purpose model
            - birefnet-portrait: portrait model
        alpha_matting: Enable alpha-matting post-processing for better edges.
        alpha_matting_foreground_threshold: Foreground threshold (0-255).
        alpha_matting_background_threshold: Background threshold (0-255).
        alpha_matting_erode_size: Erode size used to smooth edges.
        post_process_mask: Enable rembg mask post-processing.
        remove_subject: Remove the subject and inpaint the background
            (``remove_subject_and_inpaint``) instead of removing the background.
        mask_dilate, mask_blur, mask_threshold, edge_grow, aot_root,
        aot_pretrain, aot_device, aot_block_num, aot_rates, aot_crop,
        aot_crop_pad, aot_max_size, aot_noise_prefill, aot_noise_strength,
        black_subject, black_threshold, gray_subject,
        gray_saturation_threshold, gray_value_threshold, feather,
        feather_radius, save_mask: Forwarded unchanged to
            ``remove_subject_and_inpaint`` when ``remove_subject`` is set.
    """

    def _yes_no(flag):
        """Label for an on/off setting (the labels used to be empty strings)."""
        return "是" if flag else "否"

    # Suffixes kept for "_bgfill" outputs; any other input suffix is written as ".jpg".
    writable_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

    # Options forwarded to rembg in both processing modes.
    rembg_options = {
        "alpha_matting": alpha_matting,
        "alpha_matting_foreground_threshold": alpha_matting_foreground_threshold,
        "alpha_matting_background_threshold": alpha_matting_background_threshold,
        "alpha_matting_erode_size": alpha_matting_erode_size,
        "post_process_mask": post_process_mask,
    }
    # Options that only apply to the remove-subject mode.
    inpaint_options = {
        "mask_dilate": mask_dilate,
        "mask_blur": mask_blur,
        "mask_threshold": mask_threshold,
        "edge_grow": edge_grow,
        "aot_root": aot_root,
        "aot_pretrain": aot_pretrain,
        "aot_device": aot_device,
        "aot_block_num": aot_block_num,
        "aot_rates": aot_rates,
        "aot_crop": aot_crop,
        "aot_crop_pad": aot_crop_pad,
        "aot_max_size": aot_max_size,
        "aot_noise_prefill": aot_noise_prefill,
        "aot_noise_strength": aot_noise_strength,
        "black_subject": black_subject,
        "black_threshold": black_threshold,
        "gray_subject": gray_subject,
        "gray_saturation_threshold": gray_saturation_threshold,
        "gray_value_threshold": gray_value_threshold,
        "feather": feather,
        "feather_radius": feather_radius,
        "save_mask": save_mask,
    }

    # Create the output folder.
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    # Create one session and reuse it for every image.
    print(f"使用模型: {model_name}")
    session = new_session(model_name)

    # Supported input formats.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
    if HEIC_SUPPORTED:
        image_extensions.update({'.heic', '.heif'})
    else:
        print("提示: 未安装pillow-heifHEIC格式不可用。安装方法: pip install pillow-heif")

    # Collect the image files; sorted because iterdir() yields them in arbitrary order.
    input_path = Path(input_folder)
    image_files = sorted(
        f for f in input_path.iterdir()
        if f.is_file() and f.suffix.lower() in image_extensions
    )
    if not image_files:
        print(f"{input_folder} 中没有找到图片文件")
        return

    print(f"找到 {len(image_files)} 张图片,开始处理...")
    print(f"Alpha Matting: {'启用' if alpha_matting else '禁用'}")
    if alpha_matting:
        print(f" - 前景阈值: {alpha_matting_foreground_threshold}")
        print(f" - 背景阈值: {alpha_matting_background_threshold}")
        print(f" - 侵蚀大小: {alpha_matting_erode_size}")
    print(f"Mask后处理: {'启用' if post_process_mask else '禁用'}")
    print(f"去主体补背景: {'启用' if remove_subject else '禁用'}")
    if remove_subject:
        print(f" - AOT目录: {aot_root}")
        print(f" - AOT权重: {aot_pretrain}")
        print(f" - AOT设备: {aot_device}")
        print(f" - AOT块数: {aot_block_num}")
        print(f" - AOT膨胀率: {aot_rates}")
        print(f" - AOT裁剪: {_yes_no(aot_crop)}")
        if aot_crop:
            print(f" - AOT裁剪边界: {aot_crop_pad}")
        print(f" - AOT最大边: {aot_max_size}")
        print(f" - AOT噪声预填充: {_yes_no(aot_noise_prefill)}")
        if aot_noise_prefill:
            print(f" - AOT噪声强度: {aot_noise_strength}")
        print(f" - mask膨胀: {mask_dilate}")
        print(f" - mask模糊: {mask_blur}")
        print(f" - mask阈值: {mask_threshold}")
        print(f" - 边缘扩张: {edge_grow}")
        print(f" - 黑色内容作为主体: {_yes_no(black_subject)}")
        if black_subject:
            print(f" - 黑色阈值: {black_threshold}")
        print(f" - 灰阶内容作为主体: {_yes_no(gray_subject)}")
        if gray_subject:
            print(f" - 灰阶饱和度阈值: {gray_saturation_threshold}")
            print(f" - 灰阶亮度阈值: {gray_value_threshold}")
        print(f" - 边缘过渡: {_yes_no(feather)}")
        if feather:
            print(f" - 过渡半径: {feather_radius}")
        print(f" - 保存mask: {_yes_no(save_mask)}")
    print("-" * 50)

    # Process each image.
    for i, image_file in enumerate(image_files, 1):
        try:
            if remove_subject:
                suffix = image_file.suffix.lower()
                if suffix not in writable_suffixes:
                    suffix = ".jpg"
                output_filename = image_file.stem + "_bgfill" + suffix
            else:
                # Background removal writes PNG so the transparency is kept.
                output_filename = image_file.stem + "_nobg.png"
            output_path = Path(output_folder) / output_filename
            print(f"[{i}/{len(image_files)}] ", end="")

            if remove_subject:
                mask_output_path = None
                if save_mask:
                    mask_output_path = str(Path(output_folder) / (image_file.stem + "_mask.png"))
                remove_subject_and_inpaint(
                    str(image_file),
                    str(output_path),
                    session=session,
                    mask_output_path=mask_output_path,
                    **rembg_options,
                    **inpaint_options,
                )
            else:
                remove_background(
                    str(image_file),
                    str(output_path),
                    session=session,
                    **rembg_options,
                )
        except Exception as e:
            # Deliberate best-effort: report the failure and continue with the next file.
            print(f"处理 {image_file.name} 时出错: {e}")

    print("-" * 50)
    print(f"处理完成!结果保存在 {output_folder} 文件夹中")
if __name__ == "__main__":
    # Command-line entry point: parse the options, then process either a
    # single file or every supported image inside a folder.
    parser = argparse.ArgumentParser(
        description='图片去背景工具 - 使用rembg自动去除图片背景',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
示例用法:
# 使用默认参数处理images文件夹
python remove_background.py
# 处理单个文件
python remove_background.py input.jpg output.png
# 处理指定文件夹
python remove_background.py my_images/ my_output/
# 使用不同模型
python remove_background.py input.jpg output.png -m birefnet-portrait
# 自定义alpha matting参数
python remove_background.py input.jpg output.png -ft 260 -bt 12 -es 5
'''
    )
    # Positional arguments (both optional).
    parser.add_argument('input', nargs='?', default='images',
                        help='输入文件或文件夹路径(默认: images')
    parser.add_argument('output', nargs='?', default=None,
                        help='输出文件或文件夹路径可选默认为output/')
    # Model selection.
    parser.add_argument('-m', '--model', default='isnet-general-use',
                        choices=['u2net', 'u2netp', 'u2net_human_seg', 'silueta',
                                 'isnet-general-use', 'isnet-anime',
                                 'birefnet-general', 'birefnet-general-lite',
                                 'birefnet-portrait', 'birefnet-dis',
                                 'birefnet-hrsod', 'birefnet-cod', 'birefnet-massive'],
                        help='选择使用的模型 (默认: isnet-general-use)')
    # Alpha-matting options.
    parser.add_argument('-a', '--alpha-matting', '--alpha_matting', action='store_true',
                        help='启用alpha matting后处理默认: false')
    parser.add_argument('-ft', '--foreground-threshold', type=int, default=245,
                        help='前景阈值 (0-255),值越大保留越多细节 (默认: 245)')
    parser.add_argument('-bt', '--background-threshold', type=int, default=8,
                        help='背景阈值 (0-255),值越大去除越多背景 (默认: 8)')
    parser.add_argument('-es', '--erode-size', type=int, default=2,
                        help='侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)')
    # Other options.
    parser.add_argument('-p', '--post-process', '--post_process', action='store_true',
                        help='启用mask后处理默认: false')
    # Remove-subject / inpainting options.
    parser.add_argument('--remove-subject', '--remove_subject', action='store_true',
                        help='去掉主体并补全背景(默认: false')
    parser.add_argument('--aot-root', type=str, default='AOT-GAN-for-Inpainting',
                        help='AOT-GAN目录默认: AOT-GAN-for-Inpainting')
    parser.add_argument('--aot-pretrain', type=str, default=None,
                        help='AOT-GAN预训练权重文件路径必填相对路径基于当前目录')
    parser.add_argument('--aot-device', type=str, default='cpu',
                        help='AOT-GAN设备默认: cpu')
    parser.add_argument('--aot-block-num', type=int, default=8,
                        help='AOTBlock数量默认: 8')
    parser.add_argument('--aot-rates', type=str, default='1+2+4+8',
                        help='AOTBlock膨胀率默认: 1+2+4+8')
    parser.add_argument('--aot-crop', action='store_true',
                        help='AOT仅对mask区域裁剪修补默认: false')
    parser.add_argument('--aot-crop-pad', type=int, default=0,
                        help='AOT裁剪边缘留白像素默认: 0')
    parser.add_argument('--aot-max-size', type=int, default=0,
                        help='AOT输入最大边限制0为不限制默认: 0')
    parser.add_argument('--aot-noise-prefill', action='store_true',
                        help='AOT使用随机噪声预填充默认: false')
    parser.add_argument('--aot-noise-strength', type=float, default=1.0,
                        help='AOT噪声强度系数默认: 1.0')
    parser.add_argument('--mask-dilate', type=int, default=3,
                        help='mask膨胀大小默认: 3')
    parser.add_argument('--mask-blur', type=int, default=3,
                        help='mask模糊大小默认: 3建议奇数')
    parser.add_argument('--mask-threshold', type=int, default=10,
                        help='alpha阈值默认: 10')
    parser.add_argument('--edge-grow', type=int, default=0,
                        help='主体边缘扩张像素(默认: 0')
    parser.add_argument('--save-mask', '--save_mask', action='store_true',
                        help='保存mask到output目录默认: false')
    parser.add_argument('--black-subject', '--black_subject', action='store_true',
                        help='将黑色内容也视为主体(默认: false')
    parser.add_argument('--black-threshold', type=int, default=50,
                        help='黑色阈值(0-255灰度越小越黑默认: 50)')
    parser.add_argument('--gray-subject', '--gray_subject', action='store_true',
                        help='将灰阶内容也视为主体(默认: false')
    parser.add_argument('--gray-saturation-threshold', type=int, default=30,
                        help='灰阶饱和度阈值(0-255越小越接近灰阶默认: 30)')
    parser.add_argument('--gray-value-threshold', type=int, default=200,
                        help='灰阶亮度阈值(0-255越小越暗默认: 200)')
    parser.add_argument('--feather', action='store_true',
                        help='启用边缘过渡融合(默认: false')
    parser.add_argument('--feather-radius', type=int, default=5,
                        help='边缘过渡半径(默认: 5')
    args = parser.parse_args()

    def _yes_no(flag):
        """Label for an on/off setting (the labels used to be empty strings)."""
        return "是" if flag else "否"

    # Options forwarded to rembg in every mode.
    rembg_options = {
        "alpha_matting": args.alpha_matting,
        "alpha_matting_foreground_threshold": args.foreground_threshold,
        "alpha_matting_background_threshold": args.background_threshold,
        "alpha_matting_erode_size": args.erode_size,
        "post_process_mask": args.post_process,
    }
    # Options that only apply to the remove-subject mode.
    inpaint_options = {
        "mask_dilate": args.mask_dilate,
        "mask_blur": args.mask_blur,
        "mask_threshold": args.mask_threshold,
        "edge_grow": args.edge_grow,
        "aot_root": args.aot_root,
        "aot_pretrain": args.aot_pretrain,
        "aot_device": args.aot_device,
        "aot_block_num": args.aot_block_num,
        "aot_rates": args.aot_rates,
        "aot_crop": args.aot_crop,
        "aot_crop_pad": args.aot_crop_pad,
        "aot_max_size": args.aot_max_size,
        "aot_noise_prefill": args.aot_noise_prefill,
        "aot_noise_strength": args.aot_noise_strength,
        "black_subject": args.black_subject,
        "black_threshold": args.black_threshold,
        "gray_subject": args.gray_subject,
        "gray_saturation_threshold": args.gray_saturation_threshold,
        "gray_value_threshold": args.gray_value_threshold,
        "feather": args.feather,
        "feather_radius": args.feather_radius,
        "save_mask": args.save_mask,
    }

    print("=" * 50)
    print("图片去背景工具")
    print("=" * 50)

    # Decide whether the input is a file or a folder.
    input_path = Path(args.input)
    if not input_path.exists():
        print(f"错误: 输入路径不存在: {args.input}")
        # sys.exit instead of the site-provided exit(), which is not always defined.
        sys.exit(1)

    if input_path.is_file():
        # Single-file mode.
        suffix = input_path.suffix.lower()
        if args.remove_subject:
            if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
                output_name = input_path.stem + "_bgfill" + suffix
            else:
                output_name = input_path.stem + "_bgfill.jpg"
        else:
            output_name = input_path.stem + "_nobg.png"

        # Work out the output path.
        if args.output is None:
            # Default: an "output" folder next to the input file (same for both modes).
            output_path = input_path.parent / "output" / output_name
            output_path.parent.mkdir(parents=True, exist_ok=True)
        else:
            output_candidate = Path(args.output)
            if output_candidate.exists() and output_candidate.is_dir():
                output_path = output_candidate / output_name
                output_path.parent.mkdir(parents=True, exist_ok=True)
            elif output_candidate.suffix == "":
                # A path without a suffix is treated as a folder to create.
                output_candidate.mkdir(parents=True, exist_ok=True)
                output_path = output_candidate / output_name
            else:
                output_path = output_candidate
                output_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"输入文件: {input_path}")
        print(f"输出文件: {output_path}")
        print(f"模型: {args.model}")
        print(f"Alpha Matting: {'启用' if args.alpha_matting else '禁用'}")
        if args.alpha_matting:
            print(f" - 前景阈值: {args.foreground_threshold}")
            print(f" - 背景阈值: {args.background_threshold}")
            print(f" - 侵蚀大小: {args.erode_size}")
        print(f"Mask后处理: {'启用' if args.post_process else '禁用'}")
        print(f"去主体补背景: {'启用' if args.remove_subject else '禁用'}")
        if args.remove_subject:
            print(f" - AOT目录: {args.aot_root}")
            print(f" - AOT权重: {args.aot_pretrain}")
            print(f" - AOT设备: {args.aot_device}")
            print(f" - AOT块数: {args.aot_block_num}")
            print(f" - AOT膨胀率: {args.aot_rates}")
            print(f" - AOT裁剪: {_yes_no(args.aot_crop)}")
            if args.aot_crop:
                print(f" - AOT裁剪边界: {args.aot_crop_pad}")
            print(f" - AOT最大边: {args.aot_max_size}")
            print(f" - AOT噪声预填充: {_yes_no(args.aot_noise_prefill)}")
            if args.aot_noise_prefill:
                print(f" - AOT噪声强度: {args.aot_noise_strength}")
            print(f" - mask膨胀: {args.mask_dilate}")
            print(f" - mask模糊: {args.mask_blur}")
            print(f" - mask阈值: {args.mask_threshold}")
            print(f" - 边缘扩张: {args.edge_grow}")
            print(f" - 黑色内容作为主体: {_yes_no(args.black_subject)}")
            if args.black_subject:
                print(f" - 黑色阈值: {args.black_threshold}")
            print(f" - 灰阶内容作为主体: {_yes_no(args.gray_subject)}")
            if args.gray_subject:
                print(f" - 灰阶饱和度阈值: {args.gray_saturation_threshold}")
                print(f" - 灰阶亮度阈值: {args.gray_value_threshold}")
            print(f" - 边缘过渡: {_yes_no(args.feather)}")
            if args.feather:
                print(f" - 过渡半径: {args.feather_radius}")
            print(f" - 保存mask: {_yes_no(args.save_mask)}")
        print("-" * 50)

        # Create the rembg session.
        session = new_session(args.model)

        # Process the image.
        if args.remove_subject:
            mask_output_path = None
            if args.save_mask:
                mask_output_path = str(output_path.with_suffix("")) + "_mask.png"
            remove_subject_and_inpaint(
                str(input_path),
                str(output_path),
                session=session,
                mask_output_path=mask_output_path,
                **rembg_options,
                **inpaint_options,
            )
        else:
            remove_background(
                str(input_path),
                str(output_path),
                session=session,
                **rembg_options,
            )
        print("-" * 50)
        print(f"处理完成!结果保存在: {output_path}")
    elif input_path.is_dir():
        # Folder mode.
        output_folder = args.output if args.output else 'output'
        process_images_folder(
            str(input_path),
            output_folder,
            model_name=args.model,
            remove_subject=args.remove_subject,
            **rembg_options,
            **inpaint_options,
        )
    else:
        print(f"错误: 不支持的输入类型: {args.input}")
        sys.exit(1)