"""Basic Retriever"""
from datasets import Dataset, DatasetDict
from typing import List, Union, Optional, Tuple, Dict
from openicl import DatasetReader, PromptTemplate
from openicl.utils.check_type import _check_str
from accelerate import Accelerator
class BaseRetriever:
    """Basic In-context Learning Retriever Class

    Base class for In-context Learning Retriever, without any retrieval method.
    Subclasses implement :meth:`retrieve`; the prompt-building helpers defined
    here turn the retrieved indices into in-context examples and prompts.

    Attributes:
        dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
        ice_separator (:obj:`str`, optional): A string that separates each in-context example.
        ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
        prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
        ice_num (:obj:`int`, optional): The number of data in the in-context examples.
        index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
        test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
        index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
        test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
        accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
        is_main_process (:obj:`bool`): ``True`` when no accelerator is given or when ``accelerator.is_main_process`` is truthy.
    """
    index_ds = None
    test_ds = None

    def __init__(self,
                 dataset_reader: "DatasetReader",
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 prompt_eos_token: Optional[str] = '',
                 ice_num: Optional[int] = 1,
                 index_split: Optional[str] = 'train',
                 test_split: Optional[str] = 'test',
                 accelerator: Optional["Accelerator"] = None
                 ) -> None:
        self.dataset_reader = DatasetReader._check_dataset_reader(dataset_reader)
        self.ice_separator = ice_separator
        self.ice_eos_token = ice_eos_token
        self.prompt_eos_token = prompt_eos_token
        self.ice_num = ice_num
        self.index_split = index_split
        self.test_split = test_split
        self.accelerator = accelerator
        self.is_main_process = self.accelerator is None or bool(self.accelerator.is_main_process)

        dataset = self.dataset_reader.dataset
        if isinstance(dataset, Dataset):
            # A single Dataset has no splits: it serves as both index and test set.
            self.index_ds = dataset
            self.test_ds = dataset
        else:
            self.index_ds = dataset[self.index_split]
            self.test_ds = dataset[self.test_split]

        if self.accelerator is not None:
            # Only the test set is sharded (one shard per process index);
            # the index set is kept whole in every process.
            self.test_ds = self.test_ds.shard(
                num_shards=self.accelerator.num_processes,
                index=self.accelerator.process_index
            )

    def retrieve(self) -> List[List]:
        """
        Retrieve for each data in `test_ds`.

        Returns:
            `List[List]`: the index list of in-context example for each data in `test_ds`.

        Raises:
            NotImplementedError: always; subclasses provide the retrieval method.
        """
        raise NotImplementedError("Method hasn't been implemented yet")

    def get_labels(self, ice_template: Optional["PromptTemplate"] = None, prompt_template: Optional["PromptTemplate"] = None):
        """Return the list of candidate labels.

        The source is chosen in this order:

        1. the keys of ``prompt_template.template`` when it is a dict;
        2. the keys of ``ice_template.template`` when it is a dict and
           ``ice_template.ice_token`` is set;
        3. the distinct values of the output column of `test_ds`.

        Args:
            ice_template (:obj:`PromptTemplate`, optional): Template for in-context examples.
            prompt_template (:obj:`PromptTemplate`, optional): Template for the prompt.

        Returns:
            `List`: the labels. In case 3 they are de-duplicated in order of
            first occurrence, so the result is reproducible across runs.
        """
        if prompt_template is not None and isinstance(prompt_template.template, dict):
            return list(prompt_template.template.keys())
        if (ice_template is not None
                and ice_template.ice_token is not None
                and isinstance(ice_template.template, dict)):
            return list(ice_template.template.keys())
        # dict.fromkeys de-duplicates like set() but keeps insertion order,
        # whereas iterating a set of strings can differ between interpreter runs.
        return list(dict.fromkeys(self.test_ds[self.dataset_reader.output_column]))

    def generate_ice(self, idx_list: List[int], ice_template: Optional["PromptTemplate"] = None) -> str:
        """Build the in-context example string for the given index-set rows.

        Args:
            idx_list (:obj:`List[int]`): Indices into `index_ds`.
            ice_template (:obj:`PromptTemplate`, optional): When given, each example is
                rendered with ``ice_template.generate_ice_item``; otherwise the input
                columns and the output column are joined with single spaces.

        Returns:
            `str`: the examples joined by `ice_separator`, followed by `ice_eos_token`.
        """
        dr = self.dataset_reader
        generated_ice_list = []
        for idx in idx_list:
            entry = self.index_ds[idx]
            if ice_template is None:
                fields = [entry[ctx] for ctx in dr.input_columns] + [entry[dr.output_column]]
                generated_ice_list.append(' '.join(map(str, fields)))
            else:
                generated_ice_list.append(ice_template.generate_ice_item(entry, entry[dr.output_column]))
        return self.ice_separator.join(generated_ice_list) + self.ice_eos_token

    def generate_prompt(self, idx: int, ice: str, ice_template: Optional["PromptTemplate"] = None, prompt_template: Optional["PromptTemplate"] = None) -> Tuple[List[str], List]:
        """Build one prompt per candidate label for a test-set row.

        Args:
            idx (:obj:`int`): Index into `test_ds`.
            ice (:obj:`str`): The in-context example string.
            ice_template (:obj:`PromptTemplate`, optional): Template for in-context examples.
            prompt_template (:obj:`PromptTemplate`, optional): Template for the prompt.

        Returns:
            `Tuple[List[str], List]`: the prompts and the labels they were built for,
            in matching order. Labels are chosen as in :meth:`get_labels`.
        """
        labels = self.get_labels(ice_template=ice_template, prompt_template=prompt_template)
        # The templates must be forwarded; without them every prompt would be
        # rendered by the plain concatenation fallback of generate_label_prompt.
        prompt_list = [
            self.generate_label_prompt(idx, ice, label,
                                       ice_template=ice_template,
                                       prompt_template=prompt_template)
            for label in labels
        ]
        return prompt_list, labels

    def generate_label_prompt(self, idx: int, ice: str, label, ice_template: Optional["PromptTemplate"] = None, prompt_template: Optional["PromptTemplate"] = None) -> str:
        """Build the prompt for a test-set row paired with one label.

        `prompt_template` is used when given; otherwise `ice_template` is used if
        its ``ice_token`` is set; otherwise the in-context string, the
        space-joined input columns, a space and the label are concatenated.

        Args:
            idx (:obj:`int`): Index into `test_ds`.
            ice (:obj:`str`): The in-context example string.
            label: The label to place in the prompt.
            ice_template (:obj:`PromptTemplate`, optional): Template for in-context examples.
            prompt_template (:obj:`PromptTemplate`, optional): Template for the prompt.

        Returns:
            `str`: the prompt, followed by `prompt_eos_token`.
        """
        entry = self.test_ds[idx]
        if prompt_template is not None:
            return prompt_template.generate_label_prompt_item(entry, ice, label) + self.prompt_eos_token
        if ice_template is not None and ice_template.ice_token is not None:
            return ice_template.generate_label_prompt_item(entry, ice, label) + self.prompt_eos_token
        prefix_prompt = ' '.join(map(str, [entry[ctx] for ctx in self.dataset_reader.input_columns]))
        return ice + prefix_prompt + ' ' + str(label) + self.prompt_eos_token

    def generate_prompt_for_generate_task(self, idx, ice, gen_field_replace_token='', ice_template: Optional["PromptTemplate"] = None, prompt_template: Optional["PromptTemplate"] = None):
        """Build the prompt for a test-set row in a generation task.

        `prompt_template` is used when given; otherwise `ice_template` is used if
        its ``ice_token`` is set; otherwise the in-context string, the
        space-joined input columns and `gen_field_replace_token` are concatenated.

        Args:
            idx (:obj:`int`): Index into `test_ds`.
            ice (:obj:`str`): The in-context example string.
            gen_field_replace_token (:obj:`str`, optional): Text that stands in for the
                output field. Defaults to ``''``.
            ice_template (:obj:`PromptTemplate`, optional): Template for in-context examples.
            prompt_template (:obj:`PromptTemplate`, optional): Template for the prompt.

        Returns:
            `str`: the prompt, followed by `prompt_eos_token`.
        """
        entry = self.test_ds[idx]
        if prompt_template is not None:
            template = prompt_template
        elif ice_template is not None and ice_template.ice_token is not None:
            template = ice_template
        else:
            prefix_prompt = ' '.join(map(str, [entry[ctx] for ctx in self.dataset_reader.input_columns]))
            return ice + prefix_prompt + gen_field_replace_token + self.prompt_eos_token
        return template.generate_item(
            entry,
            output_field=self.dataset_reader.output_column,
            output_field_replace_token=gen_field_replace_token,
            ice_field_replace_token=ice
        ) + self.prompt_eos_token