Files
firerec/classify_images.py
2026-01-06 14:20:17 +08:00

198 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Sort images into color and black-and-white/grayscale output directories.

gavin 2026-01-06
"""
from __future__ import annotations
import argparse
import shutil
from datetime import datetime
from pathlib import Path
from typing import Iterable, List, Tuple
from PIL import Image, ImageChops, UnidentifiedImageError
def is_color_image(image_path: Path, tolerance: int = 3) -> bool:
    """
    Tell whether the image at *image_path* carries real color information.

    The image is converted to RGB and the red channel is compared against
    the green and blue channels. The image counts as color only when the
    largest per-pixel channel difference exceeds *tolerance*, so small
    compression artifacts in a grayscale picture do not flip the result.
    """
    with Image.open(image_path) as source:
        red, green, blue = source.convert("RGB").split()
        # getextrema() yields (min, max) for a single band; only max matters.
        peak_red_green = ImageChops.difference(red, green).getextrema()[1]
        peak_red_blue = ImageChops.difference(red, blue).getextrema()[1]
    return max(peak_red_green, peak_red_blue) > tolerance
def iter_image_files(directory: Path) -> List[Path]:
    """
    List the regular files located directly inside *directory*.

    Entries are returned in sorted path order; sub-directories are left out.
    No filtering by file extension or content is done here.
    """
    ordered_entries = sorted(directory.iterdir())
    return list(filter(Path.is_file, ordered_entries))
def unique_destination(target_dir: Path, filename: str) -> Path:
    """
    Pick a path inside *target_dir* for *filename* that does not exist yet.

    If ``target_dir / filename`` is free it is returned unchanged. Otherwise
    ``_1``, ``_2``, ... is appended to the stem (before the suffix) until an
    unused name is found.
    """
    original = target_dir / filename
    stem, suffix = original.stem, original.suffix

    candidate = original
    attempt = 0
    while candidate.exists():
        attempt += 1
        candidate = target_dir / f"{stem}_{attempt}{suffix}"
    return candidate
def classify_images(
    input_dir: Path,
    color_dir: Path,
    bw_dir: Path,
    *,
    move_files: bool = False,
    tolerance: int = 3,
    dry_run: bool = False,
) -> Tuple[int, int, int]:
    """
    Classify images in input_dir into color or black/white destinations.

    Args:
        input_dir: Directory whose direct child files are examined.
        color_dir: Destination for images detected as color.
        bw_dir: Destination for images detected as black/white or grayscale.
        move_files: Move files instead of copying them.
        tolerance: Channel-difference threshold passed to is_color_image.
        dry_run: Only report what would happen. Nothing is copied or moved
            and the destination directories are not created.

    Returns:
        A tuple of (color_count, bw_count, skipped_count). Files that cannot
        be opened as images are reported and counted as skipped.
    """
    # A dry run must leave the filesystem untouched, so the output
    # directories are created only when files will actually be written.
    if not dry_run:
        color_dir.mkdir(parents=True, exist_ok=True)
        bw_dir.mkdir(parents=True, exist_ok=True)

    color_count = bw_count = skipped_count = 0
    image_paths = iter_image_files(input_dir)
    total = len(image_paths)

    def print_progress(done: int) -> None:
        """Report progress every 10 files and once more on the last file."""
        if total == 0:
            return
        if done % 10 != 0 and done != total:
            return
        percent = done / total * 100
        print(f"处理进度: {done}/{total} ({percent:5.1f}%)")

    # `done` counts every examined file, including skipped ones, so the
    # progress display always finishes at total/total (100%).
    for done, image_path in enumerate(image_paths, start=1):
        try:
            has_color = is_color_image(image_path, tolerance=tolerance)
        except (UnidentifiedImageError, OSError) as exc:
            print(f"跳过无法读取的文件: {image_path} ({exc})")
            skipped_count += 1
            print_progress(done)
            continue

        target_dir = color_dir if has_color else bw_dir
        destination = unique_destination(target_dir, image_path.name)

        if dry_run:
            action = "彩色" if has_color else "黑白"
            print(f"[dry-run] {image_path} -> {action} ({destination.name})")
        elif move_files:
            shutil.move(str(image_path), destination)
        else:
            # copy2 also preserves file metadata such as timestamps.
            shutil.copy2(image_path, destination)

        if has_color:
            color_count += 1
        else:
            bw_count += 1

        print_progress(done)

    return color_count, bw_count, skipped_count
def build_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for the image classification script.

    The default color and black/white output directories are placed under
    ``data/`` and share a timestamp taken when the parser is built, so each
    run writes to a fresh pair of directories unless overridden.

    Returns:
        The configured ArgumentParser.
    """
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    default_color_dir = Path("data") / f"c-{timestamp}"
    default_bw_dir = Path("data") / f"b-{timestamp}"

    parser = argparse.ArgumentParser(
        description="将图片按彩色或黑白分类到指定目录。"
    )
    parser.add_argument(
        "-i",
        "--input-dir",
        type=Path,
        default=Path("data/color/mix"),
        # Closing parenthesis restored so the text matches the other options.
        help="待分类的图片目录(默认: data/color/mix)。",
    )
    parser.add_argument(
        "-c",
        "--color-dir",
        type=Path,
        default=default_color_dir,
        help=f"彩色图片输出目录(默认: {default_color_dir})。",
    )
    parser.add_argument(
        "-b",
        "--bw-dir",
        type=Path,
        default=default_bw_dir,
        help=f"黑白/灰阶图片输出目录(默认: {default_bw_dir})。",
    )
    parser.add_argument(
        "-t",
        "--tolerance",
        type=int,
        default=3,
        # Closing parenthesis restored here as well.
        help="颜色通道差异阈值,越大越宽松(默认: 3)。",
    )
    parser.add_argument(
        "--move",
        dest="move_files",
        action="store_true",
        help="移动文件而非复制,谨慎使用。",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="仅输出分类结果,不执行复制/移动。",
    )
    return parser
def exit_with_usage(parser: argparse.ArgumentParser, message: str) -> None:
    """
    Print the usage line plus an error message, then exit with status 2.

    Both the usage line and the error are written to standard error, so
    standard output stays clean. This function does not return: the parser
    terminates the program by raising SystemExit.

    Args:
        parser: Parser whose usage line and program name are shown.
        message: Error description appended after ``<prog>: error:``.
    """
    # ArgumentParser.error() prints the usage so the user sees the hint,
    # followed by "<prog>: error: <message>", and exits with status 2.
    parser.error(message)
def main() -> None:
    """
    Command-line entry point.

    Parses the arguments, rejects a missing or non-directory input path and
    a negative tolerance, runs the classification, and prints a one-line
    summary of the color, black/white and skipped counts.
    """
    parser = build_parser()
    args = parser.parse_args()
    source_dir = args.input_dir

    # Only the first failing check is reported, in this order.
    problem = None
    if not source_dir.exists():
        problem = f"输入目录不存在: {source_dir}"
    elif not source_dir.is_dir():
        problem = f"输入路径不是目录: {source_dir}"
    elif args.tolerance < 0:
        problem = "tolerance 需为非负整数。"
    if problem is not None:
        exit_with_usage(parser, problem)

    counts = classify_images(
        source_dir,
        args.color_dir,
        args.bw_dir,
        move_files=args.move_files,
        tolerance=args.tolerance,
        dry_run=args.dry_run,
    )
    color_count, bw_count, skipped_count = counts

    print(
        f"彩色: {color_count} 张,"
        f"黑白/灰阶: {bw_count} 张,"
        f"跳过: {skipped_count} 张。"
    )
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()