Source code for langml.layers.crf

# -*- coding: utf-8 -*-

from typing import Optional, Callable, Union

from langml import TF_KERAS
if TF_KERAS:
    import tensorflow.keras.backend as K
    import tensorflow.keras.layers as L
else:
    import keras.backend as K
    import keras.layers as L

import tensorflow as tf

from langml.tensor_typing import Tensors
from langml.third_party.crf import crf_log_likelihood, crf_decode


class CRF(L.Layer):
    """Keras layer that scores and decodes label sequences with a learned
    label-to-label transition matrix.

    The layer owns a single trainable weight, ``transitions``, of shape
    ``(output_dim, output_dim)``. Decoding and likelihood computation are
    delegated to ``crf_decode`` and ``crf_log_likelihood``.
    """

    def __init__(self, output_dim: int, sparse_target: bool = True, **kwargs):
        """
        Args:
          output_dim (int): the number of labels to tag each temporal input.
          sparse_target (bool): whether the ground-truth labels are one-hot
            encoded. When True, `loss` and `accuracy` take the argmax of
            `y_true` over axis 2; when False, `y_true` is used directly as
            label indices.

        Input shape:
          (batch_size, sentence length, output_dim)

        Output shape:
          (batch_size, sentence length, output_dim)

        Usage:
          >>> from tensorflow.keras.models import Sequential
          >>> from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM, Dense
          >>> num_labels = 10
          >>> embedding_size = 100
          >>> hidden_size = 128
          >>> model = Sequential()
          >>> model.add(Embedding(num_labels, embedding_size))
          >>> model.add(Bidirectional(LSTM(hidden_size, return_sequences=True)))
          >>> model.add(Dense(num_labels))
          >>> crf = CRF(num_labels, sparse_target=True)
          >>> model.add(crf)
          >>> model.compile('adam', loss=crf.loss, metrics=[crf.accuracy])
        """
        # Configs written by earlier versions of `get_config` carried these
        # two entries. They are not constructor arguments of the base layer,
        # which rejects unknown keywords, so drop them to keep old configs
        # loadable through `from_config`.
        kwargs.pop('supports_masking', None)
        kwargs.pop('transitions', None)
        super(CRF, self).__init__(**kwargs)
        # NOTE(review): `support_mask` looks like a misspelling of
        # `supports_masking`; nothing in this class reads it. It is kept so
        # that the set of instance attributes stays unchanged.
        self.support_mask = True
        self.output_dim = output_dim
        self.sparse_target = sparse_target
        self.input_spec = L.InputSpec(min_ndim=3)
        self.supports_masking = False
        # Populated by `call`; `loss` and `accuracy` read it afterwards.
        self.sequence_lengths = None

    def build(self, input_shape: Tensors):
        """Validate the input shape and create the transition weight.

        Args:
          input_shape: shape of the layer input; must have exactly 3 entries
            and a defined last dimension equal to `output_dim`.

        Raises:
          ValueError: if the last dimension is `None` or differs from
            `output_dim`.
        """
        assert len(input_shape) == 3
        f_shape = input_shape
        input_spec = L.InputSpec(min_ndim=3, axes={-1: f_shape[-1]})

        if f_shape[-1] is None:
            raise ValueError('The last dimension of the inputs to `CRF` '
                             'should be defined. Found `None`.')
        if f_shape[-1] != self.output_dim:
            raise ValueError('The last dimension of the input shape must be equal to output'
                             ' shape. Use a linear layer if needed.')
        self.input_spec = input_spec
        self.transitions = self.add_weight(name='transitions',
                                           shape=[self.output_dim, self.output_dim],
                                           initializer='glorot_uniform',
                                           trainable=True)
        self.built = True

    def compute_mask(self, inputs: Tensors, mask: Optional[Tensors] = None):
        """Return `None`: this layer does not pass a mask on to later layers."""
        return None

    def call(self,
             inputs: Tensors,
             sequence_lengths: Optional[Tensors] = None,
             training: Optional[Union[bool, int]] = None,
             mask: Optional[Tensors] = None,
             **kwargs) -> Tensors:
        """Decode the best label sequence for `inputs`.

        Args:
          inputs: per-step label scores, converted to a tensor of the layer's
            dtype.
          sequence_lengths: optional int32 tensor of shape (batch_size, 1)
            giving the valid length of each sequence. When omitted, every
            sequence is treated as spanning the full second axis of `inputs`.
          training: forwarded to `K.in_train_phase` to select the output.
          mask: accepted for Keras API compatibility; not used here.

        Returns:
          In the training phase, the (converted) inputs unchanged, so that
          `loss` receives the raw scores; otherwise the decoded sequence
          one-hot encoded over `output_dim`.
        """
        sequences = tf.convert_to_tensor(inputs, dtype=self.dtype)
        if sequence_lengths is not None:
            # Convert once up front so the checks below also work for inputs
            # that are not tensors yet.
            sequence_lengths = tf.convert_to_tensor(sequence_lengths)
            assert len(sequence_lengths.shape) == 2
            assert sequence_lengths.dtype == 'int32'
            seq_len_shape = sequence_lengths.get_shape().as_list()
            assert seq_len_shape[1] == 1
            self.sequence_lengths = K.flatten(sequence_lengths)
        else:
            # No lengths supplied: use the full time dimension for every row.
            self.sequence_lengths = tf.ones(tf.shape(inputs)[0], dtype=tf.int32) * (
                tf.shape(inputs)[1]
            )

        viterbi_sequence, _ = crf_decode(sequences, self.transitions,
                                         self.sequence_lengths)
        output = K.one_hot(viterbi_sequence, self.output_dim)
        return K.in_train_phase(sequences, output, training=training)

    @property
    def loss(self) -> Callable:
        """Loss function bound to this layer.

        The returned callable computes the mean negative log-likelihood from
        `crf_log_likelihood`, using this layer's transitions and the
        `sequence_lengths` recorded by the most recent `call`.
        """
        def crf_loss(y_true: Tensors, y_pred: Tensors) -> Tensors:
            # With `sparse_target` set, labels arrive one-hot encoded and are
            # reduced to indices along axis 2.
            y_true = K.argmax(y_true, 2) if self.sparse_target else y_true
            # Align the labels with the leading dimensions of `y_pred`.
            y_true = K.reshape(y_true, K.shape(y_pred)[:-1])
            log_likelihood, _ = crf_log_likelihood(
                y_pred,
                K.cast(y_true, dtype='int32'),
                self.sequence_lengths,
                transition_params=self.transitions,
            )
            return K.mean(-log_likelihood)
        return crf_loss

    @property
    def accuracy(self) -> Callable:
        """Metric function bound to this layer.

        The returned callable decodes `y_pred` with `crf_decode` and reports
        the fraction of positions whose decoded label equals `y_true`,
        counting only positions where every score exceeds -1e6.
        """
        def viterbi_accuracy(y_true: Tensors, y_pred: Tensors) -> Tensors:
            y_true = K.argmax(y_true, 2) if self.sparse_target else y_true
            y_true = K.reshape(y_true, K.shape(y_pred)[:-1])

            viterbi_sequence, _ = crf_decode(
                y_pred, self.transitions, self.sequence_lengths
            )
            # A position is counted only if all of its scores are > -1e6.
            mask = K.all(K.greater(y_pred, -1e6), axis=2)
            mask = K.cast(mask, K.floatx())

            y_true = K.cast(y_true, 'int32')
            corrects = K.cast(K.equal(y_true, viterbi_sequence), K.floatx())
            return K.sum(corrects * mask) / K.sum(mask)
        return viterbi_accuracy

    def compute_output_shape(self, input_shape: Tensors) -> Tensors:
        """Return the first two input dimensions followed by `output_dim`.

        The input shape is asserted to have rank 3.
        """
        tf.TensorShape(input_shape).assert_has_rank(3)
        return input_shape[:2] + (self.output_dim,)

    @property
    def trans(self) -> Tensors:
        """ transition parameters
        """
        return K.eval(self.transitions)

    def get_config(self) -> dict:
        """Return the constructor arguments needed to re-create this layer.

        Only `output_dim` and `sparse_target` are added to the base config.
        The transition matrix is a layer weight and is deliberately left out:
        including it made the config unusable with `from_config` and made
        this method fail on a layer that had not been built yet. Use `trans`
        to read the current transition values.
        """
        config = {
            'output_dim': self.output_dim,
            'sparse_target': self.sparse_target,
        }
        base_config = super(CRF, self).get_config()
        return dict(base_config, **config)

    @staticmethod
    def get_custom_objects() -> dict:
        """Return the `custom_objects` mapping needed to deserialize models
        that contain this layer."""
        return {'CRF': CRF}