Source code for openicl.icl_inferencer.icl_base_inferencer

"""Basic Inferencer"""

import os
import torch
from openicl import BaseRetriever, PromptTemplate
from openicl.utils.api_service import *
from openicl.icl_evaluator import *
from transformers import AutoTokenizer, AutoModelForCausalLM, PretrainedConfig, GPT2Tokenizer, AutoConfig, T5ForConditionalGeneration
from typing import List, Union, Optional
from accelerate import Accelerator
from accelerate import init_empty_weights, infer_auto_device_map

[docs]class BaseInferencer: """Basic In-context Learning Inferencer Class Base class of In-context Learning Inferencer, with no inference method. Attributes: model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class. tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`. max_model_token_num (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM. batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`. accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing. output_json_filepath (:obj:`str`, optional): File path for output `JSON` file. output_json_filename (:obj:`str`, optional): File name for output `JSON` file. api_name (:obj:`str`, optional): Name of API service. call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`. """ model = None tokenizer = None call_api = False def __init__(self, model_name: Optional[str] = 'gpt2-xl', tokenizer_name: Optional[str] = None, max_model_token_num: Optional[int] = None, model_config: Optional[PretrainedConfig] = None, batch_size: Optional[int] = 1, accelerator: Optional[Accelerator] = None, output_json_filepath: Optional[str] = "./icl_inference_output", output_json_filename: Optional[str] = "predictions", api_name: Optional[str] = None, model_parallel: Optional[bool] = False, **kwargs ) -> None: self.model_name = model_name self.tokenizer_name = tokenizer_name if tokenizer_name is not None else model_name self.accelerator = accelerator self.is_main_process = True if self.accelerator is None or self.accelerator.is_main_process else False self.api_name = api_name if 'no_split_module_classes' not in kwargs.keys(): kwargs['no_split_module_classes'] = [] if 'device_map' not in kwargs.keys(): kwargs['device_map'] = None no_split_module_classes = kwargs['no_split_module_classes'] device_map = kwargs['device_map'] self.__init_api(**kwargs) if not self.call_api: self.__init_model(self.model_name, model_config, model_parallel, device_map, no_split_module_classes) self.__init_tokenizer(self.tokenizer_name) self.device = "cuda" if torch.cuda.is_available() else "cpu" if self.model is not None: self.model.to(self.device) self.max_model_token_num = max_model_token_num self.batch_size = batch_size self.output_json_filepath = output_json_filepath self.output_json_filename = output_json_filename if not os.path.exists(self.output_json_filepath): os.makedirs(self.output_json_filepath)
[docs] def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None, prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None, output_json_filename: Optional[str] = None) -> List: """Perform In-Context Inference given a retriever and optional templates. Args: retriever (:obj:`BaseRetriever`): An instance of a Retriever class that will be used to retrieve in-context examples ice_template (:obj:`PromptTemplate`, optional): A template for generating the in-context examples prompt. Defaults to None. prompt_template (:obj:`PromptTemplate`, optional): A template for generating the final prompt. Defaults to None. output_json_filepath (:obj:`str`, optional): The file path to save the results as a `JSON` file. Defaults to None. output_json_filename (:obj:`str`, optional): The file name to save the results as a `JSON` file. Defaults to None. Raises: NotImplementedError: If the function is not implemented in the subclass. Returns: :obj:`List:` A list of string, each representing the results of one inference. """ raise NotImplementedError("Method hasn't been implemented yet")
def __init_model(self, model_name, model_config, model_parallel, device_map, no_split_module_classes): if not model_parallel: if model_config is not None: self.model = self.__get_hf_model_from_config(model_name, model_config) else: self.model = self.__get_hf_model_from_name(model_name) else: if model_config is None: model_config = AutoConfig.from_pretrained(model_name) with init_empty_weights(): empty_model = AutoModelForCausalLM.from_config(model_config) if device_map is None: device_map = infer_auto_device_map(empty_model, no_split_module_classes=no_split_module_classes, dtype="float16") self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch.float16) def __get_hf_model_from_name(self, model_name): if 't5' in model_name: return T5ForConditionalGeneration.from_pretrained(model_name) else: return AutoModelForCausalLM.from_pretrained(model_name) def __get_hf_model_from_config(self, model_name, model_config): if 't5' in model_name: raise TypeError("T5 model has no 'from_config' method") else: return AutoModelForCausalLM.from_config(model_config) def __init_tokenizer(self, tokenizer_name): if self.api_name == 'opt-175b': self.tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b", use_fast=False) else: self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.tokenizer.padding_side = "left" def __init_api(self, **kwargs): if self.api_name == None: return self.call_api = is_api_available(self.api_name) if not self.call_api: UserWarning(f"api_name '{self.api_name}' is not available, Please check it") else: update_openicl_api_request_config(self.api_name, **kwargs) def get_input_token_num(self, inputs): return len(self.tokenizer(inputs, verbose=False)['input_ids'])
class GenInferencerOutputHandler: origin_prompt_dict= {} output_dict = {} prediction_dict = {} results_dict = {} def __init__(self, num: int, accelerator: Optional[Accelerator] = None ) -> None: self.num = num self.accelerator = accelerator self.origin_prompt_dict= {} self.output_dict = {} self.prediction_dict = {} self.results_dict = {} def subprocess_write_to_json(self, output_json_filepath: str, output_json_filename: str): self.results_dict = { str(idx): { 'origin_prompt': self.origin_prompt_dict[str(idx)], 'output': self.output_dict[str(idx)], 'prediction': self.prediction_dict[str(idx)] } for idx in range(self.num) } if self.accelerator is not None: with open(f'{output_json_filepath}/process{self.accelerator.process_index}_{output_json_filename}.json', 'w', encoding='utf-8') as json_file: json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False) json_file.close() def write_to_json(self, output_json_filepath: str, output_json_filename: str): with open(f'{output_json_filepath}/{output_json_filename}.json', 'w', encoding='utf-8') as json_file: json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False) json_file.close() def merge_to_main_process(self, output_json_filepath: str, output_json_filename: str): if self.accelerator is not None and self.accelerator.is_main_process: for pid in range(self.accelerator.num_processes): with open(f'{output_json_filepath}/process{pid}_{output_json_filename}.json', 'r', encoding='utf-8') as json_file: subprocess_results_dict = json.load(json_file) self.results_dict.update(subprocess_results_dict) json_file.close() def save_orgin_prompts(self, origin_prompts: List[str]): for idx, origin_prompt in enumerate(origin_prompts): if self.accelerator is not None: idx = idx * self.accelerator.num_processes + self.accelerator.process_index self.origin_prompt_dict[str(idx)] = origin_prompt def save_prediction_and_output(self, prediction, output, idx): if self.accelerator is not None: idx = idx * self.accelerator.num_processes + self.accelerator.process_index self.prediction_dict[str(idx)] = prediction self.output_dict[str(idx)] = output class PPLInferencerOutputHandler: results_dict = {} def __init__(self, accelerator: Optional[Accelerator] = None ) -> None: self.accelerator = accelerator self.results_dict = {} def subprocess_write_to_json(self, output_json_filepath: str, output_json_filename: str): if self.accelerator is not None: with open(f'{output_json_filepath}/process{self.accelerator.process_index}_{output_json_filename}.json', 'w', encoding='utf-8') as json_file: json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False) json_file.close() def write_to_json(self, output_json_filepath: str, output_json_filename: str): with open(f'{output_json_filepath}/{output_json_filename}.json', 'w', encoding='utf-8') as json_file: json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False) json_file.close() def merge_to_main_process(self, output_json_filepath: str, output_json_filename: str): if self.accelerator is not None and self.accelerator.is_main_process: for pid in range(self.accelerator.num_processes): with open(f'{output_json_filepath}/process{pid}_{output_json_filename}.json', 'r', encoding='utf-8') as json_file: subprocess_results_dict = json.load(json_file) self.results_dict.update(subprocess_results_dict) json_file.close() def save_ice(self, ice): for idx, example in enumerate(ice): if self.accelerator is not None: idx = idx * self.accelerator.num_processes + self.accelerator.process_index if str(idx) not in self.results_dict.keys(): self.results_dict[str(idx)] = {} self.results_dict[str(idx)]['in-context examples'] = example def save_predictions(self, predictions): for idx, prediction in enumerate(predictions): if self.accelerator is not None: idx = idx * self.accelerator.num_processes + self.accelerator.process_index if str(idx) not in self.results_dict.keys(): self.results_dict[str(idx)] = {} self.results_dict[str(idx)]['prediction'] = prediction def save_prompt_and_ppl(self, label, input, prompt, ppl, idx): if self.accelerator is not None: idx = idx * self.accelerator.num_processes + self.accelerator.process_index if str(idx) not in self.results_dict.keys(): self.results_dict[str(idx)] = {} if 'label: ' + str(label) not in self.results_dict[str(idx)].keys(): self.results_dict[str(idx)]['label: ' + str(label)] = {} self.results_dict[str(idx)]['label: ' + str(label)]['testing input'] = input self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl