update at 2026-03-28 16:46:40
This commit is contained in:
1
AOT-GAN-for-Inpainting
Submodule
1
AOT-GAN-for-Inpainting
Submodule
Submodule AOT-GAN-for-Inpainting added at 2cd1afd8fd
103
README.md
103
README.md
@@ -14,7 +14,9 @@ python remove_background.py
|
||||
# 处理单个文件
|
||||
python remove_background.py input.jpg output.png
|
||||
|
||||
# 查看所有参数-m, --model)
|
||||
# 查看所有参数
|
||||
python remove_background.py -h
|
||||
```
|
||||
|
||||
不同模型适用于不同场景:
|
||||
|
||||
@@ -39,7 +41,19 @@ python remove_background.py input.jpg output.png -m birefnet-portrait
|
||||
|
||||
# 使用快速模型
|
||||
python remove_background.py input.jpg output.png -m u2netp
|
||||
``
|
||||
|
||||
# 书画类文字偏浅的温和参数示例
|
||||
python remove_background.py images output \
|
||||
--remove-subject --black_subject --gray_subject --save_mask \
|
||||
--black-threshold 30\
|
||||
--gray-saturation-threshold 30 --gray-value-threshold 30 \
|
||||
--edge-grow 2 \
|
||||
--feather --feather-radius 4 \
|
||||
--aot-pretrain experiments/G0000000.pt \
|
||||
--aot-max-size 1000
|
||||
```
|
||||
|
||||
查看已下载模型:
|
||||
```bash
|
||||
ls -lh ~/.u2net/
|
||||
```
|
||||
@@ -75,6 +89,40 @@ rm -rf ~/.u2net/
|
||||
rm ~/.u2net/u2net.onnx
|
||||
```
|
||||
|
||||
## AOT-GAN 修补后端
|
||||
|
||||
`--remove-subject` 默认使用 AOT-GAN 修补。
|
||||
AOT-GAN 依赖 PyTorch(官方仓库测试 Python 3.8 / torch 1.8.1)。建议使用独立虚拟环境或确保兼容版本。
|
||||
|
||||
```bash
|
||||
# 安装依赖(示例)
|
||||
pip install torch torchvision
|
||||
```
|
||||
|
||||
下载预训练权重后,运行示例:
|
||||
```bash
|
||||
python remove_background.py "images/IMG_9259 2.JPG" \
|
||||
--remove-subject --black-subject --gray-subject --save-mask \
|
||||
--aot-pretrain experiments/places2.pth
|
||||
```
|
||||
|
||||
CPU 无 GPU 时的加速建议(只裁剪主体区域并限制最大边):
|
||||
```bash
|
||||
python remove_background.py "images/IMG_9259 2.JPG" \
|
||||
--remove-subject --black-subject --gray-subject --save-mask \
|
||||
--aot-pretrain experiments/places2.pth \
|
||||
--aot-crop --aot-crop-pad 24 --aot-max-size 1400
|
||||
```
|
||||
|
||||
减少“补脸”倾向:启用随机噪声预填充
|
||||
```bash
|
||||
python remove_background.py "images/IMG_9259 2.JPG" \
|
||||
--remove-subject --black-subject --gray-subject --save-mask \
|
||||
--aot-pretrain experiments/places2.pth \
|
||||
--aot-crop --aot-crop-pad 64 --aot-max-size 900 \
|
||||
--aot-noise-prefill --aot-noise-strength 1.0
|
||||
```
|
||||
|
||||
## 可调整参数说明
|
||||
|
||||
### 1. 模型选择 (model_name)
|
||||
@@ -97,10 +145,11 @@ rm ~/.u2net/u2net.onnx
|
||||
|
||||
Alpha Matting 是后处理步骤,可以显著改善边缘质量,特别是头发、毛发等细节。
|
||||
|
||||
#### alpha_matting (True/False)
|
||||
- **作用**: 是否启用alpha matting
|
||||
- **默认**: False
|
||||
- **建议**: 如果边缘不自然,启用此选项
|
||||
#### alpha_matting(开关)
|
||||
- **作用**: 是否启用 alpha matting,提升边缘质量
|
||||
- **默认**: 关闭(不传 `-a/--alpha-matting`)
|
||||
- **启用方式**: 传入 `-a` 或 `--alpha-matting`
|
||||
- **效果**: 有利于细节边缘(毛发/细线),但速度稍慢
|
||||
|
||||
#### alpha_matting_foreground_threshold (0-255)
|
||||
- **作用**: 前景阈值,控制哪些区域被认为是前景
|
||||
@@ -128,9 +177,45 @@ Alpha Matting 是后处理步骤,可以显著改善边缘质量,特别是头
|
||||
|
||||
### 3. Mask后处理 (post_process_mask)
|
||||
|
||||
- **作用**: 对mask进行额外的后处理
|
||||
- **默认**: False
|
||||
- **建议**: 可以尝试启用看效果是否改善
|
||||
- **作用**: 对 mask 进行额外后处理
|
||||
- **默认**: 关闭(不传 `-p/--post-process`)
|
||||
- **启用方式**: 传入 `-p` 或 `--post-process`
|
||||
- **效果**: 有助于减少毛边,但可能略损失细节
|
||||
|
||||
### 4. 去主体补背景 (remove_subject)
|
||||
|
||||
用于“去掉主体并补全背景”。当前仅使用 AOT-GAN 修补。
|
||||
|
||||
- **remove_subject(开关)**: 启用去主体补背景(默认关闭,传 `--remove-subject` 开启)
|
||||
- **aot_root**: AOT-GAN 目录(默认: `AOT-GAN-for-Inpainting`)
|
||||
- **aot_pretrain**: AOT-GAN 权重文件路径(必填)
|
||||
- **aot_device**: AOT-GAN 设备(默认: `cpu`)
|
||||
- **aot_block_num**: AOTBlock 数量(默认: 8)
|
||||
- **aot_rates**: AOTBlock 膨胀率(默认: `1+2+4+8`)
|
||||
- **aot_crop(开关)**: 仅对 mask 覆盖区域裁剪修补(默认关闭,传 `--aot-crop` 开启)
|
||||
- **aot_crop_pad (像素)**: 裁剪边缘留白像素(默认: 0)
|
||||
- **aot_max_size (像素)**: AOT 输入最大边限制(默认: 0 表示不限制)
|
||||
- **aot_noise_prefill(开关)**: AOT使用随机噪声预填充(默认关闭)
|
||||
- **aot_noise_strength (系数)**: 噪声强度(默认: 1.0)
|
||||
- **mask_dilate (像素)**: mask 膨胀大小(默认: 3)。越大去除范围越大,风险更高
|
||||
- **mask_blur (像素)**: mask 模糊大小(默认: 3)。越大边缘越柔和但易过度
|
||||
- **mask_threshold (0-255)**: alpha 阈值(默认: 10)。越大保留越多主体
|
||||
- **edge_grow (像素)**: 主体边缘额外扩张(默认: 0)。用于清理残留边缘
|
||||
- **save_mask(开关)**: 保存 mask 方便检查(默认关闭,传 `--save-mask` 开启)
|
||||
- **black_subject(开关)**: 将黑色内容也视为主体(默认关闭,传 `--black-subject` 开启)
|
||||
- **black_threshold (0-255)**: 黑色阈值(默认: 50)。越大越容易把浅灰当黑
|
||||
- **gray_subject(开关)**: 将灰阶内容也视为主体(默认关闭,传 `--gray-subject` 开启)
|
||||
- **gray_saturation_threshold (0-255)**: 灰阶饱和度阈值(默认: 30)。越大越容易把彩色当灰
|
||||
- **gray_value_threshold (0-255)**: 灰阶亮度阈值(默认: 200)。越大越容易把浅灰当灰
|
||||
- **feather(开关)**: 启用边缘过渡(默认关闭,传 `--feather` 开启)
|
||||
- **feather_radius (像素)**: 过渡半径(默认: 5)。越大过渡越柔和但可能变糊
|
||||
- **说明**: 过渡仅在 mask 外侧进行,避免把主体边缘带回
|
||||
|
||||
### 5. 参数调优建议(针对书画/字迹)
|
||||
- 先开启 `--remove-subject`,仅看主体遮罩是否覆盖到字迹
|
||||
- 文字残留:提高 `--black-threshold` 或 `--gray-*` 阈值
|
||||
- 过度修补:降低 `--black-threshold`、`--gray-value-threshold`,并减小 `--mask-dilate/--mask-blur`
|
||||
- 边缘不自然:尝试开启 `--feather` 并使用较小的 `--feather-radius`
|
||||
|
||||
## 常见问题解决
|
||||
|
||||
|
||||
BIN
__pycache__/remove_background.cpython-312.pyc
Normal file
BIN
__pycache__/remove_background.cpython-312.pyc
Normal file
Binary file not shown.
BIN
experiments/D0000000.pt
Normal file
BIN
experiments/D0000000.pt
Normal file
Binary file not shown.
BIN
experiments/G0000000.pt
Normal file
BIN
experiments/G0000000.pt
Normal file
Binary file not shown.
BIN
experiments/O0000000.pt
Normal file
BIN
experiments/O0000000.pt
Normal file
Binary file not shown.
0
output/.DS_Store → images/.DS_Store
vendored
0
output/.DS_Store → images/.DS_Store
vendored
16
pyproject.toml
Normal file
16
pyproject.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
ignore_missing_imports = true
|
||||
exclude = '^(generative_inpainting|images|output)/'
|
||||
|
||||
[tool.ruff]
|
||||
exclude = [
|
||||
"AOT-GAN-for-Inpainting",
|
||||
"generative_inpainting",
|
||||
"images",
|
||||
"output",
|
||||
"__pycache__",
|
||||
".venv",
|
||||
"venv",
|
||||
".git",
|
||||
]
|
||||
@@ -3,10 +3,288 @@
|
||||
使用rembg库自动去除图片背景
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import cv2
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
# 避免 numba 在某些环境下缓存失败
|
||||
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
|
||||
|
||||
from rembg import remove, new_session
|
||||
from PIL import Image
|
||||
|
||||
AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {}
|
||||
|
||||
def _resolve_aot_paths(aot_root, aot_pretrain):
|
||||
"""解析 AOT-GAN 的路径配置。"""
|
||||
aot_root = Path(aot_root).resolve()
|
||||
if aot_pretrain is None:
|
||||
aot_pretrain = None
|
||||
else:
|
||||
aot_pretrain = Path(aot_pretrain)
|
||||
if not aot_pretrain.is_absolute():
|
||||
# 预训练权重相对路径以当前工作目录为基准
|
||||
aot_pretrain = (Path.cwd() / aot_pretrain).resolve()
|
||||
return aot_root, aot_pretrain
|
||||
|
||||
def _parse_aot_rates(rates_str):
|
||||
parts = [p for p in rates_str.split("+") if p]
|
||||
return [int(p) for p in parts]
|
||||
|
||||
def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None):
|
||||
"""加载/缓存 AOT-GAN 模型。"""
|
||||
aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain)
|
||||
if not aot_root.exists():
|
||||
raise FileNotFoundError(f"AOT-GAN目录不存在: {aot_root}")
|
||||
if aot_pretrain is None:
|
||||
raise ValueError("AOT-GAN需要指定预训练权重路径(--aot-pretrain)")
|
||||
if not aot_pretrain.exists():
|
||||
raise FileNotFoundError(f"AOT-GAN权重不存在: {aot_pretrain}")
|
||||
if rates is None:
|
||||
rates = [1, 2, 4, 8]
|
||||
rates_tuple = tuple(rates)
|
||||
key = (str(aot_pretrain), device, int(block_num), rates_tuple)
|
||||
if key in AOT_MODEL_CACHE:
|
||||
return AOT_MODEL_CACHE[key]
|
||||
|
||||
src_root = aot_root / "src"
|
||||
if str(src_root) not in sys.path:
|
||||
sys.path.insert(0, str(src_root))
|
||||
import importlib
|
||||
try:
|
||||
import torch
|
||||
except ImportError as exc:
|
||||
raise ImportError("未找到 PyTorch,请先按README安装依赖。") from exc
|
||||
|
||||
net = importlib.import_module("model.aotgan")
|
||||
class _Args:
|
||||
pass
|
||||
args = _Args()
|
||||
args.block_num = int(block_num)
|
||||
args.rates = rates
|
||||
model = net.InpaintGenerator(args)
|
||||
state = torch.load(str(aot_pretrain), map_location=device)
|
||||
if isinstance(state, dict):
|
||||
if "state_dict" in state:
|
||||
state = state["state_dict"]
|
||||
elif "model" in state:
|
||||
state = state["model"]
|
||||
elif "generator" in state:
|
||||
state = state["generator"]
|
||||
if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
|
||||
state = {k.replace("module.", "", 1): v for k, v in state.items()}
|
||||
model.load_state_dict(state, strict=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
AOT_MODEL_CACHE[key] = (model, device)
|
||||
return AOT_MODEL_CACHE[key]
|
||||
|
||||
def _mask_bbox(mask):
|
||||
ys, xs = np.where(mask > 0)
|
||||
if len(xs) == 0:
|
||||
return None
|
||||
return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
|
||||
|
||||
def _expand_bbox(bbox, pad, width, height):
|
||||
if bbox is None:
|
||||
return None
|
||||
x0, y0, x1, y1 = bbox
|
||||
pad = max(0, int(pad))
|
||||
x0 = max(0, x0 - pad)
|
||||
y0 = max(0, y0 - pad)
|
||||
x1 = min(width - 1, x1 + pad)
|
||||
y1 = min(height - 1, y1 + pad)
|
||||
return x0, y0, x1, y1
|
||||
|
||||
def _expand_bbox_min_size(bbox, pad, width, height, min_size):
|
||||
"""在给定 pad 基础上,确保裁剪区域至少为 min_size。"""
|
||||
expanded = _expand_bbox(bbox, pad, width, height)
|
||||
if expanded is None:
|
||||
return None
|
||||
x0, y0, x1, y1 = expanded
|
||||
roi_w = x1 - x0 + 1
|
||||
roi_h = y1 - y0 + 1
|
||||
need_w = max(0, int(min_size) - roi_w)
|
||||
need_h = max(0, int(min_size) - roi_h)
|
||||
if need_w == 0 and need_h == 0:
|
||||
return expanded
|
||||
extra = max((need_w + 1) // 2, (need_h + 1) // 2)
|
||||
return _expand_bbox(bbox, pad + extra, width, height)
|
||||
|
||||
def _build_noise_prefill(img, mask_t, strength):
|
||||
"""为mask区域生成噪声预填充,img取值范围为[-1, 1]。"""
|
||||
import torch
|
||||
img01 = (img + 1.0) / 2.0
|
||||
unmasked = 1.0 - mask_t
|
||||
denom = unmasked.sum()
|
||||
if denom.item() < 1.0:
|
||||
mean = torch.full((1, 3, 1, 1), 0.5, device=img.device)
|
||||
std = torch.full((1, 3, 1, 1), 0.2, device=img.device)
|
||||
else:
|
||||
mean = (img01 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
|
||||
var = ((img01 - mean) ** 2 * unmasked).sum(dim=(0, 2, 3), keepdim=True) / denom
|
||||
std = torch.sqrt(var + 1e-6)
|
||||
noise = mean + std * float(strength) * torch.randn_like(img01)
|
||||
noise = noise.clamp(0.0, 1.0)
|
||||
return noise * 2.0 - 1.0
|
||||
|
||||
def _inpaint_with_aot_core(
|
||||
bgr,
|
||||
mask,
|
||||
aot_root,
|
||||
aot_pretrain,
|
||||
device="cpu",
|
||||
block_num=8,
|
||||
rates=None,
|
||||
noise_prefill=False,
|
||||
noise_strength=1.0,
|
||||
):
|
||||
"""使用 AOT-GAN 进行修补,返回 BGR 图像(与输入同尺寸)。"""
|
||||
if mask.dtype != np.uint8:
|
||||
mask = mask.astype(np.uint8)
|
||||
if mask.max() <= 1:
|
||||
mask = mask * 255
|
||||
if mask.max() == 0:
|
||||
return bgr
|
||||
|
||||
h, w = mask.shape
|
||||
grid = 4
|
||||
h2 = (h // grid) * grid
|
||||
w2 = (w // grid) * grid
|
||||
if h2 == 0 or w2 == 0:
|
||||
return bgr
|
||||
|
||||
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
||||
img_crop = rgb[:h2, :w2, :]
|
||||
mask_crop = mask[:h2, :w2]
|
||||
|
||||
import torch
|
||||
img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0
|
||||
img = img * 2.0 - 1.0
|
||||
mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0)
|
||||
img = img.unsqueeze(0)
|
||||
mask_t = mask_t.unsqueeze(0)
|
||||
if device != "cpu":
|
||||
img = img.to(device)
|
||||
mask_t = mask_t.to(device)
|
||||
|
||||
model, _ = _get_aot_model(
|
||||
aot_root=aot_root,
|
||||
aot_pretrain=aot_pretrain,
|
||||
device=device,
|
||||
block_num=block_num,
|
||||
rates=rates,
|
||||
)
|
||||
with torch.no_grad():
|
||||
if noise_prefill:
|
||||
noise = _build_noise_prefill(img, mask_t, noise_strength)
|
||||
image_masked = img * (1 - mask_t) + noise * mask_t
|
||||
else:
|
||||
image_masked = img * (1 - mask_t) + mask_t
|
||||
pred = model(image_masked, mask_t)
|
||||
comp = pred * mask_t + img * (1 - mask_t)
|
||||
comp = comp[0].clamp(-1.0, 1.0)
|
||||
comp = (comp + 1.0) / 2.0 * 255.0
|
||||
comp = comp.permute(1, 2, 0).byte().cpu().numpy()
|
||||
result_bgr = cv2.cvtColor(comp, cv2.COLOR_RGB2BGR)
|
||||
|
||||
if h2 == h and w2 == w:
|
||||
return result_bgr
|
||||
output_full = bgr.copy()
|
||||
output_full[:h2, :w2, :] = result_bgr
|
||||
return output_full
|
||||
|
||||
def _inpaint_with_aot(
|
||||
bgr,
|
||||
mask,
|
||||
aot_root,
|
||||
aot_pretrain,
|
||||
device="cpu",
|
||||
block_num=8,
|
||||
rates=None,
|
||||
crop=False,
|
||||
crop_pad=0,
|
||||
max_size=0,
|
||||
noise_prefill=False,
|
||||
noise_strength=1.0,
|
||||
):
|
||||
"""使用 AOT-GAN 进行修补,支持裁剪与限幅以加速。"""
|
||||
min_side = 32
|
||||
if mask.dtype != np.uint8:
|
||||
mask = mask.astype(np.uint8)
|
||||
if mask.max() <= 1:
|
||||
mask = mask * 255
|
||||
if mask.max() == 0:
|
||||
return bgr
|
||||
|
||||
h, w = mask.shape
|
||||
x0 = y0 = 0
|
||||
x1 = w - 1
|
||||
y1 = h - 1
|
||||
if crop:
|
||||
bbox = _mask_bbox(mask)
|
||||
if bbox is None:
|
||||
return bgr
|
||||
expanded = _expand_bbox_min_size(bbox, crop_pad, w, h, min_side)
|
||||
if expanded is None:
|
||||
return bgr
|
||||
x0, y0, x1, y1 = expanded
|
||||
roi_w = x1 - x0 + 1
|
||||
roi_h = y1 - y0 + 1
|
||||
if roi_w < min_side or roi_h < min_side:
|
||||
# 裁剪区域过小会导致 AOT 失败,回退到全图修补
|
||||
crop = False
|
||||
x0 = y0 = 0
|
||||
x1 = w - 1
|
||||
y1 = h - 1
|
||||
|
||||
bgr_roi = bgr[y0 : y1 + 1, x0 : x1 + 1]
|
||||
mask_roi = mask[y0 : y1 + 1, x0 : x1 + 1]
|
||||
roi_h, roi_w = bgr_roi.shape[:2]
|
||||
|
||||
scale = 1.0
|
||||
max_size = int(max_size) if max_size else 0
|
||||
if max_size > 0:
|
||||
max_dim = max(roi_h, roi_w)
|
||||
if max_dim > max_size:
|
||||
scale = max_size / float(max_dim)
|
||||
new_w = max(1, int(round(roi_w * scale)))
|
||||
new_h = max(1, int(round(roi_h * scale)))
|
||||
if min(new_w, new_h) < min_side:
|
||||
scale = 1.0
|
||||
new_w = roi_w
|
||||
new_h = roi_h
|
||||
interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
|
||||
if scale != 1.0:
|
||||
bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp)
|
||||
mask_roi = cv2.resize(mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
|
||||
|
||||
filled_roi = _inpaint_with_aot_core(
|
||||
bgr_roi,
|
||||
mask_roi,
|
||||
aot_root=aot_root,
|
||||
aot_pretrain=aot_pretrain,
|
||||
device=device,
|
||||
block_num=block_num,
|
||||
rates=rates,
|
||||
noise_prefill=noise_prefill,
|
||||
noise_strength=noise_strength,
|
||||
)
|
||||
|
||||
if scale != 1.0:
|
||||
filled_roi = cv2.resize(
|
||||
filled_roi,
|
||||
(roi_w, roi_h),
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
|
||||
if not crop:
|
||||
return filled_roi
|
||||
output_full = bgr.copy()
|
||||
output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi
|
||||
return output_full
|
||||
|
||||
# 支持HEIC格式
|
||||
try:
|
||||
@@ -16,17 +294,6 @@ try:
|
||||
except ImportError:
|
||||
HEIC_SUPPORTED = False
|
||||
|
||||
def str2bool(v):
|
||||
"""将字符串转换为布尔值"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||
return True
|
||||
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError('布尔值应为true或false')
|
||||
|
||||
def remove_background(input_path, output_path, session=None, **kwargs):
|
||||
"""
|
||||
去除图片背景
|
||||
@@ -40,7 +307,7 @@ def remove_background(input_path, output_path, session=None, **kwargs):
|
||||
print(f"正在处理: {input_path}")
|
||||
|
||||
# 读取输入图片
|
||||
input_image = Image.open(input_path)
|
||||
input_image = ImageOps.exif_transpose(Image.open(input_path))
|
||||
|
||||
# 去除背景
|
||||
output_image = remove(input_image, session=session, **kwargs)
|
||||
@@ -50,10 +317,195 @@ def remove_background(input_path, output_path, session=None, **kwargs):
|
||||
|
||||
print(f"已保存: {output_path}")
|
||||
|
||||
def process_images_folder(input_folder, output_folder, model_name="u2net",
|
||||
alpha_matting=False, alpha_matting_foreground_threshold=240,
|
||||
alpha_matting_background_threshold=10, alpha_matting_erode_size=10,
|
||||
post_process_mask=False):
|
||||
def _pil_to_bgr(image):
|
||||
"""将PIL图片转换为OpenCV BGR格式"""
|
||||
rgb = image.convert("RGB")
|
||||
arr = np.array(rgb)
|
||||
return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
|
||||
|
||||
def _alpha_to_mask(alpha_channel, threshold=10):
|
||||
"""将alpha通道转换为二值mask(0或255)"""
|
||||
mask = (alpha_channel > threshold).astype(np.uint8) * 255
|
||||
return mask
|
||||
|
||||
def _ensure_rgba_size(image, target_size):
|
||||
if image.mode != "RGBA":
|
||||
image = image.convert("RGBA")
|
||||
if image.size != target_size:
|
||||
image = image.resize(target_size, resample=Image.NEAREST)
|
||||
return image
|
||||
|
||||
def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0):
|
||||
"""生成硬边mask与处理后mask(用于填补/过渡)"""
|
||||
if mask_dilate and mask_dilate > 0:
|
||||
kernel = np.ones((mask_dilate, mask_dilate), np.uint8)
|
||||
mask = cv2.dilate(mask, kernel, iterations=1)
|
||||
if edge_grow and edge_grow > 0:
|
||||
kernel = np.ones((3, 3), np.uint8)
|
||||
mask = cv2.dilate(mask, kernel, iterations=int(edge_grow))
|
||||
|
||||
mask_hard = mask.copy()
|
||||
mask_used = mask_hard
|
||||
if mask_blur and mask_blur > 0:
|
||||
k = mask_blur if mask_blur % 2 == 1 else mask_blur + 1
|
||||
mask_used = cv2.GaussianBlur(mask_hard, (k, k), 0)
|
||||
mask_used = (mask_used > 0).astype(np.uint8) * 255
|
||||
return mask_hard, mask_used
|
||||
|
||||
def remove_subject_and_inpaint(
|
||||
input_path,
|
||||
output_path,
|
||||
session=None,
|
||||
mask_dilate=3,
|
||||
mask_blur=3,
|
||||
mask_threshold=10,
|
||||
edge_grow=0,
|
||||
aot_root="AOT-GAN-for-Inpainting",
|
||||
aot_pretrain=None,
|
||||
aot_device="cpu",
|
||||
aot_block_num=8,
|
||||
aot_rates="1+2+4+8",
|
||||
aot_crop=False,
|
||||
aot_crop_pad=0,
|
||||
aot_max_size=0,
|
||||
aot_noise_prefill=False,
|
||||
aot_noise_strength=1.0,
|
||||
black_subject=False,
|
||||
black_threshold=50,
|
||||
gray_subject=False,
|
||||
gray_saturation_threshold=30,
|
||||
gray_value_threshold=200,
|
||||
feather=False,
|
||||
feather_radius=5,
|
||||
save_mask=False,
|
||||
mask_output_path=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
去掉主体并补全背景
|
||||
|
||||
Args:
|
||||
input_path: 输入图片路径
|
||||
output_path: 输出图片路径
|
||||
session: rembg会话对象(可选)
|
||||
aot_root: AOT-GAN目录
|
||||
aot_pretrain: AOT-GAN权重文件路径
|
||||
aot_device: AOT-GAN设备
|
||||
aot_block_num: AOT-GAN AOTBlock 数量
|
||||
aot_rates: AOT-GAN AOTBlock 膨胀率
|
||||
aot_crop: AOT仅裁剪mask区域进行修补
|
||||
aot_crop_pad: AOT裁剪边缘留白像素
|
||||
aot_max_size: AOT输入最大边限制(0为不限制)
|
||||
aot_noise_prefill: AOT使用随机噪声预填充
|
||||
aot_noise_strength: 噪声强度系数
|
||||
mask_dilate: mask膨胀大小
|
||||
mask_blur: mask模糊大小(奇数)
|
||||
mask_threshold: alpha阈值
|
||||
save_mask: 是否保存mask
|
||||
mask_output_path: mask保存路径
|
||||
**kwargs: rembg参数
|
||||
"""
|
||||
print(f"正在处理(去主体补背景): {input_path}")
|
||||
|
||||
input_image = ImageOps.exif_transpose(Image.open(input_path))
|
||||
|
||||
# 使用rembg获取主体mask
|
||||
output_image = remove(input_image, session=session, **kwargs)
|
||||
output_image = _ensure_rgba_size(output_image, input_image.size)
|
||||
|
||||
alpha = np.array(output_image.getchannel("A"))
|
||||
mask = _alpha_to_mask(alpha, threshold=mask_threshold)
|
||||
|
||||
bgr = _pil_to_bgr(input_image)
|
||||
if black_subject:
|
||||
gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
|
||||
black_mask = (gray <= black_threshold).astype(np.uint8) * 255
|
||||
mask = cv2.bitwise_or(mask, black_mask)
|
||||
if gray_subject:
|
||||
hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
|
||||
s = hsv[:, :, 1]
|
||||
v = hsv[:, :, 2]
|
||||
gray_mask = ((s <= gray_saturation_threshold) & (v <= gray_value_threshold)).astype(np.uint8) * 255
|
||||
mask = cv2.bitwise_or(mask, gray_mask)
|
||||
|
||||
mask_hard, mask_used = _prepare_mask(
|
||||
mask,
|
||||
mask_dilate=mask_dilate,
|
||||
mask_blur=mask_blur,
|
||||
edge_grow=edge_grow,
|
||||
)
|
||||
|
||||
rates = _parse_aot_rates(aot_rates)
|
||||
filled = _inpaint_with_aot(
|
||||
bgr,
|
||||
mask_used,
|
||||
aot_root=aot_root,
|
||||
aot_pretrain=aot_pretrain,
|
||||
device=aot_device,
|
||||
block_num=aot_block_num,
|
||||
rates=rates,
|
||||
crop=aot_crop,
|
||||
crop_pad=aot_crop_pad,
|
||||
max_size=aot_max_size,
|
||||
noise_prefill=aot_noise_prefill,
|
||||
noise_strength=aot_noise_strength,
|
||||
)
|
||||
|
||||
if feather and feather_radius > 0:
|
||||
# 仅在mask外侧做过渡,避免把原主体带回
|
||||
mask_bin = (mask_hard > 0).astype(np.uint8) * 255
|
||||
dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3)
|
||||
alpha = np.ones_like(dist_out, dtype=np.float32)
|
||||
outside = mask_bin == 0
|
||||
alpha[outside] = np.clip(1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0)
|
||||
alpha = alpha[:, :, None]
|
||||
blended = (alpha * filled + (1.0 - alpha) * bgr).astype(np.uint8)
|
||||
result = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
|
||||
else:
|
||||
result = cv2.cvtColor(filled, cv2.COLOR_BGR2RGB)
|
||||
Image.fromarray(result).save(output_path)
|
||||
|
||||
if save_mask:
|
||||
if mask_output_path is None:
|
||||
mask_output_path = str(Path(output_path).with_suffix("")) + "_mask.png"
|
||||
Image.fromarray(mask_used).save(mask_output_path)
|
||||
print(f"已保存mask: {mask_output_path}")
|
||||
|
||||
print(f"已保存: {output_path}")
|
||||
|
||||
def process_images_folder(
|
||||
input_folder,
|
||||
output_folder,
|
||||
model_name="u2net",
|
||||
alpha_matting=False,
|
||||
alpha_matting_foreground_threshold=240,
|
||||
alpha_matting_background_threshold=10,
|
||||
alpha_matting_erode_size=10,
|
||||
post_process_mask=False,
|
||||
remove_subject=False,
|
||||
mask_dilate=3,
|
||||
mask_blur=3,
|
||||
mask_threshold=10,
|
||||
edge_grow=0,
|
||||
aot_root="AOT-GAN-for-Inpainting",
|
||||
aot_pretrain=None,
|
||||
aot_device="cpu",
|
||||
aot_block_num=8,
|
||||
aot_rates="1+2+4+8",
|
||||
aot_crop=False,
|
||||
aot_crop_pad=0,
|
||||
aot_max_size=0,
|
||||
aot_noise_prefill=False,
|
||||
aot_noise_strength=1.0,
|
||||
black_subject=False,
|
||||
black_threshold=50,
|
||||
gray_subject=False,
|
||||
gray_saturation_threshold=30,
|
||||
gray_value_threshold=200,
|
||||
feather=False,
|
||||
feather_radius=5,
|
||||
save_mask=False,
|
||||
):
|
||||
"""
|
||||
批量处理文件夹中的所有图片
|
||||
|
||||
@@ -105,26 +557,100 @@ def process_images_folder(input_folder, output_folder, model_name="u2net",
|
||||
print(f" - 背景阈值: {alpha_matting_background_threshold}")
|
||||
print(f" - 侵蚀大小: {alpha_matting_erode_size}")
|
||||
print(f"Mask后处理: {'启用' if post_process_mask else '禁用'}")
|
||||
print(f"去主体补背景: {'启用' if remove_subject else '禁用'}")
|
||||
if remove_subject:
|
||||
print(f" - AOT目录: {aot_root}")
|
||||
print(f" - AOT权重: {aot_pretrain}")
|
||||
print(f" - AOT设备: {aot_device}")
|
||||
print(f" - AOT块数: {aot_block_num}")
|
||||
print(f" - AOT膨胀率: {aot_rates}")
|
||||
print(f" - AOT裁剪: {'是' if aot_crop else '否'}")
|
||||
if aot_crop:
|
||||
print(f" - AOT裁剪边界: {aot_crop_pad}")
|
||||
print(f" - AOT最大边: {aot_max_size}")
|
||||
print(f" - AOT噪声预填充: {'是' if aot_noise_prefill else '否'}")
|
||||
if aot_noise_prefill:
|
||||
print(f" - AOT噪声强度: {aot_noise_strength}")
|
||||
print(f" - mask膨胀: {mask_dilate}")
|
||||
print(f" - mask模糊: {mask_blur}")
|
||||
print(f" - mask阈值: {mask_threshold}")
|
||||
print(f" - 边缘扩张: {edge_grow}")
|
||||
print(f" - 黑色内容作为主体: {'是' if black_subject else '否'}")
|
||||
if black_subject:
|
||||
print(f" - 黑色阈值: {black_threshold}")
|
||||
print(f" - 灰阶内容作为主体: {'是' if gray_subject else '否'}")
|
||||
if gray_subject:
|
||||
print(f" - 灰阶饱和度阈值: {gray_saturation_threshold}")
|
||||
print(f" - 灰阶亮度阈值: {gray_value_threshold}")
|
||||
print(f" - 边缘过渡: {'是' if feather else '否'}")
|
||||
if feather:
|
||||
print(f" - 过渡半径: {feather_radius}")
|
||||
print(f" - 保存mask: {'是' if save_mask else '否'}")
|
||||
print("-" * 50)
|
||||
|
||||
# 处理每张图片
|
||||
for i, image_file in enumerate(image_files, 1):
|
||||
try:
|
||||
# 输出文件名(保持原始名称,改为PNG格式以支持透明背景)
|
||||
output_filename = image_file.stem + '_nobg.png'
|
||||
if remove_subject:
|
||||
suffix = image_file.suffix.lower()
|
||||
if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
|
||||
output_filename = image_file.stem + "_bgfill" + suffix
|
||||
else:
|
||||
output_filename = image_file.stem + "_bgfill.jpg"
|
||||
else:
|
||||
# 去背景默认使用PNG格式以支持透明背景
|
||||
output_filename = image_file.stem + "_nobg.png"
|
||||
output_path = Path(output_folder) / output_filename
|
||||
|
||||
print(f"[{i}/{len(image_files)}] ", end="")
|
||||
remove_background(
|
||||
str(image_file),
|
||||
str(output_path),
|
||||
session=session,
|
||||
alpha_matting=alpha_matting,
|
||||
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold=alpha_matting_background_threshold,
|
||||
alpha_matting_erode_size=alpha_matting_erode_size,
|
||||
post_process_mask=post_process_mask
|
||||
)
|
||||
if remove_subject:
|
||||
mask_output_path = None
|
||||
if save_mask:
|
||||
mask_output_path = str(Path(output_folder) / (image_file.stem + "_mask.png"))
|
||||
remove_subject_and_inpaint(
|
||||
str(image_file),
|
||||
str(output_path),
|
||||
session=session,
|
||||
alpha_matting=alpha_matting,
|
||||
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold=alpha_matting_background_threshold,
|
||||
alpha_matting_erode_size=alpha_matting_erode_size,
|
||||
post_process_mask=post_process_mask,
|
||||
mask_dilate=mask_dilate,
|
||||
mask_blur=mask_blur,
|
||||
mask_threshold=mask_threshold,
|
||||
edge_grow=edge_grow,
|
||||
aot_root=aot_root,
|
||||
aot_pretrain=aot_pretrain,
|
||||
aot_device=aot_device,
|
||||
aot_block_num=aot_block_num,
|
||||
aot_rates=aot_rates,
|
||||
aot_crop=aot_crop,
|
||||
aot_crop_pad=aot_crop_pad,
|
||||
aot_max_size=aot_max_size,
|
||||
aot_noise_prefill=aot_noise_prefill,
|
||||
aot_noise_strength=aot_noise_strength,
|
||||
black_subject=black_subject,
|
||||
black_threshold=black_threshold,
|
||||
gray_subject=gray_subject,
|
||||
gray_saturation_threshold=gray_saturation_threshold,
|
||||
gray_value_threshold=gray_value_threshold,
|
||||
feather=feather,
|
||||
feather_radius=feather_radius,
|
||||
save_mask=save_mask,
|
||||
mask_output_path=mask_output_path,
|
||||
)
|
||||
else:
|
||||
remove_background(
|
||||
str(image_file),
|
||||
str(output_path),
|
||||
session=session,
|
||||
alpha_matting=alpha_matting,
|
||||
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
|
||||
alpha_matting_background_threshold=alpha_matting_background_threshold,
|
||||
alpha_matting_erode_size=alpha_matting_erode_size,
|
||||
post_process_mask=post_process_mask,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理 {image_file.name} 时出错: {e}")
|
||||
@@ -171,9 +697,8 @@ if __name__ == "__main__":
|
||||
help='选择使用的模型 (默认: isnet-general-use)')
|
||||
|
||||
# Alpha Matting参数
|
||||
parser.add_argument('-a', '--alpha-matting', type=str2bool, nargs='?', const=True, default=True,
|
||||
metavar='true/false',
|
||||
help='启用alpha matting后处理(默认: true)。用法: -a 或 -a true 或 -a false')
|
||||
parser.add_argument('-a', '--alpha-matting', '--alpha_matting', action='store_true',
|
||||
help='启用alpha matting后处理(默认: false)')
|
||||
parser.add_argument('-ft', '--foreground-threshold', type=int, default=245,
|
||||
help='前景阈值 (0-255),值越大保留越多细节 (默认: 245)')
|
||||
parser.add_argument('-bt', '--background-threshold', type=int, default=8,
|
||||
@@ -182,9 +707,56 @@ if __name__ == "__main__":
|
||||
help='侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)')
|
||||
|
||||
# 其他选项
|
||||
parser.add_argument('-p', '--post-process', type=str2bool, nargs='?', const=True, default=True,
|
||||
metavar='true/false',
|
||||
help='启用mask后处理(默认: true)。用法: -p 或 -p true 或 -p false')
|
||||
parser.add_argument('-p', '--post-process', '--post_process', action='store_true',
|
||||
help='启用mask后处理(默认: false)')
|
||||
|
||||
# 去主体补背景参数
|
||||
parser.add_argument('--remove-subject', '--remove_subject', action='store_true',
|
||||
help='去掉主体并补全背景(默认: false)')
|
||||
parser.add_argument('--aot-root', type=str, default='AOT-GAN-for-Inpainting',
|
||||
help='AOT-GAN目录(默认: AOT-GAN-for-Inpainting)')
|
||||
parser.add_argument('--aot-pretrain', type=str, default=None,
|
||||
help='AOT-GAN预训练权重文件路径(必填,相对路径基于当前目录)')
|
||||
parser.add_argument('--aot-device', type=str, default='cpu',
|
||||
help='AOT-GAN设备(默认: cpu)')
|
||||
parser.add_argument('--aot-block-num', type=int, default=8,
|
||||
help='AOTBlock数量(默认: 8)')
|
||||
parser.add_argument('--aot-rates', type=str, default='1+2+4+8',
|
||||
help='AOTBlock膨胀率(默认: 1+2+4+8)')
|
||||
parser.add_argument('--aot-crop', action='store_true',
|
||||
help='AOT仅对mask区域裁剪修补(默认: false)')
|
||||
parser.add_argument('--aot-crop-pad', type=int, default=0,
|
||||
help='AOT裁剪边缘留白像素(默认: 0)')
|
||||
parser.add_argument('--aot-max-size', type=int, default=0,
|
||||
help='AOT输入最大边限制,0为不限制(默认: 0)')
|
||||
parser.add_argument('--aot-noise-prefill', action='store_true',
|
||||
help='AOT使用随机噪声预填充(默认: false)')
|
||||
parser.add_argument('--aot-noise-strength', type=float, default=1.0,
|
||||
help='AOT噪声强度系数(默认: 1.0)')
|
||||
parser.add_argument('--mask-dilate', type=int, default=3,
|
||||
help='mask膨胀大小(默认: 3)')
|
||||
parser.add_argument('--mask-blur', type=int, default=3,
|
||||
help='mask模糊大小(默认: 3,建议奇数)')
|
||||
parser.add_argument('--mask-threshold', type=int, default=10,
|
||||
help='alpha阈值(默认: 10)')
|
||||
parser.add_argument('--edge-grow', type=int, default=0,
|
||||
help='主体边缘扩张像素(默认: 0)')
|
||||
parser.add_argument('--save-mask', '--save_mask', action='store_true',
|
||||
help='保存mask到output目录(默认: false)')
|
||||
parser.add_argument('--black-subject', '--black_subject', action='store_true',
|
||||
help='将黑色内容也视为主体(默认: false)')
|
||||
parser.add_argument('--black-threshold', type=int, default=50,
|
||||
help='黑色阈值(0-255,灰度越小越黑,默认: 50)')
|
||||
parser.add_argument('--gray-subject', '--gray_subject', action='store_true',
|
||||
help='将灰阶内容也视为主体(默认: false)')
|
||||
parser.add_argument('--gray-saturation-threshold', type=int, default=30,
|
||||
help='灰阶饱和度阈值(0-255,越小越接近灰阶,默认: 30)')
|
||||
parser.add_argument('--gray-value-threshold', type=int, default=200,
|
||||
help='灰阶亮度阈值(0-255,越小越暗,默认: 200)')
|
||||
parser.add_argument('--feather', action='store_true',
|
||||
help='启用边缘过渡融合(默认: false)')
|
||||
parser.add_argument('--feather-radius', type=int, default=5,
|
||||
help='边缘过渡半径(默认: 5)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -201,13 +773,33 @@ if __name__ == "__main__":
|
||||
|
||||
# 处理单个文件
|
||||
if input_path.is_file():
|
||||
suffix = input_path.suffix.lower()
|
||||
if args.remove_subject:
|
||||
if suffix in {".jpg", ".jpeg", ".png", ".bmp", ".webp"}:
|
||||
output_name = input_path.stem + "_bgfill" + suffix
|
||||
else:
|
||||
output_name = input_path.stem + "_bgfill.jpg"
|
||||
else:
|
||||
output_name = input_path.stem + "_nobg.png"
|
||||
|
||||
# 确定输出路径
|
||||
if args.output is None:
|
||||
output_path = input_path.parent / 'output' / (input_path.stem + '_nobg.png')
|
||||
if args.remove_subject:
|
||||
output_path = input_path.parent / "output" / output_name
|
||||
else:
|
||||
output_path = input_path.parent / 'output' / output_name
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
else:
|
||||
output_path = Path(args.output)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_candidate = Path(args.output)
|
||||
if output_candidate.exists() and output_candidate.is_dir():
|
||||
output_path = output_candidate / output_name
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
elif output_candidate.suffix == "":
|
||||
output_candidate.mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_candidate / output_name
|
||||
else:
|
||||
output_path = output_candidate
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"输入文件: {input_path}")
|
||||
print(f"输出文件: {output_path}")
|
||||
@@ -218,22 +810,89 @@ if __name__ == "__main__":
|
||||
print(f" - 背景阈值: {args.background_threshold}")
|
||||
print(f" - 侵蚀大小: {args.erode_size}")
|
||||
print(f"Mask后处理: {'启用' if args.post_process else '禁用'}")
|
||||
print(f"去主体补背景: {'启用' if args.remove_subject else '禁用'}")
|
||||
if args.remove_subject:
|
||||
print(f" - AOT目录: {args.aot_root}")
|
||||
print(f" - AOT权重: {args.aot_pretrain}")
|
||||
print(f" - AOT设备: {args.aot_device}")
|
||||
print(f" - AOT块数: {args.aot_block_num}")
|
||||
print(f" - AOT膨胀率: {args.aot_rates}")
|
||||
print(f" - AOT裁剪: {'是' if args.aot_crop else '否'}")
|
||||
if args.aot_crop:
|
||||
print(f" - AOT裁剪边界: {args.aot_crop_pad}")
|
||||
print(f" - AOT最大边: {args.aot_max_size}")
|
||||
print(f" - AOT噪声预填充: {'是' if args.aot_noise_prefill else '否'}")
|
||||
if args.aot_noise_prefill:
|
||||
print(f" - AOT噪声强度: {args.aot_noise_strength}")
|
||||
print(f" - mask膨胀: {args.mask_dilate}")
|
||||
print(f" - mask模糊: {args.mask_blur}")
|
||||
print(f" - mask阈值: {args.mask_threshold}")
|
||||
print(f" - 边缘扩张: {args.edge_grow}")
|
||||
print(f" - 黑色内容作为主体: {'是' if args.black_subject else '否'}")
|
||||
if args.black_subject:
|
||||
print(f" - 黑色阈值: {args.black_threshold}")
|
||||
print(f" - 灰阶内容作为主体: {'是' if args.gray_subject else '否'}")
|
||||
if args.gray_subject:
|
||||
print(f" - 灰阶饱和度阈值: {args.gray_saturation_threshold}")
|
||||
print(f" - 灰阶亮度阈值: {args.gray_value_threshold}")
|
||||
print(f" - 边缘过渡: {'是' if args.feather else '否'}")
|
||||
if args.feather:
|
||||
print(f" - 过渡半径: {args.feather_radius}")
|
||||
print(f" - 保存mask: {'是' if args.save_mask else '否'}")
|
||||
print("-" * 50)
|
||||
|
||||
# 创建会话
|
||||
session = new_session(args.model)
|
||||
|
||||
# 处理图片
|
||||
remove_background(
|
||||
str(input_path),
|
||||
str(output_path),
|
||||
session=session,
|
||||
alpha_matting=args.alpha_matting,
|
||||
alpha_matting_foreground_threshold=args.foreground_threshold,
|
||||
alpha_matting_background_threshold=args.background_threshold,
|
||||
alpha_matting_erode_size=args.erode_size,
|
||||
post_process_mask=args.post_process
|
||||
)
|
||||
if args.remove_subject:
|
||||
mask_output_path = None
|
||||
if args.save_mask:
|
||||
mask_output_path = str(output_path.with_suffix("")) + "_mask.png"
|
||||
remove_subject_and_inpaint(
|
||||
str(input_path),
|
||||
str(output_path),
|
||||
session=session,
|
||||
alpha_matting=args.alpha_matting,
|
||||
alpha_matting_foreground_threshold=args.foreground_threshold,
|
||||
alpha_matting_background_threshold=args.background_threshold,
|
||||
alpha_matting_erode_size=args.erode_size,
|
||||
post_process_mask=args.post_process,
|
||||
mask_dilate=args.mask_dilate,
|
||||
mask_blur=args.mask_blur,
|
||||
mask_threshold=args.mask_threshold,
|
||||
edge_grow=args.edge_grow,
|
||||
aot_root=args.aot_root,
|
||||
aot_pretrain=args.aot_pretrain,
|
||||
aot_device=args.aot_device,
|
||||
aot_block_num=args.aot_block_num,
|
||||
aot_rates=args.aot_rates,
|
||||
aot_crop=args.aot_crop,
|
||||
aot_crop_pad=args.aot_crop_pad,
|
||||
aot_max_size=args.aot_max_size,
|
||||
aot_noise_prefill=args.aot_noise_prefill,
|
||||
aot_noise_strength=args.aot_noise_strength,
|
||||
black_subject=args.black_subject,
|
||||
black_threshold=args.black_threshold,
|
||||
gray_subject=args.gray_subject,
|
||||
gray_saturation_threshold=args.gray_saturation_threshold,
|
||||
gray_value_threshold=args.gray_value_threshold,
|
||||
feather=args.feather,
|
||||
feather_radius=args.feather_radius,
|
||||
save_mask=args.save_mask,
|
||||
mask_output_path=mask_output_path,
|
||||
)
|
||||
else:
|
||||
remove_background(
|
||||
str(input_path),
|
||||
str(output_path),
|
||||
session=session,
|
||||
alpha_matting=args.alpha_matting,
|
||||
alpha_matting_foreground_threshold=args.foreground_threshold,
|
||||
alpha_matting_background_threshold=args.background_threshold,
|
||||
alpha_matting_erode_size=args.erode_size,
|
||||
post_process_mask=args.post_process,
|
||||
)
|
||||
|
||||
print("-" * 50)
|
||||
print(f"处理完成!结果保存在: {output_path}")
|
||||
@@ -250,7 +909,30 @@ if __name__ == "__main__":
|
||||
alpha_matting_foreground_threshold=args.foreground_threshold,
|
||||
alpha_matting_background_threshold=args.background_threshold,
|
||||
alpha_matting_erode_size=args.erode_size,
|
||||
post_process_mask=args.post_process
|
||||
post_process_mask=args.post_process,
|
||||
remove_subject=args.remove_subject,
|
||||
mask_dilate=args.mask_dilate,
|
||||
mask_blur=args.mask_blur,
|
||||
mask_threshold=args.mask_threshold,
|
||||
edge_grow=args.edge_grow,
|
||||
aot_root=args.aot_root,
|
||||
aot_pretrain=args.aot_pretrain,
|
||||
aot_device=args.aot_device,
|
||||
aot_block_num=args.aot_block_num,
|
||||
aot_rates=args.aot_rates,
|
||||
aot_crop=args.aot_crop,
|
||||
aot_crop_pad=args.aot_crop_pad,
|
||||
aot_max_size=args.aot_max_size,
|
||||
aot_noise_prefill=args.aot_noise_prefill,
|
||||
aot_noise_strength=args.aot_noise_strength,
|
||||
black_subject=args.black_subject,
|
||||
black_threshold=args.black_threshold,
|
||||
gray_subject=args.gray_subject,
|
||||
gray_saturation_threshold=args.gray_saturation_threshold,
|
||||
gray_value_threshold=args.gray_value_threshold,
|
||||
feather=args.feather,
|
||||
feather_radius=args.feather_radius,
|
||||
save_mask=args.save_mask,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
rembg[gpu]
|
||||
pillow
|
||||
pillow-heif
|
||||
opencv-python
|
||||
|
||||
BIN
tests/__pycache__/test_remove_background.cpython-312.pyc
Normal file
BIN
tests/__pycache__/test_remove_background.cpython-312.pyc
Normal file
Binary file not shown.
78
tests/test_remove_background.py
Normal file
78
tests/test_remove_background.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
|
||||
|
||||
def _import_remove_background():
|
||||
if "remove_background" in sys.modules:
|
||||
del sys.modules["remove_background"]
|
||||
rembg_stub = types.ModuleType("rembg")
|
||||
rembg_stub.remove = lambda *args, **kwargs: None
|
||||
rembg_stub.new_session = lambda *args, **kwargs: None
|
||||
sys.modules["rembg"] = rembg_stub
|
||||
return importlib.import_module("remove_background")
|
||||
|
||||
|
||||
class RemoveBackgroundTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.mod = _import_remove_background()
|
||||
|
||||
def test_alpha_to_mask_threshold(self):
|
||||
alpha = self.mod.np.array([[0, 9, 10, 255]], dtype=self.mod.np.uint8)
|
||||
mask = self.mod._alpha_to_mask(alpha, threshold=10)
|
||||
self.assertEqual(mask.shape, alpha.shape)
|
||||
self.assertEqual(mask[0, 0], 0)
|
||||
self.assertEqual(mask[0, 1], 0)
|
||||
self.assertEqual(mask[0, 2], 0)
|
||||
self.assertEqual(mask[0, 3], 255)
|
||||
|
||||
def test_prepare_mask_outputs_binary(self):
|
||||
mask = self.mod.np.zeros((5, 5), dtype=self.mod.np.uint8)
|
||||
mask[2, 2] = 255
|
||||
mask_hard, mask_used = self.mod._prepare_mask(mask, mask_dilate=0, mask_blur=3)
|
||||
self.assertEqual(mask_hard.shape, mask.shape)
|
||||
self.assertEqual(mask_used.shape, mask.shape)
|
||||
self.assertTrue(self.mod.np.all(self.mod.np.isin(mask_used, [0, 255])))
|
||||
|
||||
def test_resolve_aot_paths_defaults(self):
|
||||
aot_root, aot_pretrain = self.mod._resolve_aot_paths(
|
||||
"AOT-GAN-for-Inpainting", None
|
||||
)
|
||||
self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
|
||||
self.assertIsNone(aot_pretrain)
|
||||
|
||||
def test_resolve_aot_paths_relative_pretrain(self):
|
||||
aot_root, aot_pretrain = self.mod._resolve_aot_paths(
|
||||
"AOT-GAN-for-Inpainting", "experiments/foo.pt"
|
||||
)
|
||||
self.assertTrue(str(aot_root).endswith("AOT-GAN-for-Inpainting"))
|
||||
self.assertTrue(str(aot_pretrain).endswith("experiments/foo.pt"))
|
||||
|
||||
def test_parse_aot_rates(self):
|
||||
rates = self.mod._parse_aot_rates("1+2+4+8")
|
||||
self.assertEqual(rates, [1, 2, 4, 8])
|
||||
|
||||
def test_mask_bbox_empty(self):
|
||||
mask = self.mod.np.zeros((3, 3), dtype=self.mod.np.uint8)
|
||||
self.assertIsNone(self.mod._mask_bbox(mask))
|
||||
|
||||
def test_mask_bbox_and_expand(self):
|
||||
mask = self.mod.np.zeros((5, 5), dtype=self.mod.np.uint8)
|
||||
mask[1:3, 2:4] = 255
|
||||
bbox = self.mod._mask_bbox(mask)
|
||||
self.assertEqual(bbox, (2, 1, 3, 2))
|
||||
expanded = self.mod._expand_bbox(bbox, 2, 5, 5)
|
||||
self.assertEqual(expanded, (0, 0, 4, 4))
|
||||
expanded_min = self.mod._expand_bbox_min_size(bbox, 0, 5, 5, 4)
|
||||
self.assertEqual(expanded_min, (1, 0, 4, 3))
|
||||
|
||||
def test_ensure_rgba_size_resizes(self):
|
||||
img = self.mod.Image.new("RGB", (2, 3), color=(0, 0, 0))
|
||||
out = self.mod._ensure_rgba_size(img, (4, 5))
|
||||
self.assertEqual(out.mode, "RGBA")
|
||||
self.assertEqual(out.size, (4, 5))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user