update at 2026-03-28 16:46:40

This commit is contained in:
douboer
2026-03-28 16:46:40 +08:00
parent 57cd9a1f39
commit a1b0d6731c
13 changed files with 921 additions and 58 deletions

BIN
.DS_Store vendored

Binary file not shown.

Submodule AOT-GAN-for-Inpainting added at 2cd1afd8fd

103
README.md
View File

@@ -14,7 +14,9 @@ python remove_background.py
# Process a single file
python remove_background.py input.jpg output.png
# List all options
python remove_background.py -h
```
Different models suit different scenarios:
@@ -39,7 +41,19 @@ python remove_background.py input.jpg output.png -m birefnet-portrait
# Use the fast model
python remove_background.py input.jpg output.png -m u2netp

# Gentle settings for calligraphy/paintings with light strokes
python remove_background.py images output \
    --remove-subject --black-subject --gray-subject --save-mask \
    --black-threshold 30 \
    --gray-saturation-threshold 30 --gray-value-threshold 30 \
    --edge-grow 2 \
    --feather --feather-radius 4 \
    --aot-pretrain experiments/G0000000.pt \
    --aot-max-size 1000
```
View downloaded models:
```bash
ls -lh ~/.u2net/
```
@@ -75,6 +89,40 @@ rm -rf ~/.u2net/
rm ~/.u2net/u2net.onnx
```
## AOT-GAN inpainting backend
`--remove-subject` uses AOT-GAN inpainting by default.
AOT-GAN depends on PyTorch (the upstream repo was tested with Python 3.8 / torch 1.8.1). A dedicated virtual environment, or verified compatible versions, is recommended.
```bash
# Install dependencies (example)
pip install torch torchvision
```
After downloading the pretrained weights, run for example:
```bash
python remove_background.py "images/IMG_9259 2.JPG" \
    --remove-subject --black-subject --gray-subject --save-mask \
    --aot-pretrain experiments/places2.pth
```
Speedup tips for CPU-only machines (inpaint only a crop around the subject and cap the longest side):
```bash
python remove_background.py "images/IMG_9259 2.JPG" \
    --remove-subject --black-subject --gray-subject --save-mask \
    --aot-pretrain experiments/places2.pth \
    --aot-crop --aot-crop-pad 24 --aot-max-size 1400
```
To reduce the model's tendency to "hallucinate" content such as faces, enable random-noise prefill:
```bash
python remove_background.py "images/IMG_9259 2.JPG" \
    --remove-subject --black-subject --gray-subject --save-mask \
    --aot-pretrain experiments/places2.pth \
    --aot-crop --aot-crop-pad 64 --aot-max-size 900 \
    --aot-noise-prefill --aot-noise-strength 1.0
```
## Adjustable parameters
### 1. Model selection (model_name)
@@ -97,10 +145,11 @@ rm ~/.u2net/u2net.onnx
Alpha Matting is a post-processing step that can noticeably improve edge quality, especially for hair, fur, and similar fine detail.
#### alpha_matting (flag)
- **Purpose**: enable alpha matting to improve edge quality
- **Default**: off (do not pass `-a/--alpha-matting`)
- **Enable**: pass `-a` or `--alpha-matting`
- **Effect**: helps fine edges (hair/thin lines) at a small speed cost
#### alpha_matting_foreground_threshold (0-255)
- **Purpose**: foreground threshold; controls which regions are treated as foreground
@@ -128,9 +177,45 @@ Alpha Matting 是后处理步骤,可以显著改善边缘质量,特别是头
### 3. Mask post-processing (post_process_mask)
- **Purpose**: apply extra post-processing to the mask
- **Default**: off (do not pass `-p/--post-process`)
- **Enable**: pass `-p` or `--post-process`
- **Effect**: helps reduce fringing, possibly at a slight cost in detail
### 4. Subject removal with background fill (remove_subject)
Removes the subject and fills in the background. Currently AOT-GAN is the only inpainting backend.
- **remove_subject (flag)**: enable subject removal with background fill (off by default; pass `--remove-subject`)
- **aot_root**: AOT-GAN directory (default: `AOT-GAN-for-Inpainting`)
- **aot_pretrain**: path to the AOT-GAN weights file (required)
- **aot_device**: AOT-GAN device (default: `cpu`)
- **aot_block_num**: number of AOTBlocks (default: 8)
- **aot_rates**: AOTBlock dilation rates (default: `1+2+4+8`)
- **aot_crop (flag)**: inpaint only a crop around the mask (off by default; pass `--aot-crop`)
- **aot_crop_pad (px)**: padding around the crop (default: 0)
- **aot_max_size (px)**: cap on the longest side of the AOT input (default: 0 = unlimited)
- **aot_noise_prefill (flag)**: prefill the masked region with random noise (off by default)
- **aot_noise_strength (factor)**: noise strength (default: 1.0)
- **mask_dilate (px)**: mask dilation size (default: 3); larger removes more, at higher risk
- **mask_blur (px)**: mask blur size (default: 3); larger gives softer edges but can overdo it
- **mask_threshold (0-255)**: alpha threshold (default: 10); larger keeps more of the subject
- **edge_grow (px)**: extra growth of the subject edge (default: 0); useful for cleaning up residual outlines
- **save_mask (flag)**: save the mask for inspection (off by default; pass `--save-mask`)
- **black_subject (flag)**: also treat black content as subject (off by default; pass `--black-subject`)
- **black_threshold (0-255)**: black threshold (default: 50); larger values classify light gray as black more readily
- **gray_subject (flag)**: also treat grayscale content as subject (off by default; pass `--gray-subject`)
- **gray_saturation_threshold (0-255)**: grayscale saturation threshold (default: 30); larger values classify colored pixels as gray more readily
- **gray_value_threshold (0-255)**: grayscale value threshold (default: 200); larger values classify light gray as gray more readily
- **feather (flag)**: enable edge feathering (off by default; pass `--feather`)
- **feather_radius (px)**: feather radius (default: 5); larger is softer but may blur
- **Note**: feathering is applied only outside the mask, to avoid reintroducing the subject's edge
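The black/gray options above are simple per-pixel tests that get OR-ed into one subject mask. A minimal numpy sketch of that combination (the `subject_mask` name is illustrative; `gray`, `s`, `v` stand for the grayscale, HSV-saturation, and HSV-value channels the script derives with OpenCV):

```python
import numpy as np

def subject_mask(gray, s, v, black_threshold=50,
                 gray_saturation_threshold=30, gray_value_threshold=200):
    """Combine dark pixels and low-saturation pixels into one 0/255 mask."""
    black = gray <= black_threshold
    grayish = (s <= gray_saturation_threshold) & (v <= gray_value_threshold)
    return (black | grayish).astype(np.uint8) * 255

# A 1x3 strip: dark ink, light gray wash, saturated color
gray = np.array([[20, 180, 120]], dtype=np.uint8)
s    = np.array([[10,  15, 200]], dtype=np.uint8)
v    = np.array([[20, 180, 120]], dtype=np.uint8)
print(subject_mask(gray, s, v))  # [[255 255 0]]: ink and wash kept, color dropped
```

Raising `black_threshold` pulls the gray wash into the "black" branch too; lowering `gray_value_threshold` releases light washes from the "grayish" branch.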
### 5. Tuning tips (for calligraphy/paintings)
- Start with `--remove-subject` alone and just check whether the subject mask covers the strokes
- Residual strokes: raise `--black-threshold` or the `--gray-*` thresholds
- Over-aggressive inpainting: lower `--black-threshold` or `--gray-value-threshold`, and reduce `--mask-dilate`/`--mask-blur`
- Unnatural edges: try `--feather` with a small `--feather-radius`
## Troubleshooting

Binary file not shown.

BIN
experiments/D0000000.pt Normal file

Binary file not shown.

BIN
experiments/G0000000.pt Normal file

Binary file not shown.

BIN
experiments/O0000000.pt Normal file

Binary file not shown.

16
pyproject.toml Normal file
View File

@@ -0,0 +1,16 @@
[tool.mypy]
python_version = "3.12"
ignore_missing_imports = true
exclude = '^(generative_inpainting|images|output)/'

[tool.ruff]
exclude = [
"AOT-GAN-for-Inpainting",
"generative_inpainting",
"images",
"output",
"__pycache__",
".venv",
"venv",
".git",
]

View File

@@ -3,10 +3,288 @@
"""
Automatically remove image backgrounds using the rembg library
"""
import os
import sys
import argparse
from pathlib import Path

import numpy as np
import cv2
from PIL import Image, ImageOps

# Avoid numba cache failures in some environments
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")

from rembg import remove, new_session
AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {}
def _resolve_aot_paths(aot_root, aot_pretrain):
    """Resolve the AOT-GAN path configuration."""
    aot_root = Path(aot_root).resolve()
    if aot_pretrain is not None:
        aot_pretrain = Path(aot_pretrain)
        if not aot_pretrain.is_absolute():
            # Relative pretrain paths are resolved against the current working directory
            aot_pretrain = (Path.cwd() / aot_pretrain).resolve()
    return aot_root, aot_pretrain

def _parse_aot_rates(rates_str):
    parts = [p for p in rates_str.split("+") if p]
    return [int(p) for p in parts]
def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None):
    """Load and cache the AOT-GAN model."""
    aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain)
    if not aot_root.exists():
        raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}")
    if aot_pretrain is None:
        raise ValueError("AOT-GAN需要指定预训练权重路径(--aot-pretrain)")
    if not aot_pretrain.exists():
        raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}")
    if rates is None:
        rates = [1, 2, 4, 8]
    rates_tuple = tuple(rates)
    key = (str(aot_pretrain), device, int(block_num), rates_tuple)
    if key in AOT_MODEL_CACHE:
        return AOT_MODEL_CACHE[key]
    src_root = aot_root / "src"
    if str(src_root) not in sys.path:
        sys.path.insert(0, str(src_root))
    import importlib
    try:
        import torch
    except ImportError as exc:
        raise ImportError("未找到 PyTorch,请先按README安装依赖。") from exc
    net = importlib.import_module("model.aotgan")

    class _Args:
        pass

    args = _Args()
    args.block_num = int(block_num)
    args.rates = rates
    model = net.InpaintGenerator(args)
    state = torch.load(str(aot_pretrain), map_location=device)
    # Unwrap common checkpoint layouts
    if isinstance(state, dict):
        if "state_dict" in state:
            state = state["state_dict"]
        elif "model" in state:
            state = state["model"]
        elif "generator" in state:
            state = state["generator"]
    # Strip DataParallel "module." prefixes if present
    if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
        state = {k.replace("module.", "", 1): v for k, v in state.items()}
    model.load_state_dict(state, strict=True)
    model.to(device)
    model.eval()
    AOT_MODEL_CACHE[key] = (model, device)
    return AOT_MODEL_CACHE[key]
def _mask_bbox(mask):
    ys, xs = np.where(mask > 0)
    if len(xs) == 0:
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

def _expand_bbox(bbox, pad, width, height):
    if bbox is None:
        return None
    x0, y0, x1, y1 = bbox
    pad = max(0, int(pad))
    x0 = max(0, x0 - pad)
    y0 = max(0, y0 - pad)
    x1 = min(width - 1, x1 + pad)
    y1 = min(height - 1, y1 + pad)
    return x0, y0, x1, y1

def _expand_bbox_min_size(bbox, pad, width, height, min_size):
    """On top of the given pad, ensure the crop is at least min_size on each side."""
    expanded = _expand_bbox(bbox, pad, width, height)
    if expanded is None:
        return None
    x0, y0, x1, y1 = expanded
    roi_w = x1 - x0 + 1
    roi_h = y1 - y0 + 1
    need_w = max(0, int(min_size) - roi_w)
    need_h = max(0, int(min_size) - roi_h)
    if need_w == 0 and need_h == 0:
        return expanded
    extra = max((need_w + 1) // 2, (need_h + 1) // 2)
    return _expand_bbox(bbox, pad + extra, width, height)
def _build_noise_prefill(img, mask_t, strength):
    """Generate a noise prefill for the masked region; img is in [-1, 1]."""
    import torch
    img01 = (img + 1.0) / 2.0
    unmasked = 1.0 - mask_t
    denom = unmasked.sum()
    if denom.item() < 1.0:
        # Fully masked image: fall back to fixed statistics
        mean = torch.full((1, 3, 1, 1), 0.5, device=img.device)
        std = torch.full((1, 3, 1, 1), 0.2, device=img.device)
    else:
        mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
        var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
        std = torch.sqrt(var + 1e-6)
    noise = mean + std * float(strength) * torch.randn_like(img01)
    noise = noise.clamp(0.0, 1.0)
    return noise * 2.0 - 1.0
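A numpy analogue of the statistics used above (mean and variance taken over unmasked pixels only, then Gaussian noise matched to them). This mirrors the torch logic for illustration, assuming a CHW image already in [0, 1]; it is not the implementation itself:

```python
import numpy as np

rng = np.random.default_rng(0)
img01 = rng.random((3, 8, 8)).astype(np.float32)   # C,H,W image in [0, 1]
mask = np.zeros((8, 8), dtype=np.float32)
mask[2:6, 2:6] = 1.0                               # region to refill

unmasked = 1.0 - mask
denom = unmasked.sum()
# Per-channel statistics computed over unmasked pixels only
mean = (img01 * unmasked).sum(axis=(1, 2), keepdims=True) / denom
var = (((img01 - mean) ** 2) * unmasked).sum(axis=(1, 2), keepdims=True) / denom
std = np.sqrt(var + 1e-6)
noise = np.clip(mean + std * rng.standard_normal(img01.shape), 0.0, 1.0)
prefilled = img01 * unmasked + noise * mask        # noise only inside the mask
assert prefilled.shape == img01.shape
```

Seeding the hole with image-matched noise gives the generator a texture hint instead of a flat gray block, which is what reduces the "invent a face" behavior.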
def _inpaint_with_aot_core(
    bgr,
    mask,
    aot_root,
    aot_pretrain,
    device="cpu",
    block_num=8,
    rates=None,
    noise_prefill=False,
    noise_strength=1.0,
):
    """Inpaint with AOT-GAN; returns a BGR image of the same size as the input."""
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.max() <= 1:
        mask = mask * 255
    if mask.max() == 0:
        return bgr
    # The network expects dimensions that are multiples of 4
    h, w = mask.shape
    grid = 4
    h2 = (h // grid) * grid
    w2 = (w // grid) * grid
    if h2 == 0 or w2 == 0:
        return bgr
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img_crop = rgb[:h2, :w2, :]
    mask_crop = mask[:h2, :w2]
    import torch
    img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0
    img = img * 2.0 - 1.0
    mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0)
    img = img.unsqueeze(0)
    mask_t = mask_t.unsqueeze(0)
    if device != "cpu":
        img = img.to(device)
        mask_t = mask_t.to(device)
    model, _ = _get_aot_model(
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=device,
        block_num=block_num,
        rates=rates,
    )
    with torch.no_grad():
        if noise_prefill:
            noise = _build_noise_prefill(img, mask_t, noise_strength)
            image_masked = img * (1 - mask_t) + noise * mask_t
        else:
            image_masked = img * (1 - mask_t) + mask_t
        pred = model(image_masked, mask_t)
        comp = pred * mask_t + img * (1 - mask_t)
    comp = comp[0].clamp(-1.0, 1.0)
    comp = (comp + 1.0) / 2.0 * 255.0
    comp = comp.permute(1, 2, 0).byte().cpu().numpy()
    result_bgr = cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)
    if h2 == h and w2 == w:
        return result_bgr
    output_full = bgr.copy()
    output_full[:h2, :w2, :] = result_bgr
    return output_full
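The core above trims height and width down to multiples of 4 before feeding the network; the arithmetic is just integer division (`align_down` is an illustrative name):

```python
def align_down(n, grid=4):
    """Largest multiple of grid not exceeding n."""
    return (n // grid) * grid

print(align_down(1001))  # 1000
print(align_down(4))     # 4
print(align_down(3))     # 0 -> the caller returns the original image unchanged
```

The trimmed strip (at most 3 pixels on the right/bottom) is copied back untouched from the input, which is why the function returns an image of the original size.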
def _inpaint_with_aot(
    bgr,
    mask,
    aot_root,
    aot_pretrain,
    device="cpu",
    block_num=8,
    rates=None,
    crop=False,
    crop_pad=0,
    max_size=0,
    noise_prefill=False,
    noise_strength=1.0,
):
    """Inpaint with AOT-GAN; supports cropping and size capping for speed."""
    min_side = 32
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.max() <= 1:
        mask = mask * 255
    if mask.max() == 0:
        return bgr
    h, w = mask.shape
    x0 = y0 = 0
    x1 = w - 1
    y1 = h - 1
    if crop:
        bbox = _mask_bbox(mask)
        if bbox is None:
            return bgr
        expanded = _expand_bbox_min_size(bbox, crop_pad, w, h, min_side)
        if expanded is None:
            return bgr
        x0, y0, x1, y1 = expanded
        roi_w = x1 - x0 + 1
        roi_h = y1 - y0 + 1
        if roi_w < min_side or roi_h < min_side:
            # A crop this small makes AOT fail; fall back to full-image inpainting
            crop = False
            x0 = y0 = 0
            x1 = w - 1
            y1 = h - 1
    bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1]
    mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1]
    roi_h, roi_w = bgr_roi.shape[:2]
    scale = 1.0
    max_size = int(max_size) if max_size else 0
    if max_size > 0:
        max_dim = max(roi_h, roi_w)
        if max_dim > max_size:
            scale = max_size / float(max_dim)
    new_w = max(1, int(round(roi_w * scale)))
    new_h = max(1, int(round(roi_h * scale)))
    if min(new_w, new_h) < min_side:
        scale = 1.0
        new_w = roi_w
        new_h = roi_h
    interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
    if scale != 1.0:
        bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp)
        mask_roi = cv2.resize(mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
    filled_roi = _inpaint_with_aot_core(
        bgr_roi,
        mask_roi,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=device,
        block_num=block_num,
        rates=rates,
        noise_prefill=noise_prefill,
        noise_strength=noise_strength,
    )
    if scale != 1.0:
        filled_roi = cv2.resize(
            filled_roi,
            (roi_w, roi_h),
            interpolation=cv2.INTER_LINEAR,
        )
    if not crop:
        return filled_roi
    output_full = bgr.copy()
    output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi
    return output_full
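The `--aot-max-size` capping above reduces to one scale factor with a `min_side` fallback; an extracted sketch of just that computation (the `compute_scale` name is illustrative):

```python
def compute_scale(roi_w, roi_h, max_size, min_side=32):
    """Scale factor that caps the longest side at max_size, if safe."""
    scale = 1.0
    if max_size > 0:
        max_dim = max(roi_h, roi_w)
        if max_dim > max_size:
            scale = max_size / float(max_dim)
            new_w = max(1, round(roi_w * scale))
            new_h = max(1, round(roi_h * scale))
            if min(new_w, new_h) < min_side:
                scale = 1.0  # refuse to shrink a side below min_side
    return scale

print(compute_scale(2800, 1400, 1400))  # 0.5
print(compute_scale(200, 40, 100))      # 1.0 (40 * 0.5 = 20 < 32, so fall back)
```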
# HEIC support
try:
@@ -16,17 +294,6 @@ try:
except ImportError:
    HEIC_SUPPORTED = False
def str2bool(v):
    """Convert a string to a boolean."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('布尔值应为true或false')
def remove_background(input_path, output_path, session=None, **kwargs):
    """
    Remove the background from an image
@@ -40,7 +307,7 @@ def remove_background(input_path, output_path, session=None, **kwargs):
    print(f"正在处理: {input_path}")
    # Read the input image, honoring EXIF orientation
    input_image = ImageOps.exif_transpose(Image.open(input_path))
    # Remove the background
    output_image = remove(input_image, session=session, **kwargs)
@@ -50,10 +317,195 @@ def remove_background(input_path, output_path, session=None, **kwargs):
    print(f"已保存: {output_path}")

def _pil_to_bgr(image):
    """Convert a PIL image to OpenCV BGR format."""
    rgb = image.convert("RGB")
    arr = np.array(rgb)
    return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
def _alpha_to_mask(alpha_channel, threshold=10):
    """Convert an alpha channel to a binary mask (0 or 255)."""
    mask = (alpha_channel > threshold).astype(np.uint8) * 255
    return mask

def _ensure_rgba_size(image, target_size):
    if image.mode != "RGBA":
        image = image.convert("RGBA")
    if image.size != target_size:
        image = image.resize(target_size, resample=Image.NEAREST)
    return image

def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0):
    """Build the hard mask plus a processed mask (for inpainting/feathering)."""
    if mask_dilate and mask_dilate > 0:
        kernel = np.ones((mask_dilate, mask_dilate), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=1)
    if edge_grow and edge_grow > 0:
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=int(edge_grow))
    mask_hard = mask.copy()
    mask_used = mask_hard
    if mask_blur and mask_blur > 0:
        # GaussianBlur requires an odd kernel size
        k = mask_blur if mask_blur % 2 == 1 else mask_blur + 1
        mask_used = cv2.GaussianBlur(mask_hard, (k, k), 0)
        mask_used = (mask_used > 0).astype(np.uint8) * 255
    return mask_hard, mask_used
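`_prepare_mask` grows the mask with `cv2.dilate` and a 3x3 kernel for each `edge_grow` step. A numpy-only illustration of what one such dilation step does (not the cv2 implementation; `dilate3x3` is an illustrative name):

```python
import numpy as np

def dilate3x3(mask):
    """One 3x3 binary dilation, built from the 9 shifted copies of the mask."""
    padded = np.pad(mask > 0, 1)
    out = np.zeros_like(mask, dtype=bool)
    for dy in (0, 1, 2):
        for dx in (0, 1, 2):
            out |= padded[dy:dy + mask.shape[0], dx:dx + mask.shape[1]]
    return out.astype(np.uint8) * 255

m = np.zeros((5, 5), dtype=np.uint8)
m[2, 2] = 255
print(int(dilate3x3(m).sum() // 255))  # 9 -- one pixel grows to a 3x3 block
```

Each `edge_grow` increment therefore pushes the subject boundary out by one pixel in every direction, which is why large values remove more background than intended.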
def remove_subject_and_inpaint(
    input_path,
    output_path,
    session=None,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
    mask_output_path=None,
    **kwargs,
):
    """
    Remove the subject and fill in the background

    Args:
        input_path: input image path
        output_path: output image path
        session: rembg session object (optional)
        aot_root: AOT-GAN directory
        aot_pretrain: path to the AOT-GAN weights file
        aot_device: AOT-GAN device
        aot_block_num: number of AOT-GAN AOTBlocks
        aot_rates: AOT-GAN AOTBlock dilation rates
        aot_crop: inpaint only a crop around the mask
        aot_crop_pad: padding around the AOT crop, in pixels
        aot_max_size: cap on the longest side of the AOT input (0 = unlimited)
        aot_noise_prefill: prefill the masked region with random noise
        aot_noise_strength: noise strength factor
        mask_dilate: mask dilation size
        mask_blur: mask blur size (odd)
        mask_threshold: alpha threshold
        save_mask: whether to save the mask
        mask_output_path: mask output path
        **kwargs: rembg parameters
    """
    print(f"正在处理(去主体补背景): {input_path}")
    input_image = ImageOps.exif_transpose(Image.open(input_path))
    # Obtain the subject mask via rembg
    output_image = remove(input_image, session=session, **kwargs)
    output_image = _ensure_rgba_size(output_image, input_image.size)
    alpha = np.array(output_image.getchannel("A"))
    mask = _alpha_to_mask(alpha, threshold=mask_threshold)
    bgr = _pil_to_bgr(input_image)
    if black_subject:
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        black_mask = (gray <= black_threshold).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, black_mask)
    if gray_subject:
        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        s = hsv[:, :, 1]
        v = hsv[:, :, 2]
        gray_mask = ((s <= gray_saturation_threshold) & (v <= gray_value_threshold)).astype(np.uint8) * 255
        mask = cv2.bitwise_or(mask, gray_mask)
    mask_hard, mask_used = _prepare_mask(
        mask,
        mask_dilate=mask_dilate,
        mask_blur=mask_blur,
        edge_grow=edge_grow,
    )
    rates = _parse_aot_rates(aot_rates)
    filled = _inpaint_with_aot(
        bgr,
        mask_used,
        aot_root=aot_root,
        aot_pretrain=aot_pretrain,
        device=aot_device,
        block_num=aot_block_num,
        rates=rates,
        crop=aot_crop,
        crop_pad=aot_crop_pad,
        max_size=aot_max_size,
        noise_prefill=aot_noise_prefill,
        noise_strength=aot_noise_strength,
    )
    if feather and feather_radius > 0:
        # Feather only outside the mask, to avoid reintroducing the original subject
        mask_bin = (mask_hard > 0).astype(np.uint8) * 255
        dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3)
        alpha = np.ones_like(dist_out, dtype=np.float32)
        outside = mask_bin == 0
        alpha[outside] = np.clip(1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0)
        alpha = alpha[:, :, None]
        blended = (alpha * filled + (1.0 - alpha) * bgr).astype(np.uint8)
        result = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
    else:
        result = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB)
    Image.fromarray(result).save(output_path)
    if save_mask:
        if mask_output_path is None:
            mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png"
        Image.fromarray(mask_used).save(mask_output_path)
        print(f"已保存mask: {mask_output_path}")
    print(f"已保存: {output_path}")
def process_images_folder(
    input_folder,
    output_folder,
    model_name="u2net",
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_size=10,
    post_process_mask=False,
    remove_subject=False,
    mask_dilate=3,
    mask_blur=3,
    mask_threshold=10,
    edge_grow=0,
    aot_root="AOT-GAN-for-Inpainting",
    aot_pretrain=None,
    aot_device="cpu",
    aot_block_num=8,
    aot_rates="1+2+4+8",
    aot_crop=False,
    aot_crop_pad=0,
    aot_max_size=0,
    aot_noise_prefill=False,
    aot_noise_strength=1.0,
    black_subject=False,
    black_threshold=50,
    gray_subject=False,
    gray_saturation_threshold=30,
    gray_value_threshold=200,
    feather=False,
    feather_radius=5,
    save_mask=False,
):
    """
    Batch-process all images in a folder
@@ -105,26 +557,100 @@ def process_images_folder(input_folder, output_folder, model_name="u2net",
    print(f" - 背景阈值: {alpha_matting_background_threshold}")
    print(f" - 侵蚀大小: {alpha_matting_erode_size}")
    print(f"Mask后处理: {'启用' if post_process_mask else '禁用'}")
    print(f"去主体补背景: {'启用' if remove_subject else '禁用'}")
    if remove_subject:
        print(f" - AOT目录: {aot_root}")
        print(f" - AOT权重: {aot_pretrain}")
        print(f" - AOT设备: {aot_device}")
        print(f" - AOT块数: {aot_block_num}")
        print(f" - AOT膨胀率: {aot_rates}")
        print(f" - AOT裁剪: {'是' if aot_crop else '否'}")
        if aot_crop:
            print(f" - AOT裁剪边界: {aot_crop_pad}")
        print(f" - AOT最大边: {aot_max_size}")
        print(f" - AOT噪声预填充: {'是' if aot_noise_prefill else '否'}")
        if aot_noise_prefill:
            print(f" - AOT噪声强度: {aot_noise_strength}")
        print(f" - mask膨胀: {mask_dilate}")
        print(f" - mask模糊: {mask_blur}")
        print(f" - mask阈值: {mask_threshold}")
        print(f" - 边缘扩张: {edge_grow}")
        print(f" - 黑色内容作为主体: {'是' if black_subject else '否'}")
        if black_subject:
            print(f" - 黑色阈值: {black_threshold}")
        print(f" - 灰阶内容作为主体: {'是' if gray_subject else '否'}")
        if gray_subject:
            print(f" - 灰阶饱和度阈值: {gray_saturation_threshold}")
            print(f" - 灰阶亮度阈值: {gray_value_threshold}")
        print(f" - 边缘过渡: {'是' if feather else '否'}")
        if feather:
            print(f" - 过渡半径: {feather_radius}")
        print(f" - 保存mask: {'是' if save_mask else '否'}")
    print("-" * 50)

    # Process each image
    for i, image_file in enumerate(image_files, 1):
        try:
            if remove_subject:
                suffix = image_file.suffix.lower()
                if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
                    output_filename = image_file.stem + "_bgfill" + suffix
                else:
                    output_filename = image_file.stem + "_bgfill.jpg"
            else:
                # Plain background removal defaults to PNG to keep transparency
                output_filename = image_file.stem + "_nobg.png"
            output_path = Path(output_folder) / output_filename
            print(f"[{i}/{len(image_files)}] ", end="")
            if remove_subject:
                mask_output_path = None
                if save_mask:
                    mask_output_path = str(Path(output_folder) / (image_file.stem + "_mask.png"))
                remove_subject_and_inpaint(
                    str(image_file),
                    str(output_path),
                    session=session,
                    alpha_matting=alpha_matting,
                    alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold=alpha_matting_background_threshold,
                    alpha_matting_erode_size=alpha_matting_erode_size,
                    post_process_mask=post_process_mask,
                    mask_dilate=mask_dilate,
                    mask_blur=mask_blur,
                    mask_threshold=mask_threshold,
                    edge_grow=edge_grow,
                    aot_root=aot_root,
                    aot_pretrain=aot_pretrain,
                    aot_device=aot_device,
                    aot_block_num=aot_block_num,
                    aot_rates=aot_rates,
                    aot_crop=aot_crop,
                    aot_crop_pad=aot_crop_pad,
                    aot_max_size=aot_max_size,
                    aot_noise_prefill=aot_noise_prefill,
                    aot_noise_strength=aot_noise_strength,
                    black_subject=black_subject,
                    black_threshold=black_threshold,
                    gray_subject=gray_subject,
                    gray_saturation_threshold=gray_saturation_threshold,
                    gray_value_threshold=gray_value_threshold,
                    feather=feather,
                    feather_radius=feather_radius,
                    save_mask=save_mask,
                    mask_output_path=mask_output_path,
                )
            else:
                remove_background(
                    str(image_file),
                    str(output_path),
                    session=session,
                    alpha_matting=alpha_matting,
                    alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold=alpha_matting_background_threshold,
                    alpha_matting_erode_size=alpha_matting_erode_size,
                    post_process_mask=post_process_mask,
                )
        except Exception as e:
            print(f"处理 {image_file.name} 时出错: {e}")
@@ -171,9 +697,8 @@ if __name__ == "__main__":
                        help='选择使用的模型 (默认: isnet-general-use)')
    # Alpha Matting parameters
    parser.add_argument('-a', '--alpha-matting', '--alpha_matting', action='store_true',
                        help='启用alpha matting后处理(默认: false)')
    parser.add_argument('-ft', '--foreground-threshold', type=int, default=245,
                        help='前景阈值 (0-255),值越大保留越多细节 (默认: 245)')
    parser.add_argument('-bt', '--background-threshold', type=int, default=8,
@@ -182,9 +707,56 @@ if __name__ == "__main__":
                        help='侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)')
    # Other options
    parser.add_argument('-p', '--post-process', '--post_process', action='store_true',
                        help='启用mask后处理(默认: false)')
    # Subject-removal / background-fill parameters
    parser.add_argument('--remove-subject', '--remove_subject', action='store_true',
                        help='去掉主体并补全背景(默认: false)')
    parser.add_argument('--aot-root', type=str, default='AOT-GAN-for-Inpainting',
                        help='AOT-GAN目录(默认: AOT-GAN-for-Inpainting)')
    parser.add_argument('--aot-pretrain', type=str, default=None,
                        help='AOT-GAN预训练权重文件路径(必填;相对路径基于当前目录)')
    parser.add_argument('--aot-device', type=str, default='cpu',
                        help='AOT-GAN设备(默认: cpu)')
    parser.add_argument('--aot-block-num', type=int, default=8,
                        help='AOTBlock数量(默认: 8)')
    parser.add_argument('--aot-rates', type=str, default='1+2+4+8',
                        help='AOTBlock膨胀率(默认: 1+2+4+8)')
    parser.add_argument('--aot-crop', action='store_true',
                        help='AOT仅对mask区域裁剪修补(默认: false)')
    parser.add_argument('--aot-crop-pad', type=int, default=0,
                        help='AOT裁剪边缘留白像素(默认: 0)')
    parser.add_argument('--aot-max-size', type=int, default=0,
                        help='AOT输入最大边限制,0为不限制(默认: 0)')
    parser.add_argument('--aot-noise-prefill', action='store_true',
                        help='AOT使用随机噪声预填充(默认: false)')
    parser.add_argument('--aot-noise-strength', type=float, default=1.0,
                        help='AOT噪声强度系数(默认: 1.0)')
    parser.add_argument('--mask-dilate', type=int, default=3,
                        help='mask膨胀大小(默认: 3)')
    parser.add_argument('--mask-blur', type=int, default=3,
                        help='mask模糊大小(默认: 3,建议奇数)')
    parser.add_argument('--mask-threshold', type=int, default=10,
                        help='alpha阈值(默认: 10)')
    parser.add_argument('--edge-grow', type=int, default=0,
                        help='主体边缘扩张像素(默认: 0)')
    parser.add_argument('--save-mask', '--save_mask', action='store_true',
                        help='保存mask到output目录(默认: false)')
    parser.add_argument('--black-subject', '--black_subject', action='store_true',
                        help='将黑色内容也视为主体(默认: false)')
    parser.add_argument('--black-threshold', type=int, default=50,
                        help='黑色阈值(0-255,灰度越小越黑,默认: 50)')
    parser.add_argument('--gray-subject', '--gray_subject', action='store_true',
                        help='将灰阶内容也视为主体(默认: false)')
    parser.add_argument('--gray-saturation-threshold', type=int, default=30,
                        help='灰阶饱和度阈值(0-255,越小越接近灰阶,默认: 30)')
    parser.add_argument('--gray-value-threshold', type=int, default=200,
                        help='灰阶亮度阈值(0-255,越小越暗,默认: 200)')
    parser.add_argument('--feather', action='store_true',
                        help='启用边缘过渡融合(默认: false)')
    parser.add_argument('--feather-radius', type=int, default=5,
                        help='边缘过渡半径(默认: 5)')
    args = parser.parse_args()
@@ -201,13 +773,33 @@ if __name__ == "__main__":
    # Process a single file
    if input_path.is_file():
        suffix = input_path.suffix.lower()
        if args.remove_subject:
            if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
                output_name = input_path.stem + "_bgfill" + suffix
            else:
                output_name = input_path.stem + "_bgfill.jpg"
        else:
            output_name = input_path.stem + "_nobg.png"

        # Determine the output path
        if args.output is None:
            # Both modes share the same default output directory
            output_path = input_path.parent / "output" / output_name
            output_path.parent.mkdir(parents=True, exist_ok=True)
        else:
            output_candidate = Path(args.output)
            if output_candidate.exists() and output_candidate.is_dir():
                output_path = output_candidate / output_name
                output_path.parent.mkdir(parents=True, exist_ok=True)
            elif output_candidate.suffix == "":
                # A suffix-less path is treated as a directory to create
                output_candidate.mkdir(parents=True, exist_ok=True)
                output_path = output_candidate / output_name
            else:
                output_path = output_candidate
                output_path.parent.mkdir(parents=True, exist_ok=True)

        print(f"输入文件: {input_path}")
        print(f"输出文件: {output_path}")
@@ -218,22 +810,89 @@ if __name__ == "__main__":
        print(f" - 背景阈值: {args.background_threshold}")
        print(f" - 侵蚀大小: {args.erode_size}")
        print(f"Mask后处理: {'启用' if args.post_process else '禁用'}")
        print(f"去主体补背景: {'启用' if args.remove_subject else '禁用'}")
        if args.remove_subject:
            print(f" - AOT目录: {args.aot_root}")
            print(f" - AOT权重: {args.aot_pretrain}")
            print(f" - AOT设备: {args.aot_device}")
            print(f" - AOT块数: {args.aot_block_num}")
            print(f" - AOT膨胀率: {args.aot_rates}")
            print(f" - AOT裁剪: {'是' if args.aot_crop else '否'}")
            if args.aot_crop:
                print(f" - AOT裁剪边界: {args.aot_crop_pad}")
            print(f" - AOT最大边: {args.aot_max_size}")
            print(f" - AOT噪声预填充: {'是' if args.aot_noise_prefill else '否'}")
            if args.aot_noise_prefill:
                print(f" - AOT噪声强度: {args.aot_noise_strength}")
            print(f" - mask膨胀: {args.mask_dilate}")
            print(f" - mask模糊: {args.mask_blur}")
            print(f" - mask阈值: {args.mask_threshold}")
            print(f" - 边缘扩张: {args.edge_grow}")
            print(f" - 黑色内容作为主体: {'是' if args.black_subject else '否'}")
            if args.black_subject:
                print(f" - 黑色阈值: {args.black_threshold}")
            print(f" - 灰阶内容作为主体: {'是' if args.gray_subject else '否'}")
            if args.gray_subject:
                print(f" - 灰阶饱和度阈值: {args.gray_saturation_threshold}")
                print(f" - 灰阶亮度阈值: {args.gray_value_threshold}")
            print(f" - 边缘过渡: {'是' if args.feather else '否'}")
            if args.feather:
                print(f" - 过渡半径: {args.feather_radius}")
            print(f" - 保存mask: {'是' if args.save_mask else '否'}")
        print("-" * 50)

        # Create the session
        session = new_session(args.model)

        # Process the image
        if args.remove_subject:
            mask_output_path = None
            if args.save_mask:
                mask_output_path = str(output_path.with_suffix("")) + "_mask.png"
            remove_subject_and_inpaint(
                str(input_path),
                str(output_path),
                session=session,
                alpha_matting=args.alpha_matting,
                alpha_matting_foreground_threshold=args.foreground_threshold,
                alpha_matting_background_threshold=args.background_threshold,
                alpha_matting_erode_size=args.erode_size,
                post_process_mask=args.post_process,
                mask_dilate=args.mask_dilate,
                mask_blur=args.mask_blur,
                mask_threshold=args.mask_threshold,
                edge_grow=args.edge_grow,
                aot_root=args.aot_root,
                aot_pretrain=args.aot_pretrain,
                aot_device=args.aot_device,
                aot_block_num=args.aot_block_num,
                aot_rates=args.aot_rates,
                aot_crop=args.aot_crop,
                aot_crop_pad=args.aot_crop_pad,
                aot_max_size=args.aot_max_size,
                aot_noise_prefill=args.aot_noise_prefill,
                aot_noise_strength=args.aot_noise_strength,
                black_subject=args.black_subject,
                black_threshold=args.black_threshold,
                gray_subject=args.gray_subject,
                gray_saturation_threshold=args.gray_saturation_threshold,
                gray_value_threshold=args.gray_value_threshold,
                feather=args.feather,
                feather_radius=args.feather_radius,
                save_mask=args.save_mask,
                mask_output_path=mask_output_path,
            )
        else:
            remove_background(
                str(input_path),
                str(output_path),
                session=session,
                alpha_matting=args.alpha_matting,
                alpha_matting_foreground_threshold=args.foreground_threshold,
                alpha_matting_background_threshold=args.background_threshold,
                alpha_matting_erode_size=args.erode_size,
                post_process_mask=args.post_process,
            )
        print("-" * 50)
        print(f"处理完成!结果保存在: {output_path}")
@@ -250,7 +909,30 @@ if __name__ == "__main__":
            alpha_matting_foreground_threshold=args.foreground_threshold,
            alpha_matting_background_threshold=args.background_threshold,
            alpha_matting_erode_size=args.erode_size,
            post_process_mask=args.post_process,
            remove_subject=args.remove_subject,
            mask_dilate=args.mask_dilate,
            mask_blur=args.mask_blur,
            mask_threshold=args.mask_threshold,
            edge_grow=args.edge_grow,
            aot_root=args.aot_root,
            aot_pretrain=args.aot_pretrain,
            aot_device=args.aot_device,
            aot_block_num=args.aot_block_num,
            aot_rates=args.aot_rates,
            aot_crop=args.aot_crop,
            aot_crop_pad=args.aot_crop_pad,
            aot_max_size=args.aot_max_size,
            aot_noise_prefill=args.aot_noise_prefill,
            aot_noise_strength=args.aot_noise_strength,
            black_subject=args.black_subject,
            black_threshold=args.black_threshold,
            gray_subject=args.gray_subject,
            gray_saturation_threshold=args.gray_saturation_threshold,
            gray_value_threshold=args.gray_value_threshold,
            feather=args.feather,
            feather_radius=args.feather_radius,
            save_mask=args.save_mask,
        )
    else:

View File

@@ -1,3 +1,4 @@
rembg[gpu] rembg[gpu]
pillow pillow
pillow-heif pillow-heif
opencv-python

View File

@@ -0,0 +1,78 @@
import importlib
import sys
import types
import unittest

def _import_remove_background():
    if "remove_background" in sys.modules:
        del sys.modules["remove_background"]
    # Stub out rembg so the module imports without the heavy dependency
    rembg_stub = types.ModuleType("rembg")
    rembg_stub.remove = lambda *args, **kwargs: None
    rembg_stub.new_session = lambda *args, **kwargs: None
    sys.modules["rembg"] = rembg_stub
    return importlib.import_module("remove_background")

class RemoveBackgroundTests(unittest.TestCase):
    def setUp(self):
        self.mod = _import_remove_background()

    def test_alpha_to_mask_threshold(self):
        alpha = self.mod.np.array([[0, 9, 10, 255]], dtype=self.mod.np.uint8)
        mask = self.mod._alpha_to_mask(alpha, threshold=10)
        self.assertEqual(mask.shape, alpha.shape)
        self.assertEqual(mask[0, 0], 0)
        self.assertEqual(mask[0, 1], 0)
        self.assertEqual(mask[0, 2], 0)
        self.assertEqual(mask[0, 3], 255)

    def test_prepare_mask_outputs_binary(self):
        mask = self.mod.np.zeros((5, 5), dtype=self.mod.np.uint8)
        mask[2, 2] = 255
        mask_hard, mask_used = self.mod._prepare_mask(mask, mask_dilate=0, mask_blur=3)
        self.assertEqual(mask_hard.shape, mask.shape)
        self.assertEqual(mask_used.shape, mask.shape)
        self.assertTrue(self.mod.np.all(self.mod.np.isin(mask_used, [0, 255])))

    def test_resolve_aot_paths_defaults(self):
        aot_root, aot_pretrain = self.mod._resolve_aot_paths(
            "AOT-GAN-for-Inpainting", None
        )
        self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
        self.assertIsNone(aot_pretrain)

    def test_resolve_aot_paths_relative_pretrain(self):
        aot_root, aot_pretrain = self.mod._resolve_aot_paths(
            "AOT-GAN-for-Inpainting", "experiments/foo.pt"
        )
        self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
        self.assertTrue(str(aot_pretrain).endswith("experiments/foo.pt"))

    def test_parse_aot_rates(self):
        rates = self.mod._parse_aot_rates("1+2+4+8")
        self.assertEqual(rates, [1, 2, 4, 8])

    def test_mask_bbox_empty(self):
        mask = self.mod.np.zeros((3, 3), dtype=self.mod.np.uint8)
        self.assertIsNone(self.mod._mask_bbox(mask))

    def test_mask_bbox_and_expand(self):
        mask = self.mod.np.zeros((5, 5), dtype=self.mod.np.uint8)
        mask[1:3, 2:4] = 255
        bbox = self.mod._mask_bbox(mask)
        self.assertEqual(bbox, (2, 1, 3, 2))
        expanded = self.mod._expand_bbox(bbox, 2, 5, 5)
        self.assertEqual(expanded, (0, 0, 4, 4))
        expanded_min = self.mod._expand_bbox_min_size(bbox, 0, 5, 5, 4)
        self.assertEqual(expanded_min, (1, 0, 4, 3))

    def test_ensure_rgba_size_resizes(self):
        img = self.mod.Image.new("RGB", (2, 3), color=(0, 0, 0))
        out = self.mod._ensure_rgba_size(img, (4, 5))
        self.assertEqual(out.mode, "RGBA")
        self.assertEqual(out.size, (4, 5))
if __name__ == "__main__":
    unittest.main()