Files
RT-DETR/rtdetr_pytorch/src/nn/backbone/regnet.py
2026-06-03 12:42:47 +08:00

23 lines
533 B
Python

import torch
import torch.nn as nn
from transformers import RegNetModel
from src.core import register
# Names exported by `from <this module> import *`.
__all__ = ['RegNet']
@register
class RegNet(nn.Module):
    """Backbone wrapping the pretrained ``facebook/regnet-y-040`` RegNet model.

    Args:
        configuration: Not used by this class.
        return_idx: Stored (as a copy) on ``self.return_idx``.
            NOTE(review): ``forward`` does not consult this value; it always
            returns ``hidden_states[2:5]``. Confirm whether the indices were
            meant to drive the selection.
    """

    # The default list literal is kept so the signature stays unchanged for
    # callers; it is copied below, so it is never shared between instances.
    def __init__(self, configuration, return_idx=[0, 1, 2, 3]):
        super().__init__()
        # Load the pretrained checkpoint by its model identifier.
        self.model = RegNetModel.from_pretrained("facebook/regnet-y-040")
        # Copy instead of aliasing: otherwise every instance created with the
        # default would hold the same list object, and mutating one instance's
        # ``return_idx`` would silently change all the others.
        self.return_idx = list(return_idx)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return selected hidden states.

        Args:
            x: Input passed straight through to the wrapped ``RegNetModel``.

        Returns:
            The slice ``hidden_states[2:5]`` of the model output, i.e. at most
            three entries starting from index 2.
            NOTE(review): presumably the last three stage feature maps;
            verify against the RegNet hidden-state layout.
        """
        outputs = self.model(x, output_hidden_states=True)
        return outputs.hidden_states[2:5]