Source code for openicl.icl_dataset_reader

"""Simple Dataset Reader"""

from typing import List, Union, Optional, Dict
from datasets import load_dataset
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from datasets.splits import NamedSplit
from openicl.icl_prompt_template import PromptTemplate
from openicl.utils.check_type import _check_dataset, _check_type_list, _check_str
import random
import torch

class DatasetReader:
    """In-context Learning Dataset Reader Class.

    Generate a DatasetReader instance through 'dataset'.

    Attributes:
        dataset (:obj:`Dataset` or :obj:`DatasetDict`): The dataset to be read.
        input_columns (:obj:`List[str]` or :obj:`str`): A list of column names
            (a string of column name) in the dataset that represent(s) the
            input field.
        output_column (:obj:`str`): A column name in the dataset that
            represents the prediction field.
        ds_size (:obj:`int` or :obj:`float`, optional): The number of pieces of
            data to return. When ds_size is an integer and greater than or
            equal to 1, `ds_size` pieces of data are randomly returned. When
            0 < :obj:`ds_size` < 1, ``int(len(dataset) * ds_size)`` pieces of
            data are randomly returned. (used for testing)
        references (:obj:`list`, optional): The list of references, initialized
            by ``self.dataset[self.test_split][self.output_column]``.
        input_template (:obj:`PromptTemplate`, optional): An instance of the
            :obj:`PromptTemplate` class, used to format the input field content
            during the retrieval process. (in some retrieval methods)
        output_template (:obj:`PromptTemplate`, optional): An instance of the
            :obj:`PromptTemplate` class, used to format the output field
            content during the retrieval process. (in some learnable retrieval
            methods)
        input_output_template (:obj:`PromptTemplate`, optional): An instance of
            the :obj:`PromptTemplate` class, used to format the input-output
            field content during the retrieval process. (in some retrieval
            methods)
    """

    # Class-level defaults; ``__init__`` only overrides the ones it is given.
    dataset = None
    input_template = None
    output_template = None
    input_output_template = None
    references = None

    def __init__(self,
                 dataset: Union[Dataset, DatasetDict, str],
                 input_columns: Union[List[str], str],
                 output_column: str,
                 name: Optional[str] = None,
                 data_files: Optional[str] = None,
                 input_template: Optional[PromptTemplate] = None,
                 output_template: Optional[PromptTemplate] = None,
                 input_output_template: Optional[PromptTemplate] = None,
                 ds_size: Union[None, int, float] = None,
                 split: Optional[NamedSplit] = None,
                 test_split: Optional[str] = 'test'
                 ) -> None:
        """Load (or wrap) the dataset and initialize columns and references.

        Args:
            dataset: A dataset object, or a string that is forwarded to
                ``load_dataset`` together with ``name`` and ``data_files``.
            input_columns: Input column name(s). A single string is split on
                whitespace into a list of column names.
            output_column: Name of the output (prediction) column.
            name: Forwarded to ``load_dataset`` when ``dataset`` is a string.
            data_files: Forwarded to ``load_dataset`` when ``dataset`` is a
                string.
            input_template: Optional template for the input field.
            output_template: Optional template for the output field.
            input_output_template: Optional template for the input-output field.
            ds_size: Optional sub-sampling size, see ``load_partial_dataset``.
            split: If given and the dataset is a ``DatasetDict``, only this
                split is kept as ``self.dataset``.
            test_split: Split whose ``output_column`` initializes
                ``self.references`` when the dataset is a ``DatasetDict`` that
                contains it.
        """
        # The _check_* helpers are defined elsewhere; presumably they validate
        # the type and hand the value back (their return value is what is stored).
        self.input_columns = _check_type_list(input_columns, [List, str])
        if isinstance(self.input_columns, str):
            self.input_columns = self.input_columns.split()
        self.output_column = _check_str(output_column)
        self.ds_size = _check_type_list(ds_size, [None, int, float])

        if input_template is not None:
            self.input_template = PromptTemplate._check_prompt_template(input_template)
        if output_template is not None:
            self.output_template = PromptTemplate._check_prompt_template(output_template)
        if input_output_template is not None:
            self.input_output_template = PromptTemplate._check_prompt_template(input_output_template)

        if isinstance(dataset, str):
            self.dataset = load_dataset(dataset, name=name, data_files=data_files)
        else:
            self.dataset = _check_dataset(dataset)

        if split is not None and isinstance(self.dataset, DatasetDict):
            self.dataset = self.dataset[split]

        if self.ds_size is not None:
            if isinstance(self.dataset, Dataset):
                # Sub-sample the resolved dataset (after loading and split
                # selection), not the raw constructor argument, which may
                # still be a name string or the full DatasetDict.
                self.dataset = load_partial_dataset(self.dataset, size=self.ds_size)
            elif isinstance(self.dataset, DatasetDict):
                for ds_name in self.dataset.keys():
                    self.dataset[ds_name] = load_partial_dataset(self.dataset[ds_name], size=self.ds_size)

        if isinstance(self.dataset, DatasetDict):
            # References stay ``None`` when the test split is absent.
            if test_split in self.dataset.keys():
                self.references = self.dataset[test_split][self.output_column]
        elif isinstance(self.dataset, Dataset):
            self.references = self.dataset[self.output_column]

    def set_references(self, column: str, split: Optional[str] = None) -> None:
        """Set :obj:`self.references` based on :obj:`column` and optional :obj:`split`.

        Args:
            column (:obj:`str`): A string of column name.
            split (:obj:`str`, optional): A string of dataset split.
                Defaults to ``None``.
        """
        source = self.dataset if split is None else self.dataset[split]
        self.references = source[column]

    def generate_input_field_prompt(self, entry: Dict) -> str:
        """Generate a prompt for the input field based on the provided :obj:`entry` data.

        Without an input template, the input columns are converted to strings
        and joined with single spaces.

        Args:
            entry (:obj:`Dict`): A piece of data to be used for generating the prompt.

        Returns:
            :obj:`str`: The generated prompt.
        """
        if self.input_template is None:
            return ' '.join(str(entry[ctx]) for ctx in self.input_columns)
        return self.input_template.generate_item(entry)

    def generate_input_field_corpus(self,
                                    dataset: Union[Dataset, DatasetDict],
                                    split: Optional[str] = None) -> List[str]:
        """Generate corpus for input field.

        Args:
            dataset (:obj:`Dataset` or :obj:`DatasetDict`): A
                :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict` instance.
            split (:obj:`str`, optional): The split of the dataset to use. If
                :obj:`None`, the entire dataset will be used. Defaults to ``None``.

        Returns:
            :obj:`List[str]`: A list of generated input field prompts.
        """
        if split is not None:
            dataset = dataset[split]
        return [self.generate_input_field_prompt(entry) for entry in dataset]

    # NOTE: the name is misspelled ("ouput"); it is kept as-is because it is
    # part of the public interface.
    def generate_ouput_field_prompt(self, entry: Dict) -> str:
        """Generate a prompt for the output field based on the provided :obj:`entry` data.

        Without an output template, the output column is converted to a string.

        Args:
            entry (:obj:`Dict`): A piece of data to be used for generating the prompt.

        Returns:
            :obj:`str`: The generated prompt.
        """
        if self.output_template is None:
            return str(entry[self.output_column])
        return self.output_template.generate_item(entry)

    def generate_output_field_corpus(self,
                                     dataset: Union[Dataset, DatasetDict],
                                     split: Optional[str] = None) -> List[str]:
        """Generate corpus for output field.

        Args:
            dataset (:obj:`Dataset` or :obj:`DatasetDict`): A
                :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict` instance.
            split (:obj:`str`, optional): The split of the dataset to use. If
                :obj:`None`, the entire dataset will be used. Defaults to ``None``.

        Returns:
            :obj:`List[str]`: A list of generated output field prompts.
        """
        if split is not None:
            dataset = dataset[split]
        return [self.generate_ouput_field_prompt(entry) for entry in dataset]

    def generate_input_output_field_prompt(self, entry: Dict) -> str:
        """Generate a prompt for the input-output field based on the provided :obj:`entry` data.

        Without an input-output template, the input columns followed by the
        output column are converted to strings and joined with single spaces.

        Args:
            entry (:obj:`Dict`): A piece of data to be used for generating the prompt.

        Returns:
            :obj:`str`: The generated prompt.
        """
        if self.input_output_template is None:
            # Every field goes through str() so non-string columns (ints,
            # floats, labels) do not make ``join`` raise TypeError.
            fields = [str(entry[ctx]) for ctx in self.input_columns]
            fields.append(str(entry[self.output_column]))
            return ' '.join(fields)
        return self.input_output_template.generate_item(entry)

    def generate_input_output_field_corpus(self,
                                           dataset: Union[Dataset, DatasetDict],
                                           split: Optional[str] = None) -> List[str]:
        """Generate corpus for input-output field.

        Args:
            dataset (:obj:`Dataset` or :obj:`DatasetDict`): A
                :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict` instance.
            split (:obj:`str`, optional): The split of the dataset to use. If
                :obj:`None`, the entire dataset will be used. Defaults to ``None``.

        Returns:
            :obj:`List[str]`: A list of generated input-output field prompts.
        """
        if split is not None:
            dataset = dataset[split]
        return [self.generate_input_output_field_prompt(entry) for entry in dataset]

    @staticmethod
    def _check_dataset_reader(obj) -> "DatasetReader":
        """Return :obj:`obj` unchanged if it is a :obj:`DatasetReader`.

        Declared as a static method so it works both as
        ``DatasetReader._check_dataset_reader(obj)`` and on an instance.

        Raises:
            TypeError: If :obj:`obj` is not a :obj:`DatasetReader`.
        """
        if isinstance(obj, DatasetReader):
            return obj
        raise TypeError(f"Expected a DatasetReader object, but got {obj}")

    def __len__(self):
        """Return ``len(self.dataset)``."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``self.dataset[idx]``."""
        return self.dataset[idx]

    def __repr__(self):
        """Return a short description with the dataset and its column names."""
        return f"DatasetReader({{\n dataset: {self.dataset},\n input_columns: {self.input_columns},\n output_columns: {self.output_column}\n}})"
