first commit
This commit is contained in:
14
rtdetr_pytorch/src/data/cifar10/__init__.py
Normal file
14
rtdetr_pytorch/src/data/cifar10/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
|
||||
import torchvision
|
||||
from typing import Optional, Callable
|
||||
|
||||
from src.core import register
|
||||
|
||||
|
||||
@register
|
||||
class CIFAR10(torchvision.datasets.CIFAR10):
|
||||
__inject__ = ['transform', 'target_transform']
|
||||
|
||||
def __init__(self, root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False) -> None:
|
||||
super().__init__(root, train, transform, target_transform, download)
|
||||
|
||||
Reference in New Issue
Block a user