langml.baselines.contrastive.simcse.model
Module Contents
Classes
Functions
class langml.baselines.contrastive.simcse.model.SimCSE(config_path: str, ckpt_path: str, params: langml.baselines.Parameters, backbone: str = 'roberta')
Bases: langml.baselines.BaselineModel

get_pooling_output(self, model: langml.tensor_typing.Models, output_index: int, pooling_strategy: str = 'cls') → langml.tensor_typing.Tensors
Get the pooling output of the BERT model.

:param model: keras.Model, the BERT model.
:param output_index: int, the output index of the feedforward layer.
:param pooling_strategy: str, the pooling strategy, one of 'cls', 'first-last-avg', 'last-avg'. Defaults to 'cls'.
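The three values accepted by pooling_strategy can be illustrated with a small NumPy sketch. This is not LangML's implementation. The helper names pool_sentence_embedding and masked_mean, the list-of-layer-outputs input, and the optional padding mask are hypothetical and chosen for illustration. The sketch also assumes that 'first-last-avg' means averaging the mean-pooled token vectors of the first and last transformer layers, and that 'last-avg' means mean pooling over the last layer only.

```python
import numpy as np

POOLING_STRATEGIES = ("cls", "first-last-avg", "last-avg")


def masked_mean(hidden: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average token vectors over the sequence axis, ignoring padded positions."""
    weights = mask[:, :, None].astype(hidden.dtype)   # (batch, seq_len, 1)
    total = (hidden * weights).sum(axis=1)             # (batch, hidden)
    count = np.maximum(weights.sum(axis=1), 1e-9)      # (batch, 1), avoids division by zero
    return total / count


def pool_sentence_embedding(layer_outputs, pooling_strategy="cls", mask=None):
    """Hypothetical helper: reduce per-layer token outputs to one vector per sentence.

    layer_outputs: sequence of arrays shaped (batch, seq_len, hidden), ordered
        from the first to the last transformer layer (an assumption of this sketch).
    mask: optional (batch, seq_len) array, 1 for real tokens and 0 for padding.
    """
    if pooling_strategy not in POOLING_STRATEGIES:
        raise ValueError(f"pooling_strategy must be one of {POOLING_STRATEGIES}")
    last = np.asarray(layer_outputs[-1])
    if pooling_strategy == "cls":
        # Vector of the first token ([CLS] / <s>) in the last layer.
        return last[:, 0]
    mask = np.ones(last.shape[:2]) if mask is None else np.asarray(mask)
    if pooling_strategy == "last-avg":
        # Mean of the last layer's token vectors.
        return masked_mean(last, mask)
    # "first-last-avg": mean of the first and last layers' averaged token vectors.
    first = np.asarray(layer_outputs[0])
    return (masked_mean(first, mask) + masked_mean(last, mask)) / 2.0


# Usage: 12 fake layer outputs for a batch of 2 sentences, 5 tokens, hidden size 8.
rng = np.random.default_rng(0)
layers = [rng.normal(size=(2, 5, 8)) for _ in range(12)]
print(pool_sentence_embedding(layers, "first-last-avg").shape)  # (2, 8)
```

Because mean pooling is linear, averaging the two layers before or after pooling over tokens gives the same result. Passing a mask keeps padding tokens from diluting the average.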