Source code for langml.baselines.ner

# -*- coding: utf-8 -*-

import re
from typing import Dict, List, Optional

import numpy as np
from seqeval.metrics import classification_report

from langml import TF_VERSION
from langml.utils import bio_decode
from langml.tensor_typing import Models


[docs]re_split = re.compile(r'.*?[\n。]+')
[docs]class Infer: def __init__(self, model: Models, tokenizer: object, id2label: Dict, max_chunk_len: Optional[int] = None, is_bert: bool = True): self.model = model self.tokenizer = tokenizer self.id2label = id2label self.max_chunk_len = max_chunk_len self.is_bert = is_bert
[docs] def decode_one(self, text: str, base_position: int = 0): """ Args: - text: str - base_position: int Return: list of tuple: [(entity, start, end, entity_type)] """ tokened = self.tokenizer.encode(text) mapping = self.tokenizer.tokens_mapping(text, tokened.tokens) token_ids = np.array([tokened.ids]) segment_ids = np.array([tokened.segment_ids]) if TF_VERSION > 1: if self.is_bert: logits = self.model([token_ids, segment_ids]) else: logits = self.model([token_ids]) else: if self.is_bert: logits = self.model.predict([token_ids, segment_ids]) else: logits = self.model.predict([token_ids]) tags = [self.id2label[i] for i in np.argmax(logits[0], axis=1)] entities = bio_decode(tags) res = [] for s, e, t in entities: s = mapping[s] e = mapping[e] s = 0 if not s else s[0] e = len(text) - 1 if not e else e[-1] res.append((base_position + s, base_position + e + 1, text[s: e + 1], t)) return res
[docs] def __call__(self, text: str): if self.max_chunk_len is None or len(text) < self.max_chunk_len: return self.decode_one(text) sentences = re_split.findall(text) if not sentences: return self.decode_one(text) results = [] prev, base_position = 0, 0 for i in range(1, len(sentences)): current_text = ''.join(sentences[prev: i]) if len(current_text) <= self.max_chunk_len and len(''.join(sentences[prev: i+1])) > self.max_chunk_len: results.extend(self.decode_one(current_text, base_position=base_position)) prev = i base_position += len(current_text) results.extend(self.decode_one(''.join(sentences[prev:]), base_position=base_position)) return results
[docs]def compute_detail_metrics(model: Models, dataloader: object, id2label: Dict, is_bert: bool = True): all_preds, all_golds = [], [] for pairs in dataloader.data: token_ids, segment_ids, labels = dataloader.encode_data(pairs) token_ids = np.array([token_ids]) segment_ids = np.array([segment_ids]) if TF_VERSION > 1: if is_bert: logits = model([token_ids, segment_ids]) else: logits = model([token_ids]) else: if is_bert: logits = model.predict([token_ids, segment_ids]) else: logits = model.predict([token_ids]) gold_tags = [id2label[i] for i in labels] pred_tags = [id2label[i] for i in np.argmax(logits[0], axis=1)] assert len(gold_tags) == len(pred_tags), f'g: {len(gold_tags)}, t: {len(pred_tags)}' all_preds.append(pred_tags) all_golds.append(gold_tags) print(classification_report(all_golds, all_preds, digits=4))