update at 2026-03-28 16:46:40

This commit is contained in:
douboer
2026-03-28 16:46:40 +08:00
parent 57cd9a1f39
commit a1b0d6731c
13 changed files with 921 additions and 58 deletions

BIN
.DS_Store vendored

Binary file not shown.

Submodule AOT-GAN-for-Inpainting added at 2cd1afd8fd

103
README.md
View File

@@ -14,7 +14,9 @@ python remove_background.py
# 处理单个文件
python remove_background.py input.jpg output.png
# 查看所有参数-m, --model)
# 查看所有参数
python remove_background.py -h
```
不同模型适用于不同场景:
@@ -39,7 +41,19 @@ python remove_background.py input.jpg output.png -m birefnet-portrait
# 使用快速模型
python remove_background.py input.jpg output.png -m u2netp
``
# 书画类文字偏浅的温和参数示例
python remove_background.py images output \
--remove-subject --black_subject --gray_subject --save_mask \
--black-threshold 30\
--gray-saturation-threshold 30 --gray-value-threshold 30 \
--edge-grow 2 \
--feather --feather-radius 4 \
--aot-pretrain experiments/G0000000.pt \
--aot-max-size 1000
```
查看已下载模型:
```bash
ls -lh ~/.u2net/
```
@@ -75,6 +89,40 @@ rm -rf ~/.u2net/
rm ~/.u2net/u2net.onnx
```
## AOT-GAN 修补后端
`--remove-subject` 默认使用 AOT-GAN 修补。
AOT-GAN 依赖 PyTorch官方仓库测试 Python 3.8 / torch 1.8.1)。建议使用独立虚拟环境或确保兼容版本。
```bash
# 安装依赖(示例)
pip install torch torchvision
```
下载预训练权重后,运行示例:
```bash
python remove_background.py "images/IMG_9259 2.JPG" \
--remove-subject --black-subject --gray-subject --save-mask \
--aot-pretrain experiments/places2.pth
```
CPU 无 GPU 时的加速建议(只裁剪主体区域并限制最大边):
```bash
python remove_background.py "images/IMG_9259 2.JPG" \
--remove-subject --black-subject --gray-subject --save-mask \
--aot-pretrain experiments/places2.pth \
--aot-crop --aot-crop-pad 24 --aot-max-size 1400
```
减少“补脸”倾向:启用随机噪声预填充
```bash
python remove_background.py "images/IMG_9259 2.JPG" \
--remove-subject --black-subject --gray-subject --save-mask \
--aot-pretrain experiments/places2.pth \
--aot-crop --aot-crop-pad 64 --aot-max-size 900 \
--aot-noise-prefill --aot-noise-strength 1.0
```
## 可调整参数说明
### 1. 模型选择 (model_name)
@@ -97,10 +145,11 @@ rm ~/.u2net/u2net.onnx
Alpha Matting 是后处理步骤,可以显著改善边缘质量,特别是头发、毛发等细节。
#### alpha_matting (True/False)
- **作用**: 是否启用alpha matting
- **默认**: False
- **建议**: 如果边缘不自然,启用此选项
#### alpha_matting(开关)
- **作用**: 是否启用 alpha matting,提升边缘质量
- **默认**: 关闭(不传 `-a/--alpha-matting`
- **启用方式**: 传入 `-a``--alpha-matting`
- **效果**: 有利于细节边缘(毛发/细线),但速度稍慢
#### alpha_matting_foreground_threshold (0-255)
- **作用**: 前景阈值,控制哪些区域被认为是前景
@@ -128,9 +177,45 @@ Alpha Matting 是后处理步骤,可以显著改善边缘质量,特别是头
### 3. Mask后处理 (post_process_mask)
- **作用**: 对mask进行额外后处理
- **默认**: False
- **建议**: 可以尝试启用看效果是否改善
- **作用**: 对 mask 进行额外后处理
- **默认**: 关闭(不传 `-p/--post-process`
- **启用方式**: 传入 `-p``--post-process`
- **效果**: 有助于减少毛边,但可能略损失细节
### 4. 去主体补背景 (remove_subject)
用于“去掉主体并补全背景”。当前仅使用 AOT-GAN 修补。
- **remove_subject开关**: 启用去主体补背景(默认关闭,传 `--remove-subject` 开启)
- **aot_root**: AOT-GAN 目录(默认: `AOT-GAN-for-Inpainting`
- **aot_pretrain**: AOT-GAN 权重文件路径(必填)
- **aot_device**: AOT-GAN 设备(默认: `cpu`
- **aot_block_num**: AOTBlock 数量(默认: 8
- **aot_rates**: AOTBlock 膨胀率(默认: `1+2+4+8`
- **aot_crop开关**: 仅对 mask 覆盖区域裁剪修补(默认关闭,传 `--aot-crop` 开启)
- **aot_crop_pad (像素)**: 裁剪边缘留白像素(默认: 0
- **aot_max_size (像素)**: AOT 输入最大边限制(默认: 0 表示不限制)
- **aot_noise_prefill开关**: AOT使用随机噪声预填充默认关闭
- **aot_noise_strength (系数)**: 噪声强度(默认: 1.0
- **mask_dilate (像素)**: mask 膨胀大小(默认: 3。越大去除范围越大风险更高
- **mask_blur (像素)**: mask 模糊大小(默认: 3。越大边缘越柔和但易过度
- **mask_threshold (0-255)**: alpha 阈值(默认: 10。越大保留越多主体
- **edge_grow (像素)**: 主体边缘额外扩张(默认: 0。用于清理残留边缘
- **save_mask开关**: 保存 mask 方便检查(默认关闭,传 `--save-mask` 开启)
- **black_subject开关**: 将黑色内容也视为主体(默认关闭,传 `--black-subject` 开启)
- **black_threshold (0-255)**: 黑色阈值(默认: 50。越大越容易把浅灰当黑
- **gray_subject开关**: 将灰阶内容也视为主体(默认关闭,传 `--gray-subject` 开启)
- **gray_saturation_threshold (0-255)**: 灰阶饱和度阈值(默认: 30。越大越容易把彩色当灰
- **gray_value_threshold (0-255)**: 灰阶亮度阈值(默认: 200。越大越容易把浅灰当灰
- **feather开关**: 启用边缘过渡(默认关闭,传 `--feather` 开启)
- **feather_radius (像素)**: 过渡半径(默认: 5。越大过渡越柔和但可能变糊
- **说明**: 过渡仅在 mask 外侧进行,避免把主体边缘带回
### 5. 参数调优建议(针对书画/字迹)
- 先开启 `--remove-subject`,仅看主体遮罩是否覆盖到字迹
- 文字残留:提高 `--black-threshold``--gray-*` 阈值
- 过度修补:降低 `--black-threshold``--gray-value-threshold`,并减小 `--mask-dilate/--mask-blur`
- 边缘不自然:尝试开启 `--feather` 并使用较小的 `--feather-radius`
## 常见问题解决

Binary file not shown.

BIN
experiments/D0000000.pt Normal file

Binary file not shown.

BIN
experiments/G0000000.pt Normal file

Binary file not shown.

BIN
experiments/O0000000.pt Normal file

Binary file not shown.

16
pyproject.toml Normal file
View File

@@ -0,0 +1,16 @@
[tool.mypy]
python_version = "3.12"
ignore_missing_imports = true
exclude = '^(generative_inpainting|images|output)/'
[tool.ruff]
exclude = [
"AOT-GAN-for-Inpainting",
"generative_inpainting",
"images",
"output",
"__pycache__",
".venv",
"venv",
".git",
]

View File

@@ -3,10 +3,288 @@
使用rembg库自动去除图片背景
"""
import os
import sys
import argparse
from pathlib import Path
import numpy as np
import cv2
from PIL import Image, ImageOps
# 避免 numba 在某些环境下缓存失败
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
from rembg import remove, new_session
from PIL import Image
AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {}
def _resolve_aot_paths(aot_root, aot_pretrain):
"""解析 AOT-GAN 的路径配置。"""
aot_root = Path(aot_root).resolve()
if aot_pretrain is None:
aot_pretrain = None
else:
aot_pretrain = Path(aot_pretrain)
if not aot_pretrain.is_absolute():
# 预训练权重相对路径以当前工作目录为基准
aot_pretrain = (Path.cwd() / aot_pretrain).resolve()
return aot_root, aot_pretrain
def _parse_aot_rates(rates_str):
parts = [p for p in rates_str.split("+") if p]
return [int(p) for p in parts]
def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None):
"""加载/缓存 AOT-GAN 模型。"""
aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain)
if not aot_root.exists():
raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}")
if aot_pretrain is None:
raise ValueError("AOT-GAN需要指定预训练权重路径--aot-pretrain")
if not aot_pretrain.exists():
raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}")
if rates is None:
rates = [1, 2, 4, 8]
rates_tuple = tuple(rates)
key = (str(aot_pretrain), device, int(block_num), rates_tuple)
if key in AOT_MODEL_CACHE:
return AOT_MODEL_CACHE[key]
src_root = aot_root / "src"
if str(src_root) not in sys.path:
sys.path.insert(0, str(src_root))
import importlib
try:
import torch
except ImportError as exc:
raise ImportError("未找到 PyTorch请先按README安装依赖。") from exc
net = importlib.import_module("model.aotgan")
class _Args:
pass
args = _Args()
args.block_num = int(block_num)
args.rates = rates
model = net.InpaintGenerator(args)
state = torch.load(str(aot_pretrain), map_location=device)
if isinstance(state, dict):
if "state_dict" in state:
state = state["state_dict"]
elif "model" in state:
state = state["model"]
elif "generator" in state:
state = state["generator"]
if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
state = {k.replace("module.", "", 1): v for k, v in state.items()}
model.load_state_dict(state, strict=True)
model.to(device)
model.eval()
AOT_MODEL_CACHE[key] = (model, device)
return AOT_MODEL_CACHE[key]
def _mask_bbox(mask):
ys, xs = np.where(mask > 0)
if len(xs) == 0:
return None
return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
def _expand_bbox(bbox, pad, width, height):
if bbox is None:
return None
x0, y0, x1, y1 = bbox
pad = max(0, int(pad))
x0 = max(0, x0 - pad)
y0 = max(0, y0 - pad)
x1 = min(width - 1, x1 + pad)
y1 = min(height - 1, y1 + pad)
return x0, y0, x1, y1
def _expand_bbox_min_size(bbox, pad, width, height, min_size):
"""在给定 pad 基础上,确保裁剪区域至少为 min_size。"""
expanded = _expand_bbox(bbox, pad, width, height)
if expanded is None:
return None
x0, y0, x1, y1 = expanded
roi_w = x1 - x0 + 1
roi_h = y1 - y0 + 1
need_w = max(0, int(min_size) - roi_w)
need_h = max(0, int(min_size) - roi_h)
if need_w == 0 and need_h == 0:
return expanded
extra = max((need_w + 1) // 2, (need_h + 1) // 2)
return _expand_bbox(bbox, pad + extra, width, height)
def _build_noise_prefill(img, mask_t, strength):
"""为mask区域生成噪声预填充img取值范围为[-1, 1]。"""
import torch
img01 = (img + 1.0) / 2.0
unmasked = 1.0 - mask_t
denom = unmasked.sum()
if denom.item() < 1.0:
mean = torch.full((1, 3, 1, 1), 0.5, device=img.device)
std = torch.full((1, 3, 1, 1), 0.2, device=img.device)
else:
mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
std = torch.sqrt(var + 1e-6)
noise = mean + std * float(strength) * torch.randn_like(img01)
noise = noise.clamp(0.0, 1.0)
return noise * 2.0 - 1.0
def _inpaint_with_aot_core(
bgr,
mask,
aot_root,
aot_pretrain,
device="cpu",
block_num=8,
rates=None,
noise_prefill=False,
noise_strength=1.0,
):
"""使用 AOT-GAN 进行修补,返回 BGR 图像(与输入同尺寸)。"""
if mask.dtype != np.uint8:
mask = mask.astype(np.uint8)
if mask.max() <= 1:
mask = mask * 255
if mask.max() == 0:
return bgr
h, w = mask.shape
grid = 4
h2 = (h // grid) * grid
w2 = (w // grid) * grid
if h2 == 0 or w2 == 0:
return bgr
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
img_crop = rgb[:h2, :w2, :]
mask_crop = mask[:h2, :w2]
import torch
img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0
img = img * 2.0 - 1.0
mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0)
img = img.unsqueeze(0)
mask_t = mask_t.unsqueeze(0)
if device != "cpu":
img = img.to(device)
mask_t = mask_t.to(device)
model, _ = _get_aot_model(
aot_root=aot_root,
aot_pretrain=aot_pretrain,
device=device,
block_num=block_num,
rates=rates,
)
with torch.no_grad():
if noise_prefill:
noise = _build_noise_prefill(img, mask_t, noise_strength)
image_masked = img * (1 - mask_t) + noise * mask_t
else:
image_masked = img * (1 - mask_t) + mask_t
pred = model(image_masked, mask_t)
comp = pred * mask_t + img * (1 - mask_t)
comp = comp[0].clamp(-1.0, 1.0)
comp = (comp + 1.0) / 2.0 * 255.0
comp = comp.permute(1, 2, 0).byte().cpu().numpy()
result_bgr = cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)
if h2 == h and w2 == w:
return result_bgr
output_full = bgr.copy()
output_full[:h2, :w2, :] = result_bgr
return output_full
def _inpaint_with_aot(
bgr,
mask,
aot_root,
aot_pretrain,
device="cpu",
block_num=8,
rates=None,
crop=False,
crop_pad=0,
max_size=0,
noise_prefill=False,
noise_strength=1.0,
):
"""使用 AOT-GAN 进行修补,支持裁剪与限幅以加速。"""
min_side = 32
if mask.dtype != np.uint8:
mask = mask.astype(np.uint8)
if mask.max() <= 1:
mask = mask * 255
if mask.max() == 0:
return bgr
h, w = mask.shape
x0 = y0 = 0
x1 = w - 1
y1 = h - 1
if crop:
bbox = _mask_bbox(mask)
if bbox is None:
return bgr
expanded = _expand_bbox_min_size(bbox, crop_pad, w, h, min_side)
if expanded is None:
return bgr
x0, y0, x1, y1 = expanded
roi_w = x1 - x0 + 1
roi_h = y1 - y0 + 1
if roi_w < min_side or roi_h < min_side:
# 裁剪区域过小会导致 AOT 失败,回退到全图修补
crop = False
x0 = y0 = 0
x1 = w - 1
y1 = h - 1
bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1]
mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1]
roi_h, roi_w = bgr_roi.shape[:2]
scale = 1.0
max_size = int(max_size) if max_size else 0
if max_size > 0:
max_dim = max(roi_h, roi_w)
if max_dim > max_size:
scale = max_size / float(max_dim)
new_w = max(1, int(round(roi_w * scale)))
new_h = max(1, int(round(roi_h * scale)))
if min(new_w, new_h) < min_side:
scale = 1.0
new_w = roi_w
new_h = roi_h
interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
if scale != 1.0:
bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp)
mask_roi = cv2.resize(mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
filled_roi = _inpaint_with_aot_core(
bgr_roi,
mask_roi,
aot_root=aot_root,
aot_pretrain=aot_pretrain,
device=device,
block_num=block_num,
rates=rates,
noise_prefill=noise_prefill,
noise_strength=noise_strength,
)
if scale != 1.0:
filled_roi = cv2.resize(
filled_roi,
(roi_w, roi_h),
interpolation=cv2.INTER_LINEAR,
)
if not crop:
return filled_roi
output_full = bgr.copy()
output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi
return output_full
# 支持HEIC格式
try:
@@ -16,17 +294,6 @@ try:
except ImportError:
HEIC_SUPPORTED = False
def str2bool(v):
"""将字符串转换为布尔值"""
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('布尔值应为true或false')
def remove_background(input_path, output_path, session=None, **kwargs):
"""
去除图片背景
@@ -40,7 +307,7 @@ def remove_background(input_path, output_path, session=None, **kwargs):
print(f"正在处理: {input_path}")
# 读取输入图片
input_image = Image.open(input_path)
input_image = ImageOps.exif_transpose(Image.open(input_path))
# 去除背景
output_image = remove(input_image, session=session, **kwargs)
@@ -50,10 +317,195 @@ def remove_background(input_path, output_path, session=None, **kwargs):
print(f"已保存: {output_path}")
def process_images_folder(input_folder, output_folder, model_name="u2net",
alpha_matting=False, alpha_matting_foreground_threshold=240,
alpha_matting_background_threshold=10, alpha_matting_erode_size=10,
post_process_mask=False):
def _pil_to_bgr(image):
"""将PIL图片转换为OpenCV BGR格式"""
rgb = image.convert("RGB")
arr = np.array(rgb)
return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
def _alpha_to_mask(alpha_channel, threshold=10):
"""将alpha通道转换为二值mask0或255"""
mask = (alpha_channel > threshold).astype(np.uint8) * 255
return mask
def _ensure_rgba_size(image, target_size):
if image.mode != "RGBA":
image = image.convert("RGBA")
if image.size != target_size:
image = image.resize(target_size, resample=Image.NEAREST)
return image
def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0):
"""生成硬边mask与处理后mask用于填补/过渡)"""
if mask_dilate and mask_dilate > 0:
kernel = np.ones((mask_dilate, mask_dilate), np.uint8)
mask = cv2.dilate(mask, kernel, iterations=1)
if edge_grow and edge_grow > 0:
kernel = np.ones((3, 3), np.uint8)
mask = cv2.dilate(mask, kernel, iterations=int(edge_grow))
mask_hard = mask.copy()
mask_used = mask_hard
if mask_blur and mask_blur > 0:
k = mask_blur if mask_blur % 2 == 1 else mask_blur + 1
mask_used = cv2.GaussianBlur(mask_hard, (k, k), 0)
mask_used = (mask_used > 0).astype(np.uint8) * 255
return mask_hard, mask_used
def remove_subject_and_inpaint(
input_path,
output_path,
session=None,
mask_dilate=3,
mask_blur=3,
mask_threshold=10,
edge_grow=0,
aot_root="AOT-GAN-for-Inpainting",
aot_pretrain=None,
aot_device="cpu",
aot_block_num=8,
aot_rates="1+2+4+8",
aot_crop=False,
aot_crop_pad=0,
aot_max_size=0,
aot_noise_prefill=False,
aot_noise_strength=1.0,
black_subject=False,
black_threshold=50,
gray_subject=False,
gray_saturation_threshold=30,
gray_value_threshold=200,
feather=False,
feather_radius=5,
save_mask=False,
mask_output_path=None,
**kwargs,
):
"""
去掉主体并补全背景
Args:
input_path: 输入图片路径
output_path: 输出图片路径
session: rembg会话对象可选
aot_root: AOT-GAN目录
aot_pretrain: AOT-GAN权重文件路径
aot_device: AOT-GAN设备
aot_block_num: AOT-GAN AOTBlock 数量
aot_rates: AOT-GAN AOTBlock 膨胀率
aot_crop: AOT仅裁剪mask区域进行修补
aot_crop_pad: AOT裁剪边缘留白像素
aot_max_size: AOT输入最大边限制0为不限制
aot_noise_prefill: AOT使用随机噪声预填充
aot_noise_strength: 噪声强度系数
mask_dilate: mask膨胀大小
mask_blur: mask模糊大小奇数
mask_threshold: alpha阈值
save_mask: 是否保存mask
mask_output_path: mask保存路径
**kwargs: rembg参数
"""
print(f"正在处理(去主体补背景): {input_path}")
input_image = ImageOps.exif_transpose(Image.open(input_path))
# 使用rembg获取主体mask
output_image = remove(input_image, session=session, **kwargs)
output_image = _ensure_rgba_size(output_image, input_image.size)
alpha = np.array(output_image.getchannel("A"))
mask = _alpha_to_mask(alpha, threshold=mask_threshold)
bgr = _pil_to_bgr(input_image)
if black_subject:
gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
black_mask = (gray <= black_threshold).astype(np.uint8) * 255
mask = cv2.bitwise_or(mask, black_mask)
if gray_subject:
hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
s = hsv[:, :, 1]
v = hsv[:, :, 2]
gray_mask = ((s <= gray_saturation_threshold) & (v <= gray_value_threshold)).astype(np.uint8) * 255
mask = cv2.bitwise_or(mask, gray_mask)
mask_hard, mask_used = _prepare_mask(
mask,
mask_dilate=mask_dilate,
mask_blur=mask_blur,
edge_grow=edge_grow,
)
rates = _parse_aot_rates(aot_rates)
filled = _inpaint_with_aot(
bgr,
mask_used,
aot_root=aot_root,
aot_pretrain=aot_pretrain,
device=aot_device,
block_num=aot_block_num,
rates=rates,
crop=aot_crop,
crop_pad=aot_crop_pad,
max_size=aot_max_size,
noise_prefill=aot_noise_prefill,
noise_strength=aot_noise_strength,
)
if feather and feather_radius > 0:
# 仅在mask外侧做过渡避免把原主体带回
mask_bin = (mask_hard > 0).astype(np.uint8) * 255
dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3)
alpha = np.ones_like(dist_out, dtype=np.float32)
outside = mask_bin == 0
alpha[outside] = np.clip(1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0)
alpha = alpha[:, :, None]
blended = (alpha * filled + (1.0 - alpha) * bgr).astype(np.uint8)
result = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
else:
result = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB)
Image.fromarray(result).save(output_path)
if save_mask:
if mask_output_path is None:
mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png"
Image.fromarray(mask_used).save(mask_output_path)
print(f"已保存mask: {mask_output_path}")
print(f"已保存: {output_path}")
def process_images_folder(
input_folder,
output_folder,
model_name="u2net",
alpha_matting=False,
alpha_matting_foreground_threshold=240,
alpha_matting_background_threshold=10,
alpha_matting_erode_size=10,
post_process_mask=False,
remove_subject=False,
mask_dilate=3,
mask_blur=3,
mask_threshold=10,
edge_grow=0,
aot_root="AOT-GAN-for-Inpainting",
aot_pretrain=None,
aot_device="cpu",
aot_block_num=8,
aot_rates="1+2+4+8",
aot_crop=False,
aot_crop_pad=0,
aot_max_size=0,
aot_noise_prefill=False,
aot_noise_strength=1.0,
black_subject=False,
black_threshold=50,
gray_subject=False,
gray_saturation_threshold=30,
gray_value_threshold=200,
feather=False,
feather_radius=5,
save_mask=False,
):
"""
批量处理文件夹中的所有图片
@@ -105,26 +557,100 @@ def process_images_folder(input_folder, output_folder, model_name="u2net",
print(f" - 背景阈值: {alpha_matting_background_threshold}")
print(f" - 侵蚀大小: {alpha_matting_erode_size}")
print(f"Mask后处理: {'启用' if post_process_mask else '禁用'}")
print(f"去主体补背景: {'启用' if remove_subject else '禁用'}")
if remove_subject:
print(f" - AOT目录: {aot_root}")
print(f" - AOT权重: {aot_pretrain}")
print(f" - AOT设备: {aot_device}")
print(f" - AOT块数: {aot_block_num}")
print(f" - AOT膨胀率: {aot_rates}")
print(f" - AOT裁剪: {'' if aot_crop else ''}")
if aot_crop:
print(f" - AOT裁剪边界: {aot_crop_pad}")
print(f" - AOT最大边: {aot_max_size}")
print(f" - AOT噪声预填充: {'' if aot_noise_prefill else ''}")
if aot_noise_prefill:
print(f" - AOT噪声强度: {aot_noise_strength}")
print(f" - mask膨胀: {mask_dilate}")
print(f" - mask模糊: {mask_blur}")
print(f" - mask阈值: {mask_threshold}")
print(f" - 边缘扩张: {edge_grow}")
print(f" - 黑色内容作为主体: {'' if black_subject else ''}")
if black_subject:
print(f" - 黑色阈值: {black_threshold}")
print(f" - 灰阶内容作为主体: {'' if gray_subject else ''}")
if gray_subject:
print(f" - 灰阶饱和度阈值: {gray_saturation_threshold}")
print(f" - 灰阶亮度阈值: {gray_value_threshold}")
print(f" - 边缘过渡: {'' if feather else ''}")
if feather:
print(f" - 过渡半径: {feather_radius}")
print(f" - 保存mask: {'' if save_mask else ''}")
print("-" * 50)
# 处理每张图片
for i, image_file in enumerate(image_files, 1):
try:
# 输出文件名保持原始名称改为PNG格式以支持透明背景
output_filename = image_file.stem + '_nobg.png'
if remove_subject:
suffix = image_file.suffix.lower()
if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
output_filename = image_file.stem + "_bgfill" + suffix
else:
output_filename = image_file.stem + "_bgfill.jpg"
else:
# 去背景默认使用PNG格式以支持透明背景
output_filename = image_file.stem + "_nobg.png"
output_path = Path(output_folder) / output_filename
print(f"[{i}/{len(image_files)}] ", end="")
remove_background(
str(image_file),
str(output_path),
session=session,
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
alpha_matting_background_threshold=alpha_matting_background_threshold,
alpha_matting_erode_size=alpha_matting_erode_size,
post_process_mask=post_process_mask
)
if remove_subject:
mask_output_path = None
if save_mask:
mask_output_path = str(Path(output_folder) / (image_file.stem + "_mask.png"))
remove_subject_and_inpaint(
str(image_file),
str(output_path),
session=session,
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
alpha_matting_background_threshold=alpha_matting_background_threshold,
alpha_matting_erode_size=alpha_matting_erode_size,
post_process_mask=post_process_mask,
mask_dilate=mask_dilate,
mask_blur=mask_blur,
mask_threshold=mask_threshold,
edge_grow=edge_grow,
aot_root=aot_root,
aot_pretrain=aot_pretrain,
aot_device=aot_device,
aot_block_num=aot_block_num,
aot_rates=aot_rates,
aot_crop=aot_crop,
aot_crop_pad=aot_crop_pad,
aot_max_size=aot_max_size,
aot_noise_prefill=aot_noise_prefill,
aot_noise_strength=aot_noise_strength,
black_subject=black_subject,
black_threshold=black_threshold,
gray_subject=gray_subject,
gray_saturation_threshold=gray_saturation_threshold,
gray_value_threshold=gray_value_threshold,
feather=feather,
feather_radius=feather_radius,
save_mask=save_mask,
mask_output_path=mask_output_path,
)
else:
remove_background(
str(image_file),
str(output_path),
session=session,
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
alpha_matting_background_threshold=alpha_matting_background_threshold,
alpha_matting_erode_size=alpha_matting_erode_size,
post_process_mask=post_process_mask,
)
except Exception as e:
print(f"处理 {image_file.name} 时出错: {e}")
@@ -171,9 +697,8 @@ if __name__ == "__main__":
help='选择使用的模型 (默认: isnet-general-use)')
# Alpha Matting参数
parser.add_argument('-a', '--alpha-matting', type=str2bool, nargs='?', const=True, default=True,
metavar='true/false',
help='启用alpha matting后处理默认: true。用法: -a 或 -a true 或 -a false')
parser.add_argument('-a', '--alpha-matting', '--alpha_matting', action='store_true',
help='启用alpha matting后处理默认: false')
parser.add_argument('-ft', '--foreground-threshold', type=int, default=245,
help='前景阈值 (0-255),值越大保留越多细节 (默认: 245)')
parser.add_argument('-bt', '--background-threshold', type=int, default=8,
@@ -182,9 +707,56 @@ if __name__ == "__main__":
help='侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)')
# 其他选项
parser.add_argument('-p', '--post-process', type=str2bool, nargs='?', const=True, default=True,
metavar='true/false',
help='启用mask后处理默认: true。用法: -p 或 -p true 或 -p false')
parser.add_argument('-p', '--post-process', '--post_process', action='store_true',
help='启用mask后处理默认: false')
# 去主体补背景参数
parser.add_argument('--remove-subject', '--remove_subject', action='store_true',
help='去掉主体并补全背景(默认: false')
parser.add_argument('--aot-root', type=str, default='AOT-GAN-for-Inpainting',
help='AOT-GAN目录默认: AOT-GAN-for-Inpainting')
parser.add_argument('--aot-pretrain', type=str, default=None,
help='AOT-GAN预训练权重文件路径必填相对路径基于当前目录')
parser.add_argument('--aot-device', type=str, default='cpu',
help='AOT-GAN设备默认: cpu')
parser.add_argument('--aot-block-num', type=int, default=8,
help='AOTBlock数量默认: 8')
parser.add_argument('--aot-rates', type=str, default='1+2+4+8',
help='AOTBlock膨胀率默认: 1+2+4+8')
parser.add_argument('--aot-crop', action='store_true',
help='AOT仅对mask区域裁剪修补默认: false')
parser.add_argument('--aot-crop-pad', type=int, default=0,
help='AOT裁剪边缘留白像素默认: 0')
parser.add_argument('--aot-max-size', type=int, default=0,
help='AOT输入最大边限制0为不限制默认: 0')
parser.add_argument('--aot-noise-prefill', action='store_true',
help='AOT使用随机噪声预填充默认: false')
parser.add_argument('--aot-noise-strength', type=float, default=1.0,
help='AOT噪声强度系数默认: 1.0')
parser.add_argument('--mask-dilate', type=int, default=3,
help='mask膨胀大小默认: 3')
parser.add_argument('--mask-blur', type=int, default=3,
help='mask模糊大小默认: 3建议奇数')
parser.add_argument('--mask-threshold', type=int, default=10,
help='alpha阈值默认: 10')
parser.add_argument('--edge-grow', type=int, default=0,
help='主体边缘扩张像素(默认: 0')
parser.add_argument('--save-mask', '--save_mask', action='store_true',
help='保存mask到output目录默认: false')
parser.add_argument('--black-subject', '--black_subject', action='store_true',
help='将黑色内容也视为主体(默认: false')
parser.add_argument('--black-threshold', type=int, default=50,
help='黑色阈值(0-255灰度越小越黑默认: 50)')
parser.add_argument('--gray-subject', '--gray_subject', action='store_true',
help='将灰阶内容也视为主体(默认: false')
parser.add_argument('--gray-saturation-threshold', type=int, default=30,
help='灰阶饱和度阈值(0-255越小越接近灰阶默认: 30)')
parser.add_argument('--gray-value-threshold', type=int, default=200,
help='灰阶亮度阈值(0-255越小越暗默认: 200)')
parser.add_argument('--feather', action='store_true',
help='启用边缘过渡融合(默认: false')
parser.add_argument('--feather-radius', type=int, default=5,
help='边缘过渡半径(默认: 5')
args = parser.parse_args()
@@ -201,13 +773,33 @@ if __name__ == "__main__":
# 处理单个文件
if input_path.is_file():
suffix = input_path.suffix.lower()
if args.remove_subject:
if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
output_name = input_path.stem + "_bgfill" + suffix
else:
output_name = input_path.stem + "_bgfill.jpg"
else:
output_name = input_path.stem + "_nobg.png"
# 确定输出路径
if args.output is None:
output_path = input_path.parent / 'output' / (input_path.stem + '_nobg.png')
if args.remove_subject:
output_path = input_path.parent / "output" / output_name
else:
output_path = input_path.parent / 'output' / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)
else:
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_candidate = Path(args.output)
if output_candidate.exists() and output_candidate.is_dir():
output_path = output_candidate / output_name
output_path.parent.mkdir(parents=True, exist_ok=True)
elif output_candidate.suffix == "":
output_candidate.mkdir(parents=True, exist_ok=True)
output_path = output_candidate / output_name
else:
output_path = output_candidate
output_path.parent.mkdir(parents=True, exist_ok=True)
print(f"输入文件: {input_path}")
print(f"输出文件: {output_path}")
@@ -218,22 +810,89 @@ if __name__ == "__main__":
print(f" - 背景阈值: {args.background_threshold}")
print(f" - 侵蚀大小: {args.erode_size}")
print(f"Mask后处理: {'启用' if args.post_process else '禁用'}")
print(f"去主体补背景: {'启用' if args.remove_subject else '禁用'}")
if args.remove_subject:
print(f" - AOT目录: {args.aot_root}")
print(f" - AOT权重: {args.aot_pretrain}")
print(f" - AOT设备: {args.aot_device}")
print(f" - AOT块数: {args.aot_block_num}")
print(f" - AOT膨胀率: {args.aot_rates}")
print(f" - AOT裁剪: {'' if args.aot_crop else ''}")
if args.aot_crop:
print(f" - AOT裁剪边界: {args.aot_crop_pad}")
print(f" - AOT最大边: {args.aot_max_size}")
print(f" - AOT噪声预填充: {'' if args.aot_noise_prefill else ''}")
if args.aot_noise_prefill:
print(f" - AOT噪声强度: {args.aot_noise_strength}")
print(f" - mask膨胀: {args.mask_dilate}")
print(f" - mask模糊: {args.mask_blur}")
print(f" - mask阈值: {args.mask_threshold}")
print(f" - 边缘扩张: {args.edge_grow}")
print(f" - 黑色内容作为主体: {'' if args.black_subject else ''}")
if args.black_subject:
print(f" - 黑色阈值: {args.black_threshold}")
print(f" - 灰阶内容作为主体: {'' if args.gray_subject else ''}")
if args.gray_subject:
print(f" - 灰阶饱和度阈值: {args.gray_saturation_threshold}")
print(f" - 灰阶亮度阈值: {args.gray_value_threshold}")
print(f" - 边缘过渡: {'' if args.feather else ''}")
if args.feather:
print(f" - 过渡半径: {args.feather_radius}")
print(f" - 保存mask: {'' if args.save_mask else ''}")
print("-" * 50)
# 创建会话
session = new_session(args.model)
# 处理图片
remove_background(
str(input_path),
str(output_path),
session=session,
alpha_matting=args.alpha_matting,
alpha_matting_foreground_threshold=args.foreground_threshold,
alpha_matting_background_threshold=args.background_threshold,
alpha_matting_erode_size=args.erode_size,
post_process_mask=args.post_process
)
if args.remove_subject:
mask_output_path = None
if args.save_mask:
mask_output_path = str(output_path.with_suffix("")) + "_mask.png"
remove_subject_and_inpaint(
str(input_path),
str(output_path),
session=session,
alpha_matting=args.alpha_matting,
alpha_matting_foreground_threshold=args.foreground_threshold,
alpha_matting_background_threshold=args.background_threshold,
alpha_matting_erode_size=args.erode_size,
post_process_mask=args.post_process,
mask_dilate=args.mask_dilate,
mask_blur=args.mask_blur,
mask_threshold=args.mask_threshold,
edge_grow=args.edge_grow,
aot_root=args.aot_root,
aot_pretrain=args.aot_pretrain,
aot_device=args.aot_device,
aot_block_num=args.aot_block_num,
aot_rates=args.aot_rates,
aot_crop=args.aot_crop,
aot_crop_pad=args.aot_crop_pad,
aot_max_size=args.aot_max_size,
aot_noise_prefill=args.aot_noise_prefill,
aot_noise_strength=args.aot_noise_strength,
black_subject=args.black_subject,
black_threshold=args.black_threshold,
gray_subject=args.gray_subject,
gray_saturation_threshold=args.gray_saturation_threshold,
gray_value_threshold=args.gray_value_threshold,
feather=args.feather,
feather_radius=args.feather_radius,
save_mask=args.save_mask,
mask_output_path=mask_output_path,
)
else:
remove_background(
str(input_path),
str(output_path),
session=session,
alpha_matting=args.alpha_matting,
alpha_matting_foreground_threshold=args.foreground_threshold,
alpha_matting_background_threshold=args.background_threshold,
alpha_matting_erode_size=args.erode_size,
post_process_mask=args.post_process,
)
print("-" * 50)
print(f"处理完成!结果保存在: {output_path}")
@@ -250,7 +909,30 @@ if __name__ == "__main__":
alpha_matting_foreground_threshold=args.foreground_threshold,
alpha_matting_background_threshold=args.background_threshold,
alpha_matting_erode_size=args.erode_size,
post_process_mask=args.post_process
post_process_mask=args.post_process,
remove_subject=args.remove_subject,
mask_dilate=args.mask_dilate,
mask_blur=args.mask_blur,
mask_threshold=args.mask_threshold,
edge_grow=args.edge_grow,
aot_root=args.aot_root,
aot_pretrain=args.aot_pretrain,
aot_device=args.aot_device,
aot_block_num=args.aot_block_num,
aot_rates=args.aot_rates,
aot_crop=args.aot_crop,
aot_crop_pad=args.aot_crop_pad,
aot_max_size=args.aot_max_size,
aot_noise_prefill=args.aot_noise_prefill,
aot_noise_strength=args.aot_noise_strength,
black_subject=args.black_subject,
black_threshold=args.black_threshold,
gray_subject=args.gray_subject,
gray_saturation_threshold=args.gray_saturation_threshold,
gray_value_threshold=args.gray_value_threshold,
feather=args.feather,
feather_radius=args.feather_radius,
save_mask=args.save_mask,
)
else:

View File

@@ -1,3 +1,4 @@
rembg[gpu]
pillow
pillow-heif
opencv-python

View File

@@ -0,0 +1,78 @@
import importlib
import sys
import types
import unittest
def _import_remove_background():
if "remove_background" in sys.modules:
del sys.modules["remove_background"]
rembg_stub = types.ModuleType("rembg")
rembg_stub.remove = lambda *args, **kwargs: None
rembg_stub.new_session = lambda *args, **kwargs: None
sys.modules["rembg"] = rembg_stub
return importlib.import_module("remove_background")
class RemoveBackgroundTests(unittest.TestCase):
def setUp(self):
self.mod = _import_remove_background()
def test_alpha_to_mask_threshold(self):
alpha = self.mod.np.array([[0, 9, 10, 255]], dtype=self.mod.np.uint8)
mask = self.mod._alpha_to_mask(alpha, threshold=10)
self.assertEqual(mask.shape, alpha.shape)
self.assertEqual(mask[0, 0], 0)
self.assertEqual(mask[0, 1], 0)
self.assertEqual(mask[0, 2], 0)
self.assertEqual(mask[0, 3], 255)
def test_prepare_mask_outputs_binary(self):
mask = self.mod.np.zeros((5, 5), dtype=self.mod.np.uint8)
mask[2, 2] = 255
mask_hard, mask_used = self.mod._prepare_mask(mask, mask_dilate=0, mask_blur=3)
self.assertEqual(mask_hard.shape, mask.shape)
self.assertEqual(mask_used.shape, mask.shape)
self.assertTrue(self.mod.np.all(self.mod.np.isin(mask_used, [0, 255])))
def test_resolve_aot_paths_defaults(self):
aot_root, aot_pretrain = self.mod._resolve_aot_paths(
"AOT-GAN-for-Inpainting", None
)
self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
self.assertIsNone(aot_pretrain)
def test_resolve_aot_paths_relative_pretrain(self):
aot_root, aot_pretrain = self.mod._resolve_aot_paths(
"AOT-GAN-for-Inpainting", "experiments/foo.pt"
)
self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
self.assertTrue(str(aot_pretrain).endswith("experiments/foo.pt"))
def test_parse_aot_rates(self):
rates = self.mod._parse_aot_rates("1+2+4+8")
self.assertEqual(rates, [1, 2, 4, 8])
def test_mask_bbox_empty(self):
mask = self.mod.np.zeros((3, 3), dtype=self.mod.np.uint8)
self.assertIsNone(self.mod._mask_bbox(mask))
def test_mask_bbox_and_expand(self):
mask = self.mod.np.zeros((5, 5), dtype=self.mod.np.uint8)
mask[1:3, 2:4] = 255
bbox = self.mod._mask_bbox(mask)
self.assertEqual(bbox, (2, 1, 3, 2))
expanded = self.mod._expand_bbox(bbox, 2, 5, 5)
self.assertEqual(expanded, (0, 0, 4, 4))
expanded_min = self.mod._expand_bbox_min_size(bbox, 0, 5, 5, 4)
self.assertEqual(expanded_min, (1, 0, 4, 3))
def test_ensure_rgba_size_resizes(self):
img = self.mod.Image.new("RGB", (2, 3), color=(0, 0, 0))
out = self.mod._ensure_rgba_size(img, (4, 5))
self.assertEqual(out.mode, "RGBA")
self.assertEqual(out.size, (4, 5))
if __name__ == "__main__":
unittest.main()