Source code for langml.layers.layer_norm

# -*- coding: utf-8 -*-

from typing import Optional, Union

from langml import TF_KERAS
if TF_KERAS:
    import tensorflow.keras as keras
    import tensorflow.keras.backend as K
    import tensorflow.keras.layers as L
else:
    import keras
    import keras.backend as K
    import keras.layers as L

from langml.tensor_typing import Tensors, Initializer, Constraint, Regularizer


class LayerNorm(L.Layer):
    """Layer normalization over the last axis of the input.

    Each slice along the last axis is shifted to zero mean and divided by
    its standard deviation. Optionally the result is multiplied by a
    learned ``gamma`` and shifted by a learned ``beta``; both weights have
    the size of the last input dimension.

    Args:
        center: If True, add the learned offset ``beta`` to the output.
        scale: If True, multiply the output by the learned ``gamma``.
        epsilon: Small constant added to the variance before the square
            root, to avoid division by zero.
        gamma_initializer: Initializer for ``gamma``.
        gamma_regularizer: Optional regularizer for ``gamma``.
        gamma_constraint: Optional constraint for ``gamma``.
        beta_initializer: Initializer for ``beta``.
        beta_regularizer: Optional regularizer for ``beta``.
        beta_constraint: Optional constraint for ``beta``.
        **kwargs: Forwarded unchanged to the base layer constructor.
    """

    def __init__(self,
                 center: bool = True,
                 scale: bool = True,
                 epsilon: float = 1e-7,
                 gamma_initializer: Initializer = 'ones',
                 gamma_regularizer: Optional[Regularizer] = None,
                 gamma_constraint: Optional[Constraint] = None,
                 beta_initializer: Initializer = 'zeros',
                 beta_regularizer: Optional[Regularizer] = None,
                 beta_constraint: Optional[Constraint] = None,
                 **kwargs):
        super(LayerNorm, self).__init__(**kwargs)
        # The layer is element-wise along the normalized axis, so an
        # incoming mask is passed through untouched (see compute_mask).
        self.supports_masking = True
        self.center = center
        self.scale = scale
        self.epsilon = epsilon
        # Resolve identifiers (strings / configs / objects) into Keras objects
        # so that get_config() can serialize them uniformly.
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.beta_constraint = keras.constraints.get(beta_constraint)

    def get_config(self) -> dict:
        """Return the layer configuration, merged over the base config."""
        config = {
            "center": self.center,
            "scale": self.scale,
            "epsilon": self.epsilon,
            "gamma_initializer": keras.initializers.serialize(self.gamma_initializer),
            "gamma_regularizer": keras.regularizers.serialize(self.gamma_regularizer),
            "gamma_constraint": keras.constraints.serialize(self.gamma_constraint),
            "beta_initializer": keras.initializers.serialize(self.beta_initializer),
            "beta_regularizer": keras.regularizers.serialize(self.beta_regularizer),
            "beta_constraint": keras.constraints.serialize(self.beta_constraint)
        }
        base_config = super(LayerNorm, self).get_config()
        return dict(base_config, **config)

    def build(self, input_shape: Tensors):
        """Create ``gamma`` and/or ``beta``, sized to the last input dimension.

        ``gamma`` is only created when ``scale`` is True and ``beta`` only
        when ``center`` is True.
        """
        # Keep the last dimension as a 1-element shape for add_weight.
        shape = input_shape[-1:]
        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                name='gamma',
            )
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                name='beta',
            )
        super(LayerNorm, self).build(input_shape)

    def call(self, inputs: Tensors, **kwargs) -> Tensors:
        """Normalize ``inputs`` along the last axis.

        Computes ``(x - mean) / sqrt(var + epsilon)`` where ``mean`` and
        ``var`` are taken over the last axis, then applies ``gamma`` and
        ``beta`` when enabled.
        """
        # layer norm: statistics are taken over axis=-1
        mean = K.mean(inputs, axis=-1, keepdims=True)
        centered = inputs - mean
        # The variance must be computed from the centered values; averaging
        # the raw squares would give E[x^2], which overestimates the spread
        # whenever the mean is non-zero.
        variance = K.mean(K.square(centered), axis=-1, keepdims=True)
        std = K.sqrt(variance + self.epsilon)
        # standard normalization: x = (x - mean) / std
        outputs = centered / std
        if self.scale:
            outputs = outputs * self.gamma
        if self.center:
            outputs = outputs + self.beta
        return outputs

    def compute_mask(self,
                     inputs: Tensors,
                     mask: Optional[Tensors] = None) -> Union[Tensors, None]:
        """Return the incoming mask unchanged."""
        return mask

    @staticmethod
    def get_custom_objects() -> dict:
        """Return the name-to-class mapping for this layer.

        The mapping has the single entry ``{'LayerNorm': LayerNorm}``.
        """
        return {'LayerNorm': LayerNorm}

    def compute_output_shape(self, input_shape: Tensors) -> Tensors:
        """Return ``input_shape`` unchanged."""
        return input_shape