-
-
Save aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 to your computer and use it in GitHub Desktop.
import functools | |
import numpy as np | |
import tensorflow.compat.v1 as tf | |
from tensorflow.python.tpu import tpu_function | |
# Passed as `momentum` to every tf.layers.batch_normalization call in this file.
BATCH_NORM_DECAY = 0.9
# Passed as `epsilon` to every tf.layers.batch_normalization call in this file.
BATCH_NORM_EPSILON = 1e-5
def Activation(inputs, activation='relu'):
    """Apply the requested non-linearity to `inputs`.

    Only 'relu' and 'silu' are accepted; anything else trips the assertion.
    'silu' is implemented with `tf.nn.swish`.
    """
    assert activation in ['relu', 'silu']
    return tf.nn.relu(inputs) if activation == 'relu' else tf.nn.swish(inputs)
def BNReLU(
        inputs, is_training, nonlinearity=True,
        init_zero=False, activation='relu'):
    """Batch-normalize `inputs`, then optionally apply an activation.

    Args:
      inputs: tensor normalized along axis 3 (presumably NHWC -- confirm
        against callers).
      is_training: forwarded as `training` to the batch-norm layer.
      nonlinearity: when false, the normalized tensor is returned without an
        activation.
      init_zero: when true, gamma starts at zero instead of one.
      activation: name forwarded to `Activation` ('relu' or 'silu').

    Returns:
      The normalized (and optionally activated) tensor.
    """
    gamma_initializer = (
        tf.zeros_initializer() if init_zero else tf.ones_initializer())
    normalized = tf.layers.batch_normalization(
        inputs=inputs,
        axis=3,
        momentum=BATCH_NORM_DECAY,
        epsilon=BATCH_NORM_EPSILON,
        center=True,
        scale=True,
        training=is_training,
        fused=True,
        gamma_initializer=gamma_initializer)
    if not nonlinearity:
        return normalized
    return Activation(normalized, activation=activation)
