Source code for langml.layers.attention

# -*- coding: utf-8 -*-

from typing import Optional, Union, List

from langml import TF_KERAS
if TF_KERAS:
    import tensorflow.keras as keras
    import tensorflow.keras.backend as K
    import tensorflow.keras.layers as L
else:
    import keras
    import keras.backend as K
    import keras.layers as L

from langml.activations import relu2
from langml.layers import SineCosinePositionEmbedding, ScaleOffset
from langml.tensor_typing import Tensors, Activation, Initializer, Constraint, Regularizer


class SelfAttention(L.Layer):
    """Single-head dot-product self-attention with learned q/k/v projections.

    The input is projected to queries, keys and values by three kernels,
    pairwise scores ``q . k`` are normalised with a softmax over the key
    axis, and the resulting weights are used to mix the values.

    Args:
        attention_units: width of the q/k/v projections. When ``None`` the
            input feature size (``input_shape[2]``) is used.
        return_attention: if True, ``call`` returns ``[output, attention]``.
        is_residual: if True, the projected values are added to the output.
        attention_activation: activation applied to q, k and v.
        attention_epsilon: value subtracted from the scores of masked keys.
        kernel_initializer: initializer of the projection kernels.
        kernel_regularizer: regularizer of the projection kernels.
        kernel_constraint: constraint of the projection kernels.
        bias_initializer: initializer of the attention bias.
        bias_regularizer: regularizer of the attention bias.
        bias_constraint: constraint of the attention bias.
        use_attention_bias: if True, one learned scalar is added to q, k, v.
        attention_penalty_weight: weight of the ``(A A^T - I)^2`` loss added
            in ``call``; the loss is only added when the weight is positive.
    """

    def __init__(self,
                 attention_units: Optional[int] = None,
                 return_attention: bool = False,
                 is_residual: bool = False,
                 attention_activation: Activation = 'relu',
                 attention_epsilon: float = 1e10,
                 kernel_initializer: Initializer = 'glorot_normal',
                 kernel_regularizer: Optional[Regularizer] = None,
                 kernel_constraint: Optional[Constraint] = None,
                 bias_initializer: Initializer = 'zeros',
                 bias_regularizer: Optional[Regularizer] = None,
                 bias_constraint: Optional[Constraint] = None,
                 use_attention_bias: bool = True,
                 attention_penalty_weight: float = 0.0,
                 **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.attention_units = attention_units
        self.return_attention = return_attention
        self.is_residual = is_residual
        self.attention_epsilon = attention_epsilon
        self.attention_activation = keras.activations.get(attention_activation)
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.use_attention_bias = use_attention_bias
        self.attention_penalty_weight = attention_penalty_weight

    def get_config(self) -> dict:
        """Return the constructor arguments merged into the base layer config."""
        config = {
            "attention_units": self.attention_units,
            "return_attention": self.return_attention,
            "is_residual": self.is_residual,
            "attention_epsilon": self.attention_epsilon,
            "attention_activation": keras.activations.serialize(self.attention_activation),
            "kernel_initializer": keras.initializers.serialize(self.kernel_initializer),
            "kernel_regularizer": keras.regularizers.serialize(self.kernel_regularizer),
            "kernel_constraint": keras.constraints.serialize(self.kernel_constraint),
            "bias_initializer": keras.initializers.serialize(self.bias_initializer),
            "bias_regularizer": keras.regularizers.serialize(self.bias_regularizer),
            "bias_constraint": keras.constraints.serialize(self.bias_constraint),
            "use_attention_bias": self.use_attention_bias,
            "attention_penalty_weight": self.attention_penalty_weight
        }
        base_config = super(SelfAttention, self).get_config()
        return dict(base_config, **config)

    def build(self, input_shape: Tensors):
        """Create the q/k/v kernels and, optionally, the scalar bias."""
        feature_dim = int(input_shape[2])
        units = feature_dim if self.attention_units is None else self.attention_units
        self.Wq = self.add_weight(shape=(feature_dim, units),
                                  name=f'{self.name}_Attn_Wq',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        # The '_Attn_Wt' suffix is intentionally left untouched: the weight
        # name is a runtime string and renaming it would change the variable
        # names of already-saved models.
        self.Wk = self.add_weight(shape=(feature_dim, units),
                                  name=f'{self.name}_Attn_Wt',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wv = self.add_weight(shape=(feature_dim, units),
                                  name=f'{self.name}_Attn_Wv',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.attn_bias = self.add_weight(shape=(1,),
                                             name=f'{self.name}_Attn_bias',
                                             initializer=self.bias_initializer,
                                             regularizer=self.bias_regularizer,
                                             constraint=self.bias_constraint)

    def call(self,
             inputs: Tensors,
             mask: Optional[Tensors] = None,
             **kwargs) -> Union[List[Tensors], Tensors]:
        """Apply self-attention to ``inputs``.

        Args:
            inputs: tensor whose axis 2 is the feature axis (see ``build``).
            mask: optional mask. A mask with one axis fewer than ``inputs``
                is treated as a per-position mask and applied to the key
                axis; a mask of the same rank is applied to the scores as-is.

        Returns:
            The attended values, or ``[values, attention]`` when
            ``return_attention`` is True.
        """
        q = K.dot(inputs, self.Wq)
        k = K.dot(inputs, self.Wk)
        v = K.dot(inputs, self.Wv)
        if self.attention_activation is not None:
            q = self.attention_activation(q)
            k = self.attention_activation(k)
            v = self.attention_activation(v)
        if self.use_attention_bias:
            q += self.attn_bias
            k += self.attn_bias
            v += self.attn_bias
        # Contract the feature axis: (B, L, U) x (B, L, U) -> (B, Lq, Lk).
        e = K.batch_dot(q, k, axes=2)
        if mask is not None:
            # Always cast so a boolean mask of any rank works in `1.0 - mask`.
            mask = K.cast(mask, K.floatx())
            if len(K.int_shape(mask)) == len(K.int_shape(inputs)) - 1:
                # (B, L) -> (B, 1, L): the mask must broadcast over queries and
                # act on the key axis, because the softmax below normalises
                # over the last (key) axis. Broadcasting over keys instead
                # would shift whole rows, which a softmax cannot see.
                mask = K.expand_dims(mask, axis=1)
            e -= self.attention_epsilon * (1.0 - mask)
        a = K.softmax(e)
        v_o = K.batch_dot(a, v)
        if self.is_residual:
            v_o += v
        if self.attention_penalty_weight > 0.0:
            self.add_loss(self._attention_penalty(a))
        if self.return_attention:
            return [v_o, a]
        return v_o

    def compute_mask(self,
                     inputs: Tensors,
                     mask: Optional[Tensors] = None) -> Union[List[Union[Tensors, None]], Tensors]:
        """Propagate the input mask; the attention output carries no mask."""
        if self.return_attention:
            return [mask, None]
        return mask

    def _attention_penalty(self, attention: Tensors) -> Tensors:
        """Return ``weight * sum((A A^T - I)^2) / batch_size``.

        The identity is sized from the last axis of ``attention``.
        """
        batch_size = K.cast(K.shape(attention)[0], K.floatx())
        input_len = K.shape(attention)[-1]
        # Build the identity by comparing a row of indices with a column.
        indices = K.expand_dims(K.arange(0, input_len), axis=0)
        diagonal = K.expand_dims(K.arange(0, input_len), axis=-1)
        eye = K.cast(K.equal(indices, diagonal), K.floatx())
        return self.attention_penalty_weight * K.sum(K.square(K.batch_dot(
            attention,
            K.permute_dimensions(attention, (0, 2, 1))) - eye)) / batch_size

    @staticmethod
    def get_custom_objects() -> dict:
        """Return the mapping needed to deserialize this layer."""
        return {'SelfAttention': SelfAttention}

    def compute_output_shape(self, input_shape: Tensors) -> Union[List[Tensors], Tensors]:
        """Return the output shape(s).

        The last axis of the output is ``attention_units`` when it is set,
        because q/k/v are projected to that width in ``call``.
        """
        if self.attention_units is None:
            output_shape = input_shape
        else:
            output_shape = input_shape[:-1] + (self.attention_units,)
        if self.return_attention:
            attention_shape = (input_shape[0], input_shape[1], input_shape[1])
            return [output_shape, attention_shape]
        return output_shape
class SelfAdditiveAttention(L.Layer):
    """Additive self-attention that re-weights every position of the input.

    Each position is projected by ``Wh``, passed through the activation and
    reduced to one score by ``We``. The scores are normalised with a softmax
    along axis 1 and multiplied element-wise with the input.

    Args:
        attention_units: width of the hidden projection. When ``None`` the
            input feature size (``input_shape[2]``) is used.
        return_attention: if True, ``call`` returns ``[output, attention]``.
        is_residual: if True, the input is added to the output.
        attention_activation: activation applied to the hidden projection.
        attention_epsilon: value subtracted from the scores of masked steps.
        kernel_initializer: initializer of the kernels.
        kernel_regularizer: regularizer of the kernels.
        kernel_constraint: constraint of the kernels.
        bias_initializer: initializer of the attention bias.
        bias_regularizer: regularizer of the attention bias.
        bias_constraint: constraint of the attention bias.
        use_attention_bias: if True, one learned scalar is added to both the
            hidden projection and the scores.
        attention_penalty_weight: weight of the loss added in ``call``; the
            loss is only added when the weight is positive.
    """

    def __init__(self,
                 attention_units: Optional[int] = None,
                 return_attention: bool = False,
                 is_residual: bool = False,
                 attention_activation: Activation = 'relu',
                 attention_epsilon: float = 1e10,
                 kernel_initializer: Initializer = 'glorot_normal',
                 kernel_regularizer: Optional[Regularizer] = None,
                 kernel_constraint: Optional[Constraint] = None,
                 bias_initializer: Initializer = 'zeros',
                 bias_regularizer: Optional[Regularizer] = None,
                 bias_constraint: Optional[Constraint] = None,
                 use_attention_bias: bool = True,
                 attention_penalty_weight: float = 0.0,
                 **kwargs):
        super(SelfAdditiveAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.attention_units = attention_units
        self.return_attention = return_attention
        self.is_residual = is_residual
        self.attention_epsilon = attention_epsilon
        self.attention_activation = keras.activations.get(attention_activation)
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.use_attention_bias = use_attention_bias
        self.attention_penalty_weight = attention_penalty_weight

    def get_config(self) -> dict:
        """Return the constructor arguments merged into the base layer config."""
        config = {
            "attention_units": self.attention_units,
            "return_attention": self.return_attention,
            "is_residual": self.is_residual,
            "attention_epsilon": self.attention_epsilon,
            "attention_activation": keras.activations.serialize(self.attention_activation),
            "kernel_initializer": keras.initializers.serialize(self.kernel_initializer),
            "kernel_regularizer": keras.regularizers.serialize(self.kernel_regularizer),
            "kernel_constraint": keras.constraints.serialize(self.kernel_constraint),
            "bias_initializer": keras.initializers.serialize(self.bias_initializer),
            "bias_regularizer": keras.regularizers.serialize(self.bias_regularizer),
            "bias_constraint": keras.constraints.serialize(self.bias_constraint),
            "use_attention_bias": self.use_attention_bias,
            "attention_penalty_weight": self.attention_penalty_weight
        }
        base_config = super(SelfAdditiveAttention, self).get_config()
        return dict(base_config, **config)

    def build(self, input_shape: Tensors):
        """Create the hidden kernel, the scoring kernel and the scalar bias."""
        feature_dim = int(input_shape[2])
        units = feature_dim if self.attention_units is None else self.attention_units
        self.Wh = self.add_weight(shape=(feature_dim, units),
                                  name=f'{self.name}_Attn_Wh',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.We = self.add_weight(shape=(units, 1),
                                  name=f'{self.name}_Attn_We',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.attn_bias = self.add_weight(shape=(1,),
                                             name=f'{self.name}_Attn_bias',
                                             initializer=self.bias_initializer,
                                             regularizer=self.bias_regularizer,
                                             constraint=self.bias_constraint)

    def call(self,
             inputs: Tensors,
             mask: Optional[Tensors] = None,
             **kwargs) -> Union[List[Tensors], Tensors]:
        """Re-weight ``inputs`` with additive attention scores.

        Args:
            inputs: tensor whose axis 2 is the feature axis (see ``build``).
            mask: optional mask. A mask with one axis fewer than ``inputs``
                gets a trailing axis so it lines up with the scores.

        Returns:
            The re-weighted input, or ``[output, attention]`` when
            ``return_attention`` is True.
        """
        h = K.dot(inputs, self.Wh)
        if self.attention_activation is not None:
            h = self.attention_activation(h)
        if self.use_attention_bias:
            h += self.attn_bias
        # We has shape (units, 1), so there is one score per position.
        e = K.dot(h, self.We)
        if self.use_attention_bias:
            e += self.attn_bias
        if mask is not None:
            # Always cast so a boolean mask of any rank works in `1.0 - mask`.
            mask = K.cast(mask, K.floatx())
            if len(K.int_shape(mask)) == len(K.int_shape(inputs)) - 1:
                mask = K.expand_dims(mask, axis=-1)
            e -= self.attention_epsilon * (1.0 - mask)
        # Normalise across positions (axis 1), not across the size-1 last axis.
        a = K.softmax(e, axis=1)
        v_o = a * inputs
        if self.is_residual:
            v_o += inputs
        if self.attention_penalty_weight > 0.0:
            self.add_loss(self._attention_penalty(a))
        if self.return_attention:
            return [v_o, a]
        return v_o

    def compute_mask(self,
                     inputs: Tensors,
                     mask: Optional[Tensors] = None) -> Union[List[Union[Tensors, None]], Tensors]:
        """Propagate the input mask; the attention output carries no mask."""
        if self.return_attention:
            return [mask, None]
        return mask

    def _attention_penalty(self, attention: Tensors) -> Tensors:
        """Return ``weight * sum((A A^T - eye)^2) / batch_size``.

        ``eye`` is sized from the last axis of ``attention``.
        """
        # NOTE(review): `call` passes weights whose last axis comes from
        # `We` (size 1), so `eye` is 1x1 here and is broadcast against the
        # batch_dot result -- confirm this is the intended penalty.
        batch_size = K.cast(K.shape(attention)[0], K.floatx())
        input_len = K.shape(attention)[-1]
        indices = K.expand_dims(K.arange(0, input_len), axis=0)
        diagonal = K.expand_dims(K.arange(0, input_len), axis=-1)
        eye = K.cast(K.equal(indices, diagonal), K.floatx())
        return self.attention_penalty_weight * K.sum(K.square(K.batch_dot(
            attention,
            K.permute_dimensions(attention, (0, 2, 1))) - eye)) / batch_size

    @staticmethod
    def get_custom_objects() -> dict:
        """Return the mapping needed to deserialize this layer."""
        return {'SelfAdditiveAttention': SelfAdditiveAttention}

    def compute_output_shape(self, input_shape: Tensors) -> Union[List[Tensors], Tensors]:
        """Return the output shape(s).

        The output keeps the input shape. The attention weights returned by
        ``call`` come from ``K.dot(h, We)`` with ``We`` of shape
        ``(units, 1)``, so their last axis is 1.
        """
        output_shape = input_shape
        if self.return_attention:
            attention_shape = (input_shape[0], input_shape[1], 1)
            return [output_shape, attention_shape]
        return output_shape
class ScaledDotProductAttention(L.Layer):
    r""" ScaledDotProductAttention

    $Attention(Q, K, V) = softmax(\frac{Q K^T}{\sqrt{d_k}}) V$

    https://arxiv.org/pdf/1706.03762.pdf

    Args:
        return_attention: if True, ``call`` returns ``[output, attention]``.
        history_only: if True, each query position can only attend to key
            positions whose index is not greater than its own.
    """

    def __init__(self,
                 return_attention: bool = False,
                 history_only: bool = False,
                 **kwargs):
        super(ScaledDotProductAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.return_attention = return_attention
        self.history_only = history_only

    def get_config(self) -> dict:
        """Return the constructor arguments merged into the base layer config."""
        config = {
            "return_attention": self.return_attention,
            "history_only": self.history_only,
        }
        base_config = super(ScaledDotProductAttention, self).get_config()
        return dict(base_config, **config)

    def call(self,
             inputs: Tensors,
             mask: Optional[Union[Tensors, List[Tensors]]] = None,
             **kwargs) -> Union[List[Tensors], Tensors]:
        """Compute scaled dot-product attention.

        Args:
            inputs: either a list ``[q, k, v]`` or one tensor used as all
                three.
            mask: either a list of masks, of which the second (the key mask)
                is used, or a single mask applied to the key axis.

        Returns:
            The attended values, or ``[values, attention]`` when
            ``return_attention`` is True.
        """
        if isinstance(inputs, list):
            q, k, v = inputs
        else:
            q = k = v = inputs
        if isinstance(mask, list):
            mask = mask[1]
        # e = \frac{QK^T}{\sqrt{d_k}}
        # shape: [(B, Lq, D), (B, Lk, D)] -> (B, Lq, Lk)
        e = K.batch_dot(q, k, axes=2) / K.sqrt(K.cast(K.shape(q)[-1], dtype=K.floatx()))
        if self.history_only:
            q_len, k_len = K.shape(q)[1], K.shape(k)[1]
            indices = K.expand_dims(K.arange(0, k_len), axis=0)
            upper = K.expand_dims(K.arange(0, q_len), axis=-1)
            # Penalise every (query, key) pair whose key index is ahead of
            # the query index.
            e -= 10000.0 * K.expand_dims(K.cast(indices > upper, K.floatx()), axis=0)
        if mask is not None:
            # Insert the query axis so the mask acts on the key axis.
            e -= 10000.0 * (1.0 - K.cast(K.expand_dims(mask, axis=-2), K.floatx()))
        # softmax(e), with the row maximum subtracted before exponentiation
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        attention = e / K.sum(e, axis=-1, keepdims=True)
        v = K.batch_dot(attention, v)
        if self.return_attention:
            return [v, attention]
        return v

    def compute_mask(self,
                     inputs: Tensors,
                     mask: Optional[Union[Tensors, List[Tensors]]] = None) -> Union[
                         List[Union[Tensors, None]], Tensors]:
        """Propagate the first (query) mask.

        When ``return_attention`` is True the layer has two outputs, so one
        mask entry per output is returned; the attention output has no mask.
        """
        if isinstance(mask, list):
            mask = mask[0]
        if self.return_attention:
            return [mask, None]
        return mask

    @staticmethod
    def get_custom_objects() -> dict:
        """Return the mapping needed to deserialize this layer."""
        return {'ScaledDotProductAttention': ScaledDotProductAttention}

    def compute_output_shape(self, input_shape: Union[Tensors, List[Tensors]]) -> Union[List[Tensors], Tensors]:
        """Return the output shape(s) from the q/k/v shapes."""
        if isinstance(input_shape, list):
            q_shape, k_shape, v_shape = input_shape
        else:
            q_shape = k_shape = v_shape = input_shape
        # Leading axes come from q, the feature axis from v.
        output_shape = q_shape[:-1] + v_shape[-1:]
        if self.return_attention:
            attention_shape = q_shape[:2] + (k_shape[1],)
            return [output_shape, attention_shape]
        return output_shape
class MultiHeadAttention(L.Layer):
    """ MultiHeadAttention

    https://arxiv.org/pdf/1706.03762.pdf

    Args:
        head_num: number of heads; must divide the feature size of the
            value input.
        return_attention: if True, ``call`` returns ``[output, attention]``.
        attention_activation: activation applied to q, k, v and the output.
        kernel_initializer: initializer of the kernels.
        kernel_regularizer: regularizer of the kernels.
        kernel_constraint: constraint of the kernels.
        bias_initializer: initializer of the biases.
        bias_regularizer: regularizer of the biases.
        bias_constraint: constraint of the biases.
        use_attention_bias: if True, biases are added to q, k, v and output.
        history_only: forwarded to ``ScaledDotProductAttention``.
    """

    def __init__(self,
                 head_num: int,
                 return_attention: bool = False,
                 attention_activation: Activation = 'relu',
                 kernel_initializer: Initializer = 'glorot_normal',
                 kernel_regularizer: Optional[Regularizer] = None,
                 kernel_constraint: Optional[Constraint] = None,
                 bias_initializer: Initializer = 'zeros',
                 bias_regularizer: Optional[Regularizer] = None,
                 bias_constraint: Optional[Constraint] = None,
                 use_attention_bias: bool = True,
                 history_only: bool = False,
                 **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.head_num = head_num
        self.return_attention = return_attention
        self.attention_activation = keras.activations.get(attention_activation)
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.use_attention_bias = use_attention_bias
        self.history_only = history_only

    def get_config(self) -> dict:
        """Return the constructor arguments merged into the base layer config."""
        config = {
            "head_num": self.head_num,
            "return_attention": self.return_attention,
            "attention_activation": keras.activations.serialize(self.attention_activation),
            "kernel_initializer": keras.initializers.serialize(self.kernel_initializer),
            "kernel_regularizer": keras.regularizers.serialize(self.kernel_regularizer),
            "kernel_constraint": keras.constraints.serialize(self.kernel_constraint),
            "bias_initializer": keras.initializers.serialize(self.bias_initializer),
            "bias_regularizer": keras.regularizers.serialize(self.bias_regularizer),
            "bias_constraint": keras.constraints.serialize(self.bias_constraint),
            "use_attention_bias": self.use_attention_bias,
            "history_only": self.history_only
        }
        base_config = super(MultiHeadAttention, self).get_config()
        return dict(base_config, **config)

    def build(self, input_shape: Tensors):
        """Create the q/k/v/output kernels and, optionally, their biases.

        ``input_shape`` is either a list of three shapes ``[q, k, v]`` or a
        single shape used for all three.
        """
        if isinstance(input_shape, list):
            q, k, v = input_shape
        else:
            q = k = v = input_shape
        # All projections map to the feature size of the value input.
        feature_dim = int(v[-1])
        assert feature_dim % self.head_num == 0, 'feature_dim should be divided by head_num with no remainder'
        self.Wq = self.add_weight(shape=(int(q[-1]), feature_dim),
                                  name=f'{self.name}_Wq',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wk = self.add_weight(shape=(int(k[-1]), feature_dim),
                                  name=f'{self.name}_Wk',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wv = self.add_weight(shape=(feature_dim, feature_dim),
                                  name=f'{self.name}_Wv',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wo = self.add_weight(shape=(feature_dim, feature_dim),
                                  name=f'{self.name}_Wo',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.bq = self.add_weight(shape=(feature_dim,),
                                      name=f'{self.name}_bq',
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
            self.bk = self.add_weight(shape=(feature_dim,),
                                      name=f'{self.name}_bk',
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
            self.bv = self.add_weight(shape=(feature_dim,),
                                      name=f'{self.name}_bv',
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
            self.bo = self.add_weight(shape=(feature_dim,),
                                      name=f'{self.name}_bo',
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

    @staticmethod
    def _reshape_to_batches(x, head_num):
        """Split the feature axis into heads and fold heads into the batch.

        (B, L, D) -> (B * head_num, L, D // head_num)
        """
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        head_dim = feature_dim // head_num
        x = K.reshape(x, (batch_size, seq_len, head_num, head_dim))
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size * head_num, seq_len, head_dim))

    @staticmethod
    def _reshape_attention_from_batches(x, head_num):
        """Unfold heads from the batch axis of the attention weights.

        (B * head_num, Lq, Lk) -> (B, Lq, head_num, Lk)
        """
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        return K.permute_dimensions(x, [0, 2, 1, 3])

    @staticmethod
    def _reshape_from_batches(x, head_num):
        """Unfold heads from the batch axis and merge them into the features.

        (B * head_num, L, d) -> (B, L, d * head_num)
        """
        input_shape = K.shape(x)
        batch_size, seq_len, feature_dim = input_shape[0], input_shape[1], input_shape[2]
        x = K.reshape(x, (batch_size // head_num, head_num, seq_len, feature_dim))
        x = K.permute_dimensions(x, [0, 2, 1, 3])
        return K.reshape(x, (batch_size // head_num, seq_len, feature_dim * head_num))

    @staticmethod
    def _reshape_mask(mask, head_num):
        """Repeat a (B, L) mask once per head: (B, L) -> (B * head_num, L).

        ``None`` is passed through unchanged.
        """
        if mask is None:
            return mask
        seq_len = K.shape(mask)[1]
        mask = K.expand_dims(mask, axis=1)
        mask = K.tile(mask, [1, head_num, 1])
        return K.reshape(mask, (-1, seq_len))

    def call(self,
             inputs: Tensors,
             mask: Optional[Tensors] = None,
             **kwargs) -> Union[List[Tensors], Tensors]:
        """Apply multi-head attention.

        Args:
            inputs: either a list ``[q, k, v]`` or one tensor used as all
                three.
            mask: either a list ``[q_mask, k_mask, v_mask]`` or one mask used
                for all three.

        Returns:
            The projected output, or ``[output, attention]`` when
            ``return_attention`` is True.
        """
        if isinstance(inputs, list):
            q, k, v = inputs
        else:
            q = k = v = inputs
        if isinstance(mask, list):
            q_mask, k_mask, v_mask = mask
        else:
            q_mask = k_mask = v_mask = mask
        q = K.dot(q, self.Wq)
        k = K.dot(k, self.Wk)
        v = K.dot(v, self.Wv)
        if self.use_attention_bias:
            q += self.bq
            k += self.bk
            v += self.bv
        if self.attention_activation is not None:
            q = self.attention_activation(q)
            k = self.attention_activation(k)
            v = self.attention_activation(v)
        scaled_dot_product_attention = ScaledDotProductAttention(
            return_attention=True,
            history_only=self.history_only,
            name=f'{self.name}-Attention',
        )
        # Heads are folded into the batch axis so one attention call handles
        # all of them.
        output, attention = scaled_dot_product_attention(
            inputs=[
                self._reshape_to_batches(q, self.head_num),
                self._reshape_to_batches(k, self.head_num),
                self._reshape_to_batches(v, self.head_num),
            ],
            mask=[
                self._reshape_mask(q_mask, self.head_num),
                self._reshape_mask(k_mask, self.head_num),
                self._reshape_mask(v_mask, self.head_num),
            ],
        )
        attention = self._reshape_attention_from_batches(attention, self.head_num)
        output = self._reshape_from_batches(output, self.head_num)
        output = K.dot(output, self.Wo)
        if self.use_attention_bias:
            output += self.bo
        if self.attention_activation is not None:
            output = self.attention_activation(output)
        if self.return_attention:
            return [output, attention]
        return output

    @staticmethod
    def get_custom_objects() -> dict:
        """Return the mapping needed to deserialize this layer."""
        return {'MultiHeadAttention': MultiHeadAttention}

    def compute_mask(self,
                     inputs: Tensors,
                     mask: Optional[Tensors] = None) -> Union[List[Union[Tensors, None]], Tensors]:
        """Propagate the first (query) mask.

        When ``return_attention`` is True the layer has two outputs, so one
        mask entry per output is returned; the attention output has no mask.
        """
        if isinstance(mask, list):
            mask = mask[0]
        if self.return_attention:
            return [mask, None]
        return mask

    def compute_output_shape(self, input_shape: Union[Tensors, List[Tensors]]) -> Union[List[Tensors], Tensors]:
        """Return the output shape(s) from the q/k/v shapes.

        The attention tensor returned by ``call`` has been reshaped by
        ``_reshape_attention_from_batches`` to (B, Lq, head_num, Lk), so its
        last axis is the key length, not the value feature size.
        """
        if isinstance(input_shape, list):
            q_shape, k_shape, v_shape = input_shape
        else:
            q_shape = k_shape = v_shape = input_shape
        output_shape = q_shape[:-1] + (v_shape[-1],)
        if self.return_attention:
            attention_shape = (*q_shape[:-1], self.head_num, k_shape[1])
            return [output_shape, attention_shape]
        return output_shape
class GatedAttentionUnit(L.Layer):
    """ Gated Attention Unit

    https://arxiv.org/abs/2202.10447

    Args:
        attention_units: width of the shared projection from which the
            queries and keys are derived.
        attention_activation: activation applied to the u, v and z
            projections.
        attention_normalizer: function applied to the attention scores; a
            string is resolved through ``keras.activations.get``, any other
            value is used as given.
        attention_epsilon: value subtracted from masked scores.
        kernel_initializer: initializer of the kernels.
        kernel_regularizer: regularizer of the kernels.
        kernel_constraint: constraint of the kernels.
        bias_initializer: initializer of the biases.
        bias_regularizer: regularizer of the biases.
        bias_constraint: constraint of the biases.
        use_attention_bias: if True, biases are added to every projection.
        use_attention_scale: if True, scores are divided by
            ``sqrt(attention_units)``.
        use_relative_position: if True, rotary position embeddings are
            applied to the queries and keys.
        use_offset: forwarded as ``offset`` to the ``ScaleOffset`` layers.
        use_scale: forwarded as ``scale`` to the ``ScaleOffset`` layers.
        is_residual: if True, the input is added to the output.
    """

    def __init__(self,
                 attention_units: int,
                 attention_activation: Activation = 'relu',
                 attention_normalizer: Activation = relu2,
                 attention_epsilon: float = 1e10,
                 kernel_initializer: Initializer = 'glorot_normal',
                 kernel_regularizer: Optional[Regularizer] = None,
                 kernel_constraint: Optional[Constraint] = None,
                 bias_initializer: Initializer = 'zeros',
                 bias_regularizer: Optional[Regularizer] = None,
                 bias_constraint: Optional[Constraint] = None,
                 use_attention_bias: bool = True,
                 use_attention_scale: bool = True,
                 use_relative_position: bool = True,
                 use_offset: bool = True,
                 use_scale: bool = True,
                 is_residual: bool = True,
                 **kwargs):
        super(GatedAttentionUnit, self).__init__(**kwargs)
        self.supports_masking = True
        self.attention_units = attention_units
        self.attention_activation = keras.activations.get(attention_activation)
        self.attention_normalizer = (
            keras.activations.get(attention_normalizer)
            if isinstance(attention_normalizer, str) else attention_normalizer)
        self.attention_epsilon = attention_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.use_attention_bias = use_attention_bias
        self.use_attention_scale = use_attention_scale
        self.use_relative_position = use_relative_position
        self.use_offset = use_offset
        self.use_scale = use_scale
        self.is_residual = is_residual

    def get_config(self) -> dict:
        """Return the constructor arguments merged into the base layer config."""
        # NOTE(review): the normalizer is serialized through
        # keras.activations; the default `relu2` comes from
        # langml.activations, so reloading presumably needs it to be
        # registered as a custom object -- confirm against the loader.
        config = {
            "attention_units": self.attention_units,
            "attention_activation": keras.activations.serialize(self.attention_activation),
            "attention_normalizer": keras.activations.serialize(self.attention_normalizer),
            "attention_epsilon": self.attention_epsilon,
            "kernel_initializer": keras.initializers.serialize(self.kernel_initializer),
            "kernel_regularizer": keras.regularizers.serialize(self.kernel_regularizer),
            "kernel_constraint": keras.constraints.serialize(self.kernel_constraint),
            "bias_initializer": keras.initializers.serialize(self.bias_initializer),
            "bias_regularizer": keras.regularizers.serialize(self.bias_regularizer),
            "bias_constraint": keras.constraints.serialize(self.bias_constraint),
            "use_attention_bias": self.use_attention_bias,
            "use_attention_scale": self.use_attention_scale,
            "use_relative_position": self.use_relative_position,
            "use_offset": self.use_offset,
            "use_scale": self.use_scale,
            "is_residual": self.is_residual
        }
        base_config = super(GatedAttentionUnit, self).get_config()
        return dict(base_config, **config)

    def build(self, input_shape: Tensors):
        """Create the u/v/z/output kernels, biases and ScaleOffset layers."""
        super(GatedAttentionUnit, self).build(input_shape)
        feature_dim = int(input_shape[-1])
        # u and v are expanded to twice the input width; Wo maps back.
        self.Wu = self.add_weight(shape=(feature_dim, 2*feature_dim),
                                  name=f'{self.name}_Wu',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wv = self.add_weight(shape=(feature_dim, 2*feature_dim),
                                  name=f'{self.name}_Wv',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wz = self.add_weight(shape=(feature_dim, self.attention_units),
                                  name=f'{self.name}_Wz',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wo = self.add_weight(shape=(2*feature_dim, feature_dim),
                                  name=f'{self.name}_Wo',
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.bu = self.add_weight(shape=(2*feature_dim,),
                                      name=f'{self.name}_bu',
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
            self.bv = self.add_weight(shape=(2*feature_dim,),
                                      name=f'{self.name}_bv',
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
            self.bz = self.add_weight(shape=(self.attention_units,),
                                      name=f'{self.name}_bz',
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
            self.bo = self.add_weight(shape=(feature_dim,),
                                      name=f'{self.name}_bo',
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)
        self.scale_offset_q = ScaleOffset(scale=self.use_scale, offset=self.use_offset,
                                          name=f'{self.name}_scale_offset_q')
        self.scale_offset_k = ScaleOffset(scale=self.use_scale, offset=self.use_offset,
                                          name=f'{self.name}_scale_offset_k')

    def apply_rotary_position_embeddings(self, sinusoidal: Tensors, *tensors):
        """ apply RoPE
        modified from: https://github.com/bojone/bert4keras/blob/master/bert4keras/backend.py#L310

        Args:
            sinusoidal: position embedding tensor; its odd-indexed features
                are used as the cosine part and its even-indexed features as
                the sine part.
            *tensors: one or more tensors of identical static shape.

        Returns:
            The rotated tensor when one tensor is given, otherwise a list of
            rotated tensors in input order.
        """
        def align(tensor, axes, ndim=None):
            """Place the axes of ``tensor`` at ``axes`` of an ``ndim`` tensor,
            inserting new size-1 axes everywhere else."""
            assert len(axes) == K.ndim(tensor)
            assert ndim or min(axes) >= 0
            ndim = ndim or max(axes) + 1
            indices = [None] * ndim
            for i in axes:
                indices[i] = slice(None)
            # Index with a tuple: that is the standard multi-axis index form.
            return tensor[tuple(indices)]

        assert len(tensors) > 0, 'at least one input tensor'
        assert all([
            K.int_shape(tensor) == K.int_shape(tensors[0]) for tensor in tensors[1:]
        ]), 'all tensors must have the same shape'
        ndim = K.ndim(tensors[0])
        sinusoidal = align(sinusoidal, [0, 1, -1], ndim)
        # Duplicate each value so it covers a pair of adjacent features.
        cos_pos = K.repeat_elements(sinusoidal[..., 1::2], 2, -1)
        sin_pos = K.repeat_elements(sinusoidal[..., ::2], 2, -1)
        outputs = []
        for tensor in tensors:
            # Interleave (-odd, even) features: each adjacent pair (a, b)
            # becomes (-b, a).
            tensor2 = K.stack([-tensor[..., 1::2], tensor[..., ::2]], ndim)
            tensor2 = K.reshape(tensor2, K.shape(tensor))
            outputs.append(tensor * cos_pos + tensor2 * sin_pos)
        return outputs[0] if len(outputs) == 1 else outputs

    def attn(self, x: Tensors, v: Tensors, mask: Optional[Tensors] = None) -> Tensors:
        """Compute the attention-weighted values.

        Args:
            x: layer input, projected by ``Wz`` to build queries and keys.
            v: value tensor mixed by the attention weights.
            mask: optional mask. A mask with one axis fewer than ``x`` is
                treated as a per-position mask and turned into a pairwise
                mask; a mask of the same rank is applied to the scores as-is.

        Returns:
            ``normalizer(q . k) . v``
        """
        z = K.dot(x, self.Wz)
        if self.use_attention_bias:
            z += self.bz
        z = self.attention_activation(z)
        # Queries and keys share z and differ only by their ScaleOffset.
        q, k = self.scale_offset_q(z), self.scale_offset_k(z)
        if self.use_relative_position:
            pos = SineCosinePositionEmbedding("zero", output_dim=self.attention_units)(x)
            q, k = self.apply_rotary_position_embeddings(pos, q, k)
        qk = K.batch_dot(q, k, axes=2)  # (B, N, S) * (B, M, S) -> (B, N, M)
        if self.use_attention_scale:
            qk /= self.attention_units**0.5
        if mask is not None:
            # Always cast so a boolean mask of any rank works in `1.0 - mask`.
            mask = K.cast(mask, K.floatx())
            if len(K.int_shape(mask)) == len(K.int_shape(x)) - 1:
                # (B, N, 1) * (B, 1, N) -> (B, N, N): a score survives only
                # when both its query and its key position are unmasked, so
                # masked keys no longer contribute to unmasked queries.
                mask = K.expand_dims(mask, axis=-1) * K.expand_dims(mask, axis=1)
            qk -= self.attention_epsilon * (1.0 - mask)
        a = self.attention_normalizer(qk)
        return K.batch_dot(a, v)  # (B, N, M) * (B, M, E) -> (B, N, E)

    def call(self,
             inputs: Tensors,
             mask: Optional[Tensors] = None,
             **kwargs) -> Tensors:
        """Apply the gated attention unit.

        Returns ``inputs + o`` when ``is_residual`` is True, otherwise ``o``,
        where ``o`` is the output projection of ``u * attn(inputs, v)``.
        """
        u = K.dot(inputs, self.Wu)
        v = K.dot(inputs, self.Wv)
        if self.use_attention_bias:
            u += self.bu
            v += self.bv
        u = self.attention_activation(u)
        v = self.attention_activation(v)
        # u gates the attended values element-wise.
        x = u * self.attn(inputs, v, mask)
        o = K.dot(x, self.Wo)
        if self.use_attention_bias:
            o += self.bo
        if self.is_residual:
            return inputs + o  # residual
        return o

    def compute_mask(self,
                     inputs: Tensors,
                     mask: Optional[Tensors] = None) -> Tensors:
        """Propagate the input mask unchanged."""
        return mask

    def compute_output_shape(self, input_shape: Tensors) -> Tensors:
        """Return the input shape; ``Wo`` maps back to the input width."""
        return input_shape

    @staticmethod
    def get_custom_objects() -> dict:
        """Return the mapping needed to deserialize this layer."""
        return {'GatedAttentionUnit': GatedAttentionUnit}