first commit
This commit is contained in:
23
rtdetr_pytorch/src/nn/backbone/regnet.py
Normal file
23
rtdetr_pytorch/src/nn/backbone/regnet.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import RegNetModel
|
||||
|
||||
|
||||
from src.core import register
|
||||
|
||||
__all__ = ['RegNet']
|
||||
|
||||
@register
|
||||
class RegNet(nn.Module):
|
||||
def __init__(self, configuration, return_idx=[0, 1, 2, 3]):
|
||||
super(RegNet, self).__init__()
|
||||
self.model = RegNetModel.from_pretrained("facebook/regnet-y-040")
|
||||
self.return_idx = return_idx
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
outputs = self.model(x, output_hidden_states = True)
|
||||
x = outputs.hidden_states[2:5]
|
||||
|
||||
return x
|
||||
Reference in New Issue
Block a user