def fixed_padding(inputs, kernel_size):
    """Pads the input along the spatial dimensions independently of input size.

    Args:
      inputs: rank-4 tensor; axes 1 and 2 are the ones padded (presumably
        NHWC height and width -- confirm against callers).
      kernel_size: int kernel size of the convolution that follows; the total
        padding per spatial axis is `kernel_size - 1`.

    Returns:
      `inputs` zero-padded on axes 1 and 2. When `kernel_size - 1` is odd the
      extra element goes at the end of the axis.
    """
    pad_total = kernel_size - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    # One [before, after] pair per axis; axes 0 and 3 are left untouched.
    return tf.pad(
        inputs,
        [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
def Conv2D(inputs, *, filters, kernel_size, strides=1):
    """Bias-free 2-D convolution that pads explicitly when strided.

    For `strides > 1` the input is first padded with `fixed_padding` and the
    convolution then runs with 'VALID' padding; for `strides == 1` the
    convolution uses 'SAME' padding directly.
    """
    if strides > 1:
        inputs = fixed_padding(inputs, kernel_size)
    padding = 'SAME' if strides == 1 else 'VALID'
    kernel_initializer = tf.variance_scaling_initializer(
        scale=2., mode='fan_in', distribution='untruncated_normal')
    return tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        use_bias=False,
        kernel_initializer=kernel_initializer)
# Functions `rel_to_abs`, `relative_logits_1d`, `relative_logits`
# and `relpos_self_attention` are fully based on
# https://github.com/tensorflow/tensor2tensor/blob/21dba2c1bdcc7ab582a2bfd8c0885c217963bb4f/tensor2tensor/layers/common_attention.py#L2225.
def rel_to_abs(x):
    """Convert relative-position logits to absolute-position logits.

    Input:  [bs, heads, length, 2*length - 1]
    Output: [bs, heads, length, length]
    """
    bs, heads, length, _ = x.shape
    # Append one zero column so every row holds 2*length entries.
    x = tf.concat(
        [x, tf.zeros((bs, heads, length, 1), dtype=x.dtype)], axis=3)
    # Flatten the last two axes and append length-1 zeros, which makes the
    # flat size equal to (length + 1) * (2*length - 1).
    flat = tf.reshape(x, [bs, heads, -1])
    flat = tf.concat(
        [flat, tf.zeros((bs, heads, length-1), dtype=x.dtype)], axis=2)
    # Reshaping to rows of 2*length - 1 shifts each row by one position
    # relative to the previous one; the wanted window is then a plain slice.
    skewed = tf.reshape(flat, [bs, heads, length+1, 2*length-1])
    return skewed[:, :, :length, length-1:]
def relative_logits_1d(*, q, rel_k, transpose_mask):
    """Compute relative-position logits along one spatial dimension.

    `q`: [bs, heads, height, width, dim]
    `rel_k`: [2*width - 1, dim]
    `transpose_mask`: permutation applied to the final rank-6 tensor.
    """
    _, heads, h, w, _ = q.shape
    logits = tf.einsum('bhxyd,md->bhxym', q, rel_k)
    # Fold height into the heads axis so rel_to_abs sees a rank-4 tensor.
    logits = tf.reshape(logits, [-1, heads * h, w, 2*w-1])
    logits = rel_to_abs(logits)
    logits = tf.reshape(logits, [-1, heads, h, w, w])
    # Insert a new axis at position 3 and repeat it h times.
    logits = tf.tile(tf.expand_dims(logits, axis=3), [1, 1, 1, h, 1, 1])
    return tf.transpose(logits, transpose_mask)
def relative_logits(q):
    """Compute 2-D relative position logits as the sum of a width and a height term.

    `q` is unpacked as [bs, heads, h, w, dim]. Two variables are created (or
    reused) under the 'relative' scope: 'r_width' with shape (2*w - 1, dim)
    and 'r_height' with shape (2*h - 1, dim).
    """
    with tf.variable_scope('relative', reuse=tf.AUTO_REUSE):
        _, _, h, w, dim = q.shape
        int_dim = dim.value
        # Note: int_dim**-0.5 is passed positionally below, so it lands in the
        # initializer's `mean` argument rather than `stddev`. The code is
        # provided as is, with this small error. The right way would be
        # normal_initializer(stddev=int_dim**-0.5).

        # Width direction.
        width_emb = tf.get_variable(
            'r_width', shape=(2*w - 1, dim),
            dtype=q.dtype,
            initializer=tf.random_normal_initializer(int_dim**-0.5))
        width_logits = relative_logits_1d(
            q=q, rel_k=width_emb,
            transpose_mask=[0, 1, 2, 4, 3, 5])

        # Height direction: swap the two spatial axes of q first.
        height_emb = tf.get_variable(
            'r_height', shape=(2*h - 1, dim),
            dtype=q.dtype,
            initializer=tf.random_normal_initializer(int_dim**-0.5))
        height_logits = relative_logits_1d(
            q=tf.transpose(q, [0, 1, 3, 2, 4]),
            rel_k=height_emb,
            transpose_mask=[0, 1, 4, 2, 5, 3])
        return height_logits + width_logits
def relpos_self_attention(
        *, q, k, v, relative=True, fold_heads=False):
    """2D self-attention with rel-pos. Add option to fold heads.

    Args:
      q, k, v: rank-5 tensors; `q` is unpacked as [bs, heads, h, w, dim] and
        `k`/`v` are contracted against it with the same layout.
      relative: when true, `relative_logits(q)` is added to the content logits.
      fold_heads: when true, the output is reshaped to [-1, h, w, heads * dim].

    Returns:
      The attention output, laid out as [bs, h, w, heads, dim] by the final
      einsum, or [-1, h, w, heads * dim] when `fold_heads` is set.
    """
    _, heads, h, w, dim = q.shape
    int_dim = dim.value
    # Scale with the plain int: `dim` is a Dimension object here (it exposes
    # `.value`), and Dimension does not appear to implement `**`.
    q = q * (int_dim ** -0.5)  # scaled dot-product
    logits = tf.einsum('bhHWd,bhPQd->bhHWPQ', q, k)
    if relative:
        logits += relative_logits(q)
    # Softmax runs over all h*w key positions at once, so flatten them first.
    weights = tf.reshape(logits, [-1, heads, h, w, h * w])
    weights = tf.nn.softmax(weights)
    weights = tf.reshape(weights, [-1, heads, h, w, h, w])
    attn_out = tf.einsum('bhHWPQ,bhPQd->bHWhd', weights, v)
    if fold_heads:
        attn_out = tf.reshape(attn_out, [-1, h, w, heads * dim])
    return attn_out
def absolute_logits(q):
    """Compute absolute position enc logits.

    `q` is unpacked as [bs, heads, h, w, dim]. Two variables are created (or
    reused) under the 'absolute' scope: 'r_width' with shape (w, dim) and
    'r_height' with shape (h, dim). Their broadcast sum forms an [h, w, dim]
    embedding that is contracted with `q`.

    Returns:
      Logits with layout [bs, heads, h, w, h, w] (query position, then key
      position), as produced by the einsum.
    """
    with tf.variable_scope('absolute', reuse=tf.AUTO_REUSE):
        # Spatial extent and per-head depth come from q itself; the previous
        # code read undefined names here.
        _, _, h, w, dim = q.shape
        int_dim = dim.value
        # Note: as in `relative_logits`, int_dim**-0.5 is passed positionally
        # and therefore lands in the initializer's `mean` argument, not
        # `stddev`. Kept as is to match that function.
        emb_w = tf.get_variable(
            'r_width', shape=(w, dim),
            dtype=q.dtype,
            initializer=tf.random_normal_initializer(int_dim**-0.5))
        emb_h = tf.get_variable(
            'r_height', shape=(h, dim),
            dtype=q.dtype,
            initializer=tf.random_normal_initializer(int_dim**-0.5))
        # Broadcast-add to one embedding per (row, column) position.
        emb = emb_h[:, None, :] + emb_w[None, :, :]
        abs_logits = tf.einsum('bhxyd,pqd->bhxypq', q, emb)
        return abs_logits
def abspos_self_attention(*, q, k, v, absolute=True, fold_heads=False):
    """2D self-attention with abs-pos. Add option to fold heads.

    Args:
      q, k, v: rank-5 tensors; `q` is unpacked as [bs, heads, h, w, dim] and
        `k`/`v` are contracted against it with the same layout.
      absolute: when true, `absolute_logits(q)` is added to the content
        logits. (The parameter was previously misspelled `absolue` while the
        body and the caller in this file both used `absolute`.)
      fold_heads: when true, the output is reshaped to [-1, h, w, heads * dim].

    Returns:
      The attention output, laid out as [bs, h, w, heads, dim] by the final
      einsum, or [-1, h, w, heads * dim] when `fold_heads` is set.
    """
    _, heads, h, w, dim = q.shape
    int_dim = dim.value
    # Scale with the plain int: `dim` is a Dimension object here (it exposes
    # `.value`), and Dimension does not appear to implement `**`.
    q = q * (int_dim ** -0.5)  # scaled dot-product
    logits = tf.einsum('bhHWd,bhPQd->bhHWPQ', q, k)
    if absolute:
        # Only build the position variables when they are actually used,
        # mirroring relpos_self_attention.
        logits += absolute_logits(q)
    # Softmax runs over all h*w key positions at once, so flatten them first.
    weights = tf.reshape(logits, [-1, heads, h, w, h * w])
    weights = tf.nn.softmax(weights)
    weights = tf.reshape(weights, [-1, heads, h, w, h, w])
    attn_out = tf.einsum('bhHWPQ,bhPQd->bHWhd', weights, v)
    if fold_heads:
        attn_out = tf.reshape(attn_out, [-1, h, w, heads * dim])
    return attn_out
def group_pointwise(
        featuremap, proj_factor=1, name='grouppoint',
        heads=4, target_dimension=None):
    """Per-head 1x1 projection of a feature map.

    A single weight 'w' of shape [in_channels, heads, proj_channels // heads]
    is created (or reused) under scope `name`, where `proj_channels` is
    `target_dimension // proj_factor` when `target_dimension` is given and
    `in_channels // proj_factor` otherwise.

    Returns:
      The einsum 'bHWD,Dhd->bhHWd' of `featuremap` with that weight, i.e. the
      heads axis is moved in front of the two spatial axes.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        in_channels = featuremap.shape[-1]
        base_channels = (
            in_channels if target_dimension is None else target_dimension)
        proj_channels = base_channels // proj_factor
        kernel = tf.get_variable(
            'w',
            [in_channels, heads, proj_channels // heads],
            dtype=featuremap.dtype,
            initializer=tf.random_normal_initializer(stddev=0.01))
        return tf.einsum('bHWD,Dhd->bhHWd', featuremap, kernel)
def MHSA(featuremap, pos_enc_type='relative', use_pos=True,
         heads=4, bottleneck_dimension=None):
    """Multi-Head Self-Attention.

    Args:
      featuremap: rank-4 input that `group_pointwise` contracts over its last
        axis (einsum 'bHWD,...').
      pos_enc_type: 'relative' or 'absolute'; selects the attention variant.
      use_pos: whether the position logits are added to the content logits.
      heads: number of attention heads. Previously read from an undefined
        name; now a parameter whose default matches `group_pointwise` and
        `BoT_Block`.
      bottleneck_dimension: total q/k/v width across heads. None keeps the
        input channel count (the `group_pointwise` default).

    Returns:
      The attention output with heads folded into the last axis:
      [-1, h, w, heads * dim].
    """
    # Validate before any variables are created.
    assert pos_enc_type in ['relative', 'absolute']
    q = group_pointwise(
        featuremap, proj_factor=1, name='q_proj', heads=heads,
        target_dimension=bottleneck_dimension)
    k = group_pointwise(
        featuremap, proj_factor=1, name='k_proj', heads=heads,
        target_dimension=bottleneck_dimension)
    v = group_pointwise(
        featuremap, proj_factor=1, name='v_proj', heads=heads,
        target_dimension=bottleneck_dimension)
    if pos_enc_type == 'relative':
        o = relpos_self_attention(
            q=q, k=k, v=v, relative=use_pos, fold_heads=True)
    else:
        o = abspos_self_attention(
            q=q, k=k, v=v, absolute=use_pos, fold_heads=True)
    return o


def BoT_Block(
        featuremap, is_training=False,
        heads=4, proj_factor=4,
        activation='relu',
        pos_enc_type='relative',
        name='all2all', strides=1,
        target_dimension=2048):
    """Bottleneck Transformer (BoT) Block.

    Pipeline, all under variable scope `name`:
      1. Shortcut: projected with a 1x1 conv + BN + activation when the block
         is strided or the channel count differs from `target_dimension`.
      2. 1x1 conv down to `target_dimension // proj_factor` channels, BN,
         activation.
      3. Multi-head self-attention (`MHSA`).
      4. For `strides == 2` only: 2x2 average pooling. Any other non-unit
         stride trips an assertion.
      5. BN + activation, 1x1 conv up to `target_dimension`, then BN with
         zero-initialized gamma and no activation.
      6. Residual sum with the shortcut, followed by the activation.

    Returns:
      The block output tensor.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        shortcut = featuremap
        in_dimension = featuremap.shape[-1]
        if strides != 1 or in_dimension != target_dimension:
            shortcut = Conv2D(
                shortcut, filters=target_dimension, kernel_size=1,
                strides=strides)
            shortcut = BNReLU(
                shortcut, is_training, activation=activation,
                nonlinearity=True)

        bottleneck_dimension = target_dimension // proj_factor
        featuremap = Conv2D(
            featuremap, filters=bottleneck_dimension, kernel_size=1,
            strides=1)
        featuremap = BNReLU(
            featuremap, is_training, activation=activation, nonlinearity=True)

        # Forward the head count and width explicitly; MHSA used to pick
        # these up from names that were never defined in its scope.
        featuremap = MHSA(
            featuremap, pos_enc_type=pos_enc_type,
            heads=heads, bottleneck_dimension=bottleneck_dimension)
        if strides != 1:
            assert strides == 2
            featuremap = tf.keras.layers.AveragePooling2D(
                pool_size=(2, 2), strides=(2, 2), padding='same')(featuremap)
        featuremap = BNReLU(
            featuremap, is_training, activation=activation, nonlinearity=True)

        featuremap = Conv2D(
            featuremap, filters=target_dimension,
            kernel_size=1, strides=1)
        # gamma starts at zero, so this branch initially contributes nothing
        # to the residual sum below.
        featuremap = BNReLU(
            featuremap, is_training, nonlinearity=False, init_zero=True)

        return Activation(shortcut + featuremap, activation=activation)
def BoT_Stack(
        featuremap, *,
        blocks_so_far,
        total_blocks,
        is_training=False,
        heads=4, proj_factor=4,
        activation='relu',
        pos_enc_type='relative',
        name='all2all_stack',
        strides=2, num_layers=3,
        target_dimension=2048):
    """Chain `num_layers` BoT blocks (the c5 block group) under scope `name`.

    Only the first block receives `strides`; every later block uses stride 1.
    Block i is built under the sub-scope 'all2all_layer_<i>'.
    `blocks_so_far` and `total_blocks` are accepted but not read here.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        for layer_idx in range(num_layers):
            layer_strides = strides if layer_idx == 0 else 1
            featuremap = BoT_Block(
                featuremap,
                is_training=is_training,
                heads=heads,
                proj_factor=proj_factor,
                activation=activation,
                pos_enc_type=pos_enc_type,
                strides=layer_strides,
                target_dimension=target_dimension,
                name='all2all_layer_%d' % layer_idx)
        return featuremap
Working through how rel_to_abs computes its result
gave me an idea for simplifying it. I think the zero-padding steps
can be removed here, though I may have missed some scenarios:
def rel_to_abs_2(rel_pos):
    """Padding-free variant of rel_to_abs.

    Input:  [bs, heads, height, width, 2 * width - 1]
    Output: [bs, heads, height, width, width]
    """
    _, heads, hh, ww, _ = rel_pos.shape
    # Flatten the last two axes: [bs, heads, height, width * (2 * width - 1)].
    flat = tf.reshape(rel_pos, [-1, heads, hh, ww * (ww * 2 - 1)])
    # Drop the first width - 1 entries and the last one (width in total), so
    # the remainder splits evenly into rows of 2 * (width - 1).
    trimmed = flat[:, :, :, ww - 1:-1]
    skewed = tf.reshape(trimmed, [-1, heads, hh, ww, 2 * (ww - 1)])
    # The first `width` columns of every row are the absolute-position logits.
    return skewed[:, :, :, :, :ww]
Test
# Compare the padding-based rel_to_abs with rel_to_abs_2 on random input.
# rel_to_abs takes a rank-4 tensor, so heads and height are merged before the
# call and split again afterwards.
rel_pos = tf.random.uniform([12, 6, 14, 16, 2 * 16 - 1])
orignal_rel_to_abs = tf.reshape(rel_to_abs(tf.reshape(rel_pos, [-1, 6 * 14, 16, 2 * 16 - 1])), [-1, 6, 14, 16, 16])
print(np.allclose(orignal_rel_to_abs, rel_to_abs_2(rel_pos)))
# True
- Adding here: keras_cv_attention_models/botnet is my
botnet implementation,
with weights loaded from timm.
- I think this
relative positional embedding
will still make sense in some future work...
@BIGBALLON the Drive link you provided for the .pth
weights does not seem to be in the right format.
Could you clarify a bit?
I agree with you that the zero padding can be omitted, and your implementation seems more concise and easier to understand.
Would you care to push your version to PyTorch Image Models (also known as the timm package), to see whether the author agrees to replace the current version with yours (no padding)?
Also, could the relative positional embedding in HaloNet be replaced with a no-padding version as well?
@bsun0802 I have been using this implementation for a long time. My keras_cv_attention_models/botnet and keras_cv_attention_models/halonet both share this no-padding version. Those model weights are all ported from timm
and produce closely matching outputs. I may discuss this with rwightman.
@leondgarse
Thanks for your reply. I just verified that your no-padding idea works for HaloNet as well, with a slight difference.
The code needs to be changed to:
b = 6 # block size
h = 1 # halo size
w = b + 2 * h # window size
To visualize: the values 1 to 8 mark the positions we want to keep.
x = torch.tensor([[0] * (w-1-i) + list(range(1,1+w)) + [0] * i for i in range(b)])
assert x.shape == (b, 2*w-1)
x
tensor([[0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8],
[0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0],
[0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0],
[0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0],
[0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
[0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0]])
For halo attention, we need to append a single 0 to the end of the flattened tensor. (If I'm not wrong, this implementation should work for any block size b
and window size w
; it may need to be adjusted if halo size h != 1
.)
x = F.pad(x.flatten(), [0, 1])[w-1:].reshape(b, -1) # rel_to_abs
x
tensor([[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0]])
Then simply slice out the intended positions.
out = x[:, :w]
out
tensor([[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8],
[1, 2, 3, 4, 5, 6, 7, 8]])
You may refer to my implementation of rel_to_abs, where the padding is also unnecessary. I call the offset a full_rank_gap
for this scenario; the result just needs to be sliced by it:
hh = 1
ww, dim = x.shape
pos_dim = (dim + 1) // 2
full_rank_gap = pos_dim - ww
print(f"{pos_dim = }, {full_rank_gap = }")
# pos_dim = 8, full_rank_gap = 2
flat_x = x.reshape([-1, hh, ww * dim])[:, :, ww - 1 : -1]
out = flat_x.reshape([-1, hh, ww, 2 * (pos_dim - 1)])
out
# tensor([[[[0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
# [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
# [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
# [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
# [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0],
# [0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0]]]])
out[:, :, :, full_rank_gap : pos_dim + full_rank_gap]
# tensor([[1, 2, 3, 4, 5, 6, 7, 8],
# [1, 2, 3, 4, 5, 6, 7, 8],
# [1, 2, 3, 4, 5, 6, 7, 8],
# [1, 2, 3, 4, 5, 6, 7, 8],
# [1, 2, 3, 4, 5, 6, 7, 8],
# [1, 2, 3, 4, 5, 6, 7, 8]])
Hi, I found a good explanation of relative position embedding:
https://theaisummer.com/positional-embeddings/
And here is an explanation I wrote in Chinese:
https://www.yuque.com/lart/ugkv9f/oazsec