Source code for langml.baselines.ner.bert_crf

# -*- coding: utf-8 -*-

from langml import TF_KERAS
if TF_KERAS:
    import tensorflow.keras as keras
    import tensorflow.keras.layers as L
else:
    import keras
    import keras.layers as L

from langml.plm.albert import load_albert
from langml.plm.bert import load_bert
from langml.layers import CRF
from langml.baselines import BaselineModel, Parameters
from langml.tensor_typing import Models


[docs]class BertCRF(BaselineModel): def __init__(self, config_path: str, ckpt_path: str, params: Parameters, backbone: str = 'roberta'): self.config_path = config_path self.ckpt_path = ckpt_path self.params = params assert backbone in ['bert', 'roberta', 'albert'] if backbone == 'albert': self.load_plm = load_albert else: self.load_plm = load_bert
[docs] def build_model(self, lazy_restore=False) -> Models: if lazy_restore: model, _, restore_bert_weights = self.load_plm(self.config_path, self.ckpt_path, lazy_restore=True) else: model, _ = self.load_plm(self.config_path, self.ckpt_path) crf = CRF(self.params.tag_size, sparse_target=False, name='crf') output = L.Dropout(self.params.dropout_rate)(model.output) output = L.Dense(self.params.tag_size, name='tag')(output) output = crf(output) train_model = keras.Model(model.input, output) train_model.summary() train_model.compile(keras.optimizers.Adam(self.params.learning_rate), loss=crf.loss, metrics=[crf.accuracy]) if lazy_restore: restore_bert_weights(model) return train_model