---
license: apache-2.0
language:
  - en
  - zh
  - de
  - ru
base_model:
  - meta-llama/Llama-3.1-8B
tags:
  - translation
  - reasoning
  - test-time
---

Reward model for Plan2Align, used for test-time translation on the zh->en, zh->de, and zh->ru language pairs.
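A common, simple way to apply a reward model at test time is best-of-N reranking: generate several candidate translations and keep the one with the highest reward. The sketch below illustrates only that generic idea; it is not the Plan2Align predictive-planning procedure described in the paper. `best_of_n` and the toy scorer are hypothetical stand-ins; in practice the scorer would wrap the `reward` function defined in this card.

```python
from typing import Callable, Sequence, Tuple


def best_of_n(candidates: Sequence[str], score: Callable[[str], float]) -> Tuple[str, float]:
    """Return the highest-scoring candidate and its score.

    Each candidate is scored exactly once (reward calls are expensive);
    ties go to the earliest candidate.
    """
    if not candidates:
        raise ValueError("need at least one candidate")
    best, best_score = candidates[0], score(candidates[0])
    for cand in candidates[1:]:
        s = score(cand)
        if s > best_score:
            best, best_score = cand, s
    return best, best_score


# Hypothetical toy scorer that prefers longer outputs. With the real model
# you might pass: lambda t: reward("English", source_text, t)
print(best_of_n(["Hi.", "Hello there.", "Hello."], score=lambda t: float(len(t))))
# ('Hello there.', 12.0)
```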

```bibtex
@article{wang2025plan2align,
  title={Plan2Align: Predictive Planning Based Test-Time Preference Alignment in Paragraph-Level Machine Translation},
  author={Wang, Kuang-Da and Chen, Teng-Ruei and Hung, Yu Heng and Ding, Shuoyang and Wu, Yueh-Hua and Wang, Yu-Chiang Frank and Yang, Chao-Han Huck and Peng, Wen-Chih and Hsieh, Ping-Chun},
  journal={arXiv preprint arXiv:2502.20795},
  year={2025}
}
```

Using the Reward Model

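```python
import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

device = "cuda:0"
torch_dtype = torch.bfloat16  # pick a dtype your hardware supports
RM = AutoModelForCausalLMWithValueHead.from_pretrained("ray24724919/plan2align_rm", torch_dtype=torch_dtype).to(device)
RM.eval()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
# RM.gradient_checkpointing_enable()  # optional; only saves memory when training

# Load the value-head weights and strip the "v_head." prefix so keys match RM.v_head
value_head_weights = load_file("path-to-valuehead-safetensors")
new_state_dict = {key.replace("v_head.", "") if key.startswith("v_head.") else key: value for key, value in value_head_weights.items()}
RM.v_head.load_state_dict(new_state_dict)
```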
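The dictionary comprehension above only renames keys: a saved key such as `v_head.summary.weight` becomes `summary.weight`, which is what `RM.v_head.load_state_dict` expects (in TRL's `ValueHead` the linear layer is named `summary`). The example key names are assumptions for illustration. The dependency-free sketch below, using a hypothetical helper `strip_prefix`, performs the same remapping on a plain dictionary so you can sanity-check the keys in your own file before loading.

```python
def strip_prefix(state_dict: dict, prefix: str = "v_head.") -> dict:
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are kept unchanged, mirroring the comprehension above.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Placeholder strings stand in for tensors.
saved = {"v_head.summary.weight": "W", "v_head.summary.bias": "b"}
print(strip_prefix(saved))
# {'summary.weight': 'W', 'summary.bias': 'b'}
```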

Reward Function

```python
import torch


def reward(language, text, response, device="cuda:0"):
    # `language` is the target language name used in the prompt, e.g. "English"
    messages = [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {language}, Chinese:\n{text}\n### {language}:"},
        {"role": "assistant", "content": response},
    ]
    # The assistant turn is already present, so no generation prompt is appended
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=False,
        return_tensors="pt",
    ).to(device)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": torch.ones_like(input_ids),  # single, unpadded sequence
    }

    with torch.no_grad():
        # TRL's value-head model returns (lm_logits, loss, value);
        # value has shape (batch, seq_len)
        outputs = RM(**inputs)
        rewards = outputs[2]

    # Use the value predicted at the final token as the scalar reward
    final_reward = rewards[:, -1].item()
    return final_reward
```
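`rewards[:, -1]` takes the value-head output at the final position, which is the last token of the chat-templated sequence when a single unpadded sequence is scored. If you batch several translations with right padding, the final position of a shorter row is a pad token, so you would instead pick each row's value at its last non-padded position. The sketch below shows that index arithmetic on plain Python lists; `last_token_values` is a hypothetical helper. With tensors, the same idea, assuming right padding, is typically `values[torch.arange(values.size(0)), mask.sum(dim=1) - 1]`.

```python
from typing import List, Sequence


def last_token_values(values: Sequence[Sequence[float]],
                      attention_mask: Sequence[Sequence[int]]) -> List[float]:
    """For each row, return the value at the last position where attention_mask == 1.

    Assumes right padding, i.e. each mask row looks like [1, ..., 1, 0, ..., 0].
    """
    out = []
    for row_vals, row_mask in zip(values, attention_mask):
        length = sum(row_mask)  # number of real (non-pad) tokens
        if length == 0:
            raise ValueError("row has no real tokens")
        out.append(row_vals[length - 1])
    return out


values = [[0.1, 0.5, 0.9], [0.2, 0.7, -3.0]]
mask = [[1, 1, 1], [1, 1, 0]]
print(last_token_values(values, mask))  # [0.9, 0.7]
```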

Prompt format for translation reward modeling

```python
messages = [
    {"role": "system", "content": "You are a helpful translator and only output the result."},
    {"role": "user", "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"},
    {"role": "assistant", "content": translation},
]
```
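Because the prompt names the source language in words ("Chinese"), the `language` placeholder is presumably a full language name rather than a code. The sketch below assumes the names `English`, `German`, and `Russian` for the three supported pairs; the mapping `LANG_NAMES` and the helper `build_messages` are hypothetical conveniences, not part of the released code.

```python
# Assumed target-language names for the zh->en, zh->de, zh->ru pairs
LANG_NAMES = {"en": "English", "de": "German", "ru": "Russian"}


def build_messages(target_code: str, source: str, translation: str) -> list:
    """Build chat messages in the reward-modeling prompt format shown above."""
    language = LANG_NAMES[target_code]  # raises KeyError for unsupported targets
    return [
        {"role": "system", "content": "You are a helpful translator and only output the result."},
        {"role": "user", "content": f"### Translate this from Chinese to {language}, Chinese:\n{source}\n### {language}:"},
        {"role": "assistant", "content": translation},
    ]


msgs = build_messages("de", "你好", "Hallo")
print(msgs[1]["content"])
```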