def load_partial_dataset(dataset: Dataset, size: Optional[Union[int, float]] = None) -> Dataset:
    """Return a reproducible random subset of :obj:`dataset`.

    Args:
        dataset (:obj:`Dataset`): The dataset to sample from. Only ``len()``
            and ``select(indices)`` are used.
        size (:obj:`int` or :obj:`float`, optional): Number of rows to keep.
            A value in ``(0, 1)`` is treated as a fraction and converted to
            ``int(size * len(dataset))`` (which may be 0 for small datasets).
            A float ``>= 1`` is truncated to an integer. ``None``, values
            ``<= 0`` and values ``>= len(dataset)`` leave the dataset unchanged.

    Returns:
        :obj:`Dataset`: The original dataset, or ``dataset.select(...)`` on the
        first ``size`` indices of a shuffled index list.
    """
    # ``None`` is the declared default; comparing it with an int would raise.
    if size is None:
        return dataset
    total_size = len(dataset)
    if size >= total_size or size <= 0:
        return dataset
    if size < 1:
        size = size * total_size
    # A slice bound must be an int, so float sizes are truncated here.
    size = int(size)
    # Seeding with the target size makes the subset deterministic per size.
    rand = random.Random(x=size)
    index_list = list(range(total_size))
    rand.shuffle(index_list)
    return dataset.select(index_list[:size])


