Source code for langml.transformer.layers

# -*- coding: utf-8 -*-

""" Yet another transformer implementation.
"""

# TODO: Transformer Decoder

from typing import Optional, List, Union, Any

from langml import TF_KERAS
if TF_KERAS:
    import tensorflow.keras as keras
    import tensorflow.keras.backend as K
    import tensorflow.keras.layers as L
else:
    import keras
    import keras.backend as K
    import keras.layers as L
from langml.tensor_typing import Tensors, Activation, Initializer, Constraint, Regularizer


[docs]class FeedForward(L.Layer): """ Feed Forward Layer https://arxiv.org/pdf/1706.03762.pdf """ def __init__(self, units, activation: Activation = 'relu', kernel_initializer: Initializer = 'glorot_normal', kernel_regularizer: Optional[Regularizer] = None, kernel_constraint: Optional[Constraint] = None, bias_initializer: Initializer = 'zeros', bias_regularizer: Optional[Regularizer] = None, bias_constraint: Optional[Constraint] = None, use_bias: bool = True, dropout_rate: float = 0.0, **kwargs): super(FeedForward, self).__init__(**kwargs) self.supports_masking = True self.units = units self.activation = keras.activations.get(activation) self.kernel_initializer = keras.initializers.get(kernel_initializer) self.kernel_regularizer = keras.regularizers.get(kernel_regularizer) self.kernel_constraint = keras.constraints.get(kernel_constraint) self.bias_initializer = keras.initializers.get(bias_initializer) self.bias_regularizer = keras.regularizers.get(bias_regularizer) self.bias_constraint = keras.constraints.get(bias_constraint) self.use_bias = use_bias self.dropout_rate = dropout_rate
[docs] def get_config(self) -> dict: config = { "units": self.units, "activation": keras.activations.serialize(self.activation), "kernel_initializer": keras.initializers.serialize(self.kernel_initializer), "kernel_regularizer": keras.regularizers.serialize(self.kernel_regularizer), "kernel_constraint": keras.constraints.serialize(self.kernel_constraint), "bias_initializer": keras.initializers.serialize(self.bias_initializer), "bias_regularizer": keras.regularizers.serialize(self.bias_regularizer), "bias_constraint": keras.constraints.serialize(self.bias_constraint), "use_bias": self.use_bias, "dropout_rate": self.dropout_rate } base_config = super(FeedForward, self).get_config() return dict(base_config, **config)
[docs] def build(self, input_shape: Tensors): feature_dim = int(input_shape[-1]) self.W1 = self.add_weight( shape=(feature_dim, self.units), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, name=f'{self.name}_W1', ) self.W2 = self.add_weight( shape=(self.units, feature_dim), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, name='{}_W2'.format(self.name), ) if self.use_bias: self.b1 = self.add_weight( shape=(self.units,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, name='{}_b1'.format(self.name), ) self.b2 = self.add_weight( shape=(feature_dim,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, name='{}_b2'.format(self.name), ) if self.dropout_rate > 0.0: self.dropout_layer = L.Dropout(self.dropout_rate) super(FeedForward, self).build(input_shape)
[docs] def call(self, inputs: Tensors, mask: Optional[Tensors] = None, training: Optional[Any] = None, **kwargs) -> Union[List[Tensors], Tensors]: hidden = K.dot(inputs, self.W1) if self.use_bias: hidden = K.bias_add(hidden, self.b1) if self.activation is not None: hidden = self.activation(hidden) if self.dropout_rate > 0.0: hidden = self.dropout_layer(hidden) output = K.dot(hidden, self.W2) if self.use_bias: output = K.bias_add(output, self.b2) return output
[docs] def compute_mask(self, inputs: Tensors, mask: Optional[Union[Tensors, List[Tensors]]] = None) -> Union[ List[Union[Tensors, None]], Tensors]: return mask
@staticmethod
[docs] def get_custom_objects() -> dict: return {'FeedForward': FeedForward}
[docs] def compute_output_shape(self, input_shape: Tensors) -> Tensors: return input_shape