I'm trying to perform continued pretraining (CPT) of Llama on a new language (the language is similar to Hindi, so some of its tokens are already in the vocabulary). The model's validation loss plateaus very early in training: one epoch is around 6k steps, and the validation loss already bottoms out at step 750.
My dataset is around 100k examples, and I'm using LoRA as well.
# Apply LoRA
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=128,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj", "lm_head"],
    lora_dropout=0,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
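After wrapping, I sanity-check what's actually trainable with peft's built-in helper:

# Verify that only the LoRA adapters (and the lm_head adapter) are trainable
model.print_trainable_parameters()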
Here are my training args:
from trl import SFTConfig, SFTTrainer

sft_config = SFTConfig(
    learning_rate=1e-3,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=8,
    warmup_ratio=0.05,
    num_train_epochs=5,
    max_grad_norm=1.0,
    logging_steps=250,
    eval_strategy="steps",
    eval_steps=250,
    save_strategy="steps",
    save_steps=500,
    output_dir="./llama-cpt",
    bf16=True,
    fp16=False,
    dataset_text_field="text",
    logging_dir="./logs",
    save_total_limit=2,
    report_to="none",
    max_seq_length=512,
    dataset_num_proc=8,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
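early_stopping_callback is just transformers' standard EarlyStoppingCallback (the patience value here is illustrative):

from transformers import EarlyStoppingCallback

# Stop when eval_loss stops improving; a patience of 4 evals is illustrative
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=4)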
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    args=sft_config,
    callbacks=[early_stopping_callback],
)
I've tried different arrangements: a higher r value, adding embed_tokens (and lm_head) to the trained modules, different learning rates, etc. But the validation loss shows the same trend; it's either around this range or around 1.59-1.60.
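For the embed/lm_head variant, I used something like the following, with the embedding and head trained in full via peft's modules_to_save rather than LoRA-adapted:

# Variant: LoRA on attention/MLP, full fine-tuning of embeddings + head
lora_config = LoraConfig(
    r=128,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    modules_to_save=["embed_tokens", "lm_head"],  # fully trained, not LoRA-adapted
    lora_dropout=0,
    bias="none",
    task_type="CAUSAL_LM",
)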
Moreover, I've also tried Mistral-7B-v0.1 and hit the same issue.
I thought the model might not be learning because too few tokens cover the new language, so I tried vocab expansion, but the same issue persists.
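The vocab expansion was the standard add-and-resize approach (sketch; new_tokens stands in for the new-language tokens I added):

# Done before wrapping the model with get_peft_model
num_added = tokenizer.add_tokens(new_tokens)  # new_tokens: list of new-language token strings
model.resize_token_embeddings(len(tokenizer))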
What else could I try?
