# -*- coding: utf-8 -*-
from typing import Optional, List, Union
from langml import TF_KERAS
if TF_KERAS:
import tensorflow.keras as keras
import tensorflow.keras.backend as K
import tensorflow.keras.layers as L
else:
import keras
import keras.backend as K
import keras.layers as L
from langml.tensor_typing import Tensors, Initializer, Constraint, Regularizer
[docs]class TokenEmbedding(L.Embedding):
@staticmethod
[docs] def get_custom_objects() -> dict:
return {'TokenEmbedding': TokenEmbedding}
[docs] def compute_mask(self,
inputs: Tensors,
mask: Optional[Tensors] = None) -> List[Union[Tensors, None]]:
return [super(TokenEmbedding, self).compute_mask(inputs, mask), None]
[docs] def call(self, inputs: Tensors) -> List[Tensors]:
return [super(TokenEmbedding, self).call(inputs), self.embeddings + 0]
[docs] def compute_output_shape(self, input_shape: Tensors) -> List[Tensors]:
return [super(TokenEmbedding, self).compute_output_shape(input_shape), K.int_shape(self.embeddings)]
[docs]class EmbeddingMatching(L.Layer):
def __init__(self,
initializer: Initializer = 'zeros',
regularizer: Optional[Regularizer] = None,
constraint: Optional[Constraint] = None,
use_bias: bool = True,
use_softmax: bool = True,
**kwargs):
super(EmbeddingMatching, self).__init__(**kwargs)
self.supports_masking = True
self.initializer = keras.initializers.get(initializer)
self.regularizer = keras.regularizers.get(regularizer)
self.constraint = keras.constraints.get(constraint)
self.use_bias = use_bias
self.use_softmax = use_softmax
[docs] def get_config(self) -> dict:
config = {
'initializer': keras.initializers.serialize(self.initializer),
'regularizer': keras.regularizers.serialize(self.regularizer),
'constraint': keras.constraints.serialize(self.constraint),
'use_bias': self.use_bias,
'use_softmax': self.use_softmax,
}
base_config = super(EmbeddingMatching, self).get_config()
return dict(base_config, **config)
[docs] def build(self, input_shape: Tensors):
if self.use_bias:
self.bias = self.add_weight(
shape=(int(input_shape[1][0]), ),
initializer=self.initializer,
regularizer=self.regularizer,
constraint=self.constraint,
name='bias',
)
super(EmbeddingMatching, self).build(input_shape)
[docs] def compute_mask(self, inputs: Tensors, mask: Optional[Tensors] = None) -> Tensors:
if isinstance(mask, list):
return mask[0]
return mask
[docs] def call(self, inputs: Tensors, mask: Optional[Tensors] = None, **kwargs) -> Tensors:
inputs, embeddings = inputs
output = K.dot(inputs, K.transpose(embeddings))
if self.use_bias:
output = K.bias_add(output, self.bias)
if self.use_softmax:
return K.softmax(output)
return output
@staticmethod
[docs] def get_custom_objects() -> dict:
return {'EmbeddingMatching': EmbeddingMatching}
[docs] def compute_output_shape(self, input_shape: Tensors) -> Tensors:
return input_shape[0][:2] + (input_shape[1][0], )
[docs]class Masked(L.Layer):
"""Generate output mask based on the given mask.
https://arxiv.org/pdf/1810.04805.pdf
"""
def __init__(self,
return_masked: bool = False,
**kwargs):
super(Masked, self).__init__(**kwargs)
self.supports_masking = True
self.return_masked = return_masked
@staticmethod
[docs] def get_custom_objects() -> dict:
return {'Masked': Masked}
[docs] def get_config(self) -> dict:
config = {
'return_masked': self.return_masked,
}
base_config = super(Masked, self).get_config()
return dict(base_config, **config)
[docs] def compute_mask(self,
inputs: Tensors,
mask: Optional[Tensors] = None) -> Union[List[Union[Tensors, None]], Tensors]:
token_mask = K.not_equal(inputs[1], 0)
masked = K.all(K.stack([token_mask, mask[0]], axis=0), axis=0)
if self.return_masked:
return [masked, None]
return masked
[docs] def call(self, inputs: Tensors, mask: Optional[Tensors] = None, **kwargs) -> Tensors:
output = inputs[0] + 0
if self.return_masked:
return [output, K.cast(self.compute_mask(inputs, mask)[0], K.floatx())]
return output
[docs] def compute_output_shape(self, input_shape: Tensors) -> Union[List[Tensors], Tensors]:
if self.return_masked:
return [input_shape[0], (2, ) + input_shape[1]]
return input_shape[0]