This model was published in HF papers on 2024-07-24 and contributed to Hugging Face Transformers on 2026-05-07.
RF-DETR
RF-DETR proposes a Receptive Field Detection Transformer (DETR) architecture designed to compete with and surpass the dominant YOLO series for real-time object detection. It achieves a new state-of-the-art balance between speed (latency) and accuracy (mAP) by combining recent transformer advances with efficient design choices.
The RF-DETR architecture is characterized by its simple and efficient structure: a DINOv2 Backbone, a Projector, and a shallow DETR Decoder. It enhances the DETR architecture for efficiency and speed using the following core modifications:
- DINOv2 Backbone: Uses a powerful DINOv2 backbone for robust feature extraction.
- Group DETR Training: Utilizes Group-Wise One-to-Many Assignment during training to accelerate convergence.
- Richer Input: Aggregates multi-level features from the backbone and uses a C2f Projector (similar to the one in YOLOv8) to pass multi-scale features to the decoder.
- Faster Decoder: Employs a shallow 3-layer DETR decoder with deformable cross-attention for lower latency.
- Optimized Queries: Uses a mixed-query scheme combining learnable content queries and generated spatial queries.
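Group-wise one-to-many assignment can be illustrated with a small toy example. In Group DETR, the object queries are split into several groups, and each group is matched one-to-one to the ground-truth objects independently. As a result, every object gets one positive query per group during training, and only one group is kept at inference. The code below is a minimal sketch of that matching step under stated assumptions, not RF-DETR's implementation. `group_wise_assignment` and `optimal_one_to_one` are hypothetical helpers, and the cost matrix is made up. Brute-force search stands in for the Hungarian algorithm that real DETR matchers use, so the sketch only suits tiny inputs.

```python
from itertools import permutations


def optimal_one_to_one(cost):
    """Minimum-cost one-to-one assignment of ground-truth objects to queries.

    cost[q][g] is the matching cost between query q and ground-truth object g.
    Brute force over permutations: only for toy sizes (real matchers use the
    Hungarian algorithm). Returns (query_index, gt_index) pairs.
    """
    num_queries = len(cost)
    num_gt = len(cost[0]) if cost else 0
    best, best_pairs = float("inf"), []
    # Try every ordered choice of distinct queries for the ground-truth objects.
    for queries in permutations(range(num_queries), num_gt):
        total = sum(cost[q][g] for g, q in enumerate(queries))
        if total < best:
            best = total
            best_pairs = [(q, g) for g, q in enumerate(queries)]
    return best_pairs


def group_wise_assignment(cost, num_groups):
    """Split queries into equal groups and match each group one-to-one.

    Every ground-truth object is matched once per group, so it receives
    num_groups positive queries overall (one-to-many). Query indices in the
    returned pairs are global, not group-local.
    """
    num_queries = len(cost)
    if num_queries % num_groups != 0:
        raise ValueError("num_queries must be divisible by num_groups")
    group_size = num_queries // num_groups
    matches = []
    for group in range(num_groups):
        start = group * group_size
        group_cost = cost[start:start + group_size]
        for q, g in optimal_one_to_one(group_cost):
            matches.append((start + q, g))
    return matches


# Toy cost matrix: 4 queries (rows) x 2 ground-truth objects (columns).
cost = [
    [0.1, 0.9],
    [0.8, 0.2],
    [0.5, 0.4],
    [0.3, 0.7],
]
print(group_wise_assignment(cost, num_groups=2))  # [(0, 0), (1, 1), (3, 0), (2, 1)]
```

With `num_groups=2`, each of the two objects is matched twice, once in each group. With `num_groups=1`, the sketch reduces to standard one-to-one DETR matching.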
You can find all the available RF-DETR checkpoints under the Roboflow organization. The original code can be found here.
Thanks to the weight conversion mapping, RfDetr is compatible with models from the original rf-detr library as well as models trained on the Roboflow platform. You can therefore train a model on the Roboflow platform, then load its weights with RfDetr in Transformers and deploy it anywhere.
Tip
Click on the RF-DETR models in the right sidebar for more examples of how to apply RF-DETR to different object detection tasks.
The example below demonstrates how to perform object detection with the [Pipeline] and the [AutoModel] class.
```python
from transformers import pipeline

detector = pipeline("object-detection", model="Roboflow/rf-detr-medium", device_map="auto")
# Returns a list of dicts with "score", "label" and "box" keys
print(detector("http://images.cocodataset.org/val2017/000000039769.jpg"))
```
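The object-detection pipeline returns one dict per detected object. Each dict contains a `"score"`, a `"label"` string, and a `"box"` dict holding the pixel coordinates `xmin`, `ymin`, `xmax`, and `ymax`. The following is a minimal sketch of post-processing that output. `summarize_detections` is a hypothetical helper, not part of Transformers, and `min_score` is an illustrative parameter.

```python
from collections import Counter


def summarize_detections(detections, min_score=0.5):
    """Filter pipeline detections by score and summarize them.

    `detections` is assumed to be the list returned by an object-detection
    pipeline: dicts with "score", "label" and "box" keys, where "box" holds
    pixel coordinates "xmin", "ymin", "xmax", "ymax".
    Returns the kept detections, a per-label count, and each kept box's area.
    """
    kept = [d for d in detections if d["score"] >= min_score]
    counts = Counter(d["label"] for d in kept)
    # Box area in pixels, e.g. for dropping very small detections.
    areas = [
        (d["box"]["xmax"] - d["box"]["xmin"]) * (d["box"]["ymax"] - d["box"]["ymin"])
        for d in kept
    ]
    return kept, counts, areas
```

Pass it the list returned by the pipeline call, for example `summarize_detections(result, min_score=0.7)`, where `result` is that list.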
```python
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image
import requests
import torch

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("Roboflow/rf-detr-medium")
model = AutoModelForObjectDetection.from_pretrained("Roboflow/rf-detr-medium", device_map="auto")

# prepare image for the model
inputs = image_processor(images=image, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs)

# target_sizes expects (height, width); PIL's image.size is (width, height)
results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label]}: {score:.2f} {box}")
```
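`post_process_object_detection` turns raw per-query class scores and normalized boxes into scored boxes in pixel coordinates. The following is a minimal NumPy sketch of how DETR-family models with sigmoid classification usually do this: apply a sigmoid, take the top-k (query, class) pairs, convert `(center_x, center_y, width, height)` to `(xmin, ymin, xmax, ymax)`, scale by the image size, and apply the threshold. It is an illustration under assumptions, not the RfDetrImageProcessor code. The sigmoid scoring, the `top_k=100` default, and the helper name `detr_style_postprocess` are all hypothetical.

```python
import numpy as np


def detr_style_postprocess(logits, boxes, image_size, threshold=0.3, top_k=100):
    """Sketch of DETR-family post-processing for a single image.

    logits: (num_queries, num_classes) raw class scores; sigmoid is assumed.
    boxes: (num_queries, 4) normalized (center_x, center_y, width, height).
    image_size: (height, width) of the original image, like target_sizes.
    top_k: hypothetical number of (query, class) candidates to keep.
    Returns scores, labels and (xmin, ymin, xmax, ymax) boxes in pixels.
    """
    probs = 1.0 / (1.0 + np.exp(-logits))
    num_classes = probs.shape[1]
    flat = probs.reshape(-1)
    k = min(top_k, flat.size)
    # Flat indices of the k highest (query, class) scores, best first.
    top = np.argsort(flat)[::-1][:k]
    scores = flat[top]
    query_idx = top // num_classes
    labels = top % num_classes

    # Center format -> corner format, still normalized to [0, 1].
    cx, cy, w, h = boxes[query_idx].T
    xyxy = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    # Scale x by width and y by height to get pixel coordinates.
    height, width = image_size
    xyxy = xyxy * np.array([width, height, width, height])

    keep = scores > threshold
    return scores[keep], labels[keep], xyxy[keep]
```

To compare the sketch against the real post-processing, feed it one image's outputs as NumPy arrays. This assumes RF-DETR exposes DETR-style `logits` and `pred_boxes` output attributes, which this page does not confirm.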
Visualizing results with supervision
You can use the supervision library to visualize detection and segmentation results. Install it with pip install supervision.
```python
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image
import supervision as sv
import requests
import torch

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("Roboflow/rf-detr-medium")
model = AutoModelForObjectDetection.from_pretrained("Roboflow/rf-detr-medium", device_map="auto")

inputs = image_processor(images=image, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs)

results = image_processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3
)[0]

detections = sv.Detections.from_transformers(
    transformers_results=results, id2label=model.config.id2label
)

box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()

annotated_image = image.copy()
annotated_image = box_annotator.annotate(annotated_image, detections)
annotated_image = label_annotator.annotate(annotated_image, detections)

sv.plot_image(annotated_image)
```
For instance segmentation, load an RF-DETR segmentation checkpoint with AutoModelForInstanceSegmentation and annotate masks instead of boxes:

```python
from transformers import AutoImageProcessor, AutoModelForInstanceSegmentation
from PIL import Image
import supervision as sv
import requests
import torch

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("Roboflow/rf-detr-seg-medium")
model = AutoModelForInstanceSegmentation.from_pretrained("Roboflow/rf-detr-seg-medium", device_map="auto")

inputs = image_processor(images=image, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs)

results = image_processor.post_process_instance_segmentation(
    outputs, target_sizes=[image.size[::-1]], threshold=0.3
)[0]

detections = sv.Detections.from_transformers(
    transformers_results=results, id2label=model.config.id2label
)

mask_annotator = sv.MaskAnnotator()
label_annotator = sv.LabelAnnotator()

annotated_image = image.copy()
annotated_image = mask_annotator.annotate(annotated_image, detections)
annotated_image = label_annotator.annotate(annotated_image, detections)

sv.plot_image(annotated_image)
```
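If you prefer not to add a dependency for quick checks, Pillow's `ImageDraw` can draw detection boxes directly. The following is a minimal sketch, assuming `result` is one element of the list returned by `post_process_object_detection`, that is, a dict with `"scores"`, `"labels"`, and `"boxes"` in `(xmin, ymin, xmax, ymax)` pixel format. `draw_detections` is a hypothetical helper. It is not part of Transformers or supervision.

```python
from PIL import Image, ImageDraw


def draw_detections(image, result, id2label, color="red"):
    """Draw boxes and captions from one post-processed detection result.

    `result` is assumed to hold "scores", "labels" and "boxes" with boxes in
    (xmin, ymin, xmax, ymax) pixel coordinates. float()/int() make it work
    with both torch tensors and plain Python numbers.
    """
    canvas = image.convert("RGB")  # convert() returns a copy; the input is untouched
    draw = ImageDraw.Draw(canvas)
    for score, label, box in zip(result["scores"], result["labels"], result["boxes"]):
        xmin, ymin, xmax, ymax = (float(v) for v in box)
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=2)
        caption = f"{id2label[int(label)]}: {float(score):.2f}"
        # Put the caption just above the box, clamped to the top edge.
        draw.text((xmin, max(ymin - 12, 0)), caption, fill=color)
    return canvas
```

With the AutoModel detection example, `draw_detections(image, results[0], model.config.id2label).show()` would display the annotated copy.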
Resources
- Scripts for finetuning [RfDetrForObjectDetection] with [Trainer] or Accelerate can be found here.
- See also: Object detection task guide.
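Fine-tuning DETR-style detectors usually starts from COCO-format targets, where each box is `[x, y, width, height]` measured from the top-left corner. The following is a minimal sketch for converting pixel `(xmin, ymin, xmax, ymax)` boxes into such a target. `to_coco_annotations` is a hypothetical helper. Whether RfDetrImageProcessor accepts this structure through an `annotations` argument, as other DETR image processors in Transformers do, is an assumption. Check the `preprocess` reference below before relying on it.

```python
def to_coco_annotations(image_id, boxes_xyxy, category_ids):
    """Convert pixel (xmin, ymin, xmax, ymax) boxes to a COCO-style target dict.

    Hypothetical helper: produces {"image_id": ..., "annotations": [...]}
    with COCO "bbox" entries in (x, y, width, height) format.
    """
    if len(boxes_xyxy) != len(category_ids):
        raise ValueError("need exactly one category id per box")
    annotations = []
    for (xmin, ymin, xmax, ymax), category_id in zip(boxes_xyxy, category_ids):
        width, height = xmax - xmin, ymax - ymin
        annotations.append({
            "image_id": image_id,
            "category_id": category_id,
            "bbox": [xmin, ymin, width, height],  # COCO: top-left x, y, width, height
            "area": width * height,
            "iscrowd": 0,
        })
    return {"image_id": image_id, "annotations": annotations}
```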
RfDetrConfig
autodoc RfDetrConfig
RfDetrDinov2Config
autodoc RfDetrDinov2Config
RfDetrImageProcessor
autodoc RfDetrImageProcessor
    - preprocess
    - post_process_object_detection
    - post_process_instance_segmentation
RfDetrModel
autodoc RfDetrModel - forward
RfDetrForObjectDetection
autodoc RfDetrForObjectDetection - forward
RfDetrForInstanceSegmentation
autodoc RfDetrForInstanceSegmentation - forward
RfDetrDinov2Backbone
autodoc RfDetrDinov2Backbone - forward