"""Basic Inferencer"""
import json
import os
import warnings
from typing import List, Union, Optional

import torch
from accelerate import Accelerator
from accelerate import init_empty_weights, infer_auto_device_map
from transformers import AutoTokenizer, AutoModelForCausalLM, PretrainedConfig, GPT2Tokenizer, AutoConfig, T5ForConditionalGeneration

from openicl import BaseRetriever, PromptTemplate
from openicl.utils.api_service import *
from openicl.icl_evaluator import *
class BaseInferencer:
    """Basic In-context Learning Inferencer Class

    Base class of In-context Learning Inferencer, with no inference method.

    Attributes:
        model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class.
        tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`.
        max_model_token_num (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM.
        batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
        accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing.
        output_json_filepath (:obj:`str`, optional): File path for output `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output `JSON` file.
        api_name (:obj:`str`, optional): Name of API service.
        call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`.
    """
    model = None
    tokenizer = None
    call_api = False

    def __init__(self,
                 model_name: Optional[str] = 'gpt2-xl',
                 tokenizer_name: Optional[str] = None,
                 max_model_token_num: Optional[int] = None,
                 model_config: Optional["PretrainedConfig"] = None,
                 batch_size: Optional[int] = 1,
                 accelerator: Optional["Accelerator"] = None,
                 output_json_filepath: Optional[str] = "./icl_inference_output",
                 output_json_filename: Optional[str] = "predictions",
                 api_name: Optional[str] = None,
                 model_parallel: Optional[bool] = False,
                 **kwargs
                 ) -> None:
        """Set up the model (or API), the tokenizer and the output directory.

        Args:
            model_name: Hugging Face model name; also used as the tokenizer
                name when ``tokenizer_name`` is ``None``.
            tokenizer_name: Tokenizer name, defaults to ``model_name``.
            max_model_token_num: Stored as-is on the instance.
            model_config: When given (and ``model_parallel`` is false) the
                model is built from this config instead of pretrained weights.
            batch_size: Stored as-is on the instance.
            accelerator: Optional ``Accelerator``; used to decide whether this
                is the main process.
            output_json_filepath: Directory for output files; created if missing.
            output_json_filename: Stored as-is on the instance.
            api_name: Name of an API service; when it is available no local
                model or tokenizer is loaded.
            model_parallel: When true, the model is loaded in float16 with a
                device map (inferred when ``device_map`` is not supplied).
            **kwargs: ``no_split_module_classes`` and ``device_map`` are read
                for model-parallel loading; the full mapping (including those
                two keys) is forwarded to the API configuration.
        """
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name if tokenizer_name is not None else model_name
        self.accelerator = accelerator
        self.is_main_process = bool(accelerator is None or accelerator.is_main_process)
        self.api_name = api_name

        # The defaults are written back into kwargs on purpose: the whole
        # mapping is forwarded to the API configuration below.
        no_split_module_classes = kwargs.setdefault('no_split_module_classes', [])
        device_map = kwargs.setdefault('device_map', None)

        self.__init_api(**kwargs)
        if not self.call_api:
            self.__init_model(self.model_name, model_config, model_parallel, device_map, no_split_module_classes)
            self.__init_tokenizer(self.tokenizer_name)
        # NOTE(review): in API mode no tokenizer is initialized here, so
        # get_input_token_num() cannot work -- confirm whether API mode
        # should also load a tokenizer.

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # A model loaded with a device map is already placed (and possibly
        # offloaded); moving it to a single device would undo that placement.
        if self.model is not None and not model_parallel:
            self.model.to(self.device)

        self.max_model_token_num = max_model_token_num
        self.batch_size = batch_size
        self.output_json_filepath = output_json_filepath
        self.output_json_filename = output_json_filename
        # exist_ok avoids a check-then-create race between parallel processes.
        os.makedirs(self.output_json_filepath, exist_ok=True)

    def inference(self,
                  retriever: "BaseRetriever",
                  ice_template: Optional["PromptTemplate"] = None,
                  prompt_template: Optional["PromptTemplate"] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        """Perform In-Context Inference given a retriever and optional templates.

        Args:
            retriever (:obj:`BaseRetriever`): An instance of a Retriever class that will be used to retrieve in-context examples
            ice_template (:obj:`PromptTemplate`, optional): A template for generating the in-context examples prompt. Defaults to None.
            prompt_template (:obj:`PromptTemplate`, optional): A template for generating the final prompt. Defaults to None.
            output_json_filepath (:obj:`str`, optional): The file path to save the results as a `JSON` file. Defaults to None.
            output_json_filename (:obj:`str`, optional): The file name to save the results as a `JSON` file. Defaults to None.

        Raises:
            NotImplementedError: If the function is not implemented in the subclass.

        Returns:
            :obj:`List:` A list of string, each representing the results of one inference.
        """
        raise NotImplementedError("Method hasn't been implemented yet")

    def __init_model(self, model_name, model_config, model_parallel, device_map, no_split_module_classes):
        """Load ``self.model``, either on a single device or with a device map."""
        if not model_parallel:
            if model_config is not None:
                self.model = self.__get_hf_model_from_config(model_name, model_config)
            else:
                self.model = self.__get_hf_model_from_name(model_name)
            return

        if model_config is None:
            model_config = AutoConfig.from_pretrained(model_name)
        # Build a weightless skeleton so a device map can be inferred
        # without materializing the parameters.
        with init_empty_weights():
            empty_model = AutoModelForCausalLM.from_config(model_config)
        if device_map is None:
            device_map = infer_auto_device_map(empty_model, no_split_module_classes=no_split_module_classes, dtype="float16")
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch.float16)

    def __get_hf_model_from_name(self, model_name):
        """Load pretrained weights; names containing 't5' use the T5 class."""
        if 't5' in model_name:
            return T5ForConditionalGeneration.from_pretrained(model_name)
        return AutoModelForCausalLM.from_pretrained(model_name)

    def __get_hf_model_from_config(self, model_name, model_config):
        """Build a causal LM from a config.

        Raises:
            TypeError: If ``model_name`` contains 't5'.
        """
        if 't5' in model_name:
            raise TypeError("T5 model has no 'from_config' method")
        return AutoModelForCausalLM.from_config(model_config)

    def __init_tokenizer(self, tokenizer_name):
        """Load ``self.tokenizer`` and configure left padding with EOS as pad."""
        if self.api_name == 'opt-175b':
            self.tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = "left"

    def __init_api(self, **kwargs):
        """Set ``self.call_api`` and configure the API when it is available.

        Does nothing when no ``api_name`` was given. Emits a ``UserWarning``
        when the named API is not available.
        """
        if self.api_name is None:
            return
        self.call_api = is_api_available(self.api_name)
        if not self.call_api:
            # The warning must be issued, not merely instantiated.
            warnings.warn(f"api_name '{self.api_name}' is not available, Please check it", UserWarning)
        else:
            update_openicl_api_request_config(self.api_name, **kwargs)

    def get_input_token_num(self, inputs):
        """Return the number of ``input_ids`` the tokenizer produces for ``inputs``."""
        return len(self.tokenizer(inputs, verbose=False)['input_ids'])
