diff --git a/.DS_Store b/.DS_Store index eb4eeed..644aa33 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/README.md b/README.md index 42a424d..8a2f4a1 100644 --- a/README.md +++ b/README.md @@ -1,293 +1,294 @@ # 图片去背景工具 -使用rembg库实现的Python去背景工具 +用于书画、篆刻作品的去背景与去主体补背景脚本。 -## 快速开始 +默认模式优先走书画专用前景提取逻辑: +- 对墨色笔画做局部背景校正与明暗差分 +- 对印章做更严格的红色检测 +- 在粗掩码内进一步细化 alpha,尽量减少字边残留底色 +- 仅在需要时才回退到 `rembg` + +## 功能概览 + +- 单张图片去背景,输出透明 PNG +- 批量处理单层目录中的图片 +- 支持 `artwork`、`auto`、`rembg` 三种前景提取模式 +- 支持 `calligraphy`、`seal`、`auto` 三种书画类型 +- 支持 AOT-GAN 去主体补背景 +- 支持常见格式:`jpg`、`jpeg`、`png`、`bmp`、`webp` +- 安装 `pillow-heif` 后可读取 `heic` / `heif` + +## 依赖 + +项目当前依赖见 [requirements.txt](/Users/gavin/removeback/requirements.txt): + +```txt +rembg[gpu] +pillow +pillow-heif +opencv-python +``` + +如果要使用 AOT-GAN 补背景,还需要额外安装 `torch` / `torchvision`,并准备 AOT-GAN 代码目录与权重文件。 + +## 环境准备 + +推荐使用现有虚拟环境: ```bash -# 激活虚拟环境 source ~/venv/bin/activate - -# 使用默认参数处理images文件夹 -python remove_background.py - -# 处理单个文件 -python remove_background.py input.jpg output.png - -# 查看所有参数 -python remove_background.py -h ``` -不同模型适用于不同场景: +安装依赖: -| 模型名称 | 大小 | 适用场景 | 推荐度 | -|---------|------|---------|--------| -| **isnet-general-use** | 179MB | 通用场景 | ⭐⭐⭐⭐⭐ 默认推荐 | -| birefnet-general | 250MB | 通用场景,质量更高 | ⭐⭐⭐⭐⭐ | -| birefnet-portrait | 250MB | 人像专用 | ⭐⭐⭐⭐⭐ | -| u2net | 176MB | 经典通用模型 | ⭐⭐⭐⭐ | -| u2netp | 4.7MB | 快速处理 | ⭐⭐⭐ | -| u2net_human_seg | 176MB | 人物分割 | ⭐⭐⭐⭐ | -| isnet-anime | 179MB | 动漫角色 | ⭐⭐⭐⭐ | -| silueta | 43MB | 精简快速 | ⭐⭐⭐ | - -**使用示例**: ```bash -# 使用默认模型 -python remove_background.py input.jpg - -# 使用人像专用模型 -python remove_background.py input.jpg output.png -m birefnet-portrait - -# 使用快速模型 -python remove_background.py input.jpg output.png -m u2netp - -# 书画类文字偏浅的温和参数示例 -python remove_background.py images output \ - --remove-subject --black_subject --gray_subject --save_mask \ - --black-threshold 30\ - --gray-saturation-threshold 30 --gray-value-threshold 30 \ - --edge-grow 2 \ - --feather 
--feather-radius 4 \ - --aot-pretrain experiments/G0000000.pt \ - --aot-max-size 1000 +pip install -r requirements.txt ``` -查看已下载模型: -```bash -ls -lh ~/.u2net/ -``` - -### 模型大小参考 - -| 模型名称 | 文件大小 | 特点 | -|---------|---------|------| -| u2net | 176MB | 通用模型 | -| u2netp | 4.7MB | 轻量级,速度快 | -| isnet-general-use | 179MB | 新一代通用,推荐 | -| birefnet-general | ~250MB | 最新通用模型 | -| birefnet-portrait | ~250MB | 人像专用 | - -### 手动下载模型(网络问题时) +如果你需要启用去主体补背景: ```bash -# 创建目录 -mkdir -p ~/.u2net/ - -# 下载指定模型(以isnet-general-use为例) -curl -L "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx" \ - -o ~/.u2net/isnet-general-use.onnx -``` - -### 清理模型缓存 - -```bash -# 删除所有已下载的模型 -rm -rf ~/.u2net/ - -# 删除特定模型 -rm ~/.u2net/u2net.onnx -``` - -## AOT-GAN 修补后端 - -`--remove-subject` 默认使用 AOT-GAN 修补。 -AOT-GAN 依赖 PyTorch(官方仓库测试 Python 3.8 / torch 1.8.1)。建议使用独立虚拟环境或确保兼容版本。 - -```bash -# 安装依赖(示例) pip install torch torchvision ``` -下载预训练权重后,运行示例: +## 快速开始 + +处理默认 `images/` 目录中的图片,结果输出到 `output/`: + ```bash -python remove_background.py "images/IMG_9259 2.JPG" \ - --remove-subject --black-subject --gray-subject --save-mask \ - --aot-pretrain experiments/places2.pth +python remove_background.py ``` -CPU 无 GPU 时的加速建议(只裁剪主体区域并限制最大边): +处理单张图片: + ```bash -python remove_background.py "images/IMG_9259 2.JPG" \ - --remove-subject --black-subject --gray-subject --save-mask \ - --aot-pretrain experiments/places2.pth \ - --aot-crop --aot-crop-pad 24 --aot-max-size 1400 +python remove_background.py input.jpg output.png ``` -减少“补脸”倾向:启用随机噪声预填充 +处理指定目录: + ```bash -python remove_background.py "images/IMG_9259 2.JPG" \ - --remove-subject --black-subject --gray-subject --save-mask \ - --aot-pretrain experiments/places2.pth \ - --aot-crop --aot-crop-pad 64 --aot-max-size 900 \ - --aot-noise-prefill --aot-noise-strength 1.0 +python remove_background.py my_images/ my_output/ ``` -## 可调整参数说明 +指定为印章场景: -### 1. 
模型选择 (model_name) - -不同模型适用于不同场景: - -- **u2net** (默认): 通用模型,适合大多数场景 -- **u2netp**: 轻量版,速度更快但精度稍低 -- **u2net_human_seg**: 专门用于人物分割 -- **silueta**: 精简版u2net (43MB),速度快 -- **isnet-general-use**: 新一代通用模型,效果可能更好 -- **isnet-anime**: 专门用于动漫角色 -- **birefnet-general**: 最新的通用模型,推荐尝试 -- **birefnet-portrait**: 专门用于人像 -- **birefnet-general-lite**: 轻量版birefnet - -**建议**: 如果u2net效果不好,试试 `isnet-general-use` 或 `birefnet-general` - -### 2. Alpha Matting 参数 - -Alpha Matting 是后处理步骤,可以显著改善边缘质量,特别是头发、毛发等细节。 - -#### alpha_matting(开关) -- **作用**: 是否启用 alpha matting,提升边缘质量 -- **默认**: 关闭(不传 `-a/--alpha-matting`) -- **启用方式**: 传入 `-a` 或 `--alpha-matting` -- **效果**: 有利于细节边缘(毛发/细线),但速度稍慢 - -#### alpha_matting_foreground_threshold (0-255) -- **作用**: 前景阈值,控制哪些区域被认为是前景 -- **默认**: 240 -- **调整建议**: - - 值越大(如270): 保留更多细节,但可能保留一些背景 - - 值越小(如210): 去除更彻底,但可能丢失细节 - - 如果前景被过度去除,增加此值 - - 如果背景残留太多,减小此值 - -#### alpha_matting_background_threshold (0-255) -- **作用**: 背景阈值,控制哪些区域被认为是背景 -- **默认**: 10 -- **调整建议**: - - 值越大(如20-30): 去除背景更彻底 - - 值越小(如5): 保留更多过渡区域 - - 如果背景残留,增加此值 - -#### alpha_matting_erode_size (像素) -- **作用**: 侵蚀大小,用于平滑边缘 -- **默认**: 10 -- **调整建议**: - - 值越大(如15-20): 边缘更平滑,但可能损失细节 - - 值越小(如5-8): 保留更多细节,但边缘可能不够平滑 - -### 3. Mask后处理 (post_process_mask) - -- **作用**: 对 mask 进行额外后处理 -- **默认**: 关闭(不传 `-p/--post-process`) -- **启用方式**: 传入 `-p` 或 `--post-process` -- **效果**: 有助于减少毛边,但可能略损失细节 - -### 4. 
去主体补背景 (remove_subject) - -用于“去掉主体并补全背景”。当前仅使用 AOT-GAN 修补。 - -- **remove_subject(开关)**: 启用去主体补背景(默认关闭,传 `--remove-subject` 开启) -- **aot_root**: AOT-GAN 目录(默认: `AOT-GAN-for-Inpainting`) -- **aot_pretrain**: AOT-GAN 权重文件路径(必填) -- **aot_device**: AOT-GAN 设备(默认: `cpu`) -- **aot_block_num**: AOTBlock 数量(默认: 8) -- **aot_rates**: AOTBlock 膨胀率(默认: `1+2+4+8`) -- **aot_crop(开关)**: 仅对 mask 覆盖区域裁剪修补(默认关闭,传 `--aot-crop` 开启) -- **aot_crop_pad (像素)**: 裁剪边缘留白像素(默认: 0) -- **aot_max_size (像素)**: AOT 输入最大边限制(默认: 0 表示不限制) -- **aot_noise_prefill(开关)**: AOT使用随机噪声预填充(默认关闭) -- **aot_noise_strength (系数)**: 噪声强度(默认: 1.0) -- **mask_dilate (像素)**: mask 膨胀大小(默认: 3)。越大去除范围越大,风险更高 -- **mask_blur (像素)**: mask 模糊大小(默认: 3)。越大边缘越柔和但易过度 -- **mask_threshold (0-255)**: alpha 阈值(默认: 10)。越大保留越多主体 -- **edge_grow (像素)**: 主体边缘额外扩张(默认: 0)。用于清理残留边缘 -- **save_mask(开关)**: 保存 mask 方便检查(默认关闭,传 `--save-mask` 开启) -- **black_subject(开关)**: 将黑色内容也视为主体(默认关闭,传 `--black-subject` 开启) -- **black_threshold (0-255)**: 黑色阈值(默认: 50)。越大越容易把浅灰当黑 -- **gray_subject(开关)**: 将灰阶内容也视为主体(默认关闭,传 `--gray-subject` 开启) -- **gray_saturation_threshold (0-255)**: 灰阶饱和度阈值(默认: 30)。越大越容易把彩色当灰 -- **gray_value_threshold (0-255)**: 灰阶亮度阈值(默认: 200)。越大越容易把浅灰当灰 -- **feather(开关)**: 启用边缘过渡(默认关闭,传 `--feather` 开启) -- **feather_radius (像素)**: 过渡半径(默认: 5)。越大过渡越柔和但可能变糊 -- **说明**: 过渡仅在 mask 外侧进行,避免把主体边缘带回 - -### 5. 
参数调优建议(针对书画/字迹) -- 先开启 `--remove-subject`,仅看主体遮罩是否覆盖到字迹 -- 文字残留:提高 `--black-threshold` 或 `--gray-*` 阈值 -- 过度修补:降低 `--black-threshold`、`--gray-value-threshold`,并减小 `--mask-dilate/--mask-blur` -- 边缘不自然:尝试开启 `--feather` 并使用较小的 `--feather-radius` - -## 常见问题解决 - -### 问题1: 前景被过度去除 -**解决方案**: -```python -alpha_matting = True -alpha_matting_foreground_threshold = 270 # 增加此值 -alpha_matting_background_threshold = 10 # 保持较小 +```bash +python remove_background.py input.jpg output.png --artwork-type seal ``` -### 问题2: 背景残留太多 -**解决方案**: -```python -alpha_matting = True -alpha_matting_foreground_threshold = 240 # 保持默认或减小 -alpha_matting_background_threshold = 20 # 增加此值 -post_process_mask = True # 启用后处理 +强制使用 `rembg`: + +```bash +python remove_background.py input.jpg output.png --foreground-mode rembg -m isnet-general-use ``` -### 问题3: 边缘不自然、有锯齿 -**解决方案**: -```python -alpha_matting = True -alpha_matting_erode_size = 15 # 增加平滑程度 +查看完整参数: + +```bash +python remove_background.py -h ``` -### 问题4: 毛发、头发细节丢失 -**解决方案**: -```python -model_name = "birefnet-portrait" # 使用人像专用模型 -alpha_matting = True -alpha_matting_foreground_threshold = 270 # 增加以保留细节 -alpha_matting_erode_size = 5 # 减小以保留细节 +## 输出规则 + +- 单张图片默认输出为 `*_nobg.png` +- 处理目录时,输出文件写入你指定的输出目录 +- 脚本内置的目录批处理只扫描输入目录的第一层文件,不递归子目录 +- 如果启用 `--remove-subject`,输出文件名改为 `*_bgfill.<原扩展>` 或 `*_bgfill.jpg` + +## 前景提取模式 + +### `artwork` + +默认模式。优先适用于书法、国画、篆刻等纸本图像。 + +### `auto` + +先尝试书画专用掩码;如果结果明显不可信,再回退到 `rembg`。 + +### `rembg` + +强制使用通用抠图模型。适合非书画类图片,或书画专用规则不适合的特殊样本。 + +## 书画类型 + +### `auto` + +自动兼容书法与印章。 + +### `calligraphy` + +更偏重墨色笔画、灰黑色文字。 + +### `seal` + +更偏重红章、篆刻印记。 + +## 常用参数 + +### 书画专用参数 + +- `--foreground-mode` +- `--artwork-type` +- `--artwork-max-size` + +建议: +- 大图先尝试 `--artwork-max-size 1600` +- 超大图如果速度较慢,可降低到 `1200` 或 `1000` +- 红章较多的图片优先试 `--artwork-type seal` +- 纯墨迹优先试 `--artwork-type calligraphy` + +### rembg 相关参数 + +这些参数只在 `--foreground-mode rembg` 或 `auto` 回退到 `rembg` 时生效: + +- `-m, --model` +- `-a, --alpha-matting` +- `-ft, 
--foreground-threshold` +- `-bt, --background-threshold` +- `-es, --erode-size` +- `-p, --post-process` + +当前支持的 `rembg` 模型包括: + +- `u2net` +- `u2netp` +- `u2net_human_seg` +- `silueta` +- `isnet-general-use` +- `isnet-anime` +- `birefnet-general` +- `birefnet-general-lite` +- `birefnet-portrait` +- `birefnet-dis` +- `birefnet-hrsod` +- `birefnet-cod` +- `birefnet-massive` + +### 去主体补背景参数 + +启用: + +```bash +python remove_background.py input.jpg output.jpg --remove-subject --aot-pretrain experiments/your_model.pt ``` -## 推荐配置 +常用参数: -### 配置1: 高质量人像 -```python -model_name = "birefnet-portrait" -alpha_matting = True -alpha_matting_foreground_threshold = 260 -alpha_matting_background_threshold = 15 -alpha_matting_erode_size = 10 -post_process_mask = True +- `--aot-root` +- `--aot-pretrain` +- `--aot-device` +- `--aot-block-num` +- `--aot-rates` +- `--aot-crop` +- `--aot-crop-pad` +- `--aot-max-size` +- `--aot-noise-prefill` +- `--aot-noise-strength` +- `--mask-dilate` +- `--mask-blur` +- `--mask-threshold` +- `--edge-grow` +- `--save-mask` +- `--black-subject` +- `--black-threshold` +- `--gray-subject` +- `--gray-saturation-threshold` +- `--gray-value-threshold` +- `--feather` +- `--feather-radius` + +一个偏保守的示例: + +```bash +python remove_background.py "images/inpaint/IMG_9259 2.JPG" output.jpg \ + --remove-subject \ + --foreground-mode artwork \ + --artwork-type auto \ + --aot-pretrain experiments/G0000000.pt \ + --aot-crop \ + --aot-crop-pad 64 \ + --aot-max-size 900 \ + --aot-noise-prefill \ + --aot-noise-strength 1.0 ``` -### 配置2: 通用高质量 -```python -model_name = "birefnet-general" -alpha_matting = True -alpha_matting_foreground_threshold = 250 -alpha_matting_background_threshold = 12 -alpha_matting_erode_size = 10 -post_process_mask = True +## 典型用法 + +书法图去背景: + +```bash +python remove_background.py input.jpg output.png \ + --foreground-mode artwork \ + --artwork-type calligraphy ``` -### 配置3: 快速处理 -```python -model_name = "u2netp" -alpha_matting = False 
-post_process_mask = False +印章图去背景: + +```bash +python remove_background.py input.jpg output.png \ + --foreground-mode artwork \ + --artwork-type seal ``` -## 测试不同参数 +通用图片走 `rembg`: -建议按以下顺序调整: +```bash +python remove_background.py input.jpg output.png \ + --foreground-mode rembg \ + --model birefnet-general +``` -1. 先尝试不同的模型 -2. 启用alpha_matting -3. 调整foreground_threshold和background_threshold -4. 最后调整erode_size +## 调参建议 -每次修改后运行脚本,对比结果。 +如果背景没有去掉: + +- 先尝试 `--artwork-type calligraphy` +- 再尝试 `--foreground-mode auto` +- 非书画图直接改用 `--foreground-mode rembg` + +如果字边仍有底色: + +- 先确认原图是否有严重纸纹、阴影或压缩噪声 +- 降低 `--artwork-max-size` 可能更快,但通常不利于细节 +- 对超大图可以保留 `1600`,必要时单独抽样检查结果 + +如果去主体补背景不自然: + +- 开启 `--aot-crop` +- 增加 `--aot-crop-pad` +- 尝试 `--aot-noise-prefill` +- 减小 `--mask-dilate`、`--mask-blur` + +## 校验命令 + +项目当前可用的检查命令: + +```bash +~/venv/bin/python -m pytest -q +~/venv/bin/python -m mypy remove_background.py tests/test_remove_background.py +~/venv/bin/python -m ruff check remove_background.py tests/test_remove_background.py +~/venv/bin/python -m ruff format remove_background.py tests/test_remove_background.py +``` + +说明: +- `mypy` 与 `ruff` 的项目配置见 [pyproject.toml](/Users/gavin/removeback/pyproject.toml) +- `images/` 与 `output/` 已在静态检查配置里排除 + +## 已知限制 + +- CLI 自带的目录批处理不递归子目录 +- 书画专用规则仍然可能受极端纸色、重阴影、扫描边框影响 +- 某些透明 PNG 在图片预览器里会显示白底或棋盘底,这是预览器合成效果,不代表 alpha 一定有问题 + +## 仓库结构 + +```txt +remove_background.py 主脚本 +tests/test_remove_background.py +requirements.txt +pyproject.toml +images/ 输入样例 +output/ 输出目录 +experiments/ AOT-GAN 权重示例 +``` diff --git a/__pycache__/remove_background.cpython-312.pyc b/__pycache__/remove_background.cpython-312.pyc index 1e78c22..f3b2116 100644 Binary files a/__pycache__/remove_background.cpython-312.pyc and b/__pycache__/remove_background.cpython-312.pyc differ diff --git a/images/.DS_Store b/images/.DS_Store index ec5b2aa..493a0f3 100644 Binary files a/images/.DS_Store and b/images/.DS_Store differ diff --git a/remove_background.py b/remove_background.py 
index a59690f..3b71468 100644 --- a/remove_background.py +++ b/remove_background.py @@ -1,7 +1,5 @@ -""" -图片去背景工具 -使用rembg库自动去除图片背景 -""" +"""书画与篆刻作品去背景工具。""" + import os import sys import argparse @@ -13,9 +11,15 @@ from PIL import Image, ImageOps # 避免 numba 在某些环境下缓存失败 os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache") -from rembg import remove, new_session +try: + from rembg import remove, new_session +except ImportError: + remove = None + new_session = None AOT_MODEL_CACHE: dict[tuple[str, str, int, tuple[int, ...]], tuple[object, object]] = {} +REMBG_SESSION_CACHE: dict[str, object] = {} + def _resolve_aot_paths(aot_root, aot_pretrain): """解析 AOT-GAN 的路径配置。""" @@ -29,10 +33,12 @@ def _resolve_aot_paths(aot_root, aot_pretrain): aot_pretrain = (Path.cwd() / aot_pretrain).resolve() return aot_root, aot_pretrain + def _parse_aot_rates(rates_str): parts = [p for p in rates_str.split("+") if p] return [int(p) for p in parts] + def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None): """加载/缓存 AOT-GAN 模型。""" aot_root, aot_pretrain = _resolve_aot_paths(aot_root, aot_pretrain) @@ -53,14 +59,17 @@ def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None if str(src_root) not in sys.path: sys.path.insert(0, str(src_root)) import importlib + try: import torch except ImportError as exc: raise ImportError("未找到 PyTorch,请先按README安装依赖。") from exc net = importlib.import_module("model.aotgan") + class _Args: pass + args = _Args() args.block_num = int(block_num) args.rates = rates @@ -81,12 +90,14 @@ def _get_aot_model(aot_root, aot_pretrain, device="cpu", block_num=8, rates=None AOT_MODEL_CACHE[key] = (model, device) return AOT_MODEL_CACHE[key] + def _mask_bbox(mask): ys, xs = np.where(mask > 0) if len(xs) == 0: return None return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()) + def _expand_bbox(bbox, pad, width, height): if bbox is None: return None @@ -98,6 +109,7 @@ def _expand_bbox(bbox, pad, width, height): 
y1 = min(height - 1, y1 + pad) return x0, y0, x1, y1 + def _expand_bbox_min_size(bbox, pad, width, height, min_size): """在给定 pad 基础上,确保裁剪区域至少为 min_size。""" expanded = _expand_bbox(bbox, pad, width, height) @@ -113,9 +125,11 @@ def _expand_bbox_min_size(bbox, pad, width, height, min_size): extra = max((need_w + 1) // 2, (need_h + 1) // 2) return _expand_bbox(bbox, pad + extra, width, height) + def _build_noise_prefill(img, mask_t, strength): """为mask区域生成噪声预填充,img取值范围为[-1, 1]。""" import torch + img01 = (img + 1.0) / 2.0 unmasked = 1.0 - mask_t denom = unmasked.sum() @@ -130,6 +144,7 @@ def _build_noise_prefill(img, mask_t, strength): noise = noise.clamp(0.0, 1.0) return noise * 2.0 - 1.0 + def _inpaint_with_aot_core( bgr, mask, @@ -161,6 +176,7 @@ def _inpaint_with_aot_core( mask_crop = mask[:h2, :w2] import torch + img = torch.from_numpy(img_crop).permute(2, 0, 1).float() / 255.0 img = img * 2.0 - 1.0 mask_t = torch.from_numpy((mask_crop > 0).astype(np.float32)).unsqueeze(0) @@ -196,6 +212,7 @@ def _inpaint_with_aot_core( output_full[:h2, :w2, :] = result_bgr return output_full + def _inpaint_with_aot( bgr, mask, @@ -259,7 +276,9 @@ def _inpaint_with_aot( interp = cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR if scale != 1.0: bgr_roi = cv2.resize(bgr_roi, (new_w, new_h), interpolation=interp) - mask_roi = cv2.resize(mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST) + mask_roi = cv2.resize( + mask_roi, (new_w, new_h), interpolation=cv2.INTER_NEAREST + ) filled_roi = _inpaint_with_aot_core( bgr_roi, @@ -286,48 +305,373 @@ def _inpaint_with_aot( output_full[y0 : y1 + 1, x0 : x1 + 1] = filled_roi return output_full + # 支持HEIC格式 try: from pillow_heif import register_heif_opener + register_heif_opener() HEIC_SUPPORTED = True except ImportError: HEIC_SUPPORTED = False -def remove_background(input_path, output_path, session=None, **kwargs): + +def _get_rembg_session(model_name): + """按模型缓存 rembg 会话,避免重复加载。""" + if new_session is None: + raise ImportError( + 
"未找到 rembg,请先安装相关依赖或改用 --foreground-mode artwork。" + ) + if model_name not in REMBG_SESSION_CACHE: + REMBG_SESSION_CACHE[model_name] = new_session(model_name) + return REMBG_SESSION_CACHE[model_name] + + +def _resize_for_processing(image, max_size): + """按最大边缩放图片,返回缩放后的图片与缩放比例。""" + height, width = image.shape[:2] + max_size = int(max_size) if max_size else 0 + if max_size <= 0 or max(height, width) <= max_size: + return image.copy(), 1.0 + scale = max_size / float(max(height, width)) + resized = cv2.resize( + image, + (max(1, int(round(width * scale))), max(1, int(round(height * scale)))), + interpolation=cv2.INTER_AREA, + ) + return resized, scale + + +def _otsu_threshold(channel, floor=0): + """返回 Otsu 阈值,并设置一个最小下限避免纸张纹理误检。""" + if channel.size == 0 or int(channel.max()) <= 0: + return int(floor) + threshold, _ = cv2.threshold(channel, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + return max(int(round(threshold)), int(floor)) + + +def _remove_small_components(mask, min_area): + """移除面积过小的连通域,降低纸张纹理噪声。""" + if min_area <= 0 or int(mask.max()) == 0: + return mask + num_labels, labels, stats, _ = cv2.connectedComponentsWithStats( + mask, connectivity=8 + ) + cleaned = np.zeros_like(mask) + for label in range(1, num_labels): + area = int(stats[label, cv2.CC_STAT_AREA]) + if area >= min_area: + cleaned[labels == label] = 255 + return cleaned + + +def _remove_border_frame_components( + mask, + min_width_ratio=0.7, + min_height_ratio=0.7, + max_fill_ratio=0.35, +): + """移除贴边的大型空心边框连通域,避免纸张外轮廓被误判为前景。""" + if int(mask.max()) == 0: + return mask + + height, width = mask.shape + num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8) + cleaned = mask.copy() + for label in range(1, num_labels): + x, y, comp_w, comp_h, area = stats[label] + touches_border = x == 0 or y == 0 or (x + comp_w) == width or (y + comp_h) == height + if not touches_border: + continue + if comp_w < int(width * float(min_width_ratio)): + continue + if comp_h < int(height 
* float(min_height_ratio)): + continue + fill_ratio = float(area) / float(max(1, comp_w * comp_h)) + if fill_ratio <= float(max_fill_ratio): + cleaned[labels == label] = 0 + return cleaned + + +def _artwork_mask_is_reasonable(mask): + """检查书画掩码是否可信,用于 auto 模式回退。""" + if mask.size == 0 or int(mask.max()) == 0: + return False + + coverage = float(np.count_nonzero(mask)) / float(mask.size) + if coverage < 0.0005 or coverage > 0.55: + return False + + height, width = mask.shape + border = max(4, min(height, width) // 32) + border_pixels = np.concatenate( + [ + mask[:border, :].reshape(-1), + mask[-border:, :].reshape(-1), + mask[:, :border].reshape(-1), + mask[:, -border:].reshape(-1), + ] + ) + border_ratio = float(np.count_nonzero(border_pixels)) / float(border_pixels.size) + return border_ratio <= 0.28 + + +def _compute_detail_alpha(rgb_image): + """根据局部暗度与严格红章特征生成细节 alpha。""" + height, width = rgb_image.shape[:2] + gray = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY) + sigma_bg = max(8.0, max(height, width) / 80.0) + bg_gray = cv2.GaussianBlur(gray, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg) + dark_score = cv2.subtract(bg_gray, gray).astype(np.float32) + dark_alpha = np.clip((dark_score - 18.0) / 48.0, 0.0, 1.0) + + hsv = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2HSV) + hue = hsv[:, :, 0].astype(np.float32) + saturation = hsv[:, :, 1].astype(np.float32) + value = hsv[:, :, 2].astype(np.float32) + red = rgb_image[:, :, 0].astype(np.float32) + green = rgb_image[:, :, 1].astype(np.float32) + blue = rgb_image[:, :, 2].astype(np.float32) + red_dominance = red - np.maximum(green, blue) + red_hue_mask = ((hue <= 12.0) | (hue >= 168.0)).astype(np.float32) + seal_alpha = np.clip((red_dominance - 20.0) / 40.0, 0.0, 1.0) + seal_alpha *= np.clip((saturation - 50.0) / 80.0, 0.0, 1.0) + seal_alpha *= (value >= 40.0).astype(np.float32) + seal_alpha *= red_hue_mask + + detail_alpha = np.maximum(dark_alpha, seal_alpha) + return np.clip((detail_alpha - 0.07) / 0.93, 0.0, 1.0) + + +def 
_extract_artwork_mask(input_image, artwork_type="auto", max_size=1600): + """为书画/篆刻作品提取前景掩码,避免依赖通用人像分割模型。""" + rgb = np.array(input_image.convert("RGB")) + resized, scale = _resize_for_processing(rgb, max_size=max_size) + height, width = resized.shape[:2] + + gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY) + lab = cv2.cvtColor(resized, cv2.COLOR_RGB2LAB) + hsv = cv2.cvtColor(resized, cv2.COLOR_RGB2HSV) + sigma_bg = max(6.0, max(height, width) / 40.0) + + bg_gray = cv2.GaussianBlur(gray, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg) + dark_score = cv2.subtract(bg_gray, gray) + + lab_a = lab[:, :, 1] + lab_b = lab[:, :, 2] + bg_a = cv2.GaussianBlur(lab_a, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg) + bg_b = cv2.GaussianBlur(lab_b, (0, 0), sigmaX=sigma_bg, sigmaY=sigma_bg) + chroma_score = np.sqrt( + (lab_a.astype(np.float32) - bg_a.astype(np.float32)) ** 2 + + (lab_b.astype(np.float32) - bg_b.astype(np.float32)) ** 2 + ) + chroma_score = np.clip(chroma_score * 3.0, 0, 255).astype(np.uint8) + + red = resized[:, :, 0].astype(np.int16) + green = resized[:, :, 1].astype(np.int16) + blue = resized[:, :, 2].astype(np.int16) + hue = hsv[:, :, 0] + saturation = hsv[:, :, 1] + value = hsv[:, :, 2] + red_dominance = np.clip(red - np.maximum(green, blue), 0, 255).astype(np.uint8) + seal_score = np.clip( + red_dominance.astype(np.int16) * 2 + + np.clip(saturation.astype(np.int16) - 45, 0, 255), + 0, + 255, + ).astype(np.uint8) + seal_score = np.where( + ((hue <= 12) | (hue >= 168)) + & (saturation >= 55) + & (value >= 40) + & (red_dominance >= 20), + seal_score, + 0, + ).astype(np.uint8) + color_score = np.maximum(chroma_score, seal_score) + + if artwork_type == "seal": + combined_score = np.maximum( + np.clip(dark_score.astype(np.float32) * 1.0, 0, 255).astype(np.uint8), + np.clip(color_score.astype(np.float32) * 1.35, 0, 255).astype(np.uint8), + ) + dark_floor = 12 + color_floor = 18 + combined_floor = 20 + elif artwork_type == "calligraphy": + combined_score = np.maximum( + 
np.clip(dark_score.astype(np.float32) * 1.35, 0, 255).astype(np.uint8), + np.clip(color_score.astype(np.float32) * 0.8, 0, 255).astype(np.uint8), + ) + dark_floor = 14 + color_floor = 28 + combined_floor = 18 + else: + combined_score = np.maximum( + np.clip(dark_score.astype(np.float32) * 1.2, 0, 255).astype(np.uint8), + np.clip(color_score.astype(np.float32) * 1.2, 0, 255).astype(np.uint8), + ) + dark_floor = 14 + color_floor = 16 + combined_floor = 18 + + smooth_sigma = max(0.6, max(height, width) / 320.0) + dark_smooth = cv2.GaussianBlur( + dark_score, (0, 0), sigmaX=smooth_sigma, sigmaY=smooth_sigma + ) + color_smooth = cv2.GaussianBlur( + color_score, (0, 0), sigmaX=smooth_sigma, sigmaY=smooth_sigma + ) + combined_smooth = cv2.GaussianBlur( + combined_score, + (0, 0), + sigmaX=smooth_sigma, + sigmaY=smooth_sigma, + ) + + dark_mask = (dark_smooth >= _otsu_threshold(dark_smooth, floor=dark_floor)).astype( + np.uint8 + ) * 255 + color_mask = ( + color_smooth >= _otsu_threshold(color_smooth, floor=color_floor) + ).astype(np.uint8) * 255 + combined_mask = ( + combined_smooth >= _otsu_threshold(combined_smooth, floor=combined_floor) + ).astype(np.uint8) * 255 + + if artwork_type == "seal": + mask = cv2.bitwise_or(combined_mask, color_mask) + mask = cv2.bitwise_or(mask, dark_mask) + elif artwork_type == "calligraphy": + mask = cv2.bitwise_or(combined_mask, dark_mask) + else: + mask = cv2.bitwise_or(combined_mask, cv2.bitwise_or(dark_mask, color_mask)) + close_kernel = np.ones((3, 3), dtype=np.uint8) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, close_kernel, iterations=1) + min_area = max(12, (height * width) // 25000) + mask = _remove_small_components(mask, min_area=min_area) + mask = _remove_border_frame_components(mask) + mask = cv2.dilate(mask, close_kernel, iterations=1) + + if scale < 1.0: + mask = cv2.resize( + mask, (rgb.shape[1], rgb.shape[0]), interpolation=cv2.INTER_NEAREST + ) + return mask + + +def _mask_to_transparent_image(input_image, mask): + 
"""将二值掩码转为带透明背景的 RGBA 图片。""" + rgb = np.array(input_image.convert("RGB")) + detail_alpha = _compute_detail_alpha(rgb) + alpha = ((mask.astype(np.float32) / 255.0) * detail_alpha * 255.0).astype(np.uint8) + alpha = cv2.GaussianBlur(alpha, (0, 0), sigmaX=0.8, sigmaY=0.8) + alpha = np.where(mask > 0, alpha, 0).astype(np.uint8) + rgba = np.dstack((rgb, alpha)) + return Image.fromarray(rgba) + + +def _extract_foreground_mask( + input_image, + session=None, + model_name="isnet-general-use", + foreground_mode="artwork", + artwork_type="auto", + artwork_max_size=1600, + mask_threshold=10, + **kwargs, +): + """按指定模式提取前景掩码。""" + if foreground_mode in {"artwork", "auto"}: + artwork_mask = _extract_artwork_mask( + input_image, + artwork_type=artwork_type, + max_size=artwork_max_size, + ) + if foreground_mode == "artwork" or _artwork_mask_is_reasonable(artwork_mask): + return artwork_mask, "artwork" + + if foreground_mode not in {"rembg", "auto"}: + raise ValueError(f"不支持的前景提取模式: {foreground_mode}") + + if remove is None: + raise ImportError( + "未找到 rembg,请先安装相关依赖或改用 --foreground-mode artwork。" + ) + if session is None: + session = _get_rembg_session(model_name) + output_image = remove(input_image, session=session, **kwargs) + output_image = _ensure_rgba_size(output_image, input_image.size) + alpha = np.array(output_image.getchannel("A")) + return _alpha_to_mask(alpha, threshold=mask_threshold), "rembg" + + +def remove_background( + input_path, + output_path, + session=None, + model_name="isnet-general-use", + foreground_mode="artwork", + artwork_type="auto", + artwork_max_size=1600, + **kwargs, +): """ - 去除图片背景 - + 去除图片背景。 + Args: input_path: 输入图片路径 output_path: 输出图片路径 session: rembg会话对象(可选) + model_name: rembg 模型名称 + foreground_mode: 前景提取模式,可选 artwork / auto / rembg + artwork_type: 书画类型,可选 auto / calligraphy / seal + artwork_max_size: 书画掩码估算时的最大边 **kwargs: 其他参数,如alpha_matting相关参数 """ print(f"正在处理: {input_path}") - + # 读取输入图片 input_image = 
ImageOps.exif_transpose(Image.open(input_path)) - - # 去除背景 + + if foreground_mode in {"artwork", "auto"}: + mask = _extract_artwork_mask( + input_image, + artwork_type=artwork_type, + max_size=artwork_max_size, + ) + if foreground_mode == "artwork" or _artwork_mask_is_reasonable(mask): + output_image = _mask_to_transparent_image(input_image, mask) + output_image.save(output_path) + print(f"已保存: {output_path}") + return + + if remove is None: + raise ImportError( + "未找到 rembg,请先安装相关依赖或改用 --foreground-mode artwork。" + ) + if session is None: + session = _get_rembg_session(model_name) + output_image = remove(input_image, session=session, **kwargs) - - # 保存输出图片 output_image.save(output_path) - + print(f"已保存: {output_path}") + def _pil_to_bgr(image): """将PIL图片转换为OpenCV BGR格式""" rgb = image.convert("RGB") arr = np.array(rgb) return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR) + def _alpha_to_mask(alpha_channel, threshold=10): """将alpha通道转换为二值mask(0或255)""" mask = (alpha_channel > threshold).astype(np.uint8) * 255 return mask + def _ensure_rgba_size(image, target_size): if image.mode != "RGBA": image = image.convert("RGBA") @@ -335,6 +679,7 @@ def _ensure_rgba_size(image, target_size): image = image.resize(target_size, resample=Image.NEAREST) return image + def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0): """生成硬边mask与处理后mask(用于填补/过渡)""" if mask_dilate and mask_dilate > 0: @@ -352,10 +697,15 @@ def _prepare_mask(mask, mask_dilate=3, mask_blur=3, edge_grow=0): mask_used = (mask_used > 0).astype(np.uint8) * 255 return mask_hard, mask_used + def remove_subject_and_inpaint( input_path, output_path, session=None, + model_name="isnet-general-use", + foreground_mode="artwork", + artwork_type="auto", + artwork_max_size=1600, mask_dilate=3, mask_blur=3, mask_threshold=10, @@ -388,6 +738,10 @@ def remove_subject_and_inpaint( input_path: 输入图片路径 output_path: 输出图片路径 session: rembg会话对象(可选) + model_name: rembg 模型名称 + foreground_mode: 前景提取模式,可选 artwork / auto / rembg + 
artwork_type: 书画类型,可选 auto / calligraphy / seal + artwork_max_size: 书画掩码估算时的最大边 aot_root: AOT-GAN目录 aot_pretrain: AOT-GAN权重文件路径 aot_device: AOT-GAN设备 @@ -409,12 +763,16 @@ def remove_subject_and_inpaint( input_image = ImageOps.exif_transpose(Image.open(input_path)) - # 使用rembg获取主体mask - output_image = remove(input_image, session=session, **kwargs) - output_image = _ensure_rgba_size(output_image, input_image.size) - - alpha = np.array(output_image.getchannel("A")) - mask = _alpha_to_mask(alpha, threshold=mask_threshold) + mask, mask_source = _extract_foreground_mask( + input_image, + session=session, + model_name=model_name, + foreground_mode=foreground_mode, + artwork_type=artwork_type, + artwork_max_size=artwork_max_size, + mask_threshold=mask_threshold, + **kwargs, + ) bgr = _pil_to_bgr(input_image) if black_subject: @@ -425,7 +783,9 @@ def remove_subject_and_inpaint( hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV) s = hsv[:, :, 1] v = hsv[:, :, 2] - gray_mask = ((s <= gray_saturation_threshold) & (v <= gray_value_threshold)).astype(np.uint8) * 255 + gray_mask = ( + (s <= gray_saturation_threshold) & (v <= gray_value_threshold) + ).astype(np.uint8) * 255 mask = cv2.bitwise_or(mask, gray_mask) mask_hard, mask_used = _prepare_mask( @@ -457,7 +817,9 @@ def remove_subject_and_inpaint( dist_out = cv2.distanceTransform(255 - mask_bin, cv2.DIST_L2, 3) alpha = np.ones_like(dist_out, dtype=np.float32) outside = mask_bin == 0 - alpha[outside] = np.clip(1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0) + alpha[outside] = np.clip( + 1.0 - (dist_out[outside] / float(feather_radius)), 0.0, 1.0 + ) alpha = alpha[:, :, None] blended = (alpha * filled + (1.0 - alpha) * bgr).astype(np.uint8) result = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB) @@ -471,12 +833,17 @@ def remove_subject_and_inpaint( Image.fromarray(mask_used).save(mask_output_path) print(f"已保存mask: {mask_output_path}") + print(f"前景提取模式: {mask_source}") print(f"已保存: {output_path}") + def process_images_folder( 
input_folder, output_folder, - model_name="u2net", + model_name="isnet-general-use", + foreground_mode="artwork", + artwork_type="auto", + artwork_max_size=1600, alpha_matting=False, alpha_matting_foreground_threshold=240, alpha_matting_background_threshold=10, @@ -508,11 +875,11 @@ def process_images_folder( ): """ 批量处理文件夹中的所有图片 - + Args: input_folder: 输入文件夹路径 output_folder: 输出文件夹路径 - model_name: 模型名称,可选值: + model_name: rembg 模型名称,可选值: - u2net (默认): 通用模型 - u2netp: 轻量版u2net - u2net_human_seg: 人物分割 @@ -521,6 +888,9 @@ def process_images_folder( - isnet-anime: 动漫角色高精度分割 - birefnet-general: 通用模型 - birefnet-portrait: 人像模型 + foreground_mode: 前景提取模式,可选 artwork / auto / rembg + artwork_type: 书画类型,可选 auto / calligraphy / seal + artwork_max_size: 书画掩码估算时的最大边 alpha_matting: 是否启用alpha matting后处理(改善边缘质量) alpha_matting_foreground_threshold: 前景阈值 (0-255),值越大保留越多前景 alpha_matting_background_threshold: 背景阈值 (0-255),值越大去除越多背景 @@ -529,27 +899,33 @@ def process_images_folder( """ # 创建输出文件夹 Path(output_folder).mkdir(parents=True, exist_ok=True) - - # 创建会话(重用会话可以提高性能) - print(f"使用模型: {model_name}") - session = new_session(model_name) - + + print(f"前景提取模式: {foreground_mode}") + print(f"书画类型: {artwork_type}") + if foreground_mode in {"rembg", "auto"}: + print(f"rembg模型: {model_name}") + # 支持的图片格式 - image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'} + image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".webp"} if HEIC_SUPPORTED: - image_extensions.update({'.heic', '.heif'}) + image_extensions.update({".heic", ".heif"}) else: - print("提示: 未安装pillow-heif,HEIC格式不可用。安装方法: pip install pillow-heif") - + print( + "提示: 未安装pillow-heif,HEIC格式不可用。安装方法: pip install pillow-heif" + ) + # 获取所有图片文件 input_path = Path(input_folder) - image_files = [f for f in input_path.iterdir() - if f.is_file() and f.suffix.lower() in image_extensions] - + image_files = [ + f + for f in input_path.iterdir() + if f.is_file() and f.suffix.lower() in image_extensions + ] + if not image_files: print(f"在 
{input_folder} 中没有找到图片文件") return - + print(f"找到 {len(image_files)} 张图片,开始处理...") print(f"Alpha Matting: {'启用' if alpha_matting else '禁用'}") if alpha_matting: @@ -587,7 +963,7 @@ def process_images_folder( print(f" - 过渡半径: {feather_radius}") print(f" - 保存mask: {'是' if save_mask else '否'}") print("-" * 50) - + # 处理每张图片 for i, image_file in enumerate(image_files, 1): try: @@ -601,16 +977,21 @@ def process_images_folder( # 去背景默认使用PNG格式以支持透明背景 output_filename = image_file.stem + "_nobg.png" output_path = Path(output_folder) / output_filename - + print(f"[{i}/{len(image_files)}] ", end="") if remove_subject: mask_output_path = None if save_mask: - mask_output_path = str(Path(output_folder) / (image_file.stem + "_mask.png")) + mask_output_path = str( + Path(output_folder) / (image_file.stem + "_mask.png") + ) remove_subject_and_inpaint( str(image_file), str(output_path), - session=session, + model_name=model_name, + foreground_mode=foreground_mode, + artwork_type=artwork_type, + artwork_max_size=artwork_max_size, alpha_matting=alpha_matting, alpha_matting_foreground_threshold=alpha_matting_foreground_threshold, alpha_matting_background_threshold=alpha_matting_background_threshold, @@ -644,25 +1025,29 @@ def process_images_folder( remove_background( str(image_file), str(output_path), - session=session, + model_name=model_name, + foreground_mode=foreground_mode, + artwork_type=artwork_type, + artwork_max_size=artwork_max_size, alpha_matting=alpha_matting, alpha_matting_foreground_threshold=alpha_matting_foreground_threshold, alpha_matting_background_threshold=alpha_matting_background_threshold, alpha_matting_erode_size=alpha_matting_erode_size, post_process_mask=post_process_mask, ) - + except Exception as e: print(f"处理 {image_file.name} 时出错: {e}") - + print("-" * 50) print(f"处理完成!结果保存在 {output_folder} 文件夹中") + if __name__ == "__main__": parser = argparse.ArgumentParser( - description='图片去背景工具 - 使用rembg自动去除图片背景', + description="书画与篆刻作品去背景工具 - 默认使用轻量书画掩码提取", 
formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=''' + epilog=""" 示例用法: # 使用默认参数处理images文件夹 python remove_background.py @@ -673,104 +1058,233 @@ if __name__ == "__main__": # 处理指定文件夹 python remove_background.py my_images/ my_output/ - # 使用不同模型 - python remove_background.py input.jpg output.png -m birefnet-portrait + # 强制使用 rembg 旧流程 + python remove_background.py input.jpg output.png --foreground-mode rembg -m isnet-general-use + + # 指定为篆刻模式 + python remove_background.py input.jpg output.png --artwork-type seal # 自定义alpha matting参数 python remove_background.py input.jpg output.png -ft 260 -bt 12 -es 5 - ''' + """, ) - + # 必需参数 - parser.add_argument('input', nargs='?', default='images', - help='输入文件或文件夹路径(默认: images)') - parser.add_argument('output', nargs='?', default=None, - help='输出文件或文件夹路径(可选,默认为output/)') - + parser.add_argument( + "input", + nargs="?", + default="images", + help="输入文件或文件夹路径(默认: images)", + ) + parser.add_argument( + "output", + nargs="?", + default=None, + help="输出文件或文件夹路径(可选,默认为output/)", + ) + # 模型选择 - parser.add_argument('-m', '--model', default='isnet-general-use', - choices=['u2net', 'u2netp', 'u2net_human_seg', 'silueta', - 'isnet-general-use', 'isnet-anime', - 'birefnet-general', 'birefnet-general-lite', - 'birefnet-portrait', 'birefnet-dis', - 'birefnet-hrsod', 'birefnet-cod', 'birefnet-massive'], - help='选择使用的模型 (默认: isnet-general-use)') - + parser.add_argument( + "-m", + "--model", + default="isnet-general-use", + choices=[ + "u2net", + "u2netp", + "u2net_human_seg", + "silueta", + "isnet-general-use", + "isnet-anime", + "birefnet-general", + "birefnet-general-lite", + "birefnet-portrait", + "birefnet-dis", + "birefnet-hrsod", + "birefnet-cod", + "birefnet-massive", + ], + help="选择 rembg 模型(仅在 --foreground-mode rembg/auto 时使用)", + ) + parser.add_argument( + "--foreground-mode", + default="artwork", + choices=["artwork", "auto", "rembg"], + help="前景提取模式:artwork 为书画专用快速模式,auto 失败时回退 rembg,rembg 为旧流程 (默认: artwork)", + ) + 
parser.add_argument( + "--artwork-type", + default="auto", + choices=["auto", "calligraphy", "seal"], + help="书画类型:auto 自动兼容书法与篆刻,seal 更偏重红色印章,calligraphy 更偏重墨色笔画 (默认: auto)", + ) + parser.add_argument( + "--artwork-max-size", + type=int, + default=1600, + help="书画掩码估算时的最大边,越小越快,越大越精细 (默认: 1600)", + ) + # Alpha Matting参数 - parser.add_argument('-a', '--alpha-matting', '--alpha_matting', action='store_true', - help='启用alpha matting后处理(默认: false)') - parser.add_argument('-ft', '--foreground-threshold', type=int, default=245, - help='前景阈值 (0-255),值越大保留越多细节 (默认: 245)') - parser.add_argument('-bt', '--background-threshold', type=int, default=8, - help='背景阈值 (0-255),值越大去除越多背景 (默认: 8)') - parser.add_argument('-es', '--erode-size', type=int, default=2, - help='侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)') - + parser.add_argument( + "-a", + "--alpha-matting", + "--alpha_matting", + action="store_true", + help="启用alpha matting后处理(默认: false)", + ) + parser.add_argument( + "-ft", + "--foreground-threshold", + type=int, + default=245, + help="前景阈值 (0-255),值越大保留越多细节 (默认: 245)", + ) + parser.add_argument( + "-bt", + "--background-threshold", + type=int, + default=8, + help="背景阈值 (0-255),值越大去除越多背景 (默认: 8)", + ) + parser.add_argument( + "-es", + "--erode-size", + type=int, + default=2, + help="侵蚀大小,用于平滑边缘,值越大越平滑但可能丢失细节 (默认: 2)", + ) + # 其他选项 - parser.add_argument('-p', '--post-process', '--post_process', action='store_true', - help='启用mask后处理(默认: false)') + parser.add_argument( + "-p", + "--post-process", + "--post_process", + action="store_true", + help="启用mask后处理(默认: false)", + ) # 去主体补背景参数 - parser.add_argument('--remove-subject', '--remove_subject', action='store_true', - help='去掉主体并补全背景(默认: false)') - parser.add_argument('--aot-root', type=str, default='AOT-GAN-for-Inpainting', - help='AOT-GAN目录(默认: AOT-GAN-for-Inpainting)') - parser.add_argument('--aot-pretrain', type=str, default=None, - help='AOT-GAN预训练权重文件路径(必填,相对路径基于当前目录)') - parser.add_argument('--aot-device', type=str, 
default='cpu', - help='AOT-GAN设备(默认: cpu)') - parser.add_argument('--aot-block-num', type=int, default=8, - help='AOTBlock数量(默认: 8)') - parser.add_argument('--aot-rates', type=str, default='1+2+4+8', - help='AOTBlock膨胀率(默认: 1+2+4+8)') - parser.add_argument('--aot-crop', action='store_true', - help='AOT仅对mask区域裁剪修补(默认: false)') - parser.add_argument('--aot-crop-pad', type=int, default=0, - help='AOT裁剪边缘留白像素(默认: 0)') - parser.add_argument('--aot-max-size', type=int, default=0, - help='AOT输入最大边限制,0为不限制(默认: 0)') - parser.add_argument('--aot-noise-prefill', action='store_true', - help='AOT使用随机噪声预填充(默认: false)') - parser.add_argument('--aot-noise-strength', type=float, default=1.0, - help='AOT噪声强度系数(默认: 1.0)') - parser.add_argument('--mask-dilate', type=int, default=3, - help='mask膨胀大小(默认: 3)') - parser.add_argument('--mask-blur', type=int, default=3, - help='mask模糊大小(默认: 3,建议奇数)') - parser.add_argument('--mask-threshold', type=int, default=10, - help='alpha阈值(默认: 10)') - parser.add_argument('--edge-grow', type=int, default=0, - help='主体边缘扩张像素(默认: 0)') - parser.add_argument('--save-mask', '--save_mask', action='store_true', - help='保存mask到output目录(默认: false)') - parser.add_argument('--black-subject', '--black_subject', action='store_true', - help='将黑色内容也视为主体(默认: false)') - parser.add_argument('--black-threshold', type=int, default=50, - help='黑色阈值(0-255,灰度越小越黑,默认: 50)') - parser.add_argument('--gray-subject', '--gray_subject', action='store_true', - help='将灰阶内容也视为主体(默认: false)') - parser.add_argument('--gray-saturation-threshold', type=int, default=30, - help='灰阶饱和度阈值(0-255,越小越接近灰阶,默认: 30)') - parser.add_argument('--gray-value-threshold', type=int, default=200, - help='灰阶亮度阈值(0-255,越小越暗,默认: 200)') - parser.add_argument('--feather', action='store_true', - help='启用边缘过渡融合(默认: false)') - parser.add_argument('--feather-radius', type=int, default=5, - help='边缘过渡半径(默认: 5)') - + parser.add_argument( + "--remove-subject", + "--remove_subject", + action="store_true", + 
help="去掉主体并补全背景(默认: false)", + ) + parser.add_argument( + "--aot-root", + type=str, + default="AOT-GAN-for-Inpainting", + help="AOT-GAN目录(默认: AOT-GAN-for-Inpainting)", + ) + parser.add_argument( + "--aot-pretrain", + type=str, + default=None, + help="AOT-GAN预训练权重文件路径(必填,相对路径基于当前目录)", + ) + parser.add_argument( + "--aot-device", type=str, default="cpu", help="AOT-GAN设备(默认: cpu)" + ) + parser.add_argument( + "--aot-block-num", type=int, default=8, help="AOTBlock数量(默认: 8)" + ) + parser.add_argument( + "--aot-rates", + type=str, + default="1+2+4+8", + help="AOTBlock膨胀率(默认: 1+2+4+8)", + ) + parser.add_argument( + "--aot-crop", action="store_true", help="AOT仅对mask区域裁剪修补(默认: false)" + ) + parser.add_argument( + "--aot-crop-pad", type=int, default=0, help="AOT裁剪边缘留白像素(默认: 0)" + ) + parser.add_argument( + "--aot-max-size", + type=int, + default=0, + help="AOT输入最大边限制,0为不限制(默认: 0)", + ) + parser.add_argument( + "--aot-noise-prefill", + action="store_true", + help="AOT使用随机噪声预填充(默认: false)", + ) + parser.add_argument( + "--aot-noise-strength", + type=float, + default=1.0, + help="AOT噪声强度系数(默认: 1.0)", + ) + parser.add_argument( + "--mask-dilate", type=int, default=3, help="mask膨胀大小(默认: 3)" + ) + parser.add_argument( + "--mask-blur", type=int, default=3, help="mask模糊大小(默认: 3,建议奇数)" + ) + parser.add_argument( + "--mask-threshold", type=int, default=10, help="alpha阈值(默认: 10)" + ) + parser.add_argument( + "--edge-grow", type=int, default=0, help="主体边缘扩张像素(默认: 0)" + ) + parser.add_argument( + "--save-mask", + "--save_mask", + action="store_true", + help="保存mask到output目录(默认: false)", + ) + parser.add_argument( + "--black-subject", + "--black_subject", + action="store_true", + help="将黑色内容也视为主体(默认: false)", + ) + parser.add_argument( + "--black-threshold", + type=int, + default=50, + help="黑色阈值(0-255,灰度越小越黑,默认: 50)", + ) + parser.add_argument( + "--gray-subject", + "--gray_subject", + action="store_true", + help="将灰阶内容也视为主体(默认: false)", + ) + parser.add_argument( + 
"--gray-saturation-threshold", + type=int, + default=30, + help="灰阶饱和度阈值(0-255,越小越接近灰阶,默认: 30)", + ) + parser.add_argument( + "--gray-value-threshold", + type=int, + default=200, + help="灰阶亮度阈值(0-255,越小越暗,默认: 200)", + ) + parser.add_argument( + "--feather", action="store_true", help="启用边缘过渡融合(默认: false)" + ) + parser.add_argument( + "--feather-radius", type=int, default=5, help="边缘过渡半径(默认: 5)" + ) + args = parser.parse_args() - + print("=" * 50) print("图片去背景工具") print("=" * 50) - + # 判断输入是文件还是文件夹 input_path = Path(args.input) - + if not input_path.exists(): print(f"错误: 输入路径不存在: {args.input}") exit(1) - + # 处理单个文件 if input_path.is_file(): suffix = input_path.suffix.lower() @@ -787,7 +1301,7 @@ if __name__ == "__main__": if args.remove_subject: output_path = input_path.parent / "output" / output_name else: - output_path = input_path.parent / 'output' / output_name + output_path = input_path.parent / "output" / output_name output_path.parent.mkdir(parents=True, exist_ok=True) else: output_candidate = Path(args.output) @@ -800,10 +1314,14 @@ if __name__ == "__main__": else: output_path = output_candidate output_path.parent.mkdir(parents=True, exist_ok=True) - + print(f"输入文件: {input_path}") print(f"输出文件: {output_path}") - print(f"模型: {args.model}") + print(f"前景提取模式: {args.foreground_mode}") + print(f"书画类型: {args.artwork_type}") + print(f"书画最大边: {args.artwork_max_size}") + if args.foreground_mode in {"rembg", "auto"}: + print(f"rembg模型: {args.model}") print(f"Alpha Matting: {'启用' if args.alpha_matting else '禁用'}") if args.alpha_matting: print(f" - 前景阈值: {args.foreground_threshold}") @@ -840,10 +1358,7 @@ if __name__ == "__main__": print(f" - 过渡半径: {args.feather_radius}") print(f" - 保存mask: {'是' if args.save_mask else '否'}") print("-" * 50) - - # 创建会话 - session = new_session(args.model) - + # 处理图片 if args.remove_subject: mask_output_path = None @@ -852,7 +1367,10 @@ if __name__ == "__main__": remove_subject_and_inpaint( str(input_path), str(output_path), - 
session=session, + model_name=args.model, + foreground_mode=args.foreground_mode, + artwork_type=args.artwork_type, + artwork_max_size=args.artwork_max_size, alpha_matting=args.alpha_matting, alpha_matting_foreground_threshold=args.foreground_threshold, alpha_matting_background_threshold=args.background_threshold, @@ -886,25 +1404,31 @@ if __name__ == "__main__": remove_background( str(input_path), str(output_path), - session=session, + model_name=args.model, + foreground_mode=args.foreground_mode, + artwork_type=args.artwork_type, + artwork_max_size=args.artwork_max_size, alpha_matting=args.alpha_matting, alpha_matting_foreground_threshold=args.foreground_threshold, alpha_matting_background_threshold=args.background_threshold, alpha_matting_erode_size=args.erode_size, post_process_mask=args.post_process, ) - + print("-" * 50) print(f"处理完成!结果保存在: {output_path}") - + # 处理文件夹 elif input_path.is_dir(): - output_folder = args.output if args.output else 'output' - + output_folder = args.output if args.output else "output" + process_images_folder( str(input_path), output_folder, model_name=args.model, + foreground_mode=args.foreground_mode, + artwork_type=args.artwork_type, + artwork_max_size=args.artwork_max_size, alpha_matting=args.alpha_matting, alpha_matting_foreground_threshold=args.foreground_threshold, alpha_matting_background_threshold=args.background_threshold, @@ -934,7 +1458,7 @@ if __name__ == "__main__": feather_radius=args.feather_radius, save_mask=args.save_mask, ) - + else: print(f"错误: 不支持的输入类型: {args.input}") exit(1) diff --git a/tests/test_remove_background.py b/tests/test_remove_background.py index b9469e8..d0d9824 100644 --- a/tests/test_remove_background.py +++ b/tests/test_remove_background.py @@ -1,7 +1,9 @@ import importlib import sys +import tempfile import types import unittest +from pathlib import Path def _import_remove_background(): @@ -18,6 +20,15 @@ class RemoveBackgroundTests(unittest.TestCase): def setUp(self): self.mod = 
_import_remove_background() + def _make_paper_image(self, width=160, height=120): + rng = self.mod.np.random.default_rng(0) + base = self.mod.np.full((height, width, 3), 230, dtype=self.mod.np.uint8) + noise = rng.integers(-8, 9, size=(height, width, 3), dtype=self.mod.np.int16) + image = self.mod.np.clip(base.astype(self.mod.np.int16) + noise, 0, 255).astype( + self.mod.np.uint8 + ) + return image + def test_alpha_to_mask_threshold(self): alpha = self.mod.np.array([[0, 9, 10, 255]], dtype=self.mod.np.uint8) mask = self.mod._alpha_to_mask(alpha, threshold=10) @@ -67,12 +78,126 @@ class RemoveBackgroundTests(unittest.TestCase): expanded_min = self.mod._expand_bbox_min_size(bbox, 0, 5, 5, 4) self.assertEqual(expanded_min, (1, 0, 4, 3)) + def test_remove_border_frame_components_keeps_inner_content(self): + mask = self.mod.np.zeros((12, 12), dtype=self.mod.np.uint8) + mask[0, :] = 255 + mask[-1, :] = 255 + mask[:, 0] = 255 + mask[:, -1] = 255 + mask[4:8, 5:7] = 255 + + cleaned = self.mod._remove_border_frame_components( + mask, + min_width_ratio=0.7, + min_height_ratio=0.7, + max_fill_ratio=0.35, + ) + + self.assertEqual(int(cleaned[0, 0]), 0) + self.assertGreater(cleaned[4:8, 5:7].mean(), 200) + def test_ensure_rgba_size_resizes(self): img = self.mod.Image.new("RGB", (2, 3), color=(0, 0, 0)) out = self.mod._ensure_rgba_size(img, (4, 5)) self.assertEqual(out.mode, "RGBA") self.assertEqual(out.size, (4, 5)) + def test_extract_artwork_mask_detects_calligraphy_strokes(self): + image = self._make_paper_image() + self.mod.cv2.line(image, (20, 20), (130, 90), (20, 20, 20), thickness=6) + self.mod.cv2.line(image, (45, 15), (45, 105), (30, 30, 30), thickness=5) + + mask = self.mod._extract_artwork_mask( + self.mod.Image.fromarray(image), + artwork_type="calligraphy", + max_size=120, + ) + + self.assertGreater(mask[25:95, 42:49].mean(), 180) + self.assertLess(mask[:15, :15].mean(), 20) + + def test_extract_artwork_mask_detects_red_seal(self): + image = self._make_paper_image() 
+ self.mod.cv2.rectangle(image, (40, 30), (120, 95), (165, 30, 40), thickness=5) + self.mod.cv2.line(image, (55, 40), (105, 85), (170, 25, 35), thickness=6) + + mask = self.mod._extract_artwork_mask( + self.mod.Image.fromarray(image), + artwork_type="seal", + max_size=120, + ) + + self.assertGreater(mask[38:92, 38:122].mean(), 40) + self.assertLess(mask[:15, :15].mean(), 20) + + def test_extract_artwork_mask_ignores_warm_paper_cast(self): + image = self.mod.np.full( + (120, 160, 3), (222, 188, 150), dtype=self.mod.np.uint8 + ) + self.mod.cv2.line(image, (30, 20), (30, 100), (28, 28, 28), thickness=6) + self.mod.cv2.line(image, (55, 20), (125, 95), (32, 32, 32), thickness=5) + + mask = self.mod._extract_artwork_mask( + self.mod.Image.fromarray(image), + artwork_type="calligraphy", + max_size=120, + ) + + self.assertGreater(mask[25:100, 26:34].mean(), 180) + self.assertLess(mask[:12, :12].mean(), 20) + + def test_mask_to_transparent_image_refines_alpha_inside_coarse_mask(self): + image = self.mod.np.full( + (120, 160, 3), (224, 190, 156), dtype=self.mod.np.uint8 + ) + self.mod.cv2.line(image, (40, 18), (40, 102), (24, 24, 24), thickness=6) + self.mod.cv2.rectangle(image, (102, 24), (132, 54), (170, 30, 40), thickness=-1) + + coarse_mask = self.mod.np.zeros((120, 160), dtype=self.mod.np.uint8) + coarse_mask[10:110, 20:140] = 255 + + out = self.mod._mask_to_transparent_image( + self.mod.Image.fromarray(image), + coarse_mask, + ) + alpha = self.mod.np.array(out.getchannel("A")) + + self.assertGreater(alpha[25:95, 36:45].mean(), 110) + self.assertGreater(alpha[30:50, 108:128].mean(), 90) + self.assertLess(alpha[70:90, 70:90].mean(), 25) + + def test_remove_background_artwork_mode_skips_rembg(self): + image = self._make_paper_image(width=100, height=80) + self.mod.cv2.line(image, (15, 15), (85, 60), (20, 20, 20), thickness=5) + + original_remove = self.mod.remove + + def _unexpected_remove(*args, **kwargs): + raise AssertionError("artwork 模式不应调用 rembg") + + 
self.mod.remove = _unexpected_remove + try: + with tempfile.TemporaryDirectory() as tmpdir: + input_path = Path(tmpdir) / "input.png" + output_path = Path(tmpdir) / "output.png" + self.mod.Image.fromarray(image).save(input_path) + + self.mod.remove_background( + str(input_path), + str(output_path), + foreground_mode="artwork", + artwork_type="calligraphy", + artwork_max_size=120, + ) + + out = self.mod.Image.open(output_path) + alpha = self.mod.np.array(out.getchannel("A")) + self.assertEqual(out.mode, "RGBA") + self.assertGreater(alpha[30:55, 35:70].mean(), 70) + self.assertLess(alpha[:10, :10].mean(), 10) + finally: + self.mod.remove = original_remove + if __name__ == "__main__": unittest.main()
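Review note on `test_remove_background_artwork_mode_skips_rembg`: the test swaps `self.mod.remove` by hand and restores it in a `finally` block. The same guard could probably be written with `unittest.mock.patch.object`, which restores the attribute on exit. The sketch below is not the project's code. `fake_mod` and `toy_remove_background` are hypothetical stand-ins, and the assumption that `artwork` mode returns before touching `remove` is taken only from what the test asserts.

```python
import types
from unittest import mock

# Hypothetical stand-in for the remove_background module. The real module
# exposes `remove` (imported from rembg); this one just echoes its input.
fake_mod = types.SimpleNamespace(remove=lambda image, **kwargs: image)


def toy_remove_background(mod, image, foreground_mode="artwork"):
    """Toy dispatcher (assumption): artwork mode never calls mod.remove."""
    if foreground_mode == "artwork":
        return ("artwork", image)
    return ("rembg", mod.remove(image))


def check_artwork_mode_skips_rembg(mod):
    """Patch mod.remove so any call fails, then run the artwork path."""
    original = mod.remove
    with mock.patch.object(
        mod,
        "remove",
        side_effect=AssertionError("artwork mode must not call rembg"),
    ) as patched:
        source, _ = toy_remove_background(mod, "pixels", foreground_mode="artwork")
        patched.assert_not_called()
    # patch.object has put the original attribute back at this point.
    return source, mod.remove is original


print(check_artwork_mode_skips_rembg(fake_mod))
```

Compared with the manual swap, the context manager keeps the restore step out of the test body, and `assert_not_called()` states the intent directly instead of relying only on the raised `AssertionError`.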