Source code for langml.prompt.base

# -*- coding: utf-8 -*-

from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple

from langml.tensor_typing import Models
from langml.tokenizer import Tokenizer
from langml.plm import load_albert, load_bert
from langml.log import info


class Template:
    """Prompt template together with its verbalizer (label -> label tokens).

    On construction the template tokens and every label token are converted
    to ids with the supplied tokenizer, and a reverse ``token id -> label``
    lookup is built for decoding.
    """

    def __init__(self,
                 template: List[str],
                 label_tokens_map: Dict[str, List[str]],
                 tokenizer: Tokenizer) -> None:
        """Encode the template and the label tokens.

        Args:
          - template: List[str], template tokens
          - label_tokens_map: Dict[str, List[str]], verbalizer, map of label to tokens
          - tokenizer: langml.Tokenizer, tokenizer used to look up token ids

        Raises:
          - AssertionError: if a label token is encoded to the UNK id
        """
        self.tokenizer = tokenizer
        # The UNK id is the marker used below to detect out-of-vocabulary label tokens.
        self.unk_id = self.tokenizer.token_to_id(self.tokenizer.special_tokens.UNK)
        self.template_ids = self.encode_template(template)
        self.label2tokens, self.id2label = self.encode_label_tokens_map(label_tokens_map)
        info(f'template ids: {self.template_ids}')

    def __len__(self) -> int:
        """Return the number of encoded template tokens."""
        return len(self.template_ids)

    def encode_template(self, template: List[str]) -> List[int]:
        """Convert each template token to its id, preserving order.

        Args:
          - template: List[str], template tokens

        Return:
          List[int], one ``tokenizer.token_to_id`` result per token
        """
        return [self.tokenizer.token_to_id(token) for token in template]

    def encode_label_tokens_map(
        self, label_tokens_map: Dict[str, List[str]]
    ) -> Tuple[Dict[str, List[int]], Dict[int, str]]:
        """Encode the verbalizer and build the reverse lookup.

        Args:
          - label_tokens_map: Dict[str, List[str]], map of label to tokens

        Return:
          Tuple of ``(label2ids, id2label)``: label -> list of token ids, and
          token id -> label. If the same token id appears under several
          labels, the label processed last wins in ``id2label``.

        Raises:
          - AssertionError: if a token is encoded to ``self.unk_id``
        """
        label2ids: Dict[str, List[int]] = {}
        id2label: Dict[int, str] = {}
        for label, tokens in label_tokens_map.items():
            token_ids = []
            for token in tokens:
                token_id = self.tokenizer.token_to_id(token)
                # Raised explicitly instead of via an `assert` statement so the
                # check is not stripped under `python -O`; the exception type
                # and message are unchanged for existing callers.
                if token_id == self.unk_id:
                    raise AssertionError(
                        f'unknown token {token}! please specify a token from vocabulary')
                token_ids.append(token_id)
                id2label[token_id] = label
            label2ids[label] = token_ids
        return label2ids, id2label

    def decode_label(self, idx: int, default: str = '<UNK>') -> str:
        """Return the label registered for token id ``idx``.

        Args:
          - idx: int, token id
          - default: str, value returned when ``idx`` is not a label token id
        """
        return self.id2label.get(idx, default)
class BasePromptModel(metaclass=ABCMeta):
    """Abstract base for prompt models built on a pretrained language model.

    The constructor loads the backbone and keeps the template and
    hyper-parameters on the instance; subclasses implement ``build_model``.
    """

    def __init__(self,
                 plm_backbone: str,
                 plm_config_path: str,
                 plm_ckpt_path: str,
                 template: Template,
                 learning_rate: float = 1e-5,
                 freeze_plm: bool = True) -> None:
        """ Initialize Prompt Model
        Args:
          - plm_backbone: str, backbone of pretrained language model;
            'albert' selects `load_albert`, any other value selects `load_bert`
          - plm_config_path: str, configure path of pretrained language model
          - plm_ckpt_path: str, checkpoint path of pretrained language model
          - template: Template, prompt template, stored as `self.template`
          - learning_rate: float, learning rate, stored as `self.learning_rate`
          - freeze_plm: bool, whether to freeze pretrained language model weights;
            only stored as `self.freeze_plm` here
        """
        self.model = None
        self.freeze_plm = freeze_plm

        # Both loaders are called with the same keyword arguments, so pick the
        # function first and issue a single call.
        load_plm = load_albert if plm_backbone == 'albert' else load_bert
        _, self.plm, self.lazy_restore_callback = load_plm(
            config_path=plm_config_path,
            checkpoint_path=plm_ckpt_path,
            pretraining=True,
            with_mlm=True,
            with_nsp=False,
            lazy_restore=True,
        )

        self.template = template
        self.learning_rate = learning_rate

    @abstractmethod
    def build_model(self) -> Models:
        """Build and return the model; must be overridden by subclasses."""
        raise NotImplementedError
class BasePromptTask(metaclass=ABCMeta):
    """Abstract base for a task driven by a prompt model.

    Subclasses implement ``fit`` and ``predict``.
    """

    def __init__(self, prompt_model: BasePromptModel, tokenizer: Tokenizer) -> None:
        """Keep the collaborators and build the underlying model.

        Args:
          - prompt_model: BasePromptModel, supplies `template` and `build_model()`
          - tokenizer: langml.Tokenizer, used to look up the MASK token id
        """
        self.prompt_model = prompt_model
        self.template = prompt_model.template
        self.tokenizer = tokenizer

        mask_token = tokenizer.special_tokens.MASK
        self.mask_id = tokenizer.token_to_id(mask_token)

        # Built last, after the template and mask id are in place.
        self.model = prompt_model.build_model()

    @abstractmethod
    def fit(self):
        """Train the task; must be overridden by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def predict(self):
        """Run prediction; must be overridden by subclasses."""
        raise NotImplementedError
class BaseDataGenerator(metaclass=ABCMeta):
    """Abstract base for data generators that can be iterated endlessly.

    Subclasses implement ``make_iter`` and ``__len__``; calling the instance
    returns a generator that replays ``make_iter`` forever.
    """

    @abstractmethod
    def make_iter(self, random: bool = False):
        """Return one pass over ``(inputs, labels)`` pairs; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self):
        """Return the generator's length; must be overridden."""
        raise NotImplementedError

    def __call__(self, random: bool = False):
        """Yield ``(inputs, labels)`` tuples indefinitely.

        A new ``make_iter(random=random)`` pass is started each time the
        previous one is exhausted; this generator never terminates by itself.
        """
        while True:
            one_pass = self.make_iter(random=random)
            for batch_inputs, batch_labels in one_pass:
                yield batch_inputs, batch_labels