class GenInferencerOutputHandler:
    """Collect prompts, raw outputs and predictions of a generation run and
    write them to JSON, optionally merging per-process files.

    When an accelerator is set, a local sample index ``idx`` is stored under
    the key ``str(idx * num_processes + process_index)``; without one the key
    is ``str(idx)``.
    """
    # Class-level defaults are kept for backward compatibility only; every
    # instance rebinds its own dictionaries in __init__.
    origin_prompt_dict = {}
    output_dict = {}
    prediction_dict = {}
    results_dict = {}

    def __init__(self,
                 num: int,
                 accelerator: Optional["Accelerator"] = None
                 ) -> None:
        """
        Args:
            num: Number of local samples that will be written out.
            accelerator: Optional accelerator providing ``num_processes``,
                ``process_index`` and ``is_main_process``.
        """
        self.num = num
        self.accelerator = accelerator
        self.origin_prompt_dict = {}
        self.output_dict = {}
        self.prediction_dict = {}
        self.results_dict = {}

    def _global_key(self, idx) -> str:
        """Map a local sample index to the dictionary key used for storage."""
        if self.accelerator is not None:
            idx = idx * self.accelerator.num_processes + self.accelerator.process_index
        return str(idx)

    def subprocess_write_to_json(self, output_json_filepath: str, output_json_filename: str):
        """Assemble ``results_dict`` from the saved parts and, when an
        accelerator is set, dump it to this process's own JSON file.

        Raises:
            KeyError: If a prompt, output or prediction is missing for one of
                the ``num`` local samples.
        """
        # Use the same key mapping as the save_* methods; plain str(idx)
        # would miss every entry on processes with a non-zero index.
        keys = [self._global_key(idx) for idx in range(self.num)]
        self.results_dict = {
            key: {
                'origin_prompt': self.origin_prompt_dict[key],
                'output': self.output_dict[key],
                'prediction': self.prediction_dict[key]
            } for key in keys
        }
        if self.accelerator is not None:
            path = f'{output_json_filepath}/process{self.accelerator.process_index}_{output_json_filename}.json'
            with open(path, 'w', encoding='utf-8') as json_file:
                json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False)

    def write_to_json(self, output_json_filepath: str, output_json_filename: str):
        """Dump ``results_dict`` to ``<filepath>/<filename>.json``."""
        with open(f'{output_json_filepath}/{output_json_filename}.json', 'w', encoding='utf-8') as json_file:
            json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False)

    def merge_to_main_process(self, output_json_filepath: str, output_json_filename: str):
        """On the main process, merge all per-process JSON files into
        ``results_dict``, ordered by numeric sample index.

        Does nothing without an accelerator or on non-main processes.
        """
        if self.accelerator is None or not self.accelerator.is_main_process:
            return
        for pid in range(self.accelerator.num_processes):
            path = f'{output_json_filepath}/process{pid}_{output_json_filename}.json'
            with open(path, 'r', encoding='utf-8') as json_file:
                self.results_dict.update(json.load(json_file))
        # Files are merged process by process, which interleaves indices;
        # restore sample order so values() lines up with the dataset.
        self.results_dict = dict(sorted(self.results_dict.items(), key=lambda item: int(item[0])))

    def save_orgin_prompts(self, origin_prompts: List[str]):
        """Store each prompt under the key of its position in ``origin_prompts``."""
        for idx, origin_prompt in enumerate(origin_prompts):
            self.origin_prompt_dict[self._global_key(idx)] = origin_prompt

    def save_prediction_and_output(self, prediction, output, idx):
        """Store the prediction and raw output of local sample ``idx``."""
        key = self._global_key(idx)
        self.prediction_dict[key] = prediction
        self.output_dict[key] = output