class DatasetEncoder(torch.utils.data.Dataset):
    """Torch dataset that tokenizes a list of texts once, at construction.

    Each item is a dict with ``input_ids``, ``attention_mask`` and a
    ``metadata`` dict holding the item index, token count and original text.
    """

    def __init__(self, datalist: List, model_name=None, tokenizer=None) -> None:
        """Store the texts, resolve the tokenizer and encode everything.

        Args:
            datalist (:obj:`List`): The items to tokenize.
            model_name: Name passed to ``AutoTokenizer.from_pretrained`` when
                no ``tokenizer`` is given.
            tokenizer: A ready tokenizer; takes precedence over ``model_name``.

        Raises:
            ValueError: If both ``model_name`` and ``tokenizer`` are ``None``.
        """
        self.datalist = datalist
        if model_name is None and tokenizer is None:
            raise ValueError("model_name and tokenizer could not both be None")
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # NOTE(review): the padding setup is kept inside this branch, i.e.
            # applied only to tokenizers loaded here; the flattened source
            # does not show the original indentation - confirm.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.tokenizer.padding_side = "left"
        self.encode_dataset = []
        self.init_dataset()
        self.datalist_length = len(self.encode_dataset)

    def init_dataset(self):
        """Tokenize every item of ``self.datalist`` into ``self.encode_dataset``."""
        for idx, data in enumerate(self.datalist):
            tokenized_data = self.tokenizer.encode_plus(data, truncation=True, return_tensors='pt', verbose=False)
            # Index 0 takes the first (only) row of the tokenizer output.
            input_ids = tokenized_data.input_ids[0]
            self.encode_dataset.append({
                'input_ids': input_ids,
                'attention_mask': tokenized_data.attention_mask[0],
                "metadata": {"id": idx, "len": len(input_ids), "text": data}
            })

    def __len__(self):
        """Return the number of encoded items."""
        return self.datalist_length

    def __getitem__(self, idx):
        """Return the encoded item at position ``idx``."""
        return self.encode_dataset[idx]