Source code for langml.baselines.contrastive.utils

# -*- coding: utf-8 -*-

import random
from typing import List


# Punctuation marks that may be inserted when augmenting Chinese text.
# NOTE(review): only the full stop is a full-width character here; the other
# marks are ASCII -- confirm this is intended for Chinese input.
CN_PUNCTUATIONS = [
    '。',
    ',',
    '?',
    '!',
    ';',
]

# Punctuation marks that may be inserted when augmenting English text.
EN_PUNCTUATIONS = [
    '.',
    ',',
    '!',
    '?',
    ';',
    ':',
]
def aeda_augment(words: List[str], ratio: float = 0.3, language: str = 'EN') -> str:
    """Augment a tokenized text by randomly inserting punctuation marks.

    Implements AEDA: An Easier Data Augmentation Technique for Text
    Classification. A random number of token positions is chosen and one
    randomly selected punctuation mark is inserted immediately before each
    chosen token.

    Args:
        words: tokens of the input text.
        ratio: controls the maximum number of insertions, which is
            ``int(ratio * len(words) + 1)`` clamped to the range
            ``[1, len(words)]``. The actual count is drawn uniformly
            between 1 and that maximum.
        language: one of ``'EN'`` or ``'CN'``. Selects the punctuation set
            and the join string (a space for EN, nothing for CN).

    Returns:
        The augmented tokens joined into a single string. An empty
        ``words`` list yields an empty string.

    Raises:
        AssertionError: if ``language`` is not ``'EN'`` or ``'CN'``.
    """
    assert language in ['EN', 'CN'], 'please specify language from ["EN", "CN"]'

    # Nothing to augment; sampling a position from an empty range would raise.
    if not words:
        return ''

    if language == 'EN':
        punctuations = EN_PUNCTUATIONS
        join_str = ' '
    else:
        punctuations = CN_PUNCTUATIONS
        join_str = ''

    n_words = len(words)
    # Clamp the upper bound: random.sample cannot draw more positions than
    # there are tokens (ratio close to or above 1), and randint needs an
    # upper bound of at least 1 (negative ratio).
    max_insertions = max(1, min(int(ratio * n_words + 1), n_words))
    n_insertions = random.randint(1, max_insertions)
    # A set gives O(1) membership checks in the loop below.
    positions = set(random.sample(range(n_words), n_insertions))

    augmented_words = []
    for i, word in enumerate(words):
        if i in positions:
            augmented_words.append(random.choice(punctuations))
        augmented_words.append(word)
    return join_str.join(augmented_words)
def whitespace_tokenize(text: str) -> List[str]:
    """Split text into tokens on runs of whitespace.

    Args:
        text: the string to tokenize.

    Returns:
        The whitespace-separated tokens; leading and trailing whitespace
        produce no empty tokens, so blank input yields an empty list.
    """
    tokens = text.split()
    return tokens