class PPLInferencerOutputHandler:
    """Collect in-context examples, predictions and per-label perplexities of
    a PPL run and write them to JSON, optionally merging per-process files.

    When an accelerator is set, a local sample index ``idx`` is stored under
    the key ``str(idx * num_processes + process_index)``; without one the key
    is ``str(idx)``.
    """
    # Class-level default kept for backward compatibility only; every
    # instance rebinds its own dictionary in __init__.
    results_dict = {}

    def __init__(self,
                 accelerator: Optional["Accelerator"] = None
                 ) -> None:
        """
        Args:
            accelerator: Optional accelerator providing ``num_processes``,
                ``process_index`` and ``is_main_process``.
        """
        self.accelerator = accelerator
        self.results_dict = {}

    def _entry(self, idx) -> dict:
        """Return (creating it if needed) the result record of local sample ``idx``."""
        if self.accelerator is not None:
            idx = idx * self.accelerator.num_processes + self.accelerator.process_index
        return self.results_dict.setdefault(str(idx), {})

    def subprocess_write_to_json(self, output_json_filepath: str, output_json_filename: str):
        """When an accelerator is set, dump ``results_dict`` to this process's own JSON file."""
        if self.accelerator is not None:
            path = f'{output_json_filepath}/process{self.accelerator.process_index}_{output_json_filename}.json'
            with open(path, 'w', encoding='utf-8') as json_file:
                json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False)

    def write_to_json(self, output_json_filepath: str, output_json_filename: str):
        """Dump ``results_dict`` to ``<filepath>/<filename>.json``."""
        with open(f'{output_json_filepath}/{output_json_filename}.json', 'w', encoding='utf-8') as json_file:
            json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False)

    def merge_to_main_process(self, output_json_filepath: str, output_json_filename: str):
        """On the main process, merge all per-process JSON files into
        ``results_dict``, ordered by numeric sample index.

        Does nothing without an accelerator or on non-main processes.
        """
        if self.accelerator is None or not self.accelerator.is_main_process:
            return
        for pid in range(self.accelerator.num_processes):
            path = f'{output_json_filepath}/process{pid}_{output_json_filename}.json'
            with open(path, 'r', encoding='utf-8') as json_file:
                self.results_dict.update(json.load(json_file))
        # Files are merged process by process, which interleaves indices;
        # restore sample order so values() lines up with the dataset.
        self.results_dict = dict(sorted(self.results_dict.items(), key=lambda item: int(item[0])))

    def save_ice(self, ice):
        """Store each element of ``ice`` as the 'in-context examples' of its position."""
        for idx, example in enumerate(ice):
            self._entry(idx)['in-context examples'] = example

    def save_predictions(self, predictions):
        """Store each element of ``predictions`` as the 'prediction' of its position."""
        for idx, prediction in enumerate(predictions):
            self._entry(idx)['prediction'] = prediction

    def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
        """Store the testing input, prompt and PPL of sample ``idx`` under
        the sub-record ``'label: <label>'``."""
        label_record = self._entry(idx).setdefault('label: ' + str(label), {})
        label_record['testing input'] = input
        label_record['prompt'] = prompt
        label_record['PPL'] = ppl