125 lines
4.0 KiB
Python
125 lines
4.0 KiB
Python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import paddle
|
|
import paddle.nn as nn
|
|
import paddle.nn.functional as F
|
|
|
|
from paddle.nn.initializer import TruncatedNormal, Constant, Assign
|
|
|
|
# Common initializations
|
|
# Shared initializer instances, built once at import time.
ones_ = Constant(value=1.0)  # fills with the constant 1.0
zeros_ = Constant(value=0.0)  # fills with the constant 0.0
trunc_normal_ = TruncatedNormal(std=0.02)  # truncated normal, std 0.02
|
|
|
|
|
|
# Common Layers
|
|
def drop_path(x, drop_prob=0., training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    The original name is misleading, as 'Drop Connect' is a different form of
    dropout described in a separate paper.
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x (tensor): input; its first axis is used as the per-sample axis.
        drop_prob (float|None): probability of dropping a sample's path.
            ``None`` and ``0`` both disable dropping.
        training (bool): dropping is only applied when this is True.

    Returns:
        ``x`` itself when dropping is disabled; otherwise ``x`` divided by
        the keep probability and multiplied by a per-sample 0/1 mask.
    """
    # DropPath's constructor defaults drop_prob to None; without this check
    # `1 - drop_prob` below would raise TypeError in training mode.
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
    # One mask value per sample, broadcast across all remaining axes.
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    # Rescale by 1 / keep_prob before masking.
    output = x.divide(keep_prob) * random_tensor
    return output
|
|
|
|
|
|
class DropPath(nn.Layer):
    """Layer form of ``drop_path``.

    The drop probability is stored exactly as passed to the constructor and
    forwarded to ``drop_path`` on every call, together with this layer's
    ``training`` flag.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` to ``x`` with the stored probability."""
        is_training = self.training
        return drop_path(x, self.drop_prob, is_training)
|
|
|
|
|
|
class Identity(nn.Layer):
    """Pass-through layer: ``forward`` returns its argument untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` as-is."""
        return input
|
|
|
|
|
|
# Common functions
|
|
|
|
|
|
def to_2tuple(x):
    """Expand a scalar-like value into a pair.

    A list or tuple is returned unchanged (same object, no length check);
    any other value ``x`` becomes the tuple ``(x, x)``.
    """
    if not isinstance(x, (list, tuple)):
        return (x, x)
    return x
|
|
|
|
|
|
def add_parameter(layer, datas, name=None):
    """Create a parameter on ``layer`` initialised from ``datas``.

    Args:
        layer: object providing ``create_parameter`` and ``add_parameter``.
        datas: initial values; its ``shape`` attribute is used as the
            parameter shape and it is passed to the ``Assign`` initializer.
        name (str|None): when truthy, the new parameter is also registered
            on ``layer`` under this name.

    Returns:
        The parameter returned by ``layer.create_parameter``.
    """
    initializer = Assign(datas)
    param = layer.create_parameter(
        shape=datas.shape, default_initializer=initializer)
    if not name:
        return param
    layer.add_parameter(name, param)
    return param
|
|
|
|
|
|
def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping square windows, padding first if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): side length of each window.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition.
        (num_h, num_w): number of windows along the height and width.
    """
    batch, height, width, channels = paddle.shape(x)

    # Amount needed to reach the next multiple of window_size
    # (evaluates to 0 when the size is already a multiple).
    pad_h = (window_size - height % window_size) % window_size
    pad_w = (window_size - width % window_size) % window_size

    # Pad on a channels-first view, then restore the [B, H, W, C] layout.
    # NOTE(review): the leading zero in each pair suggests padding goes only
    # on the trailing side of W and H -- confirm against F.pad's pad ordering.
    channels_first = x.transpose([0, 3, 1, 2])
    pad_spec = paddle.to_tensor(
        [0, int(pad_w), 0, int(pad_h)], dtype='int32')
    x = F.pad(channels_first, pad_spec).transpose([0, 2, 3, 1])

    Hp = height + pad_h
    Wp = width + pad_w
    num_h = Hp // window_size
    num_w = Wp // window_size

    # Split H and W into (window index, offset inside window), then bring the
    # two window-index axes together so each window becomes one leading entry.
    grid = x.reshape(
        [batch, num_h, window_size, num_w, window_size, channels])
    windows = grid.transpose([0, 1, 3, 2, 4, 5]).reshape(
        [-1, window_size, window_size, channels])
    return windows, (Hp, Wp), (num_h, num_w)
|
|
|
|
|
|
def window_unpartition(x, pad_hw, num_hw, hw):
    """
    Reassemble windows into the original layout and strip the padding.

    Args:
        x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        pad_hw (Tuple): padded height and width (Hp, Wp).
        num_hw (Tuple): number of windows along height and width (num_h, num_w).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    num_h, num_w = num_hw
    H, W = hw

    total_windows, window_size, _, C = paddle.shape(x)
    # The leading axis holds every window of every sample; recover the batch.
    B = total_windows // (num_h * num_w)

    grid = x.reshape([B, num_h, num_w, window_size, window_size, C])
    # Interleave window index and in-window offset per axis, then merge them.
    padded = grid.transpose([0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, C])

    # Crop back to the pre-padding height and width.
    return padded[:, :H, :W, :]
|