# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description: Train a reward model with an InstructGPT-style pairwise ranking loss,
optionally with LoRA/PEFT adapters.
"""
import math
import os
from dataclasses import dataclass, field
from glob import glob
from typing import Any, List, Union, Optional, Dict

import torch
from datasets import load_dataset
from loguru import logger
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
from sklearn.metrics import mean_squared_error, mean_absolute_error
from torch.utils.data import Dataset
from transformers import (
    AutoConfig,
    PreTrainedTokenizerBase,
    BloomForSequenceClassification,
    LlamaForSequenceClassification,
    LlamaTokenizer,
    BloomTokenizerFast,
    AlbertForSequenceClassification,
    BertForSequenceClassification,
    BertTokenizer,
    AutoTokenizer,
    RobertaForSequenceClassification,
    AutoModelForSequenceClassification,
    RobertaTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer import TRAINING_ARGS_NAME
MODEL_CLASSES = {
    "bert": (AutoConfig, BertForSequenceClassification, BertTokenizer),
    "roberta": (AutoConfig, RobertaForSequenceClassification, RobertaTokenizer),
    "albert": (AutoConfig, AlbertForSequenceClassification, AutoTokenizer),
    "bloom": (AutoConfig, BloomForSequenceClassification, BloomTokenizerFast),
    "llama": (AutoConfig, LlamaForSequenceClassification, LlamaTokenizer),
    "auto": (AutoConfig, AutoModelForSequenceClassification, AutoTokenizer),
}
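# Each entry maps a model_type to (config class, model class, tokenizer class),
# e.g. MODEL_CLASSES["llama"] -> (AutoConfig, LlamaForSequenceClassification, LlamaTokenizer).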

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_type: str = field(
        default=None,
        metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
    )
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set it if you want to train a model from scratch."
            )
        },
    )
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The tokenizer for weights initialization. Don't set it if you want to train a model from scratch."
            )
        },
    )
    load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    device_map: Optional[str] = field(
        default="auto",
        metadata={"help": "Device to map the model to. If `auto` is passed, the device is selected automatically."},
    )
    trust_remote_code: bool = field(
        default=True,
        metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
    )

    def __post_init__(self):
        if self.model_type is None:
            raise ValueError(
                "You must specify a valid model_type to run training. Available model types are " + ", ".join(
                    MODEL_CLASSES.keys()))
        if self.model_name_or_path is None:
            raise ValueError("You must specify a valid model_name_or_path to run training.")

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
    validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."})
    max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
    max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=1,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=4,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

@dataclass
class PeftArguments(TrainingArguments):
    use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
    target_modules: Optional[str] = field(default="all")
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.05)
    lora_alpha: Optional[float] = field(default=32.0)
    modules_to_save: Optional[str] = field(default=None)
    peft_path: Optional[str] = field(default=None)
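
# Note: target_modules="all" is expanded in main() to every torch.nn.Linear
# module found by find_all_linear_names(), excluding the lm_head/score heads.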

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # Here `preds` are rewards_chosen and `labels` are rewards_rejected (see
    # RewardTrainer.prediction_step), so mse/mae measure the gap between the
    # two reward distributions rather than an error against gold labels.
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    # MSE
    mse = mean_squared_error(labels, preds)
    # MAE
    mae = mean_absolute_error(labels, preds)
    return {"mse": mse, "mae": mae}

@dataclass
class RewardDataCollatorWithPadding:
    """We need to define a special data collator that batches the data in our chosen vs rejected format."""

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        features_chosen = []
        features_rejected = []
        for feature in features:
            features_chosen.append(
                {
                    "input_ids": feature["input_ids_chosen"],
                    "attention_mask": feature["attention_mask_chosen"],
                }
            )
            features_rejected.append(
                {
                    "input_ids": feature["input_ids_rejected"],
                    "attention_mask": feature["attention_mask_rejected"],
                }
            )
        batch_chosen = self.tokenizer.pad(
            features_chosen,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch_rejected = self.tokenizer.pad(
            features_rejected,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch = {
            "input_ids_chosen": batch_chosen["input_ids"],
            "attention_mask_chosen": batch_chosen["attention_mask"],
            "input_ids_rejected": batch_rejected["input_ids"],
            "attention_mask_rejected": batch_rejected["attention_mask"],
            "return_loss": True,
        }
        return batch
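
# A batched output sketch (shapes assume padding="max_length" with max_length=L):
# {
#     "input_ids_chosen":        LongTensor[batch, L],
#     "attention_mask_chosen":   LongTensor[batch, L],
#     "input_ids_rejected":      LongTensor[batch, L],
#     "attention_mask_rejected": LongTensor[batch, L],
#     "return_loss":             True,
# }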

class RewardTrainer(Trainer):
    """
    Trainer for reward models.
    Define how to compute the reward loss. Use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        rewards_chosen = model(input_ids=inputs["input_ids_chosen"],
                               attention_mask=inputs["attention_mask_chosen"])[0]
        rewards_rejected = model(input_ids=inputs["input_ids_rejected"],
                                 attention_mask=inputs["attention_mask_rejected"])[0]
        loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if return_outputs:
            return loss, {"rewards_chosen": rewards_chosen, "rewards_rejected": rewards_rejected}
        return loss
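
    # For a pair with scalar rewards r_chosen and r_rejected, the loss is
    # -log(sigmoid(r_chosen - r_rejected)): near zero when the chosen response
    # already scores higher, and growing roughly linearly in the margin when it
    # does not, so the model learns to rank chosen above rejected.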

    def evaluate(
            self,
            eval_dataset: Optional[Dataset] = None,
            ignore_keys: Optional[List[str]] = None,
            metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        if eval_dataset is None:
            eval_dataset = self.eval_dataset
        return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # Prepare inputs for chosen and rejected separately
        device = model.device
        inputs_chosen = {
            "input_ids": inputs["input_ids_chosen"].to(device),
            "attention_mask": inputs["attention_mask_chosen"].to(device),
        }
        inputs_rejected = {
            "input_ids": inputs["input_ids_rejected"].to(device),
            "attention_mask": inputs["attention_mask_rejected"].to(device),
        }
        # No gradients are needed at prediction time
        with torch.no_grad():
            rewards_chosen = model(**inputs_chosen).logits.detach()
            rewards_rejected = model(**inputs_rejected).logits.detach()
        # Mirror the loss from compute_loss
        loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if prediction_loss_only:
            return (loss, None, None)
        # Returned as (loss, logits, labels): compute_metrics receives
        # rewards_chosen as preds and rewards_rejected as labels.
        return (loss, rewards_chosen, rewards_rejected)

    def save_model(self, output_dir=None, _internal_call=False):
        """Save the LoRA model."""
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        self.model.save_pretrained(output_dir)

def save_model(output_dir, model, tokenizer, args):
    """Save the model and the tokenizer."""
    os.makedirs(output_dir, exist_ok=True)
    # Take care of distributed/parallel training
    model_to_save = model.module if hasattr(model, "module") else model
    model_to_save.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))

class CastOutputToFloat(torch.nn.Sequential):
    """Cast the output of the model to float"""

    def forward(self, x):
        return super().forward(x).to(torch.float32)
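
# The score head is wrapped in CastOutputToFloat below so the scalar reward is
# computed in fp32 even when the backbone runs in fp16/8-bit, which keeps the
# logsigmoid loss numerically stable.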

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

def find_all_linear_names(peft_model, int4=False, int8=False):
    cls = torch.nn.Linear
    if int4 or int8:
        import bitsandbytes as bnb
        if int4:
            cls = bnb.nn.Linear4bit
        elif int8:
            cls = bnb.nn.Linear8bitLt
    lora_module_names = set()
    for name, module in peft_model.named_modules():
        if isinstance(module, cls):
            # The output heads (lm_head/score) are not added to lora_module_names
            if 'lm_head' in name:
                continue
            if 'score' in name:
                continue
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    return sorted(lora_module_names)
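
# A sketch of the expected result (assuming a LLaMA-style checkpoint):
# find_all_linear_names(model) ->
#     ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj']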

def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    logger.info(f"Model args: {model_args}")
    logger.info(f"Data args: {data_args}")
    logger.info(f"Training args: {training_args}")
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bit training: {training_args.fp16}"
    )
    # Set seed before initializing model.
    set_seed(training_args.seed)
    # Load model
    if not model_args.model_type:
        raise ValueError("Please specify a model_type from the list: " + ", ".join(MODEL_CLASSES.keys()))
    config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        if world_size > 1:
            model_args.device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        config = config_class.from_pretrained(
            model_args.model_name_or_path,
            num_labels=1,
            torch_dtype=torch_dtype,
            trust_remote_code=model_args.trust_remote_code,
            cache_dir=model_args.cache_dir
        )
        if model_args.model_type in ['bloom', 'llama']:
            model = model_class.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                load_in_8bit=model_args.load_in_8bit,
                device_map=model_args.device_map,
                trust_remote_code=model_args.trust_remote_code,
            )
            model.score = CastOutputToFloat(model.score)
        else:
            model = model_class.from_pretrained(
                model_args.model_name_or_path,
                config=config,
                cache_dir=model_args.cache_dir,
                ignore_mismatched_sizes=True
            )
            model.to(training_args.device)
    else:
        raise ValueError("Error, model_name_or_path is None, RM must be loaded from a pre-trained model")
    # Load tokenizer
    if model_args.model_type == "bloom":
        model_args.use_fast_tokenizer = True
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "trust_remote_code": model_args.trust_remote_code,
    }
    tokenizer_name_or_path = model_args.tokenizer_name_or_path
    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_args.model_name_or_path
    tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0
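    # Note: some tokenizers (e.g. the original LLaMA tokenizer) ship without a
    # pad token; id 0 (<unk> for LLaMA) is used as a fallback so batch padding works.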
    if training_args.use_peft:
        if training_args.peft_path is not None:
            logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
            model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
        else:
            logger.info("Init new peft model")
            target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
            if target_modules and 'all' in target_modules:
                target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
            modules_to_save = training_args.modules_to_save
            if modules_to_save is not None:
                modules_to_save = modules_to_save.split(',')
            logger.info(f"Peft target_modules: {target_modules}")
            logger.info(f"Peft lora_rank: {training_args.lora_rank}")
            peft_config = LoraConfig(
                task_type=TaskType.SEQ_CLS,
                target_modules=target_modules,
                inference_mode=False,
                r=training_args.lora_rank,
                lora_alpha=training_args.lora_alpha,
                lora_dropout=training_args.lora_dropout,
                modules_to_save=modules_to_save)
            # Prepare the 8-bit base model before wrapping it with LoRA, so the
            # freshly added adapter weights keep requires_grad=True.
            if model_args.load_in_8bit:
                model = prepare_model_for_int8_training(model)
            model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        logger.info("Full parameters training")
        print_trainable_parameters(model)
    # Get reward dataset for tuning the reward model.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
            train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"train files: {', '.join(train_data_files)}")
            data_files["train"] = train_data_files
        if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
            eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
                f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
            logger.info(f"eval files: {', '.join(eval_data_files)}")
            data_files["validation"] = eval_data_files
        raw_datasets = load_dataset(
            'json',
            data_files=data_files,
            cache_dir=model_args.cache_dir,
        )
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                'json',
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            raw_datasets["train"] = load_dataset(
                'json',
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    logger.info(f"Raw datasets: {raw_datasets}")
    # Preprocessing the datasets
    full_max_length = data_args.max_source_length + data_args.max_target_length

    def preprocess_reward_function(examples):
        """
        Turn the dataset into pairs of question + answer, where input_ids_chosen holds
        the preferred question + answer and input_ids_rejected holds the other one.
        """
        new_examples = {
            "input_ids_chosen": [],
            "attention_mask_chosen": [],
            "input_ids_rejected": [],
            "attention_mask_rejected": [],
        }
        for question, chosen, rejected in zip(examples["question"], examples["response_chosen"],
                                              examples["response_rejected"]):
            tokenized_chosen = tokenizer("Question: " + question + "\n\nAnswer: " + chosen)
            tokenized_rejected = tokenizer("Question: " + question + "\n\nAnswer: " + rejected)
            new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
            new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
            new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
            new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
        return new_examples
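
    # Expected record schema, inferred from the fields accessed above:
    # {"question": "...", "response_chosen": "...", "response_rejected": "..."}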
    train_dataset = None
    max_train_samples = 0
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets['train']
        max_train_samples = len(train_dataset)
        if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
        with training_args.main_process_first(desc="Train dataset tokenization"):
            tokenized_dataset = train_dataset.shuffle().map(
                preprocess_reward_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=train_dataset.column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
            train_dataset = tokenized_dataset.filter(
                lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
                    x['input_ids_chosen']) <= full_max_length
            )
            logger.debug(f"Num train_samples: {len(train_dataset)}")
            logger.debug("Tokenized training example:")
            logger.debug(tokenizer.decode(train_dataset[0]['input_ids_chosen']))
    eval_dataset = None
    max_eval_samples = 0
    if training_args.do_eval:
        with training_args.main_process_first(desc="Eval dataset tokenization"):
            if "validation" not in raw_datasets:
                raise ValueError("--do_eval requires a validation dataset")
            eval_dataset = raw_datasets["validation"]
            max_eval_samples = len(eval_dataset)
            if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
                max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
                eval_dataset = eval_dataset.select(range(max_eval_samples))
            logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
            tokenized_dataset = eval_dataset.map(
                preprocess_reward_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=eval_dataset.column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
            eval_dataset = tokenized_dataset.filter(
                lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
                    x['input_ids_chosen']) <= full_max_length
            )
            logger.debug(f"Num eval_samples: {len(eval_dataset)}")
            logger.debug("Tokenized eval example:")
            logger.debug(tokenizer.decode(eval_dataset[0]['input_ids_chosen']))
    # Initialize our Trainer
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False
    else:
        model.config.use_cache = True
    model.enable_input_require_grads()
    if torch.cuda.device_count() > 1:
        # Keep Trainer from attempting its own DataParallel when more than one GPU is available
        model.is_parallelizable = True
        model.model_parallel = True
    trainer = RewardTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        data_collator=RewardDataCollatorWithPadding(
            tokenizer=tokenizer, max_length=full_max_length, padding="max_length"
        ),
    )
    # Training
    if training_args.do_train:
        logger.info("*** Train ***")
        logger.debug(f"Train dataloader example: {next(iter(trainer.get_train_dataloader()))}")
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        metrics["train_samples"] = max_train_samples
        logger.debug(f"Training metrics: {metrics}")
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        logger.info(f"Saving model checkpoint to {training_args.output_dir}")
        save_model(training_args.output_dir, model, tokenizer, training_args)
    # Evaluation
    if training_args.do_eval and trainer.is_world_process_zero():
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = max_eval_samples
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        logger.debug(f"Eval metrics: {metrics}")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

if __name__ == "__main__":
    main()
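
# A minimal launch sketch; paths and hyperparameters below are illustrative,
# not defaults shipped with this script:
#   python reward_modeling.py \
#       --model_type llama \
#       --model_name_or_path /path/to/llama-base \
#       --train_file_dir ./data/reward/train \
#       --validation_file_dir ./data/reward/validation \
#       --do_train --do_eval \
#       --use_peft True \
#       --output_dir ./outputs-rm \
#       --per_device_train_batch_size 4 \
#       --num_train_epochs 1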