Source code for langml.baselines.contrastive.simcse.dataloder

# -*- coding: utf-8 -*-

import json
import math
from random import shuffle
from typing import Callable, List, Tuple

import numpy as np
import tensorflow as tf
from boltons.iterutils import chunked_iter

from langml.baselines import BaseDataLoader
from langml.baselines.contrastive.utils import aeda_augment, whitespace_tokenize
from langml import TF_KERAS
if TF_KERAS:
    from tensorflow.keras.preprocessing.sequence import pad_sequences
else:
    from keras.preprocessing.sequence import pad_sequences


class DataLoader(BaseDataLoader):
    """Batch generator over (text, augmented text) pairs for SimCSE training.

    Each batch is tokenized with the supplied tokenizer and post-padded to
    the longest sequence in that batch.
    """

    def __init__(self, data: List, tokenizer: object, batch_size: int = 32):
        """
        Args:
          data: List, items are unpacked as (text, augmented_text) in `make_iter`
          tokenizer: object, must expose `encode(text)` returning an object
            with `ids` and `segment_ids` attributes
          batch_size: int, number of pairs per batch, default 32
        """
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.data = data

    def __len__(self) -> int:
        """Return the number of batches per pass; a trailing partial batch counts."""
        n_examples = len(self.data)
        return math.ceil(n_examples / self.batch_size)

    @staticmethod
    def load_data(fpath: str,
                  apply_aeda: bool = True,
                  aeda_tokenize: Callable = whitespace_tokenize,
                  aeda_language: str = 'EN') -> Tuple[
                      List[Tuple[str, str]], List[Tuple[str, str, int]]]:
        """ Read a JSON-lines file and build training pairs.

        A line containing the key `text_left` is read as a labeled pair
        (`text_left`, `text_right`, `label`) and contributes both of its texts
        to the training pairs; any other line must contain the key `text`.
        Blank lines are skipped.

        Args:
          fpath: str, path of data
          apply_aeda: bool, whether to apply the AEDA technique to augment data, default True
          aeda_tokenize: Callable, specify aeda tokenize function, it works when set apply_aeda=True
          aeda_language: str, specifying the language, it works when set apply_aeda=True

        Returns:
          A tuple `(data, data_with_label)`: `data` holds (text, augmented_text)
          pairs, where augmented_text equals text when apply_aeda is False;
          `data_with_label` holds (text_left, text_right, label) triples.
        """
        pairs, labeled_pairs = [], []
        with open(fpath, 'r', encoding='utf-8') as reader:
            for raw_line in reader:
                stripped = raw_line.strip()
                if not stripped:
                    continue
                record = json.loads(stripped)
                if 'text_left' in record:
                    left, right = record['text_left'], record['text_right']
                    labeled_pairs.append((left, right, int(record['label'])))
                    sentences = [left, right]
                else:
                    sentences = [record['text']]
                for sentence in sentences:
                    augmented = (
                        aeda_augment(aeda_tokenize(sentence), language=aeda_language)
                        if apply_aeda else sentence
                    )
                    pairs.append((sentence, augmented))
        return pairs, labeled_pairs

    def make_iter(self, random: bool = False):
        """Yield one pass of padded batches over `self.data`.

        Args:
          random: bool, shuffle `self.data` in place before iterating, default False

        Yields:
          `(inputs, labels)` where `inputs` is the list
          [tokens, segments, augmented_tokens, augmented_segments] of padded
          arrays and `labels` is an all-zero array of shape (batch, 1).
        """
        if random:
            # In-place shuffle: the caller's list is reordered as well.
            shuffle(self.data)

        for chunk in chunked_iter(self.data, self.batch_size):
            token_ids, segment_ids = [], []
            aug_token_ids, aug_segment_ids = [], []
            for text, augmented_text in chunk:
                encoded = self.tokenizer.encode(text)
                token_ids.append(encoded.ids)
                segment_ids.append(encoded.segment_ids)
                encoded = self.tokenizer.encode(augmented_text)
                aug_token_ids.append(encoded.ids)
                aug_segment_ids.append(encoded.segment_ids)

            # Pad every feature column to the longest sequence in this batch.
            inputs = [
                pad_sequences(column, padding='post', truncating='post')
                for column in (token_ids, segment_ids, aug_token_ids, aug_segment_ids)
            ]
            labels = np.zeros((len(inputs[0]), 1))
            yield inputs, labels
class TFDataLoader(DataLoader):
    """`DataLoader` variant that serves the pairs as a `tf.data.Dataset`."""

    def make_iter(self, random: bool = False):
        """Build a repeating, padded-batch dataset over `self.data`.

        Args:
          random: bool, add a shuffle stage with a buffer of
            `batch_size * 1000` examples, default False

        Returns:
          A `tf.data.Dataset` whose elements are `(features, label)`;
          `features` is a dict keyed by 'Input-Token', 'Input-Segment',
          'Input-Augmented-Token' and 'Input-Augmented-Segment'.
        """
        feature_names = (
            'Input-Token',
            'Input-Segment',
            'Input-Augmented-Token',
            'Input-Augmented-Segment',
        )

        def gen_features():
            # One un-padded example per (text, augmented_text) pair; the label is always [0].
            for text, augmented_text in self.data:
                encoded = self.tokenizer.encode(text)
                features = {
                    'Input-Token': encoded.ids,
                    'Input-Segment': encoded.segment_ids,
                }
                encoded = self.tokenizer.encode(augmented_text)
                features['Input-Augmented-Token'] = encoded.ids
                features['Input-Augmented-Segment'] = encoded.segment_ids
                yield features, [0]

        output_types = ({name: tf.int64 for name in feature_names}, tf.int64)
        output_shapes = (
            {name: tf.TensorShape((None, )) for name in feature_names},
            tf.TensorShape((1, )),
        )

        buffer_size = self.batch_size * 1000
        dataset = tf.data.Dataset.from_generator(
            gen_features,
            output_types=output_types,
            output_shapes=output_shapes,
        )
        dataset = dataset.repeat()
        if random:
            dataset = dataset.shuffle(buffer_size)
        # The same shapes double as padded_shapes: variable-length features
        # are padded to the longest sequence in each batch.
        batched = dataset.padded_batch(self.batch_size, output_shapes)
        # NOTE(review): prefetch runs after batching, so this buffer appears to
        # be counted in batches rather than examples - confirm the size is intended.
        return batched.prefetch(buffer_size)

    def __call__(self, random: bool = False):
        """Return the dataset built by `make_iter`."""
        return self.make_iter(random=random)