Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes; see the raw diff for the full changeset.
- .gitattributes +9 -0
- images/dataset.png +3 -0
- images/demo1.png +3 -0
- previous_version/Video-R1-main-previous/images/2B_curve.png +3 -0
- previous_version/Video-R1-main-previous/images/7B_curve.png +3 -0
- previous_version/Video-R1-main-previous/images/7B_nextqa.png +3 -0
- previous_version/Video-R1-main-previous/images/CATER_new_003595.gif +3 -0
- previous_version/Video-R1-main-previous/images/sample.png +3 -0
- previous_version/Video-R1-main-previous/src/distill_r1/create_hf_dataset.py +119 -0
- previous_version/Video-R1-main-previous/src/distill_r1/generate_scene_qa_pairs.ipynb +569 -0
- previous_version/Video-R1-main-previous/src/distill_r1/grpo_r1_distilled.jpg +3 -0
- previous_version/Video-R1-main-previous/src/distill_r1/query_r1.py +114 -0
- previous_version/Video-R1-main-previous/src/eval/prompts/geoqa_test_prompts.jsonl +0 -0
- previous_version/Video-R1-main-previous/src/eval/prompts/superclevr_test200_counting_problems.jsonl +200 -0
- previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_counting_superclevr.py +136 -0
- previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_geoqa.py +149 -0
- previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_geoqa_multigpu.py +205 -0
- previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_video_counting.py +141 -0
- previous_version/Video-R1-main-previous/src/qwen-vl-utils/.python-version +1 -0
- previous_version/Video-R1-main-previous/src/qwen-vl-utils/README.md +94 -0
- previous_version/Video-R1-main-previous/src/qwen-vl-utils/pyproject.toml +75 -0
- previous_version/Video-R1-main-previous/src/qwen-vl-utils/requirements-dev.lock +84 -0
- previous_version/Video-R1-main-previous/src/qwen-vl-utils/requirements.lock +32 -0
- previous_version/Video-R1-main-previous/src/qwen-vl-utils/src/qwen_vl_utils/__init__.py +7 -0
- previous_version/Video-R1-main-previous/src/qwen-vl-utils/src/qwen_vl_utils/vision_process.py +379 -0
- previous_version/Video-R1-main-previous/src/r1-v/temp_image.png +3 -0
- src/r1-v/.gitignore +178 -0
- src/r1-v/LICENSE +201 -0
- src/r1-v/Makefile +20 -0
- src/r1-v/setup.cfg +41 -0
- src/r1-v/setup.py +132 -0
- src/r1-v/src/open_r1/__init__.py +0 -0
- src/r1-v/src/open_r1/evaluate.py +85 -0
- src/r1-v/src/open_r1/generate.py +156 -0
- src/r1-v/src/open_r1/grpo-cot-72BEval.py +489 -0
- src/r1-v/src/open_r1/grpo-cot-LLMEval.py +552 -0
- src/r1-v/src/open_r1/grpo-cot-answerBERT-eval.py +429 -0
- src/r1-v/src/open_r1/grpo-cot-noDesEval.py +446 -0
- src/r1-v/src/open_r1/grpo-cot-noInfo.py +346 -0
- src/r1-v/src/open_r1/grpo-cot-qwenEval.py +523 -0
- src/r1-v/src/open_r1/grpo-cot-selfEval.py +457 -0
- src/r1-v/src/open_r1/grpo-cot-selfEvalConst.py +456 -0
- src/r1-v/src/open_r1/grpo-cot.py +351 -0
- src/r1-v/src/open_r1/grpo-description-LLMEval.py +579 -0
- src/r1-v/src/open_r1/grpo.py +318 -0
- src/r1-v/src/open_r1/grpo_vllm_caption.py +266 -0
- src/r1-v/src/open_r1/sft_video.py +304 -0
- src/r1-v/src/open_r1/trainer/__init__.py +12 -0
- src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified_error.py +1061 -0
- src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified_orig.py +935 -0
.gitattributes
CHANGED
@@ -36,3 +36,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 images/curves.png filter=lfs diff=lfs merge=lfs -text
 images/demo2.png filter=lfs diff=lfs merge=lfs -text
 images/performance.png filter=lfs diff=lfs merge=lfs -text
+images/dataset.png filter=lfs diff=lfs merge=lfs -text
+images/demo1.png filter=lfs diff=lfs merge=lfs -text
+previous_version/Video-R1-main-previous/src/r1-v/temp_image.png filter=lfs diff=lfs merge=lfs -text
+previous_version/Video-R1-main-previous/src/distill_r1/grpo_r1_distilled.jpg filter=lfs diff=lfs merge=lfs -text
+previous_version/Video-R1-main-previous/images/7B_nextqa.png filter=lfs diff=lfs merge=lfs -text
+previous_version/Video-R1-main-previous/images/sample.png filter=lfs diff=lfs merge=lfs -text
+previous_version/Video-R1-main-previous/images/CATER_new_003595.gif filter=lfs diff=lfs merge=lfs -text
+previous_version/Video-R1-main-previous/images/2B_curve.png filter=lfs diff=lfs merge=lfs -text
+previous_version/Video-R1-main-previous/images/7B_curve.png filter=lfs diff=lfs merge=lfs -text
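For reference, attribute lines like these are what `git lfs track` writes; a minimal illustrative sketch (not part of this commit) of registering the same patterns from a script, with the pattern list taken from the hunk above:

    import subprocess

    # Each `git lfs track <pattern>` call appends a matching
    # "filter=lfs diff=lfs merge=lfs -text" line to .gitattributes.
    patterns = [
        "images/dataset.png",
        "images/demo1.png",
        "previous_version/Video-R1-main-previous/src/r1-v/temp_image.png",
        "previous_version/Video-R1-main-previous/src/distill_r1/grpo_r1_distilled.jpg",
        "previous_version/Video-R1-main-previous/images/7B_nextqa.png",
        "previous_version/Video-R1-main-previous/images/sample.png",
        "previous_version/Video-R1-main-previous/images/CATER_new_003595.gif",
        "previous_version/Video-R1-main-previous/images/2B_curve.png",
        "previous_version/Video-R1-main-previous/images/7B_curve.png",
    ]
    for pattern in patterns:
        subprocess.run(["git", "lfs", "track", pattern], check=True)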
images/dataset.png
ADDED (Git LFS)

images/demo1.png
ADDED (Git LFS)

previous_version/Video-R1-main-previous/images/2B_curve.png
ADDED (Git LFS)

previous_version/Video-R1-main-previous/images/7B_curve.png
ADDED (Git LFS)

previous_version/Video-R1-main-previous/images/7B_nextqa.png
ADDED (Git LFS)

previous_version/Video-R1-main-previous/images/CATER_new_003595.gif
ADDED (Git LFS)

previous_version/Video-R1-main-previous/images/sample.png
ADDED (Git LFS)
previous_version/Video-R1-main-previous/src/distill_r1/create_hf_dataset.py
ADDED
@@ -0,0 +1,119 @@
+import json
+import os
+import random
+from datasets import load_dataset
+from tqdm import tqdm
+
+random.seed(1234)
+VAL_NUM = 5000
+
+
+def create_r1_train_dataset(
+    valid_pair_json,
+    data_dir,
+    img_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/trainA/",
+):
+    os.makedirs(data_dir, exist_ok=True)
+    pairs = [json.loads(line) for line in open(valid_pair_json, "r")]
+    mapped_pairs = []
+
+    for idx, pair in tqdm(enumerate(pairs)):
+        img_filename = pair["img_filename"]
+        new_pair = {}
+        try:
+            new_pair["thinking"] = (
+                pair["r1_response"]
+                .split("<think>")[1]
+                .split("</think>")[0]
+                .replace("scene description", "image")
+            )
+        except Exception as e:
+            print(f"Error processing pair response: ", pair["r1_response"])
+            continue  # skip this pair
+        # add index to distinguish the same image
+        dataset_filename = (
+            img_filename.split(".")[0] + "_" + str(idx) + "." + img_filename.split(".")[1]
+        )
+        if not os.path.exists(f"{data_dir}/{img_filename}"):
+            os.system(f"cp {img_dir}/{img_filename} {data_dir}/{dataset_filename}")
+        q, a = pair["q"], pair["a"]
+        new_pair["problem"] = q
+        # get the thinking path
+
+        new_pair["thinking"] = "<think>" + new_pair["thinking"] + "</think>"
+        new_pair["solution"] = f"<answer> {a} </answer>"
+        new_pair["file_name"] = dataset_filename
+        mapped_pairs.append(new_pair)
+    with open(f"{data_dir}/metadata.jsonl", "w") as f:
+        for pair in mapped_pairs:
+            f.write(json.dumps(pair) + "\n")
+
+    train_dataset = load_dataset(
+        "imagefolder",
+        data_dir=data_dir,
+        split="train",
+    )
+    return train_dataset
+
+
+def create_val_dataset(
+    json_file,
+    data_dir,
+    val_num=VAL_NUM,
+    image_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/valB",
+):
+    os.makedirs(data_dir, exist_ok=True)
+    val = json.load(open(json_file))
+    random.shuffle(val)
+    val = val[:val_num]
+    val_pairs = []
+    for idx, pair in tqdm(enumerate(val)):
+        q, a = pair["q"], pair["a"]
+        img_filename = pair["img_filename"]
+        # copy images to the DATA_DIR
+        val_filename = (
+            img_filename.split(".")[0] + f"_{idx}." + img_filename.split(".")[1]
+        )
+        if not os.path.exists(f"{data_dir}/{img_filename}"):
+            os.system(f"cp {image_dir}/{img_filename} {data_dir}/{val_filename}")
+        new_pair = {}
+        new_pair["problem"] = q
+        new_pair["solution"] = f"<answer> {a} </answer>"
+        new_pair["file_name"] = val_filename
+        val_pairs.append(new_pair)
+    with open(f"{data_dir}/metadata.jsonl", "w") as f:
+        for pair in val_pairs:
+            f.write(json.dumps(pair) + "\n")
+    val_dataset = load_dataset("imagefolder", data_dir=data_dir, split="train")
+    return val_dataset
+
+
+# valA split
+VALA_DATA_DIR = "data/Clevr_CoGenT_ValA"
+VALB_DATA_DIR = "data/Clevr_CoGenT_ValB"
+valA_json = (
+    "/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_valA.json"
+)
+valB_json = (
+    "/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_valB.json"
+)
+TRAIN_DATADIR = "data/Clevr_CoGenT_TrainA_R1"
+train_dataset = create_r1_train_dataset(
+    "/home/lilei/Visual-R1/filter_results_v2/valid_pairs.jsonl",
+    TRAIN_DATADIR,
+)
+
+# print(train_dataset)
+valA_dataset = create_val_dataset(
+    valA_json,
+    VALA_DATA_DIR,
+    image_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/valA",
+)
+valB_dataset = create_val_dataset(
+    valB_json,
+    VALB_DATA_DIR,
+    image_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/valB",
+)
+valA_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_ValA")
+valB_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_ValB")
+train_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_TrainA_R1")
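For context, a minimal sketch (not part of this commit) of how the pushed training split could be loaded back and inspected; the repo id and the problem/thinking/solution fields are the ones the script above writes, and the image column is added by the imagefolder builder:

    from datasets import load_dataset

    # Assumes the push_to_hub calls above succeeded and the repo is accessible.
    ds = load_dataset("MMInstruction/Clevr_CoGenT_TrainA_R1", split="train")
    sample = ds[0]
    print(sample["problem"])    # question text
    print(sample["thinking"])   # <think>...</think> trace distilled from R1
    print(sample["solution"])   # <answer> ... </answer>
    print(sample["image"])      # PIL image loaded from the copied file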
previous_version/Video-R1-main-previous/src/distill_r1/generate_scene_qa_pairs.ipynb
ADDED
@@ -0,0 +1,569 @@
Jupyter notebook (Python 3 ipykernel, Python 3.12.2); its cells and recorded outputs, rendered in order:

In [1]:
import json
import random

In [4]:
!ls CLEVR_CoGenT_v1.0/scenes

CLEVR_trainA_scenes.json CLEVR_valA_scenes.json CLEVR_valB_scenes.json

In [15]:
split = "valB"
clevr_train_json = f"CLEVR_CoGenT_v1.0/scenes/CLEVR_{split}_scenes.json"
train_qs = f"CLEVR_CoGenT_v1.0/questions/CLEVR_{split}_questions.json"
data = json.load(open(clevr_train_json))
qs = json.load(open(train_qs))

In [16]:
print(data['scenes'][0]['objects'])
print("len: ", len(data['scenes'][0]['objects']))

[{'pixel_coords': [343, 131, 11.278693199157715], 'size': 'small', 'color': 'green', 'material': 'metal', 'shape': 'sphere', '3d_coords': [0.9906095862388611, 2.083291530609131, 0.3499999940395355], 'rotation': 107.73596690369371}, {'pixel_coords': [396, 172, 9.857704162597656], 'size': 'small', 'color': 'cyan', 'material': 'rubber', 'shape': 'sphere', '3d_coords': [2.69626522064209, 1.5257188081741333, 0.3499999940395355], 'rotation': 305.3536122513589}, {'pixel_coords': [115, 182, 8.91348934173584], 'size': 'large', 'color': 'yellow', 'material': 'rubber', 'shape': 'cylinder', '3d_coords': [0.049163494259119034, -2.864100217819214, 0.699999988079071], 'rotation': 161.8370138842408}, {'pixel_coords': [203, 131, 10.548327445983887], 'size': 'large', 'color': 'purple', 'material': 'rubber', 'shape': 'cube', '3d_coords': [-0.4719269275665283, -0.5699371695518494, 0.699999988079071], 'rotation': 159.41862667811446}, {'pixel_coords': [253, 75, 13.141877174377441], 'size': 'large', 'color': 'red', 'material': 'rubber', 'shape': 'cube', '3d_coords': [-2.036878824234009, 2.222999334335327, 0.699999988079071], 'rotation': 37.40490732771224}]
len: 5

In [17]:
def object_info_to_description(object_list):
    descriptions = []
    random.shuffle(object_list)
    for obj in object_list:
        desc = f"A {obj['size']} {obj['color']} {obj['material']} {obj['shape']}"
        desc += f" rotated {obj['rotation']:.1f}° located at"
        desc += f" 3D coordinates ({obj['3d_coords'][0]:.2f}, {obj['3d_coords'][1]:.2f}, {obj['3d_coords'][2]:.2f})"
        desc += f" and pixel coordinates ({obj['pixel_coords'][0]}, {obj['pixel_coords'][1]}, {obj['pixel_coords'][2]:.2f})"
        descriptions.append(desc)

    final_description = "Scene Description:\n"
    for i, desc in enumerate(descriptions, 1):
        final_description += f"{desc}\n"

    return final_description

In [18]:
object_info_to_description(data['scenes'][0]['objects'])

Out[18]: 'Scene Description:\nA large yellow rubber cylinder rotated 161.8° located at 3D coordinates (0.05, -2.86, 0.70) and pixel coordinates (115, 182, 8.91)\nA large purple rubber cube rotated 159.4° located at 3D coordinates (-0.47, -0.57, 0.70) and pixel coordinates (203, 131, 10.55)\nA large red rubber cube rotated 37.4° located at 3D coordinates (-2.04, 2.22, 0.70) and pixel coordinates (253, 75, 13.14)\nA small green metal sphere rotated 107.7° located at 3D coordinates (0.99, 2.08, 0.35) and pixel coordinates (343, 131, 11.28)\nA small cyan rubber sphere rotated 305.4° located at 3D coordinates (2.70, 1.53, 0.35) and pixel coordinates (396, 172, 9.86)\n'

In [19]:
img2obj_dict = {}
for scene in data['scenes']:
    obj_list = scene['objects']
    img2obj_dict[scene['image_filename']] = obj_list

In [20]:
qs['questions'][0]

Out[20]:
{'question_index': 0,
 'question_family_index': 29,
 'image_index': 0,
 'question': 'The big thing that is in front of the large rubber cube in front of the small thing that is behind the tiny matte ball is what color?',
 'answer': 'yellow',
 'image_filename': 'CLEVR_valB_000000.png',
 'split': 'valB',
 'program': [{'value_inputs': [], 'inputs': [], 'function': 'scene'},
  {'value_inputs': ['small'], 'inputs': [0], 'function': 'filter_size'},
  {'value_inputs': ['rubber'], 'inputs': [1], 'function': 'filter_material'},
  {'value_inputs': ['sphere'], 'inputs': [2], 'function': 'filter_shape'},
  {'value_inputs': [], 'inputs': [3], 'function': 'unique'},
  {'value_inputs': ['behind'], 'inputs': [4], 'function': 'relate'},
  {'value_inputs': ['small'], 'inputs': [5], 'function': 'filter_size'},
  {'value_inputs': [], 'inputs': [6], 'function': 'unique'},
  {'value_inputs': ['front'], 'inputs': [7], 'function': 'relate'},
  {'value_inputs': ['large'], 'inputs': [8], 'function': 'filter_size'},
  {'value_inputs': ['rubber'], 'inputs': [9], 'function': 'filter_material'},
  {'value_inputs': ['cube'], 'inputs': [10], 'function': 'filter_shape'},
  {'value_inputs': [], 'inputs': [11], 'function': 'unique'},
  {'value_inputs': ['front'], 'inputs': [12], 'function': 'relate'},
  {'value_inputs': ['large'], 'inputs': [13], 'function': 'filter_size'},
  {'value_inputs': [], 'inputs': [14], 'function': 'unique'},
  {'value_inputs': [], 'inputs': [15], 'function': 'query_color'}]}

In [21]:
random.shuffle(qs['questions'])
cnt = 0
qa_pairs = []
added_pair = set()
for qd in qs['questions']:
    img_idx = qd['image_filename']
    total_count = len(img2obj_dict[img_idx])  # object list length
    desc = object_info_to_description(img2obj_dict[img_idx])
    question, answer = qd['question'], qd['answer']
    if 'how many' in question.lower() or 'number' in question.lower():
        qa_pairs.append({
            "img_filename": img_idx,
            'q': question,
            'a': answer,
            'description': desc
        })
    if img_idx not in added_pair:
        qa_pairs.append({
            "img_filename": img_idx,
            'q': "How many items are there in the described scene?",
            'a': total_count,
            'description': desc
        })
        added_pair.add(img_idx)

In [22]:
len(qa_pairs)

Out[22]: 59978

In [23]:
len(added_pair)

Out[23]: 14884

In [24]:
# save for later loading
with open(f"clever_counting_problems_clevr_cogent_v1.0_{split}.json", 'w') as fw:
    json.dump(qa_pairs, fw, indent=4)

In [20]:
random.shuffle(qa_pairs)

In [57]:
qa_pairs[:2]

Out[57]:
[{'img_filename': 'CLEVR_trainA_048403.png',
  'q': 'How many things are both on the right side of the big yellow rubber thing and left of the purple ball?',
  'a': '5',
  'description': 'Scene Description:\nA large red rubber cylinder rotated 291.3° located at 3D coordinates (-0.89, -2.73, 0.70) and pixel coordinates (101, 152, 10.04)\nA small purple metal sphere rotated 247.7° located at 3D coordinates (2.93, 0.87, 0.35) and pixel coordinates (379, 183, 9.66)\nA large cyan rubber cylinder rotated 114.5° located at 3D coordinates (-2.40, 2.23, 0.70) and pixel coordinates (246, 82, 13.94)\nA small red metal cylinder rotated 109.9° located at 3D coordinates (-0.95, 1.77, 0.35) and pixel coordinates (270, 113, 12.83)\nA small red rubber cylinder rotated 343.7° located at 3D coordinates (-0.12, -0.74, 0.35) and pixel coordinates (209, 153, 10.82)\nA large red rubber cylinder rotated 324.5° located at 3D coordinates (-2.71, -2.21, 0.70) and pixel coordinates (84, 119, 11.59)\nA small red metal cylinder rotated 1.1° located at 3D coordinates (2.88, -0.12, 0.35) and pixel coordinates (342, 200, 9.12)\nA small gray rubber cube rotated 144.9° located at 3D coordinates (0.79, 0.98, 0.35) and pixel coordinates (299, 145, 11.19)\nA large yellow rubber cube rotated 90.0° located at 3D coordinates (-1.78, -0.31, 0.70) and pixel coordinates (180, 110, 12.05)\n'},
 {'img_filename': 'CLEVR_trainA_048403.png',
  'q': 'How many items are there in the described scene?',
  'a': 9,
  'description': 'Scene Description:\nA large red rubber cylinder rotated 291.3° located at 3D coordinates (-0.89, -2.73, 0.70) and pixel coordinates (101, 152, 10.04)\nA small purple metal sphere rotated 247.7° located at 3D coordinates (2.93, 0.87, 0.35) and pixel coordinates (379, 183, 9.66)\nA large cyan rubber cylinder rotated 114.5° located at 3D coordinates (-2.40, 2.23, 0.70) and pixel coordinates (246, 82, 13.94)\nA small red metal cylinder rotated 109.9° located at 3D coordinates (-0.95, 1.77, 0.35) and pixel coordinates (270, 113, 12.83)\nA small red rubber cylinder rotated 343.7° located at 3D coordinates (-0.12, -0.74, 0.35) and pixel coordinates (209, 153, 10.82)\nA large red rubber cylinder rotated 324.5° located at 3D coordinates (-2.71, -2.21, 0.70) and pixel coordinates (84, 119, 11.59)\nA small red metal cylinder rotated 1.1° located at 3D coordinates (2.88, -0.12, 0.35) and pixel coordinates (342, 200, 9.12)\nA small gray rubber cube rotated 144.9° located at 3D coordinates (0.79, 0.98, 0.35) and pixel coordinates (299, 145, 11.19)\nA large yellow rubber cube rotated 90.0° located at 3D coordinates (-1.78, -0.31, 0.70) and pixel coordinates (180, 110, 12.05)\n'}]

In [26]:
def query_r1(query='who are you?', model="deepseek-ai/DeepSeek-R1"):
    # Create the chat completion
    response = client.chat.completions.create(
        model=model,
        messages=[
            {'role': 'user',
             'content': query}
        ],
        stream=False,
    )
    # Print the response
    return response.choices[0].message.content

In [44]:
def format_query(qa_dict):
    query = "Answer the question according to scene description.\n\n"
    query += qa_dict['description']
    query += f"\nQuestion:\n{qa_dict['q']}"
    return query

In [39]:
debug_query = format_query(qa_pairs[0])
print(query_r1(debug_query))

<think>Okay, let's see. The user is asking how many items are there in the described scene. Let me go through the scene description step by step.

So, the scene description lists each object with details like color, material, shape, rotation, 3D coordinates, and pixel coordinates. Each entry starts with "A" which usually indicates one item each. Let me count each one.

First entry: "A small green metal cylinder..." That's one. Second: "A small blue rubber cylinder..." Second item. Third: "A small cyan rubber cylinder..." That's three. Fourth: "A large cyan metal sphere..." Four. Fifth: "A large brown metal cube..." Five. Sixth: "A large yellow rubber cube..." Six. Seventh: "A large brown rubber cylinder..." That's seven.

Wait, did I miss any? Let me check again. The list has entries from "A small green..." up to the seventh one. Each sentence starts with "A", which suggests each is a separate item. No commas separating multiple items in a single entry. Each has different attributes and coordinates, so they must all be distinct.

So the answer should be 7 items.
</think>

There are 7 items in the described scene. Each entry corresponds to one distinct object, listed by their properties, coordinates, and rotations.
None

In [41]:
print(debug_query)

Answer the question accordingly to scene description.

Scene Description:
A small green metal cylinder rotated 329.5° located at 3D coordinates (-2.49, -1.65, 0.35) and pixel coordinates (111, 132, 11.81)
A small blue rubber cylinder rotated 312.2° located at 3D coordinates (-1.73, -2.91, 0.35) and pixel coordinates (76, 163, 10.57)
A small cyan rubber cylinder rotated 48.4° located at 3D coordinates (-2.10, -0.22, 0.35) and pixel coordinates (172, 118, 12.41)
A large cyan metal sphere rotated 27.4° located at 3D coordinates (1.52, -1.26, 0.70) and pixel coordinates (247, 181, 9.33)
A large brown metal cube rotated 107.7° located at 3D coordinates (-0.73, 2.39, 0.70) and pixel coordinates (290, 92, 12.93)
A large yellow rubber cube rotated 288.2° located at 3D coordinates (0.52, 0.63, 0.70) and pixel coordinates (279, 130, 11.09)
A large brown rubber cylinder rotated 229.8° located at 3D coordinates (2.38, 0.38, 0.70) and pixel coordinates (343, 166, 9.77)

Question:
How many items are there in the described scene?

In [42]:
qa_pairs[1]

Out[42]:
{'img_filename': 'CLEVR_train_044000.png',
 'q': 'How many rubber objects are either small blue spheres or small things?',
 'a': '2',
 'description': 'Scene Description:\nA large purple rubber sphere rotated 78.4° located at 3D coordinates (2.27, 0.87, 0.70) and pixel coordinates (360, 156, 9.49)\nA large gray metal cube rotated 152.7° located at 3D coordinates (2.79, -1.26, 0.70) and pixel coordinates (301, 213, 7.91)\nA large purple metal sphere rotated 79.2° located at 3D coordinates (-2.66, -2.74, 0.70) and pixel coordinates (51, 126, 10.61)\nA large blue rubber sphere rotated 279.5° located at 3D coordinates (1.31, 2.72, 0.70) and pixel coordinates (376, 112, 11.19)\nA small brown rubber cube rotated 124.1° located at 3D coordinates (-2.49, 2.61, 0.35) and pixel coordinates (251, 82, 13.79)\nA small green rubber sphere rotated 323.9° located at 3D coordinates (-2.02, 0.45, 0.35) and pixel coordinates (197, 109, 12.22)\n'}

In [45]:
debug_query1 = format_query(qa_pairs[1])
res1 = query_r1(debug_query1)

<think>Okay, let's tackle this question. The user wants to know how many rubber objects are either small blue spheres or small things. Hmm, first, I need to parse each part of the question correctly.

Let's go back to the scene description. The objects are listed with their attributes: size, color, material, shape, rotation, 3D and pixel coordinates. The key here is to filter the objects based on the criteria given.

The question has two parts: "small blue spheres" OR "small things". So any rubber object that is either a small blue sphere or any small thing (regardless of other attributes) counts. But wait, do both categories need to be rubber? Because the question says "rubber objects are either...". So rubber is the material, and the condition is either being a small blue sphere or a small thing. So first, all rubber objects, and among them, count those that are either small blue spheres or small (regardless of color or shape). Wait, no. Let me re-read the question.

The question is: "How many rubber objects are either small blue spheres or small things?" So rubber is the material. Within all rubber objects, count the ones that are either (1) small blue spheres OR (2) small things. Wait, does (2) being small things mean that even if they're small and of any color or shape, but they must be rubber?

Yes, because the entire set is rubber objects. So first, select all objects where material is rubber. Then, within those, count how many are either (1) small, blue, sphere, or (2) small (any color or shape). Wait, but the structure is "either X or Y", where X is "small blue sphere" and Y is "small things". But "small things" would include all small objects, regardless of color and shape. However, since we've already narrowed it to rubber objects, "small things" here would be small rubber objects, regardless of color and shape.

But wait, the condition is within rubber objects. So for the first part, small blue spheres (must check size, color, shape) and for the second part, small things (size is small, any color and shape, but since material is already rubber, that's covered). But wait, does the OR merge the two conditions, leading to rubber objects that are either (small blue spheres) or (small any-color any-shape).

So the combined condition is: object is rubber AND ( (is small AND blue AND sphere) OR (is small) ). Wait, but if the condition for the second part is just "small things", which would imply any small object. But the entire group is rubber objects. So it's rubber objects that are small blue spheres OR rubber objects that are small (regardless of color or shape).

Wait, no. Let's parse the sentence again: "rubber objects are either small blue spheres or small things". The "either/or" applies to "small blue spheres" and "small things". So, each rubber object has to be either (a small blue sphere) or (a small thing). However, "small things" here might refer to any small object regardless of other attributes. So if a rubber object is small, regardless of color or shape, it counts. But then, the first condition (small blue sphere) would also satisfy being a small thing. Wait, so there's an overlap. But when dealing with OR conditions, we have to avoid double-counting. So, the actual count is the number of small rubber objects (since any small rubber object is covered by the second part, which includes all small rubber objects, whether blue spheres or not) plus any objects that are small blue spheres but not rubber? But no, the question specifies "rubber objects", so we can ignore non-rubber ones.

Wait, perhaps the wording is: "rubber objects that are either small blue spheres or small things". So "small things" here must reference other attributes. Wait, maybe there's ambiguity here. If the user is grouping "small things" as a separate category, regardless of being the other attributes. Let me try to approach this step by step.

First, list all the rubber objects from the scene description:

Looking through the list:

1. A large purple rubber sphere ... location etc.
So material rubber, large, purple, sphere.

2. A large gray metal cube ... metal, so not rubber.

3. A large purple metal sphere ... metal, not rubber.

4. A large blue rubber sphere ... rubber, large, blue, sphere.

5. A small brown rubber cube ... rubber, small, brown, cube.

6. A small green rubber sphere ... rubber, small, green, sphere.

So the rubber objects are items 1,4,5,6.

Now, for each of these four rubber objects, check if they are either (small blue sphere) or (small things).

Let's check each:

1. Large purple rubber sphere. Size: large. So for the first condition (small blue sphere): no. For the second condition (small thing): size large, so no. Not included.

4. Large blue rubber sphere. Large. So, even though it's blue and sphere, the size is large. So (small blue sphere: no, since size is large). (small thing: no.) So not included.

5. Small brown rubber cube. Size is small. So regardless of color and shape (brown, cube), this is a small thing. So it meets the second condition. So count it.

6. Small green rubber sphere. Size: small. So meets the second condition (small thing). But it's a sphere and green, not blue. So this counts under the second category.

Now, what about the first condition: small blue spheres. Are there any rubber objects that are small blue spheres? The only blue rubber sphere is item 4, which is large. So none of the rubber objects are small blue spheres.

So now the two options under the OR are:

- small blue sphere (none found)
- small things (items 5 and 6), both are rubber.

Therefore total count is 0 + 2 = 2.

Wait, but the initial question is: number of rubber objects that are either (small blue spheres) or (small things). Since "small blue spheres" are a subset of "small things" (since if something is a small blue sphere, it is a small thing). Therefore if there were any such objects, they would be counted twice, but in this case, since there are none, the total is just the count of small things.

But according to the data, the rubber objects that are small are two: item 5 (small brown cube) and item 6 (small green sphere). Neither is a blue sphere, so no overlap. Therefore, sum is 2.

So the answer should be 2.

I need to confirm again whether the OR condition includes the union, meaning any rubber object that is either condition. So if either condition is true, count it.

Since both 5 and 6 are small, even though they aren't blue spheres, they belong to the second category (small things). The blue sphere (if it were small) would have been counted as well. But in this case, the valid counts are 5 and 6. So yes, answer is 2.
</think>

The answer is 2.

**Explanation:**
First, identify all rubber objects from the scene description:
1. Large purple rubber sphere (not small)
2. Large blue rubber sphere (not small)
3. Small brown rubber cube (small)
4. Small green rubber sphere (small)

Next, apply the criteria:
- **Small blue spheres**: None of the rubber objects meet this (the only blue rubber sphere is large).
- **Small rubber objects (regardless of color/shape)**: The small brown rubber cube and small green rubber sphere qualify (2 objects).

Thus, there are **2 rubber objects** that fit either criterion.

In [47]:
qa_pairs[1]

Out[47]:
{'img_filename': 'CLEVR_train_044000.png',
 'q': 'How many rubber objects are either small blue spheres or small things?',
 'a': '2',
 'description': 'Scene Description:\nA large purple rubber sphere rotated 78.4° located at 3D coordinates (2.27, 0.87, 0.70) and pixel coordinates (360, 156, 9.49)\nA large gray metal cube rotated 152.7° located at 3D coordinates (2.79, -1.26, 0.70) and pixel coordinates (301, 213, 7.91)\nA large purple metal sphere rotated 79.2° located at 3D coordinates (-2.66, -2.74, 0.70) and pixel coordinates (51, 126, 10.61)\nA large blue rubber sphere rotated 279.5° located at 3D coordinates (1.31, 2.72, 0.70) and pixel coordinates (376, 112, 11.19)\nA small brown rubber cube rotated 124.1° located at 3D coordinates (-2.49, 2.61, 0.35) and pixel coordinates (251, 82, 13.79)\nA small green rubber sphere rotated 323.9° located at 3D coordinates (-2.02, 0.45, 0.35) and pixel coordinates (197, 109, 12.22)\n'}

In [ ]:
previous_version/Video-R1-main-previous/src/distill_r1/grpo_r1_distilled.jpg
ADDED (Git LFS)
previous_version/Video-R1-main-previous/src/distill_r1/query_r1.py
ADDED
@@ -0,0 +1,114 @@
+import json
+import random
+import os
+from openai import OpenAI
+from tqdm import tqdm
+import concurrent.futures
+from typing import List, Dict, Optional
+from datetime import datetime
+from threading import Lock
+import time
+from prompt import R1_SYS_PROMPT
+# Initialize the client
+client = OpenAI(
+    api_key=os.environ.get("SL_KEY", "YOUR_SILCONFLOW_KEY"),
+    base_url="https://api.siliconflow.cn/v1",
+)
+
+# Create a lock for thread-safe file writing
+file_lock = Lock()
+
+def format_query(qa_dict: Dict, v2=False) -> str:
+    query = "Answer the question according to scene description.\n\n"
+    query += qa_dict["description"]
+    query += f"\nQuestion:\n{qa_dict['q']}"
+    if v2:
+        query += "\nInstructions:\n"
+        query += "1. Carefully analyze the scene description\n"
+        query += "2. Provide your reasoning if necessary\n"
+        query += "3. For the final answer, start a new line with '**The answer is: **' followed by your answer\n"
+    return query
+
+def write_to_jsonl(result: Dict, filename: str):
+    """Thread-safe function to write a result to JSONL file"""
+    with file_lock:
+        with open(filename, 'a') as f:
+            f.write(json.dumps(result) + '\n')
+
+def query_r1(qa_pair: Dict, output_file: str, model: str = "deepseek-ai/DeepSeek-R1", v2=False) -> Optional[Dict]:
+    query = format_query(qa_pair, v2=v2)
+    try:
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": R1_SYS_PROMPT},
+                {"role": "user", "content": query}],
+            stream=False,
+            max_tokens=4096
+        )
+        result = {
+            **qa_pair,
+            "r1_response": response.choices[0].message.content,
+            "timestamp": datetime.now().isoformat()
+        }
+        # Write result immediately
+        write_to_jsonl(result, output_file)
+        time.sleep(4)
+        return result
+    except Exception as e:
+        print(f"Error processing query: {e}")
+        error_result = {
+            **qa_pair,
+            "error": str(e),
+            "timestamp": datetime.now().isoformat()
+        }
+        write_to_jsonl(error_result, f"errors_{output_file}")
+        time.sleep(10)
+        return None
+
+def process_qa_pairs_parallel(qa_pairs: List[Dict], output_file: str, max_workers: int = 10) -> List[Dict]:
+    successful_count = 0
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Create futures for all qa_pairs
+        futures = [executor.submit(query_r1, qa_pair, output_file, v2="v2" in output_file) for qa_pair in qa_pairs]
+
+        # Process results as they complete with progress bar
+        results = []
+        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
+            try:
+                result = future.result()
+                if result is not None:
+                    results.append(result)
+                    successful_count += 1
+            except Exception as e:
+                print(f"Failed to process query: {e}")
+
+    return results
+
+if __name__ == "__main__":
+    # Load and shuffle QA pairs
+    random.seed(1234)
+    qa_pairs = json.load(open("/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_trainA.json"))
+    random.shuffle(qa_pairs)
+    qa_pairs = qa_pairs[:10000]
+    # Create output filename with timestamp
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    output_file = f"r1_results_clevr_cogent_v1.0_trainA_v2.jsonl"
+
+    finished = set()
+    with open(output_file, 'r') as f:
+        for line in f:
+            ins = json.loads(line)
+            key = ins["img_filename"] + "-" + ins["q"] + "-" + str(ins["a"])
+            finished.add(key)
+    qa_pairs = [ins for ins in qa_pairs if ins["img_filename"] + "-" + ins["q"] + "-" + str(ins["a"]) not in finished]
+    print("Finished: ", len(finished))
+    print("Remaining: ", len(qa_pairs))
+    # Process QA pairs in parallel
+    r1_results = process_qa_pairs_parallel(qa_pairs, output_file)
+
+    # Print final statistics
+    print(f"Successfully processed {len(r1_results)} out of {len(qa_pairs)} queries")
+    print(f"Results saved to {output_file}")
+    print(f"Any errors were saved to errors_{output_file}")
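The commit does not include the filtering step that turns this script's output into the filter_results_v2/valid_pairs.jsonl consumed by create_hf_dataset.py above; a minimal hypothetical sketch of what it could look like, assuming a pair is kept only when the R1 response contains a complete think block:

    import json

    # Hypothetical filter between query_r1.py output and create_hf_dataset.py input.
    src = "r1_results_clevr_cogent_v1.0_trainA_v2.jsonl"   # produced by query_r1.py
    dst = "valid_pairs.jsonl"                              # consumed by create_hf_dataset.py

    with open(src) as fin, open(dst, "w") as fout:
        for line in fin:
            ins = json.loads(line)
            resp = ins.get("r1_response", "")
            # Keep only responses with a complete <think>...</think> block,
            # since create_hf_dataset.py splits on those tags.
            if "<think>" in resp and "</think>" in resp:
                fout.write(json.dumps(ins) + "\n")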
previous_version/Video-R1-main-previous/src/eval/prompts/geoqa_test_prompts.jsonl
ADDED (diff too large to render; see the raw diff)
previous_version/Video-R1-main-previous/src/eval/prompts/superclevr_test200_counting_problems.jsonl
ADDED
@@ -0,0 +1,200 @@
+{"image_path": "./images/superCLEVR_new_025000.png", "question": "How many different items are there in the image?", "ground_truth": 4}
+{"image_path": "./images/superCLEVR_new_025001.png", "question": "How many different items are there in the image?", "ground_truth": 9}
+{"image_path": "./images/superCLEVR_new_025002.png", "question": "How many different items are there in the image?", "ground_truth": 10}
+{"image_path": "./images/superCLEVR_new_025003.png", "question": "How many different items are there in the image?", "ground_truth": 4}
+{"image_path": "./images/superCLEVR_new_025004.png", "question": "How many different items are there in the image?", "ground_truth": 3}
+{"image_path": "./images/superCLEVR_new_025005.png", "question": "How many different items are there in the image?", "ground_truth": 3}
+{"image_path": "./images/superCLEVR_new_025006.png", "question": "How many different items are there in the image?", "ground_truth": 3}
+{"image_path": "./images/superCLEVR_new_025007.png", "question": "How many different items are there in the image?", "ground_truth": 4}
+{"image_path": "./images/superCLEVR_new_025008.png", "question": "How many different items are there in the image?", "ground_truth": 9}
+{"image_path": "./images/superCLEVR_new_025009.png", "question": "How many different items are there in the image?", "ground_truth": 10}
+{"image_path": "./images/superCLEVR_new_025010.png", "question": "How many different items are there in the image?", "ground_truth": 7}
+{"image_path": "./images/superCLEVR_new_025011.png", "question": "How many different items are there in the image?", "ground_truth": 7}
+{"image_path": "./images/superCLEVR_new_025012.png", "question": "How many different items are there in the image?", "ground_truth": 7}
+{"image_path": "./images/superCLEVR_new_025013.png", "question": "How many different items are there in the image?", "ground_truth": 6}
+{"image_path": "./images/superCLEVR_new_025014.png", "question": "How many different items are there in the image?", "ground_truth": 5}
+{"image_path": "./images/superCLEVR_new_025015.png", "question": "How many different items are there in the image?", "ground_truth": 10}
+{"image_path": "./images/superCLEVR_new_025016.png", "question": "How many different items are there in the image?", "ground_truth": 4}
+{"image_path": "./images/superCLEVR_new_025017.png", "question": "How many different items are there in the image?", "ground_truth": 5}
+{"image_path": "./images/superCLEVR_new_025018.png", "question": "How many different items are there in the image?", "ground_truth": 6}
+{"image_path": "./images/superCLEVR_new_025019.png", "question": "How many different items are there in the image?", "ground_truth": 8}
+{"image_path": "./images/superCLEVR_new_025020.png", "question": "How many different items are there in the image?", "ground_truth": 10}
+{"image_path": "./images/superCLEVR_new_025021.png", "question": "How many different items are there in the image?", "ground_truth": 3}
+{"image_path": "./images/superCLEVR_new_025022.png", "question": "How many different items are there in the image?", "ground_truth": 4}
+{"image_path": "./images/superCLEVR_new_025023.png", "question": "How many different items are there in the image?", "ground_truth": 4}
+{"image_path": "./images/superCLEVR_new_025024.png", "question": "How many different items are there in the image?", "ground_truth": 5}
+{"image_path": "./images/superCLEVR_new_025025.png", "question": "How many different items are there in the image?", "ground_truth": 5}
+{"image_path": "./images/superCLEVR_new_025026.png", "question": "How many different items are there in the image?", "ground_truth": 7}
+{"image_path": "./images/superCLEVR_new_025027.png", "question": "How many different items are there in the image?", "ground_truth": 4}
+{"image_path": "./images/superCLEVR_new_025028.png", "question": "How many different items are there in the image?", "ground_truth": 4}
+{"image_path": "./images/superCLEVR_new_025029.png", "question": "How many different items are there in the image?", "ground_truth": 9}
+{"image_path": "./images/superCLEVR_new_025030.png", "question": "How many different items are there in the image?", "ground_truth": 8}
+{"image_path": "./images/superCLEVR_new_025031.png", "question": "How many different items are there in the image?", "ground_truth": 6}
+{"image_path": "./images/superCLEVR_new_025032.png", "question": "How many different items are there in the image?", "ground_truth": 3}
+{"image_path": "./images/superCLEVR_new_025033.png", "question": "How many different items are there in the image?", "ground_truth": 10}
+{"image_path": "./images/superCLEVR_new_025034.png", "question": "How many different items are there in the image?", "ground_truth": 9}
+{"image_path": "./images/superCLEVR_new_025035.png", "question": "How many different items are there in the image?", "ground_truth": 9}
+{"image_path": "./images/superCLEVR_new_025036.png", "question": "How many different items are there in the image?", "ground_truth": 3}
+{"image_path": "./images/superCLEVR_new_025037.png", "question": "How many different items are there in the image?", "ground_truth": 6}
+{"image_path": "./images/superCLEVR_new_025038.png", "question": "How many different items are there in the image?", "ground_truth": 6}
+{"image_path": "./images/superCLEVR_new_025039.png", "question": "How many different items are there in the image?", "ground_truth": 5}
+{"image_path": "./images/superCLEVR_new_025040.png", "question": "How many different items are there in the image?", "ground_truth": 3}
+{"image_path": "./images/superCLEVR_new_025041.png", "question": "How many different items are there in the image?", "ground_truth": 10}
+{"image_path": "./images/superCLEVR_new_025042.png", "question": "How many different items are there in the image?", "ground_truth": 6}
(remaining entries not shown in this truncated view)
|
| 21 |
+
{"image_path": "./images/superCLEVR_new_025020.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 22 |
+
{"image_path": "./images/superCLEVR_new_025021.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 23 |
+
{"image_path": "./images/superCLEVR_new_025022.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 24 |
+
{"image_path": "./images/superCLEVR_new_025023.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 25 |
+
{"image_path": "./images/superCLEVR_new_025024.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 26 |
+
{"image_path": "./images/superCLEVR_new_025025.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 27 |
+
{"image_path": "./images/superCLEVR_new_025026.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 28 |
+
{"image_path": "./images/superCLEVR_new_025027.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 29 |
+
{"image_path": "./images/superCLEVR_new_025028.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 30 |
+
{"image_path": "./images/superCLEVR_new_025029.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 31 |
+
{"image_path": "./images/superCLEVR_new_025030.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 32 |
+
{"image_path": "./images/superCLEVR_new_025031.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 33 |
+
{"image_path": "./images/superCLEVR_new_025032.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 34 |
+
{"image_path": "./images/superCLEVR_new_025033.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 35 |
+
{"image_path": "./images/superCLEVR_new_025034.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 36 |
+
{"image_path": "./images/superCLEVR_new_025035.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 37 |
+
{"image_path": "./images/superCLEVR_new_025036.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 38 |
+
{"image_path": "./images/superCLEVR_new_025037.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 39 |
+
{"image_path": "./images/superCLEVR_new_025038.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 40 |
+
{"image_path": "./images/superCLEVR_new_025039.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 41 |
+
{"image_path": "./images/superCLEVR_new_025040.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 42 |
+
{"image_path": "./images/superCLEVR_new_025041.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 43 |
+
{"image_path": "./images/superCLEVR_new_025042.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 44 |
+
{"image_path": "./images/superCLEVR_new_025043.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 45 |
+
{"image_path": "./images/superCLEVR_new_025044.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 46 |
+
{"image_path": "./images/superCLEVR_new_025045.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 47 |
+
{"image_path": "./images/superCLEVR_new_025046.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 48 |
+
{"image_path": "./images/superCLEVR_new_025047.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 49 |
+
{"image_path": "./images/superCLEVR_new_025048.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 50 |
+
{"image_path": "./images/superCLEVR_new_025049.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 51 |
+
{"image_path": "./images/superCLEVR_new_025050.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 52 |
+
{"image_path": "./images/superCLEVR_new_025051.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 53 |
+
{"image_path": "./images/superCLEVR_new_025052.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 54 |
+
{"image_path": "./images/superCLEVR_new_025053.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 55 |
+
{"image_path": "./images/superCLEVR_new_025054.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 56 |
+
{"image_path": "./images/superCLEVR_new_025055.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 57 |
+
{"image_path": "./images/superCLEVR_new_025056.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 58 |
+
{"image_path": "./images/superCLEVR_new_025057.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 59 |
+
{"image_path": "./images/superCLEVR_new_025058.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 60 |
+
{"image_path": "./images/superCLEVR_new_025059.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 61 |
+
{"image_path": "./images/superCLEVR_new_025060.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 62 |
+
{"image_path": "./images/superCLEVR_new_025061.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 63 |
+
{"image_path": "./images/superCLEVR_new_025062.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 64 |
+
{"image_path": "./images/superCLEVR_new_025063.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 65 |
+
{"image_path": "./images/superCLEVR_new_025064.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 66 |
+
{"image_path": "./images/superCLEVR_new_025065.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 67 |
+
{"image_path": "./images/superCLEVR_new_025066.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 68 |
+
{"image_path": "./images/superCLEVR_new_025067.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 69 |
+
{"image_path": "./images/superCLEVR_new_025068.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 70 |
+
{"image_path": "./images/superCLEVR_new_025069.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 71 |
+
{"image_path": "./images/superCLEVR_new_025070.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 72 |
+
{"image_path": "./images/superCLEVR_new_025071.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 73 |
+
{"image_path": "./images/superCLEVR_new_025072.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 74 |
+
{"image_path": "./images/superCLEVR_new_025073.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 75 |
+
{"image_path": "./images/superCLEVR_new_025074.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 76 |
+
{"image_path": "./images/superCLEVR_new_025075.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 77 |
+
{"image_path": "./images/superCLEVR_new_025076.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 78 |
+
{"image_path": "./images/superCLEVR_new_025077.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 79 |
+
{"image_path": "./images/superCLEVR_new_025078.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 80 |
+
{"image_path": "./images/superCLEVR_new_025079.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 81 |
+
{"image_path": "./images/superCLEVR_new_025080.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 82 |
+
{"image_path": "./images/superCLEVR_new_025081.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 83 |
+
{"image_path": "./images/superCLEVR_new_025082.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 84 |
+
{"image_path": "./images/superCLEVR_new_025083.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 85 |
+
{"image_path": "./images/superCLEVR_new_025084.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 86 |
+
{"image_path": "./images/superCLEVR_new_025085.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 87 |
+
{"image_path": "./images/superCLEVR_new_025086.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 88 |
+
{"image_path": "./images/superCLEVR_new_025087.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 89 |
+
{"image_path": "./images/superCLEVR_new_025088.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 90 |
+
{"image_path": "./images/superCLEVR_new_025089.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 91 |
+
{"image_path": "./images/superCLEVR_new_025090.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 92 |
+
{"image_path": "./images/superCLEVR_new_025091.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 93 |
+
{"image_path": "./images/superCLEVR_new_025092.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 94 |
+
{"image_path": "./images/superCLEVR_new_025093.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 95 |
+
{"image_path": "./images/superCLEVR_new_025094.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 96 |
+
{"image_path": "./images/superCLEVR_new_025095.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 97 |
+
{"image_path": "./images/superCLEVR_new_025096.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 98 |
+
{"image_path": "./images/superCLEVR_new_025097.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 99 |
+
{"image_path": "./images/superCLEVR_new_025098.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 100 |
+
{"image_path": "./images/superCLEVR_new_025099.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 101 |
+
{"image_path": "./images/superCLEVR_new_025100.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 102 |
+
{"image_path": "./images/superCLEVR_new_025101.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 103 |
+
{"image_path": "./images/superCLEVR_new_025102.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 104 |
+
{"image_path": "./images/superCLEVR_new_025103.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 105 |
+
{"image_path": "./images/superCLEVR_new_025104.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 106 |
+
{"image_path": "./images/superCLEVR_new_025105.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 107 |
+
{"image_path": "./images/superCLEVR_new_025106.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 108 |
+
{"image_path": "./images/superCLEVR_new_025107.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 109 |
+
{"image_path": "./images/superCLEVR_new_025108.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 110 |
+
{"image_path": "./images/superCLEVR_new_025109.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 111 |
+
{"image_path": "./images/superCLEVR_new_025110.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 112 |
+
{"image_path": "./images/superCLEVR_new_025111.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 113 |
+
{"image_path": "./images/superCLEVR_new_025112.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 114 |
+
{"image_path": "./images/superCLEVR_new_025113.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 115 |
+
{"image_path": "./images/superCLEVR_new_025114.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 116 |
+
{"image_path": "./images/superCLEVR_new_025115.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 117 |
+
{"image_path": "./images/superCLEVR_new_025116.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 118 |
+
{"image_path": "./images/superCLEVR_new_025117.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 119 |
+
{"image_path": "./images/superCLEVR_new_025118.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 120 |
+
{"image_path": "./images/superCLEVR_new_025119.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 121 |
+
{"image_path": "./images/superCLEVR_new_025120.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 122 |
+
{"image_path": "./images/superCLEVR_new_025121.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 123 |
+
{"image_path": "./images/superCLEVR_new_025122.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 124 |
+
{"image_path": "./images/superCLEVR_new_025123.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 125 |
+
{"image_path": "./images/superCLEVR_new_025124.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 126 |
+
{"image_path": "./images/superCLEVR_new_025125.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 127 |
+
{"image_path": "./images/superCLEVR_new_025126.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 128 |
+
{"image_path": "./images/superCLEVR_new_025127.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 129 |
+
{"image_path": "./images/superCLEVR_new_025128.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 130 |
+
{"image_path": "./images/superCLEVR_new_025129.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 131 |
+
{"image_path": "./images/superCLEVR_new_025130.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 132 |
+
{"image_path": "./images/superCLEVR_new_025131.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 133 |
+
{"image_path": "./images/superCLEVR_new_025132.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 134 |
+
{"image_path": "./images/superCLEVR_new_025133.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 135 |
+
{"image_path": "./images/superCLEVR_new_025134.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 136 |
+
{"image_path": "./images/superCLEVR_new_025135.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 137 |
+
{"image_path": "./images/superCLEVR_new_025136.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 138 |
+
{"image_path": "./images/superCLEVR_new_025137.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 139 |
+
{"image_path": "./images/superCLEVR_new_025138.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 140 |
+
{"image_path": "./images/superCLEVR_new_025139.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 141 |
+
{"image_path": "./images/superCLEVR_new_025140.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 142 |
+
{"image_path": "./images/superCLEVR_new_025141.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 143 |
+
{"image_path": "./images/superCLEVR_new_025142.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 144 |
+
{"image_path": "./images/superCLEVR_new_025143.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 145 |
+
{"image_path": "./images/superCLEVR_new_025144.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 146 |
+
{"image_path": "./images/superCLEVR_new_025145.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 147 |
+
{"image_path": "./images/superCLEVR_new_025146.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 148 |
+
{"image_path": "./images/superCLEVR_new_025147.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 149 |
+
{"image_path": "./images/superCLEVR_new_025148.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 150 |
+
{"image_path": "./images/superCLEVR_new_025149.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 151 |
+
{"image_path": "./images/superCLEVR_new_025150.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 152 |
+
{"image_path": "./images/superCLEVR_new_025151.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 153 |
+
{"image_path": "./images/superCLEVR_new_025152.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 154 |
+
{"image_path": "./images/superCLEVR_new_025153.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 155 |
+
{"image_path": "./images/superCLEVR_new_025154.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 156 |
+
{"image_path": "./images/superCLEVR_new_025155.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 157 |
+
{"image_path": "./images/superCLEVR_new_025156.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 158 |
+
{"image_path": "./images/superCLEVR_new_025157.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 159 |
+
{"image_path": "./images/superCLEVR_new_025158.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 160 |
+
{"image_path": "./images/superCLEVR_new_025159.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 161 |
+
{"image_path": "./images/superCLEVR_new_025160.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 162 |
+
{"image_path": "./images/superCLEVR_new_025161.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 163 |
+
{"image_path": "./images/superCLEVR_new_025162.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 164 |
+
{"image_path": "./images/superCLEVR_new_025163.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 165 |
+
{"image_path": "./images/superCLEVR_new_025164.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 166 |
+
{"image_path": "./images/superCLEVR_new_025165.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 167 |
+
{"image_path": "./images/superCLEVR_new_025166.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 168 |
+
{"image_path": "./images/superCLEVR_new_025167.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 169 |
+
{"image_path": "./images/superCLEVR_new_025168.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 170 |
+
{"image_path": "./images/superCLEVR_new_025169.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 171 |
+
{"image_path": "./images/superCLEVR_new_025170.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 172 |
+
{"image_path": "./images/superCLEVR_new_025171.png", "question": "How many different items are there in the image?", "ground_truth": 7}
|
| 173 |
+
{"image_path": "./images/superCLEVR_new_025172.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 174 |
+
{"image_path": "./images/superCLEVR_new_025173.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 175 |
+
{"image_path": "./images/superCLEVR_new_025174.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 176 |
+
{"image_path": "./images/superCLEVR_new_025175.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 177 |
+
{"image_path": "./images/superCLEVR_new_025176.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 178 |
+
{"image_path": "./images/superCLEVR_new_025177.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 179 |
+
{"image_path": "./images/superCLEVR_new_025178.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 180 |
+
{"image_path": "./images/superCLEVR_new_025179.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 181 |
+
{"image_path": "./images/superCLEVR_new_025180.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 182 |
+
{"image_path": "./images/superCLEVR_new_025181.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 183 |
+
{"image_path": "./images/superCLEVR_new_025182.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 184 |
+
{"image_path": "./images/superCLEVR_new_025183.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 185 |
+
{"image_path": "./images/superCLEVR_new_025184.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 186 |
+
{"image_path": "./images/superCLEVR_new_025185.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 187 |
+
{"image_path": "./images/superCLEVR_new_025186.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 188 |
+
{"image_path": "./images/superCLEVR_new_025187.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 189 |
+
{"image_path": "./images/superCLEVR_new_025188.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 190 |
+
{"image_path": "./images/superCLEVR_new_025189.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 191 |
+
{"image_path": "./images/superCLEVR_new_025190.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 192 |
+
{"image_path": "./images/superCLEVR_new_025191.png", "question": "How many different items are there in the image?", "ground_truth": 8}
|
| 193 |
+
{"image_path": "./images/superCLEVR_new_025192.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 194 |
+
{"image_path": "./images/superCLEVR_new_025193.png", "question": "How many different items are there in the image?", "ground_truth": 9}
|
| 195 |
+
{"image_path": "./images/superCLEVR_new_025194.png", "question": "How many different items are there in the image?", "ground_truth": 10}
|
| 196 |
+
{"image_path": "./images/superCLEVR_new_025195.png", "question": "How many different items are there in the image?", "ground_truth": 5}
|
| 197 |
+
{"image_path": "./images/superCLEVR_new_025196.png", "question": "How many different items are there in the image?", "ground_truth": 6}
|
| 198 |
+
{"image_path": "./images/superCLEVR_new_025197.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
| 199 |
+
{"image_path": "./images/superCLEVR_new_025198.png", "question": "How many different items are there in the image?", "ground_truth": 4}
|
| 200 |
+
{"image_path": "./images/superCLEVR_new_025199.png", "question": "How many different items are there in the image?", "ground_truth": 3}
|
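The counting script below consumes this file as one JSON object per line. As a minimal sketch (assuming the file sits at the same relative path the script's `PROMPT_PATH` constant points to), each record can be sanity-checked before any GPU time is spent on inference:

```python
import json

PROMPT_PATH = "./prompts/superclevr_test200_counting_problems.jsonl"

# Each line is an independent JSON object with exactly these three keys.
with open(PROMPT_PATH, "r") as f:
    records = [json.loads(line) for line in f if line.strip()]

for r in records:
    assert set(r) == {"image_path", "question", "ground_truth"}
    # All ground truths in this file are integers between 3 and 10.
    assert isinstance(r["ground_truth"], int) and 3 <= r["ground_truth"] <= 10

print(f"{len(records)} prompts loaded, first image: {records[0]['image_path']}")
```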
previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_counting_superclevr.py
ADDED
@@ -0,0 +1,136 @@
+from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
+import torch
+import json
+from tqdm import tqdm
+import re
+
+
+MODEL_PATH="Qwen2-VL-2B-GRPO-CLEVR-70k/checkpoint-100" # Qwen2-VL-2B-Instruct for original scores
+BSZ=64 # reduce it if GPU OOM
+OUTPUT_PATH="./logs/counting_results_superclevr_200_qwen2vl_2b_instruct_grpo_100.json"
+PROMPT_PATH="./prompts/superclevr_test200_counting_problems.jsonl"
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+model = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_PATH,
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+    device_map="auto",
+)
+
+# default processor
+processor = AutoProcessor.from_pretrained(MODEL_PATH)
+
+data = []
+with open(PROMPT_PATH, "r") as f:
+    for line in f:
+        data.append(json.loads(line))
+
+
+QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
+
+messages = []
+
+for i in data:
+    message = [{
+        "role": "user",
+        "content": [
+            {
+                "type": "image",
+                "image": f"file://{i['image_path']}"
+            },
+            {
+                "type": "text",
+                "text": QUESTION_TEMPLATE.format(Question=i['question'])
+            }
+        ]
+    }]
+    messages.append(message)
+
+
+all_outputs = []  # List to store all answers
+
+# Process data in batches
+for i in tqdm(range(0, len(messages), BSZ)):
+    batch_messages = messages[i:i + BSZ]
+
+    # Preparation for inference
+    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
+
+    image_inputs, video_inputs = process_vision_info(batch_messages)
+    inputs = processor(
+        text=text,
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to("cuda")
+
+    # Inference: generation of the output
+    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
+
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    batch_output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+
+    all_outputs.extend(batch_output_text)
+    print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
+
+
+def extract_number_answer(output_str):
+    # Try to find the number within <answer> tags; if it cannot be found, return None
+    answer_pattern = r'<answer>\s*(\d+)\s*</answer>'
+    match = re.search(answer_pattern, output_str)
+
+    if match:
+        return int(match.group(1))
+    return None
+
+
+final_output = []
+correct_number = 0
+
+for input_example, model_output in zip(data, all_outputs):
+    original_output = model_output
+    ground_truth = input_example['ground_truth']
+    model_answer = extract_number_answer(original_output)
+
+    # Create a result dictionary for this example
+    result = {
+        'question': input_example,
+        'ground_truth': ground_truth,
+        'model_output': original_output,
+        'extracted_answer': model_answer
+    }
+    final_output.append(result)
+
+    # Count correct answers
+    if model_answer is not None and model_answer == ground_truth:
+        correct_number += 1
+
+# Calculate and print accuracy
+accuracy = correct_number / len(data) * 100
+print(f"\nAccuracy: {accuracy:.2f}%")
+
+# Save results to a JSON file
+output_path = OUTPUT_PATH
+with open(output_path, "w") as f:
+    json.dump({
+        'accuracy': accuracy,
+        'results': final_output
+    }, f, indent=2)
+
+print(f"Results saved to {output_path}")
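Everything downstream of generation hinges on `extract_number_answer`, which accepts only a bare integer inside `<answer>` tags; any other phrasing is scored as wrong. A self-contained check of that behavior (the function body is copied verbatim from the script above):

```python
import re

def extract_number_answer(output_str):
    # Matches the first run of digits wrapped in <answer> ... </answer>.
    answer_pattern = r'<answer>\s*(\d+)\s*</answer>'
    match = re.search(answer_pattern, output_str)
    if match:
        return int(match.group(1))
    return None

assert extract_number_answer("<think>count them</think><answer> 7 </answer>") == 7
assert extract_number_answer("<answer>seven</answer>") is None  # words are rejected
assert extract_number_answer("The answer is 7.") is None        # tags are required
```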
previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_geoqa.py
ADDED
@@ -0,0 +1,149 @@
+from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
+import torch
+import json
+from tqdm import tqdm
+import re
+from math_verify import parse, verify
+
+
+MODEL_PATH="<MODEL_PATH>" # qwen2vl model, or a GRPO-trained model on the GeoQA train split
+BSZ=50 # reduce it if GPU OOM
+OUTPUT_PATH="<OUTPUT_LOG>"
+PROMPT_PATH="./prompts/geoqa_test_prompts.jsonl"
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+model = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_PATH,
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+    device_map="auto",
+)
+
+# default processor
+processor = AutoProcessor.from_pretrained(MODEL_PATH)
+
+data = []
+with open(PROMPT_PATH, "r") as f:
+    for line in f:
+        data.append(json.loads(line))
+
+
+QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
+
+messages = []
+
+for i in data:
+    message = [{
+        "role": "user",
+        "content": [
+            {
+                "type": "image",
+                "image": f"file://{i['image_path']}"
+            },
+            {
+                "type": "text",
+                "text": QUESTION_TEMPLATE.format(Question=i['question'])
+            }
+        ]
+    }]
+    messages.append(message)
+
+
+all_outputs = []  # List to store all answers
+
+# Process data in batches
+for i in tqdm(range(0, len(messages), BSZ)):
+    batch_messages = messages[i:i + BSZ]
+
+    # Preparation for inference
+    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
+
+    image_inputs, video_inputs = process_vision_info(batch_messages)
+    inputs = processor(
+        text=text,
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to("cuda")
+
+    # Inference: generation of the output
+    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024, do_sample=False)
+
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    batch_output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+
+    all_outputs.extend(batch_output_text)
+    print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
+
+
+final_output = []
+correct_number = 0
+
+for input_example, model_output in zip(data, all_outputs):
+    original_output = model_output
+    ground_truth = input_example['ground_truth']
+    model_answer = parse(original_output)
+
+    # Count correct answers
+    if model_answer is not None and float(verify(model_answer, parse(ground_truth))) > 0:
+        correct_number += 1
+        is_correct = True
+    else:
+        is_correct = False
+
+    # parse() may return an empty list; indexing it then raises, which the except branch handles
+    try:
+        result = {
+            'question': input_example,
+            'ground_truth': ground_truth,
+            'model_output': original_output,
+            'extracted_answer': str(model_answer[0]) if model_answer is not None else None,
+            'is_correct': is_correct
+        }
+    except Exception as e:
+        print("no answer parsed", e, model_answer)
+        result = {
+            'question': input_example,
+            'ground_truth': ground_truth,
+            'model_output': original_output,
+            'extracted_answer': None,
+            'is_correct': is_correct
+        }
+
+    final_output.append(result)
+
+
+# Calculate and print accuracy
+accuracy = correct_number / len(data) * 100
+print(f"\nAccuracy: {accuracy:.2f}%")
+
+# Save results to a JSON file
+output_path = OUTPUT_PATH
+with open(output_path, "w") as f:
+    json.dump({
+        'accuracy': accuracy,
+        'results': final_output
+    }, f, indent=2, ensure_ascii=False)
+
+print(f"Results saved to {output_path}")
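Unlike the counting script, this variant delegates answer extraction to `math_verify`: `parse` pulls candidate expressions out of the raw model output, and `verify` checks mathematical equivalence rather than string equality. A small sketch of that comparison step, mirroring the argument order used above (the exact return types are an assumption about the installed `math_verify` version):

```python
from math_verify import parse, verify

ground_truth = "25"
model_output = "<think>half of 50</think><answer>25.0</answer>"

# parse() returns a list of extracted expressions (possibly empty),
# which is why the script indexes model_answer[0] inside a try block.
pred = parse(model_output)
gold = parse(ground_truth)

# verify() is an equivalence check, so "25.0" is accepted against "25".
print(float(verify(pred, gold)) > 0)  # expected: True
```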
previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_geoqa_multigpu.py
ADDED
@@ -0,0 +1,205 @@
+from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
+import torch
+import json
+import tqdm
+from math_verify import parse, verify
+import argparse
+import pandas as pd
+from torch.multiprocessing import Process, set_start_method, Manager
+from transformers.utils.logging import disable_progress_bar
+disable_progress_bar()
+
+# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+# >>>>> 1. get evaluation configuration <<<<<
+# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+def get_eval_config():
+    parser = argparse.ArgumentParser(description="Inference script for GeoQA evaluation.")
+    parser.add_argument("--model_path", required=True, type=str, help="Path to the model checkpoint (e.g., qwen2vl model or a fine-tuned model).")
+    parser.add_argument("--batch_size", default=4, type=int, help="Batch size for inference. Reduce if GPU OOM (default: 4).")
+    parser.add_argument("--output_path", required=True, type=str, help="Path to save inference result (e.g., JSON file).")
+    parser.add_argument("--prompt_path", required=True, type=str, help="Path to the prompts JSONL file for GeoQA evaluation.")
+    all_gpu = ",".join(map(str, range(torch.cuda.device_count())))
+    parser.add_argument("--gpu_ids", default=all_gpu, help="comma-separated list of GPU IDs to use")
+    args = parser.parse_args()
+    return args
+
+# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+# >>>>>>>>>> 2. load testset <<<<<<<<<<<<<
+# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+def prepare_test_messages(testset_path):
+    testset_data = pd.read_json(testset_path, lines=True).to_dict(orient="records")
+    QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
+    tested_messages = []
+    for i in testset_data:
+        message = [{
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": f"file://{i['image_path']}"
+                },
+                {
+                    "type": "text",
+                    "text": QUESTION_TEMPLATE.format(Question=i['question'])
+                }
+            ]
+        }]
+        tested_messages.append(message)
+    return testset_data, tested_messages
+
+
+# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+# >>>>> 3. use several GPUs to accelerate inference at testset <<<<<
+# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+
+def init_model(model_path, gpu_id):
+    """Initialize a model (model_path) on a specific GPU."""
+    # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+    model = Qwen2VLForConditionalGeneration.from_pretrained(
+        model_path,
+        torch_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+        device_map=f"cuda:{gpu_id}",
+    )
+
+    # default processor
+    processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
+    return model, processor
+
+def answer_a_batch_question_qwen(batch_messages, model, processor):
+    """Let Qwen answer a batch of questions."""
+    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
+    image_inputs, video_inputs = process_vision_info(batch_messages)
+    inputs = processor(
+        text=text,
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to(model.device)
+
+    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024)  # do_sample=False
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    batch_output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    return batch_output_text
+
+def infer_on_single_gpu(model_path, device_id, chunk_of_tested_messages, batch_size, results=None):
+    """Initialize a model on this single GPU and let it answer its assigned chunk of questions."""
+    model, processor = init_model(model_path, device_id)
+
+    ### split into batches
+    responses = []
+    batch_messages_list = [chunk_of_tested_messages[start: start + batch_size]
+                           for start in range(0, len(chunk_of_tested_messages), batch_size)]
+
+    for batch_messages in tqdm.auto.tqdm(batch_messages_list, desc=f"GPU {device_id} progress", position=device_id, leave=False):
+        batch_output_text = answer_a_batch_question_qwen(batch_messages, model, processor)
+
+        responses.extend(batch_output_text)
+
+    results[device_id] = responses
+    return
+
+
+def multi_gpu_inference(prompts, gpu_ids, model_path, batch_size):
+    """Let each GPU (along with a model) answer a chunk of questions."""
+    set_start_method("spawn", force=True)
+    manager = Manager()
+    gpu_id2result = manager.dict()
+
+    gpu_ids = [int(gpu_id.strip()) for gpu_id in gpu_ids.split(',')]
+    num_gpus = len(gpu_ids)
+
+    chunk_size = len(prompts) // num_gpus
+    processes = []
+    for i, gpu_id in enumerate(gpu_ids):
+        start_idx = i * chunk_size
+        end_idx = (i + 1) * chunk_size if i != num_gpus - 1 else len(prompts)
+        chunk = prompts[start_idx: end_idx]
+        process = Process(target=infer_on_single_gpu, args=(model_path, gpu_id, chunk, batch_size, gpu_id2result))
+        process.start()
+        processes.append(process)
+
+    # for process in tqdm.auto.tqdm(processes, desc="Inference progress", position=num_gpus, leave=True):
+    for process in processes:
+        process.join()
+
+    all_predicts = []
+    for gpu_id in gpu_ids:
+        all_predicts.extend(gpu_id2result[gpu_id])
+
+    return all_predicts
+
+# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+# >>>>>>>>>> 4. compute metrics <<<<<<<<<<<
+# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
+
+def compute_metrics(testset_data, all_predicts):
+    final_output = []
+    correct_number = 0
+
+    for input_example, model_output in zip(testset_data, all_predicts):
+        original_output = model_output
+        ground_truth = input_example['ground_truth']
+        model_answer = parse(original_output)
+
+        # Count correct answers
+        if model_answer is not None and float(verify(model_answer, parse(ground_truth))) > 0:
+            correct_number += 1
+            is_correct = True
+        else:
+            is_correct = False
+
+        try:
+            result = {
+                'question': input_example,
+                'ground_truth': ground_truth,
+                'model_output': original_output,
+                'extracted_answer': str(model_answer[0]) if model_answer is not None else None,
+                'is_correct': is_correct
+            }
+        except Exception as e:
+            print("no answer parsed", e, model_answer)
+            result = {
+                'question': input_example,
+                'ground_truth': ground_truth,
+                'model_output': original_output,
+                'extracted_answer': None,
+                'is_correct': is_correct
+            }
+
+        final_output.append(result)
+
+    # Calculate and print accuracy
+    accuracy = correct_number / len(testset_data) * 100
+    print(f"\nAccuracy: {accuracy:.2f}%")
+
+    # Save results to a JSON file; relies on the module-level `args` set in __main__
+    with open(args.output_path, "w") as f:
+        json.dump({
+            'accuracy': accuracy,
+            'results': final_output
+        }, f, indent=2, ensure_ascii=False)
+
+    print(f"Results saved to {args.output_path}")
+
+
+if __name__ == "__main__":
+    args = get_eval_config()
+    testset_data, tested_messages = prepare_test_messages(testset_path=args.prompt_path)
+    all_predicts = multi_gpu_inference(tested_messages, args.gpu_ids, args.model_path, args.batch_size)
+    compute_metrics(testset_data, all_predicts)
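One detail worth noting: `multi_gpu_inference` splits the prompt list into contiguous chunks and later concatenates the per-GPU result lists in GPU-ID order, so predictions line up with `testset_data` without any reshuffling. The chunking arithmetic, extracted into a standalone sketch:

```python
def split_into_chunks(prompts, num_gpus):
    # Equal-sized contiguous chunks; the last GPU absorbs the remainder,
    # exactly as in multi_gpu_inference above.
    chunk_size = len(prompts) // num_gpus
    return [
        prompts[i * chunk_size: (i + 1) * chunk_size if i != num_gpus - 1 else len(prompts)]
        for i in range(num_gpus)
    ]

chunks = split_into_chunks(list(range(10)), 3)
assert [len(c) for c in chunks] == [3, 3, 4]
assert sum(chunks, []) == list(range(10))  # order is preserved after concatenation
```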
previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_video_counting.py
ADDED
@@ -0,0 +1,141 @@
+from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
+import torch
+import json
+from tqdm import tqdm
+import re
+import os
+
+
+MODEL_PATH="YOUR_PATH" # Qwen2-VL-2B-Instruct for original scores
+BSZ=64 # reduce it if GPU OOM
+OUTPUT_PATH="YOUR_PATH/test.json"
+PROMPT_PATH="YOUR_PATH/test_dvd.jsonl"
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+model = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_PATH,
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+    device_map="auto",
+)
+
+# default processor
+processor = AutoProcessor.from_pretrained(MODEL_PATH)
+
+data = []
+with open(PROMPT_PATH, "r") as f:
+    for line in f:
+        data.append(json.loads(line))
+
+# detailed step-by-step
+QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
+
+messages = []
+
+for x in data:
+    message = [{
+        "role": "user",
+        "content": [
+            {
+                "type": "video",
+                "video": os.getcwd() + "/src/r1-v/data" + x['video_filename'][1:]
+            },
+            {
+                "type": "text",
+                "text": QUESTION_TEMPLATE.format(Question=x['problem'])
+            }
+        ]
+    }]
+    messages.append(message)
+
+
+all_outputs = []  # List to store all answers
+
+# Process data in batches
+for i in tqdm(range(0, len(messages), BSZ)):
+    batch_messages = messages[i:i + BSZ]
+
+    # Preparation for inference
+    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
+
+    image_inputs, video_inputs = process_vision_info(batch_messages)
+    inputs = processor(
+        text=text,
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to("cuda")
+
+    # Inference: generation of the output
+    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
+
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    batch_output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+
+    all_outputs.extend(batch_output_text)
+    print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
+
+
+def extract_number_answer(output_str):
+    # Try to find the number within <answer> tags; if it cannot be found, return None
+    answer_pattern = r'<answer>\s*(\d+)\s*</answer>'
+    match = re.search(answer_pattern, output_str)
+
+    if match:
+        return int(match.group(1))
+    return None
+
+
+final_output = []
+correct_number = 0
+
+for input_example, model_output in zip(data, all_outputs):
+    original_output = model_output
+    # The ground truth is stored as an <answer>-tagged solution string, so reuse the same extractor
+    ground_truth = extract_number_answer(input_example['solution'])
+    model_answer = extract_number_answer(original_output)
+
+    # Create a result dictionary for this example
+    result = {
+        'question': input_example,
+        'ground_truth': ground_truth,
+        'model_output': original_output,
+        'extracted_answer': model_answer
+    }
+    final_output.append(result)
+
+    # Count correct answers
+    if model_answer is not None and model_answer == ground_truth:
+        correct_number += 1
+
+# Calculate and print accuracy
+accuracy = correct_number / len(data) * 100
+print(f"\nAccuracy: {accuracy:.2f}%")
+
+# Save results to a JSON file
+output_path = OUTPUT_PATH
+with open(output_path, "w") as f:
+    json.dump({
+        'accuracy': accuracy,
+        'results': final_output
+    }, f, indent=2)
+
+print(f"Results saved to {output_path}")
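The video path above is built as `os.getcwd() + "/src/r1-v/data" + x['video_filename'][1:]`, which silently assumes every `video_filename` in the JSONL begins with a `./`-relative prefix whose leading dot must be stripped before prepending the data root. A hedged sketch that makes the assumption explicit and fails loudly when it does not hold (the example filename is hypothetical):

```python
import os

def resolve_video_path(video_filename, data_root=None):
    # Mirrors the script's expression
    #     os.getcwd() + "/src/r1-v/data" + video_filename[1:]
    # but validates the leading "./" instead of blindly dropping one character.
    if data_root is None:
        data_root = os.path.join(os.getcwd(), "src", "r1-v", "data")
    if not video_filename.startswith("./"):
        raise ValueError(f"expected a './'-relative path, got {video_filename!r}")
    return data_root + video_filename[1:]

# Hypothetical record, for illustration only.
print(resolve_video_path("./videos/example_video.mp4"))
```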
previous_version/Video-R1-main-previous/src/qwen-vl-utils/.python-version
ADDED
@@ -0,0 +1 @@
+3.8.19
previous_version/Video-R1-main-previous/src/qwen-vl-utils/README.md
ADDED
@@ -0,0 +1,94 @@
# qwen-vl-utils

Qwen-VL Utils contains a set of helper functions for processing and integrating visual language information with the Qwen-VL series models.

## Install

```bash
pip install qwen-vl-utils
```

## Usage

### Qwen2VL

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


# You can directly insert a local file path, a URL, or a base64-encoded image at the position you want in the text.
messages = [
    # Image
    ## Local file path
    [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
    ## Image URL
    [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
    ## Base64 encoded image
    [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
    ## PIL.Image.Image
    [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
    ## The model dynamically adjusts the image size; specify dimensions if required.
    [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
    # Video
    ## Local video path
    [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
    ## Local video frames
    [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"]}, {"type": "text", "text": "Describe this video."}]}],
    ## The model dynamically adjusts the video nframes, height and width; specify these args if required.
    [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
]

processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos = process_vision_info(messages)
inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt")
print(inputs)
generated_ids = model.generate(**inputs)
print(generated_ids)
```

### Qwen2.5VL

```python
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


# You can set the maximum number of tokens for a video through the environment variable VIDEO_MAX_PIXELS,
# based on the maximum number of tokens that the model can accept.
# export VIDEO_MAX_PIXELS = 32000 * 28 * 28 * 0.9


# You can directly insert a local file path, a URL, or a base64-encoded image at the position you want in the text.
messages = [
    # Image
    ## Local file path
    [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
    ## Image URL
    [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
    ## Base64 encoded image
    [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
    ## PIL.Image.Image
    [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
    ## The model dynamically adjusts the image size; specify dimensions if required.
    [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
    # Video
    ## Local video path
    [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
    ## Local video frames
    [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"]}, {"type": "text", "text": "Describe this video."}]}],
    ## The model dynamically adjusts the video nframes, height and width; specify these args if required.
    [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
]

processor = AutoProcessor.from_pretrained(model_path)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt", **video_kwargs)
print(inputs)
generated_ids = model.generate(**inputs)
print(generated_ids)
```
previous_version/Video-R1-main-previous/src/qwen-vl-utils/pyproject.toml
ADDED
@@ -0,0 +1,75 @@
[project]
name = "qwen-vl-utils"
version = "0.0.10"
description = "Qwen Vision Language Model Utils - PyTorch"
authors = [
    { name = "Qwen Team", email = "[email protected]" },
]
dependencies = [
    "requests",
    "pillow",
    "av",
    "packaging",
]
readme = "README.md"
requires-python = ">= 3.8"
license = {text = "Apache-2.0"}
keywords = [
    'large language model',
    'vision language model',
    'qwen-vl',
    'pytorch',
]
classifiers = [
    'Development Status :: 4 - Beta',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'Programming Language :: Python :: 3',
    'License :: OSI Approved :: Apache Software License',
]

[project.urls]
Homepage = "https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils"
Repository = "https://github.com/QwenLM/Qwen2-VL.git"
Issues = "https://github.com/QwenLM/Qwen2-VL/issues"

[project.optional-dependencies]
decord = [
    "decord",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = [
    "torch",
    "torchvision",
]

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/qwen_vl_utils"]

[tool.ruff]
line-length = 119

[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]

[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["qwen_vl_utils"]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
previous_version/Video-R1-main-previous/src/qwen-vl-utils/requirements-dev.lock
ADDED
@@ -0,0 +1,84 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
#   pre: false
#   features: ["decord"]
#   all-features: false
#   with-sources: false
#   generate-hashes: false
#   universal: false

-e file:.
av==12.3.0
    # via qwen-vl-utils
certifi==2022.12.7
    # via requests
charset-normalizer==2.1.1
    # via requests
decord==0.6.0
    # via qwen-vl-utils
filelock==3.13.1
    # via torch
    # via triton
fsspec==2024.2.0
    # via torch
idna==3.4
    # via requests
jinja2==3.1.3
    # via torch
markupsafe==2.1.5
    # via jinja2
mpmath==1.3.0
    # via sympy
networkx==3.1
    # via torch
numpy==1.24.1
    # via decord
    # via torchvision
nvidia-cublas-cu12==12.1.3.1
    # via nvidia-cudnn-cu12
    # via nvidia-cusolver-cu12
    # via torch
nvidia-cuda-cupti-cu12==12.1.105
    # via torch
nvidia-cuda-nvrtc-cu12==12.1.105
    # via torch
nvidia-cuda-runtime-cu12==12.1.105
    # via torch
nvidia-cudnn-cu12==9.1.0.70
    # via torch
nvidia-cufft-cu12==11.0.2.54
    # via torch
nvidia-curand-cu12==10.3.2.106
    # via torch
nvidia-cusolver-cu12==11.4.5.107
    # via torch
nvidia-cusparse-cu12==12.1.0.106
    # via nvidia-cusolver-cu12
    # via torch
nvidia-nccl-cu12==2.20.5
    # via torch
nvidia-nvjitlink-cu12==12.6.68
    # via nvidia-cusolver-cu12
    # via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
    # via torch
packaging==24.1
    # via qwen-vl-utils
pillow==10.2.0
    # via qwen-vl-utils
    # via torchvision
requests==2.28.1
    # via qwen-vl-utils
sympy==1.12
    # via torch
torch==2.4.0
    # via torchvision
torchvision==0.19.0
triton==3.0.0
    # via torch
typing-extensions==4.9.0
    # via torch
urllib3==1.26.13
    # via requests
previous_version/Video-R1-main-previous/src/qwen-vl-utils/requirements.lock
ADDED
@@ -0,0 +1,32 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
#   pre: false
#   features: ["decord"]
#   all-features: false
#   with-sources: false
#   generate-hashes: false
#   universal: false

-e file:.
av==12.3.0
    # via qwen-vl-utils
certifi==2022.12.7
    # via requests
charset-normalizer==2.1.1
    # via requests
decord==0.6.0
    # via qwen-vl-utils
idna==3.4
    # via requests
numpy==1.24.4
    # via decord
packaging==24.1
    # via qwen-vl-utils
pillow==10.2.0
    # via qwen-vl-utils
requests==2.28.1
    # via qwen-vl-utils
urllib3==1.26.13
    # via requests
previous_version/Video-R1-main-previous/src/qwen-vl-utils/src/qwen_vl_utils/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .vision_process import (
    extract_vision_info,
    fetch_image,
    fetch_video,
    process_vision_info,
    smart_resize,
)
previous_version/Video-R1-main-previous/src/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
ADDED
@@ -0,0 +1,379 @@
from __future__ import annotations

import base64
import logging
import math
import os
import sys
import time
import warnings
from functools import lru_cache
from io import BytesIO

import requests
import torch
import torchvision
from packaging import version
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode
from typing import Optional


logger = logging.getLogger(__name__)

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200

# VIDEO_MIN_PIXELS = 128 * 28 * 28
# VIDEO_MAX_PIXELS = 768 * 28 * 28
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 128 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 16

# Set the maximum number of video token inputs.
# Here, 128K represents the maximum number of input tokens for the VLLM model.
# Remember to adjust it according to your own configuration.
VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")


def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar


def to_rgb(pil_image: Image.Image) -> Image.Image:
    if pil_image.mode == 'RGBA':
        white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
        white_background.paste(pil_image, mask=pil_image.split()[3])  # Use alpha channel as mask
        return white_background
    else:
        return pil_image.convert("RGB")


def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
    if "image" in ele:
        image = ele["image"]
    else:
        image = ele["image_url"]
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        response = requests.get(image, stream=True)
        image_obj = Image.open(BytesIO(response.content))
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = to_rgb(image_obj)
    ## resize
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=size_factor,
        )
    else:
        width, height = image.size
        min_pixels = ele.get("min_pixels", MIN_PIXELS)
        max_pixels = ele.get("max_pixels", MAX_PIXELS)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    image = image.resize((resized_width, resized_height))

    return image


def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """calculate the number of frames for video used for model inputs.

    Args:
        ele (dict): a dict contains the configuration of video.
            support either `fps` or `nframes`:
                - nframes: the number of frames to extract for model inputs.
                - fps: the fps to extract frames for model inputs.
                    - min_frames: the minimum number of frames of the video, only used when fps is provided.
                    - max_frames: the maximum number of frames of the video, only used when fps is provided.
        total_frames (int): the original total number of frames of the video.
        video_fps (int | float): the original fps of the video.

    Raises:
        ValueError: nframes should in interval [FRAME_FACTOR, total_frames].

    Returns:
        int: the number of frames for video used for model inputs.
    """
    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
    if "nframes" in ele:
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    else:
        fps = ele.get("fps", FPS)
        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
        nframes = total_frames / video_fps * fps
        if nframes > total_frames:
            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
        nframes = floor_by_factor(nframes, FRAME_FACTOR)
    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
        raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes


def _read_video_torchvision(
    ele: dict,
) -> (torch.Tensor, float):
    """read video using torchvision.io.read_video

    Args:
        ele (dict): a dict contains the configuration of video.
        support keys:
            - video: the path of video. support "file://", "http://", "https://" and local path.
            - video_start: the start time of video.
            - video_end: the end time of video.
    Returns:
        torch.Tensor: the video tensor with shape (T, C, H, W).
    """
    video_path = ele["video"]
    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
        if "http://" in video_path or "https://" in video_path:
            warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
        if "file://" in video_path:
            video_path = video_path[7:]
    st = time.time()
    video, audio, info = io.read_video(
        video_path,
        start_pts=ele.get("video_start", 0.0),
        end_pts=ele.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    total_frames, video_fps = video.size(0), info["video_fps"]
    logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    video = video[idx]
    return video, sample_fps


def is_decord_available() -> bool:
    import importlib.util

    return importlib.util.find_spec("decord") is not None


def _read_video_decord(
    ele: dict,
) -> (torch.Tensor, float):
    """read video using decord.VideoReader

    Args:
        ele (dict): a dict contains the configuration of video.
        support keys:
            - video: the path of video. support "file://", "http://", "https://" and local path.
            - video_start: the start time of video.
            - video_end: the end time of video.
    Returns:
        torch.Tensor: the video tensor with shape (T, C, H, W).
    """
    import decord
    video_path = ele["video"]
    st = time.time()
    vr = decord.VideoReader(video_path)
    # TODO: support start_pts and end_pts
    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError("not support start_pts and end_pts in decord for now.")
    total_frames, video_fps = len(vr), vr.get_avg_fps()
    logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    video = vr.get_batch(idx).asnumpy()
    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    return video, sample_fps


VIDEO_READER_BACKENDS = {
    "decord": _read_video_decord,
    "torchvision": _read_video_torchvision,
}

FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)


@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    if FORCE_QWENVL_VIDEO_READER is not None:
        video_reader_backend = FORCE_QWENVL_VIDEO_READER
    elif is_decord_available():
        video_reader_backend = "decord"
    else:
        video_reader_backend = "torchvision"
    print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
    return video_reader_backend


def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
    if isinstance(ele["video"], str):
        video_reader_backend = get_video_reader_backend()
        try:
            video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
        except Exception as e:
            logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
            video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)

        nframes, _, height, width = video.shape
        min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
        total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
        max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
        max_pixels_supposed = ele.get("max_pixels", max_pixels)
        if max_pixels_supposed > max_pixels:
            logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
        max_pixels = min(max_pixels_supposed, max_pixels)
        if "resized_height" in ele and "resized_width" in ele:
            resized_height, resized_width = smart_resize(
                ele["resized_height"],
                ele["resized_width"],
                factor=image_factor,
            )
        else:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=image_factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
        video = transforms.functional.resize(
            video,
            [resized_height, resized_width],
            interpolation=InterpolationMode.BICUBIC,
            antialias=True,
        ).float()
        if return_video_sample_fps:
            return video, sample_fps
        return video
    else:
        assert isinstance(ele["video"], (list, tuple))
        process_info = ele.copy()
        process_info.pop("type", None)
        process_info.pop("video", None)
        images = [
            fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
            for video_element in ele["video"]
        ]
        nframes = ceil_by_factor(len(images), FRAME_FACTOR)
        if len(images) < nframes:
            images.extend([images[-1]] * (nframes - len(images)))
        if return_video_sample_fps:
            return images, process_info.pop("fps", 2.0)
        return images


def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
    vision_infos = []
    if isinstance(conversations[0], dict):
        conversations = [conversations]
    for conversation in conversations:
        for message in conversation:
            if isinstance(message["content"], list):
                for ele in message["content"]:
                    if (
                        "image" in ele
                        or "image_url" in ele
                        or "video" in ele
                        or ele["type"] in ("image", "image_url", "video")
                    ):
                        vision_infos.append(ele)
    return vision_infos


def process_vision_info(
    conversations: list[dict] | list[list[dict]],
    return_video_kwargs: bool = False,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:

    vision_infos = extract_vision_info(conversations)
    ## Read images or videos
    image_inputs = []
    video_inputs = []
    video_sample_fps_list = []
    for vision_info in vision_infos:
        if "image" in vision_info or "image_url" in vision_info:
            image_inputs.append(fetch_image(vision_info))
        elif "video" in vision_info:
            video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
            video_sample_fps_list.append(video_sample_fps)
            video_inputs.append(video_input)
        else:
            raise ValueError("image, image_url or video should in content.")
    if len(image_inputs) == 0:
        image_inputs = None
    if len(video_inputs) == 0:
        video_inputs = None
    if return_video_kwargs:
        return image_inputs, video_inputs, {'fps': video_sample_fps_list}
    return image_inputs, video_inputs
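To make the two sizing helpers above concrete, here is a small self-contained check (assumes the package is installed; the numbers follow directly from the constants in the file):

```python
from qwen_vl_utils import smart_resize
from qwen_vl_utils.vision_process import smart_nframes

# smart_resize keeps both sides divisible by the factor (28) while staying
# within [min_pixels, max_pixels] and roughly preserving the aspect ratio:
# 1080x1920 rounds to 1092x1932.
h, w = smart_resize(1080, 1920)
assert h % 28 == 0 and w % 28 == 0

# smart_nframes for a 10 s clip at 30 fps, sampled at the default 2.0 fps:
# 300 / 30 * 2 = 20 frames, clamped to FPS_MAX_FRAMES (16), which is already
# divisible by FRAME_FACTOR (2).
print(smart_nframes({}, total_frames=300, video_fps=30))  # 16
```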
previous_version/Video-R1-main-previous/src/r1-v/temp_image.png
ADDED
Git LFS Details
src/r1-v/.gitignore
ADDED
@@ -0,0 +1,178 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc

# Temp folders
data/
wandb/
scripts/
checkpoints/
.vscode/
src/r1-v/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
src/r1-v/Makefile
ADDED
@@ -0,0 +1,20 @@
.PHONY: style quality

# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := src

style:
	black --line-length 119 --target-version py310 $(check_dirs) setup.py
	isort $(check_dirs) setup.py

quality:
	black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
	isort --check-only $(check_dirs) setup.py
	flake8 --max-line-length 119 $(check_dirs) setup.py


# Evaluation

evaluate:
src/r1-v/setup.cfg
ADDED
@@ -0,0 +1,41 @@
[isort]
default_section = FIRSTPARTY
ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = open_r1
known_third_party =
    transformers
    datasets
    fugashi
    git
    h5py
    matplotlib
    nltk
    numpy
    packaging
    pandas
    psutil
    pytest
    rouge_score
    sacrebleu
    seqeval
    sklearn
    streamlit
    torch
    tqdm

line_length = 119
lines_after_imports = 2
multi_line_output = 3
use_parentheses = True

[flake8]
ignore = E203, E501, E741, W503, W605
max-line-length = 119
per-file-ignores =
    # imported but unused
    __init__.py: F401

[tool:pytest]
doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
src/r1-v/setup.py
ADDED
@@ -0,0 +1,132 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py


import re
import shutil
from pathlib import Path

from setuptools import find_packages, setup


# Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
if stale_egg_info.exists():
    print(
        (
            "Warning: {} exists.\n\n"
            "If you recently updated open_r1, this is expected,\n"
            "but it may prevent open_r1 from installing in editable mode.\n\n"
            "This directory is automatically generated by Python's packaging tools.\n"
            "I will remove it now.\n\n"
            "See https://github.com/pypa/pip/issues/5466 for details.\n"
        ).format(stale_egg_info)
    )
    shutil.rmtree(stale_egg_info)


# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
_deps = [
    "accelerate>=1.2.1",
    "bitsandbytes>=0.43.0",
    "black>=24.4.2",
    "datasets>=3.2.0",
    "deepspeed==0.15.4",
    "distilabel[vllm,ray,openai]>=1.5.2",
    "einops>=0.8.0",
    "flake8>=6.0.0",
    "hf_transfer>=0.1.4",
    "huggingface-hub[cli]>=0.19.2,<1.0",
    "isort>=5.12.0",
    "liger_kernel==0.5.2",
    "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
    "math-verify",  # Used for math verification in grpo
    "packaging>=23.0",
    "parameterized>=0.9.0",
    "pytest",
    "safetensors>=0.3.3",
    "sentencepiece>=0.1.99",
    "torch>=2.5.1",
    # "transformers @ git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef",
    "trl==0.16.0",
    "vllm==0.7.2",
    "wandb>=0.19.1",
    "pillow",
]

# this is a lookup table with items like:
#
# tokenizers: "tokenizers==0.9.4"
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}


def deps_list(*pkgs):
    return [deps[pkg] for pkg in pkgs]


extras = {}
extras["tests"] = deps_list("pytest", "parameterized")
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("black", "isort", "flake8")
extras["eval"] = deps_list("lighteval", "math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]

# core dependencies shared across the whole project - keep this to a bare minimum :)
install_requires = [
    deps["accelerate"],
    deps["bitsandbytes"],
    deps["einops"],
    deps["datasets"],
    deps["deepspeed"],
    deps["hf_transfer"],
    deps["huggingface-hub"],
    deps["liger_kernel"],
    deps["packaging"],  # utilities from PyPA to e.g., compare versions
    deps["safetensors"],
    deps["sentencepiece"],
    # deps["transformers"],
    deps["trl"],
]

setup(
    name="r1-v",
    version="0.1.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    author="The r1-v team and the Hugging Face team (past and future)",
    description="R1-V",
    license="Apache",
    url="https://github.com/Deep-Agent/R1-V",
    package_dir={"": "src"},
    packages=find_packages("src"),
    zip_safe=False,
    extras_require=extras,
    python_requires=">=3.10.9",
    install_requires=install_requires,
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
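The deps comprehension above is compact; as a quick standalone illustration (with a trimmed copy of the requirement list), it maps each bare package name to its full requirement string, which is what deps_list then looks up. With the extras defined here, pip install -e ".[dev]" pulls in the quality, test, and eval groups.

import re

# Trimmed copy of the requirement list, just to show what the lookup table holds.
_deps = ["accelerate>=1.2.1", "huggingface-hub[cli]>=0.19.2,<1.0", "pytest"]
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
print(deps)
# {'accelerate': 'accelerate>=1.2.1',
#  'huggingface-hub': 'huggingface-hub[cli]>=0.19.2,<1.0',
#  'pytest': 'pytest'}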
src/r1-v/src/open_r1/__init__.py
ADDED
File without changes
src/r1-v/src/open_r1/evaluate.py
ADDED
@@ -0,0 +1,85 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom evaluation tasks for LightEval."""

from lighteval.metrics.dynamic_metrics import (
    ExprExtractionConfig,
    LatexExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language


metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(LatexExtractionConfig(),),
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    aggregation_function=max,
)


def prompt_fn(line, task_name: str = None):
    """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
    return Doc(
        task_name=task_name,
        query=line["problem"],
        choices=[line["solution"]],
        gold_index=0,
    )


# Define tasks
aime24 = LightevalTaskConfig(
    name="aime24",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/aime_2024",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[metric],
    version=1,
)
math_500 = LightevalTaskConfig(
    name="math_500",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/MATH-500",
    hf_subset="default",
    hf_avail_splits=["test"],
    evaluation_splits=["test"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[metric],
    version=1,
)

# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
TASKS_TABLE.append(math_500)

# MODULE LOGIC
if __name__ == "__main__":
    print([t["name"] for t in TASKS_TABLE])
    print(len(TASKS_TABLE))
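A minimal sanity check for this module, assuming src/r1-v/src is on PYTHONPATH and lighteval is installed at the commit pinned in setup.py; when actually evaluating, the file is passed to lighteval through its custom-tasks option so the "custom" suite defined here becomes discoverable.

# Hedged sketch: import the task table and confirm both custom tasks are registered.
from open_r1.evaluate import TASKS_TABLE

print([task.name for task in TASKS_TABLE])  # expected: ['aime24', 'math_500']
print(len(TASKS_TABLE))                     # expected: 2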
src/r1-v/src/open_r1/generate.py
ADDED
@@ -0,0 +1,156 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration


def build_distilabel_pipeline(
    model: str,
    base_url: str = "http://localhost:8000/v1",
    prompt_column: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_new_tokens: int = 8192,
    num_generations: int = 1,
) -> Pipeline:
    generation_kwargs = {"max_new_tokens": max_new_tokens}

    if temperature is not None:
        generation_kwargs["temperature"] = temperature

    if top_p is not None:
        generation_kwargs["top_p"] = top_p

    with Pipeline().ray() as pipeline:
        TextGeneration(
            llm=OpenAILLM(
                base_url=base_url,
                api_key="something",
                model=model,
                # thinking can take some time...
                timeout=10 * 60,
                generation_kwargs=generation_kwargs,
            ),
            input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
            input_batch_size=64,  # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
            num_generations=num_generations,
        )

    return pipeline


if __name__ == "__main__":
    import argparse

    from datasets import load_dataset

    parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
    parser.add_argument(
        "--hf-dataset",
        type=str,
        required=True,
        help="HuggingFace dataset to load",
    )
    parser.add_argument(
        "--hf-dataset-config",
        type=str,
        required=False,
        help="Dataset config to use",
    )
    parser.add_argument(
        "--hf-dataset-split",
        type=str,
        default="train",
        help="Dataset split to use",
    )
    parser.add_argument("--prompt-column", type=str, default="prompt")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model name to use for generation",
    )
    parser.add_argument(
        "--vllm-server-url",
        type=str,
        default="http://localhost:8000/v1",
        help="URL of the vLLM server",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for generation",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        help="Top-p value for generation",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=8192,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--num-generations",
        type=int,
        default=1,
        help="Number of generations per problem",
    )
    parser.add_argument(
        "--hf-output-dataset",
        type=str,
        required=False,
        help="HuggingFace repo to push results to",
    )
    parser.add_argument(
        "--private",
        action="store_true",
        help="Whether to make the output dataset private when pushing to HF Hub",
    )

    args = parser.parse_args()

    print("\nRunning with arguments:")
    for arg, value in vars(args).items():
        print(f"  {arg}: {value}")
    print()

    print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
    dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
    print("Dataset loaded!")

    pipeline = build_distilabel_pipeline(
        model=args.model,
        base_url=args.vllm_server_url,
        prompt_column=args.prompt_column,
        temperature=args.temperature,
        top_p=args.top_p,
        max_new_tokens=args.max_new_tokens,
        num_generations=args.num_generations,
    )

    print("Running generation pipeline...")
    distiset = pipeline.run(dataset=dataset, use_cache=False)
    print("Generation pipeline finished!")

    if args.hf_output_dataset:
        print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
        distiset.push_to_hub(args.hf_output_dataset, private=args.private)
        print("Dataset pushed!")
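A hedged usage sketch that mirrors the __main__ block above, driving the pipeline builder directly from Python. The dataset and model identifiers below are placeholders (not from this repo), and a vLLM OpenAI-compatible server is assumed to be serving the model at the given URL.

from datasets import load_dataset

from open_r1.generate import build_distilabel_pipeline  # safe to import: argparse runs only under __main__

# Placeholder dataset with a "prompt" column; swap in a real repo id.
dataset = load_dataset("your-org/your-prompt-dataset", split="train")

pipeline = build_distilabel_pipeline(
    model="deepseek-ai/DeepSeek-R1",       # must match the id the vLLM server reports
    base_url="http://localhost:8000/v1",
    prompt_column="prompt",
    temperature=0.6,
    max_new_tokens=8192,
    num_generations=1,
)
distiset = pipeline.run(dataset=dataset, use_cache=False)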
src/r1-v/src/open_r1/grpo-cot-72BEval.py
ADDED
@@ -0,0 +1,489 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from typing import Dict, List, Optional
from mathruler.grader import extract_boxed_content, grade_answer

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
# from utils.math_cot import *
# from qa_metrics.pedant import PEDANT
# from qa_metrics.answerBERT import AnswerBertActor

# pedant = PEDANT()
# answerBERT = AnswerBertActor(device='cuda:7')
client = OpenAI(
    base_url="http://29.81.228.243:8081/v1",  # your vLLM server (stray space before "/v1" removed)
    api_key="ANYKEY",  # if you set --api-key when launching
)


def validate_description(description, question):
    input_message = "You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. First provide a step-by-step reasoning within <think> </think> tags, then provide your answer as a single final answer, single letter choice, or a short phrase ENCLOSED with <answer> </answer> tags. \nText description: {Description}\nQuestion: {Question}\nPlease only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions.".format(Description=description, Question=question)
    response = client.chat.completions.create(
        model="Qwen2.5-72B-Instruct",  # **must match** the returned id
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": input_message}
        ]
    )

    # print('*'*10)
    # print('Input Prompt: ', input_message)
    # print('-'*10)
    # print('Output Message: ', response.choices[0].message.content)
    # print('-'*10)
    # time.sleep(40)

    return response.choices[0].message.content


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )

    # reward_funcs: list[str] = field(
    #     default_factory=lambda: ["accuracy"],
    #     metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
    # )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )


def accuracy_reward(completions, solution, **kwargs):
    def extract_answer(text: str) -> str:
        """
        1) Try the full <answer> … </answer> block.
        2) If that is missing, grab whatever follows the opening <answer> tag.
        3) Otherwise return the original text.
        """
        # ① normal case <answer> … </answer>
        m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ② fallback <answer> … <end-of-string>
        m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ③ nothing found
        return text.strip()

    def extract_description(predict: str) -> Optional[str]:
        """
        Extracts the content of the <des>…</des> block from `predict`.
        Returns the inner text (with leading/trailing whitespace stripped),
        or the full input if no <des> tag is found.
        """
        match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
        if not match:
            return predict
        return match.group(1).strip()

    def single_accuracy_reward(predict: str, ground_truth: str) -> float:
        answer = predict
        return 1.0 if grade_answer(answer, ground_truth) else 0.0

    def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
        # format_score = format_reward(predict)
        accuracy_score = single_accuracy_reward(predict, ground_truth)

        # return (1 - format_weight) * accuracy_score + format_weight * format_score
        return accuracy_score

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    def wer(reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)

    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure

    question_type = kwargs['problem_type'][0]
    questions = kwargs['problem']
    # questions = kwargs['prompt']

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    extracted_content_descriptions = [extract_description(str(ele)) for ele in contents]
    description_answer_outputs = []

    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = [
            executor.submit(validate_description, desc, q)
            for desc, q in zip(extracted_content_descriptions, questions)
        ]
        # Collect in submission order: iterating `as_completed(futures)` (as the original did)
        # returns results in completion order and misaligns them with `gt_answers` below.
        for future in futures:
            try:
                description_answer_outputs.append(future.result())
            except Exception as e:
                # handle/log e
                # description_answer_outputs.append(None)
                print('Description output error: ', e)
                description_answer_outputs.append(0)

    contents = [str(ele) for ele in contents]
    description_answer_outputs = [str(ele) for ele in description_answer_outputs]

    gt_answers = [extract_answer(str(sol)) for sol in solution]
    extracted_description_outputs = [extract_answer(str(description_answer_outputs[index_description])) for index_description in range(len(description_answer_outputs))]

    # print('GT answers: ', gt_answers)
    # print('Description answers: ', description_answer_outputs[0])
    # print('-'*10)
    # import time
    # time.sleep(10)

    description_rewards = [compute_math_score_single(extracted_description_outputs[count_idx], gt_answers[count_idx]) for count_idx in range(len(description_answer_outputs))]

    # print('()'*10)
    # print("Question: ", questions[0])
    # print(gt_answers)
    # print('Description outputs', description_answer_outputs[0])
    # print(description_rewards)
    # print('-'*10)
    # time.sleep(30)

    # for content, sol, description_reward in zip(contents, solution, description_rewards):
    # `enumerate` added so the OCR branch can index the per-sample question and description answer.
    for idx, (content, gt_ans, description_reward) in enumerate(zip(contents, gt_answers, description_rewards)):
        try:
            output_ans = extract_answer(str(content))
            # gt_ans = extract_answer(sol)

            if question_type == "OCR":
                # NOTE: this branch needs the PEDANT scorer whose import is commented out above;
                # re-enable `from qa_metrics.pedant import PEDANT` and `pedant = PEDANT()` first.
                description_extraction = extracted_description_outputs[idx]
                # description_error_rate = wer(gt_ans, description_extraction)
                description_pendat_reward = pedant.get_score(gt_ans, description_extraction, questions[idx])
                # error_rate = wer(gt_ans, output_ans)
                answer_pedant_reward = pedant.get_score(gt_ans, output_ans, questions[idx])
                # reward = (1 - error_rate) + (1 - description_error_rate)
                # reward = max(0.0, min(2.0, reward))
                # print('Extracted description: ', description_extraction)
                # print('Generated answer: ', output_ans)
                # print('Sol: ', gt_ans)
                # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
                # print('-' * 10)
                reward = description_pendat_reward + answer_pedant_reward
            # elif question_type == "free-form":
            #     score = compute_rouge_score(gt_ans, output_ans)
            #     reward = max(0.0, min(1.0, score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                if gt_number is None or out_number is None:
                    reward = 0.0
                else:  # `else` added: the original fell through and computed rel_diff on None values
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))
                    reward = 1 - rel_diff
            elif question_type == 'math' or question_type == 'unify' or question_type == "multiple choice" or question_type == "numerical":
                # description_reward = compute_math_score_single(description_extraction, gt_ans)
                answer_reward = compute_math_score_single(output_ans, gt_ans)
                # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
                # print('-' * 10)
                reward = description_reward + answer_reward
                # reward = answer_reward
            else:
                print('Falling back to none rewards')
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {gt_ans}\n")

    # print("rewards: ", rewards)
    return rewards


def simple_format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [0.1 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": simple_format_reward,
}

# SYSTEM_PROMPT = (
#     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
#     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
#     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
#     "<think> reasoning process here </think><answer> answer here </answer>"
# )

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags. "
    "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
    "and provide this step-by-step reasoning within <think> </think> tags. "
    "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags. "
    "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
)


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    # QUESTION_TEMPLATE = (
    #     "{Question}\n"
    #     "Please think about this question as if you were a human pondering deeply. "
    #     "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
    #     "It's encouraged to include self-reflection or verification in the reasoning process. "
    #     "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    # )

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
        "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
        "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
        "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
        "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
        "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
    )

    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
    }

    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']

        # msg = {
        #     "prompt":
        #     [{
        #         "role": "user",
        #         "content": [
        #             {
        #                 "type": example['data_type'],
        #                 # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
        #             },
        #             {
        #                 "type": "text",
        #                 "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
        #             }
        #         ]
        #     }]
        # }

        msg = {
            "prompt":
            [{
                "role": "user",
                "content": [
                    {
                        "type": example['data_type'],
                        # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                    },
                    {
                        "type": "text",
                        "text": QUESTION_TEMPLATE.format(Question=question)
                    }
                ]
            }]
        }

        # return msg
        return {
            "prompt": msg["prompt"],
            "problem": question,
        }

    dataset = dataset.map(make_conversation_image_and_video)

    # print('Example problem')
    # print(dataset['train']['problem'][10])
    # time.sleep(30)

    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
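Both GRPO variants in this repo hinge on the <des> </des> <think> </think> <answer> </answer> protocol checked by simple_format_reward and parsed by extract_answer. A small self-contained sketch of those pieces plus the wer helper (regexes and edit-distance logic copied from the file above; the sample completion is invented):

import re

sample = ("<des> A bar chart with five labeled bars; bar C is tallest. </des> "
          "<think> The question asks for the tallest bar, which is C. </think> "
          "<answer> C </answer>")

# simple_format_reward's pattern: a full-string match earns the 0.1 format bonus.
pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
print(bool(re.fullmatch(pattern, sample, re.DOTALL)))  # True

# accuracy_reward's answer extraction.
m = re.search(r"<answer>\s*(.*?)\s*</answer>", sample, flags=re.DOTALL | re.IGNORECASE)
print(m.group(1))  # 'C'

# wer(): word-level edit distance divided by reference length.
def wer(reference, hypothesis):
    ref, hyp = reference.split(), hypothesis.split()
    m_, n = len(ref), len(hyp)
    d = [[0] * (n + 1) for _ in range(m_ + 1)]
    for i in range(m_ + 1):
        d[i][0] = i
    for j in range(n + 1):
        d[0][j] = j
    for i in range(1, m_ + 1):
        for j in range(1, n + 1):
            d[i][j] = d[i - 1][j - 1] if ref[i - 1] == hyp[j - 1] else 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
    return d[m_][n] / max(1, m_)

print(wer("the cat sat", "the cat sat down"))  # one insertion over 3 reference words -> 0.333...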
src/r1-v/src/open_r1/grpo-cot-LLMEval.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import re
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
|
| 20 |
+
from datasets import load_dataset, load_from_disk
|
| 21 |
+
from transformers import Qwen2VLForConditionalGeneration
|
| 22 |
+
|
| 23 |
+
from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
|
| 24 |
+
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 25 |
+
|
| 26 |
+
from datasets import Dataset, DatasetDict
|
| 27 |
+
|
| 28 |
+
from typing import Dict, List, Optional
|
| 29 |
+
from mathruler.grader import extract_boxed_content, grade_answer
|
| 30 |
+
|
| 31 |
+
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
| 32 |
+
from rouge_score import rouge_scorer
|
| 33 |
+
# from utils.gpt_eval import infer
|
| 34 |
+
# from utils.math_cot import *
|
| 35 |
+
# from qa_metrics.pedant import PEDANT
|
| 36 |
+
# from qa_metrics.answerBERT import AnswerBertActor
|
| 37 |
+
|
| 38 |
+
# pedant = PEDANT()
|
| 39 |
+
# answerBERT = AnswerBertActor(device='cuda:7')
|
| 40 |
+
|
| 41 |
+
alpha = 1.0
|
| 42 |
+
|
| 43 |
+
TYPE_TEMPLATE = {
|
| 44 |
+
"multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
|
| 45 |
+
"numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
|
| 46 |
+
"OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
|
| 47 |
+
"free-form": " Please provide your text answer within the <answer> </answer> tags.",
|
| 48 |
+
"regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
|
| 49 |
+
"math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
'''
|
| 53 |
+
gpt infer
|
| 54 |
+
'''
|
| 55 |
+
import os
|
| 56 |
+
from openai import AzureOpenAI
|
| 57 |
+
import time
|
| 58 |
+
|
| 59 |
+
import base64
|
| 60 |
+
from mimetypes import guess_type
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def azure_gpt4(messages, model):
|
| 64 |
+
outputs = []
|
| 65 |
+
for message in messages:
|
| 66 |
+
input_prompt = [
|
| 67 |
+
{ "role": "system", "content": "You are a helpful assistant." },
|
| 68 |
+
{ "role": "user", "content": [
|
| 69 |
+
{
|
| 70 |
+
"type": "text",
|
| 71 |
+
"text": message["instruction"]
|
| 72 |
+
},
|
| 73 |
+
# {
|
| 74 |
+
# "type": "image_url",
|
| 75 |
+
# "image_url": {
|
| 76 |
+
# "url": message["image"]
|
| 77 |
+
# }
|
| 78 |
+
# }
|
| 79 |
+
]}
|
| 80 |
+
]
|
| 81 |
+
## try N times if API exceed limit ...
|
| 82 |
+
for i in range(10):
|
| 83 |
+
try:
|
| 84 |
+
output = client.chat.completions.create(
|
| 85 |
+
model=model, messages=input_prompt, max_tokens=2000
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
output_text = output.choices[0].message.content
|
| 89 |
+
break ## exit if successful
|
| 90 |
+
|
| 91 |
+
except Exception as e:
|
| 92 |
+
print(f'Index {i} got error message: {e}')
|
| 93 |
+
output_text = ''
|
| 94 |
+
time.sleep(3)
|
| 95 |
+
|
| 96 |
+
outputs.append(output_text)
|
| 97 |
+
|
| 98 |
+
return outputs
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
client = AzureOpenAI(
|
| 102 |
+
api_key = "83f30a2a22324395b854bd343db38d85",
|
| 103 |
+
api_version = "2024-08-01-preview",
|
| 104 |
+
azure_endpoint = "https://francecentral.api.cognitive.microsoft.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
model = "gpt-4o"
|
| 108 |
+
prompt_template = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. Provide your answer as a single final answer or a short phrase enclosed with <answer></answer>. \nText description: {text}\nQuestion: {question}'''
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def infer(prompt):
|
| 112 |
+
# prompt_question = prompt_question.replace('<image>', '')
|
| 113 |
+
# prompt = prompt_template.replace('{text}', text).replace('{question}', prompt_question)
|
| 114 |
+
|
| 115 |
+
messages = [
|
| 116 |
+
{"instruction": prompt},
|
| 117 |
+
]
|
| 118 |
+
prompt_success = False
|
| 119 |
+
prompt_time = 0
|
| 120 |
+
outputs = ['<answer> None </answer>']
|
| 121 |
+
while prompt_success == False and prompt_time <= 2:
|
| 122 |
+
try:
|
| 123 |
+
outputs = azure_gpt4(messages, model)
|
| 124 |
+
prompt_success = True
|
| 125 |
+
except:
|
| 126 |
+
prompt_time += 1
|
| 127 |
+
time.sleep(5)
|
| 128 |
+
|
| 129 |
+
return outputs[0]
|
| 130 |
+
|
| 131 |
+
'''
|
| 132 |
+
end of gpt infer
|
| 133 |
+
'''
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 137 |
+
|
| 138 |
+
def _call_infer(desc):
|
| 139 |
+
return infer(desc)
|
| 140 |
+
|
| 141 |
+
@dataclass
|
| 142 |
+
class GRPOScriptArguments(ScriptArguments):
|
| 143 |
+
"""
|
| 144 |
+
Script arguments for the GRPO training script.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
reward_funcs (`list[str]`):
|
| 148 |
+
List of reward functions. Possible values: 'accuracy', 'format'.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
reward_funcs: list[str] = field(
|
| 152 |
+
default_factory=lambda: ["accuracy", "format"],
|
| 153 |
+
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# reward_funcs: list[str] = field(
|
| 157 |
+
# default_factory=lambda: ["accuracy"],
|
| 158 |
+
# metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
|
| 159 |
+
# )
|
| 160 |
+
max_pixels: Optional[int] = field(
|
| 161 |
+
default=12845056,
|
| 162 |
+
metadata={"help": "Maximum number of pixels for the image"},
|
| 163 |
+
)
|
| 164 |
+
min_pixels: Optional[int] = field(
|
| 165 |
+
default=3136,
|
| 166 |
+
metadata={"help": "Minimum number of pixels for the image"},
|
| 167 |
+
)
|
| 168 |
+
temporal: Optional[bool] = field(
|
| 169 |
+
default=True,
|
| 170 |
+
metadata={"help": "whether using temporal GRPO"},
|
| 171 |
+
)
|
| 172 |
+
len_control: Optional[bool] = field(
|
| 173 |
+
default=True,
|
| 174 |
+
metadata={"help": "whether using length reward"},
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def accuracy_reward(completions, solution, **kwargs):
|
| 180 |
+
def extract_answer(text: str) -> str:
|
| 181 |
+
"""
|
| 182 |
+
1) Try the full <answer> … </answer> block.
|
| 183 |
+
2) If that is missing, grab whatever follows the opening <answer> tag.
|
| 184 |
+
3) Otherwise return the original text.
|
| 185 |
+
"""
|
| 186 |
+
# ① normal case <answer> … </answer>
|
| 187 |
+
m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
|
| 188 |
+
if m:
|
| 189 |
+
return m.group(1).strip()
|
| 190 |
+
|
| 191 |
+
# ② fallback <answer> … <end-of-string>
|
| 192 |
+
m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
|
| 193 |
+
if m:
|
| 194 |
+
return m.group(1).strip()
|
| 195 |
+
|
| 196 |
+
# ③ nothing found
|
| 197 |
+
return text.strip()
|
| 198 |
+
|
| 199 |
+
def extract_description(predict: str) -> Optional[str]:
|
| 200 |
+
"""
|
| 201 |
+
Extracts the content of the <answer>…</answer> block from `predict`.
|
| 202 |
+
Returns the inner text (with leading/trailing whitespace stripped),
|
| 203 |
+
or None if no <answer> tag is found.
|
| 204 |
+
"""
|
| 205 |
+
match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
|
| 206 |
+
if not match:
|
| 207 |
+
return predict
|
| 208 |
+
return match.group(1).strip()
|
| 209 |
+
|
| 210 |
+
def single_accuracy_reward(predict: str, ground_truth: str) -> float:
|
| 211 |
+
answer = predict
|
| 212 |
+
return 1.0 if grade_answer(answer, ground_truth) else 0.0
|
| 213 |
+
|
| 214 |
+
def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> Dict[str, float]:
|
| 215 |
+
predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
|
| 216 |
+
# format_score = format_reward(predict)
|
| 217 |
+
accuracy_score = single_accuracy_reward(predict, ground_truth)
|
| 218 |
+
|
| 219 |
+
# return (1 - format_weight) * accuracy_score + format_weight * format_score
|
| 220 |
+
return accuracy_score
|
| 221 |
+
|
| 222 |
+
def normalize_number(num_str):
|
| 223 |
+
try:
|
| 224 |
+
num_str = num_str.replace(',', '')
|
| 225 |
+
return float(num_str)
|
| 226 |
+
except Exception as e:
|
| 227 |
+
print(f"Error converting '{num_str}' to float: {e}")
|
| 228 |
+
return None
|
| 229 |
+
|
| 230 |
+
def wer(reference, hypothesis):
|
| 231 |
+
ref_words = reference.split()
|
| 232 |
+
hyp_words = hypothesis.split()
|
| 233 |
+
m = len(ref_words)
|
| 234 |
+
n = len(hyp_words)
|
| 235 |
+
d = [[0]*(n+1) for _ in range(m+1)]
|
| 236 |
+
for i in range(m+1):
|
| 237 |
+
d[i][0] = i
|
| 238 |
+
for j in range(n+1):
|
| 239 |
+
d[0][j] = j
|
| 240 |
+
for i in range(1, m+1):
|
| 241 |
+
for j in range(1, n+1):
|
| 242 |
+
if ref_words[i-1] == hyp_words[j-1]:
|
| 243 |
+
d[i][j] = d[i-1][j-1]
|
| 244 |
+
else:
|
| 245 |
+
d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
|
| 246 |
+
return d[m][n] / max(1, m)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def compute_rouge_score(reference, hypothesis, use_stemmer=True):
|
| 250 |
+
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
|
| 251 |
+
scores = scorer.score(reference, hypothesis)
|
| 252 |
+
average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
|
| 253 |
+
return average_fmeasure

    # print('Computing rewards now...')
    # second_prompts = kwargs.get("second_prompts")  # list[str] or None
    # second_completions = kwargs.get("second_completions")
    # second_contents = [comp[0]["content"] for comp in second_completions]
    # print('second prompts', second_prompts)
    # print('-'*10)
    # print('second completions', second_completions)
    # print('-'*10)

    # import time
    # time.sleep(30)
    question_type = kwargs['problem_type'][0]
    questions = kwargs['problem']

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    extracted_content_descriptions = [extract_description(ele) for ele in contents]

    description_query_inputs = []

    for index in range(len(extracted_content_descriptions)):
        prompt_question = questions[index]
        des_text = extracted_content_descriptions[index]
        prompt_question = prompt_question.replace('<image>', '')
        prompt_input = prompt_template.replace('{text}', des_text).replace('{question}', prompt_question) + TYPE_TEMPLATE[question_type]
        description_query_inputs.append(prompt_input)

    description_score_outputs = []
    with ThreadPoolExecutor(max_workers=8) as executor:
        # Kick off all inference calls in parallel.
        # futures = [
        #     executor.submit(_call_infer, desc, ques)
        #     for desc, ques in zip(extracted_content_descriptions, questions)
        # ]
        futures = [
            executor.submit(_call_infer, desc)
            for desc in description_query_inputs
        ]
        # Collect results in submission order so each output stays aligned
        # with its question and ground-truth answer. (Iterating as_completed
        # here would append in completion order and could misalign them.)
        for fut in futures:
            description_score_outputs.append(extract_answer(fut.result()))

    gt_answers = [extract_answer(sol) for sol in solution]
    description_rewards = [
        compute_math_score_single(description_score_outputs[count_idx], gt_answers[count_idx])
        for count_idx in range(len(description_score_outputs))
    ]
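    # Design note: each completion's <des> block is re-posed to an external
    # answerer through _call_infer; if that answerer can solve the question
    # from the description alone, the description itself earns a reward,
    # independently of the <answer> the policy produced.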

    # print(gt_answers)
    # print(description_score_outputs)
    # print(description_rewards)
    # print('-'*10)

    for content, gt_ans, description_reward in zip(contents, gt_answers, description_rewards):
    # for content, sol, question in zip(contents, solution, questions):
    # for content, sol, second_content in zip(contents, solution, second_completions):
        try:
            output_ans = extract_answer(content)
            # gt_ans = extract_answer(sol)
            # description_extraction = extract_answer(second_content)
            # if question_type == "multiple choice":
            #     reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            # elif question_type == "numerical":
            #     gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
            #     out_has_decimal = ("." in output_ans) or ("," in output_ans)
            #     if gt_has_decimal != out_has_decimal:
            #         reward = 0.0
            #     else:
            #         gt_number = normalize_number(gt_ans)
            #         out_number = normalize_number(output_ans)
            #         if gt_number is None or out_number is None:
            #             reward = 0.0
            #         else:
            #             reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            if question_type == "OCR":
                # description_extraction = extract_answer(second_content)
                # description_error_rate = wer(gt_ans, description_extraction)
                # description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
                # error_rate = wer(gt_ans, output_ans)
                answer_pedant_reward = pedant.get_score(gt_ans, output_ans, questions[0])
                # reward = (1 - error_rate) + (1 - description_error_rate)
                # reward = max(0.0, min(2.0, reward))
                # print('Extracted description: ', description_extraction)
                # print('Generated answer: ', output_ans)
                # print('Sol: ', gt_ans)
                # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
                # print('-' * 10)
                # reward = description_pendat_reward + answer_pedant_reward
                reward = answer_pedant_reward
            # elif question_type == "free-form":
            #     score = compute_rouge_score(gt_ans, output_ans)
            #     reward = max(0.0, min(1.0, score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                if gt_number is None or out_number is None:
                    reward = 0.0
                else:
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))
                    reward = 1 - rel_diff
            elif question_type in ('math', 'unify', 'multiple choice', 'numerical'):
                answer_reward = compute_math_score_single(output_ans, gt_ans)

                # print(f"Extracted description: {description_extraction} | Generated answer: {output_ans} | Sol: {gt_ans}")
                # print(f'Description reward: {description_reward} | answer reward: {answer_reward} | final reward: {reward}')
                # print('-' * 10)

                # Reward shaping: a correct answer that the description could
                # not support is discounted to alpha instead of the full sum.
                if description_reward == 0 and answer_reward == 1:
                    reward = alpha
                else:
                    reward = description_reward + answer_reward
                # reward = answer_reward
            else:
                print('Falling back to zero reward')
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {gt_ans}\n")

    return rewards


def simple_format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [0.1 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": simple_format_reward,
}
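
# Usage note: the registry maps the names accepted by the reward_funcs script
# argument to the callables above, so a launch with `--reward_funcs accuracy
# format` (an illustrative flag spelling, assuming TrlParser's usual handling
# of list-valued fields) applies both the accuracy reward and the 0.1 format
# bonus during GRPO training.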

# SYSTEM_PROMPT = (
#     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
#     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
#     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
#     "<think> reasoning process here </think><answer> answer here </answer>"
# )

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags. "
    "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
    "and provide this step-by-step reasoning within <think> </think> tags. "
    "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags. "
    "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
)


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    # QUESTION_TEMPLATE = (
    #     "{Question}\n"
    #     "Please think about this question as if you were a human pondering deeply. "
    #     "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
    #     "It's encouraged to include self-reflection or verification in the reasoning process. "
    #     "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    # )

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
        "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
        "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
        "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
        "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
        "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
    )

    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']

        msg = {
            "prompt": [{
                "role": "user",
                "content": [
                    {
                        "type": example['data_type'],
                        # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                    },
                    {
                        "type": "text",
                        "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                    }
                ]
            }]
        }

        return msg


    dataset = dataset.map(make_conversation_image_and_video)

    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo-cot-answerBERT-eval.py
ADDED
@@ -0,0 +1,429 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from typing import Dict, List, Optional
from mathruler.grader import extract_boxed_content, grade_answer

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
# from utils.math_cot import *
# from qa_metrics.pedant import PEDANT
from qa_metrics.answerBERT import AnswerBertActor

# pedant = PEDANT()
answerBERT = AnswerBertActor(device='cuda:0')


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )

    # reward_funcs: list[str] = field(
    #     default_factory=lambda: ["accuracy"],
    #     metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
    # )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )
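
# A hypothetical launch of this script (model name and paths are placeholders,
# not taken from this repo's docs):
#   python grpo-cot-answerBERT-eval.py \
#       --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
#       --dataset_name data/train.json \
#       --reward_funcs accuracy format \
#       --output_dir checkpoints/grpo-answerbert
# TrlParser maps these flags onto GRPOScriptArguments, GRPOConfig, and
# ModelConfig in the __main__ block at the bottom of the file.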


def accuracy_reward(completions, solution, **kwargs):
    def extract_answer(text: str) -> str:
        """
        1) Try the full <answer> … </answer> block.
        2) If that is missing, grab whatever follows the opening <answer> tag.
        3) Otherwise return the original text.
        """
        # ① normal case: <answer> … </answer>
        m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ② fallback: <answer> … <end-of-string>
        m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ③ nothing found
        return text.strip()

    def extract_description(predict: str) -> str:
        """
        Extracts the content of the <des>…</des> block from `predict`.
        Returns the inner text (with leading/trailing whitespace stripped),
        or the original text if no <des> tag is found.
        """
        match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
        if not match:
            return predict
        return match.group(1).strip()

    def single_accuracy_reward(predict: str, ground_truth: str) -> float:
        answer = predict
        return 1.0 if grade_answer(answer, ground_truth) else 0.0

    def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
        # format_score = format_reward(predict)
        accuracy_score = single_accuracy_reward(predict, ground_truth)

        # return (1 - format_weight) * accuracy_score + format_weight * format_score
        return accuracy_score

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    def wer(reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)


    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure

    # print('Computing rewards now...')
    # second_prompts = kwargs.get("second_prompts")  # list[str] or None
    # second_completions = kwargs.get("second_completions")
    # second_contents = [comp[0]["content"] for comp in second_completions]
    # print('second prompts', second_prompts)
    # print('-'*10)
    # print('second completions', second_completions)
    # print('-'*10)

    # import time
    # time.sleep(30)
    question_type = kwargs['problem_type'][0]
    questions = kwargs['problem']

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    extracted_content_descriptions = [extract_description(ele) for ele in contents]
    # extracted_content_answers = [extract_answer(ele) for ele in contents]
    # model = kwargs.get("model")  # may be None if called elsewhere
    # tokenizer = kwargs.get("tokenizer")
    # # (optional) example use: let the model score the generated answer
    # if model is not None and tokenizer is not None:
    #     model.eval()
    description_inputs = [
        questions[index_count] + ' [SEP] ' + extracted_content_descriptions[index_count]
        for index_count in range(len(extracted_content_descriptions))
    ]
    description_rewards = answerBERT.batch_predict(description_inputs, batch_size=32)
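    # AnswerBERT receives "question [SEP] description" pairs and returns one
    # score per pair, so each description is graded on whether it carries
    # enough information to answer its question, with no extra generation
    # round.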

    for content, sol, description_reward in zip(contents, solution, description_rewards):
    # for content, sol, question in zip(contents, solution, questions):
    # for content, sol, second_content in zip(contents, solution, second_completions):
        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            # description_extraction = extract_answer(second_content)
            # if question_type == "multiple choice":
            #     reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            # elif question_type == "numerical":
            #     gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
            #     out_has_decimal = ("." in output_ans) or ("," in output_ans)
            #     if gt_has_decimal != out_has_decimal:
            #         reward = 0.0
            #     else:
            #         gt_number = normalize_number(gt_ans)
            #         out_number = normalize_number(output_ans)
            #         if gt_number is None or out_number is None:
            #             reward = 0.0
            #         else:
            #             reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            if question_type == "OCR":
                # NOTE: this branch still references the commented-out PEDANT
                # scorer and the undefined names `description_extraction` and
                # `question`, so as written it raises a NameError and falls
                # through to the except clause (reward = 0.0).
                # description_error_rate = wer(gt_ans, description_extraction)
                description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
                # error_rate = wer(gt_ans, output_ans)
                answer_pedant_reward = pedant.get_score(gt_ans, output_ans, question)
                # reward = (1 - error_rate) + (1 - description_error_rate)
                # reward = max(0.0, min(2.0, reward))
                # print('Extracted description: ', description_extraction)
                # print('Generated answer: ', output_ans)
                # print('Sol: ', gt_ans)
                # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
                # print('-' * 10)
                reward = description_pendat_reward + answer_pedant_reward
            # elif question_type == "free-form":
            #     score = compute_rouge_score(gt_ans, output_ans)
            #     reward = max(0.0, min(1.0, score))
            # elif question_type == "regression":
            #     gt_number = normalize_number(gt_ans)
            #     out_number = normalize_number(output_ans)
            #     if gt_number is None or out_number is None:
            #         reward = 0.0
            #     rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
            #     rel_diff = min(1.0, max(0.0, rel_diff))
            #     reward = 1 - rel_diff
            elif question_type in ('math', 'unify', 'multiple choice', 'numerical', 'regression'):
                # print('Extracted description: ', description_extraction)
                # print('Generated answer: ', output_ans)
                # print('Sol: ', gt_ans)

                # description_reward = compute_math_score_single(description_extraction, gt_ans)
                answer_reward = compute_math_score_single(output_ans, gt_ans)
                # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
                # print('-' * 10)
                reward = description_reward + answer_reward
            else:
                print('Falling back to zero reward')
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards


def simple_format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [0.1 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": simple_format_reward,
}

# SYSTEM_PROMPT = (
#     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
#     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
#     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
#     "<think> reasoning process here </think><answer> answer here </answer>"
# )

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags. "
    "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
    "and provide this step-by-step reasoning within <think> </think> tags. "
    "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags. "
    "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
)


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    # QUESTION_TEMPLATE = (
    #     "{Question}\n"
    #     "Please think about this question as if you were a human pondering deeply. "
    #     "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
    #     "It's encouraged to include self-reflection or verification in the reasoning process. "
    #     "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    # )

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
        "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
        "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
        "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
        "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
        "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
    )

    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
    }
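
    # TYPE_TEMPLATE appends a per-problem-type instruction to the question:
    # "multiple choice" asks for a bare option letter and "numerical" /
    # "regression" ask for a plain number, which keeps extract_answer's
    # output directly comparable to the ground-truth answers.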

    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']

        msg = {
            "prompt": [{
                "role": "user",
                "content": [
                    {
                        "type": example['data_type'],
                        # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                    },
                    {
                        "type": "text",
                        "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                    }
                ]
            }]
        }

        return msg

    dataset = dataset.map(make_conversation_image_and_video)

    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo-cot-noDesEval.py
ADDED
@@ -0,0 +1,446 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from typing import Dict, List, Optional
from mathruler.grader import extract_boxed_content, grade_answer

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
# from utils.math_cot import *
# from qa_metrics.pedant import PEDANT
# from qa_metrics.answerBERT import AnswerBertActor

# pedant = PEDANT()
# answerBERT = AnswerBertActor(device='cuda:7')


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )

    # reward_funcs: list[str] = field(
    #     default_factory=lambda: ["accuracy"],
    #     metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
    # )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )


def accuracy_reward(completions, solution, **kwargs):
    def extract_answer(text: str) -> str:
        """
        1) Try the full <answer> … </answer> block.
        2) If that is missing, grab whatever follows the opening <answer> tag.
        3) Otherwise return the original text.
        """
        # ① normal case: <answer> … </answer>
        m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ② fallback: <answer> … <end-of-string>
        m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ③ nothing found
        return text.strip()

    def extract_description(predict: str) -> str:
        """
        Extracts the content of the <des>…</des> block from `predict`.
        Returns the inner text (with leading/trailing whitespace stripped),
        or the original text if no <des> tag is found.
        """
        match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
        if not match:
            return predict
        return match.group(1).strip()

    def single_accuracy_reward(predict: str, ground_truth: str) -> float:
        answer = predict
        return 1.0 if grade_answer(answer, ground_truth) else 0.0

    def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
        # format_score = format_reward(predict)
        accuracy_score = single_accuracy_reward(predict, ground_truth)

        # return (1 - format_weight) * accuracy_score + format_weight * format_score
        return accuracy_score

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    def wer(reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)


    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure

    # print('Computing rewards now...')
    # second_prompts = kwargs.get("second_prompts")  # list[str] or None
    # second_completions = kwargs.get("second_completions")
    # second_contents = [comp[0]["content"] for comp in second_completions]
    # print('second prompts', second_prompts)
    # print('-'*10)
    # print('second completions', second_completions)
    # print('-'*10)

    # import time
    # time.sleep(30)
    question_type = kwargs['problem_type'][0]
    questions = kwargs['problem']

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    # extracted_content_descriptions = [extract_description(ele) for ele in contents]
    # extracted_content_answers = [extract_answer(ele) for ele in contents]
    # model = kwargs.get("model")  # may be None if called elsewhere
    # tokenizer = kwargs.get("tokenizer")
    # # (optional) example use: let the model score the generated answer
    # if model is not None and tokenizer is not None:
    #     model.eval()
    # description_inputs = [questions[index_count] + ' [SEP] ' + extracted_content_descriptions[index_count] for index_count in range(len(extracted_content_descriptions))]
    # description_rewards = answerBERT.batch_predict(description_inputs, batch_size=64)
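    # Unlike the answerBERT variant of this script, the description-scoring
    # stage above is left commented out, so this file trains on the
    # final-answer reward alone (the "noDes" ablation in the filename).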

    for content, sol in zip(contents, solution):
    # for content, sol, question in zip(contents, solution, questions):
    # for content, sol, second_content in zip(contents, solution, second_completions):
        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            # description_extraction = extract_answer(second_content)
            # if question_type == "multiple choice":
            #     reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            # elif question_type == "numerical":
            #     gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
            #     out_has_decimal = ("." in output_ans) or ("," in output_ans)
            #     if gt_has_decimal != out_has_decimal:
            #         reward = 0.0
            #     else:
            #         gt_number = normalize_number(gt_ans)
            #         out_number = normalize_number(output_ans)
            #         if gt_number is None or out_number is None:
            #             reward = 0.0
            #         else:
            #             reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            if question_type == "OCR":
                # NOTE: this branch still references the commented-out PEDANT
                # scorer and the undefined names `description_extraction` and
                # `question`, so as written it raises a NameError and falls
                # through to the except clause (reward = 0.0).
                # description_error_rate = wer(gt_ans, description_extraction)
                description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
                # error_rate = wer(gt_ans, output_ans)
                answer_pedant_reward = pedant.get_score(gt_ans, output_ans, question)
                # reward = (1 - error_rate) + (1 - description_error_rate)
                # reward = max(0.0, min(2.0, reward))
                # print('Extracted description: ', description_extraction)
                # print('Generated answer: ', output_ans)
                # print('Sol: ', gt_ans)
                # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
                # print('-' * 10)
                reward = description_pendat_reward + answer_pedant_reward
            # elif question_type == "free-form":
            #     score = compute_rouge_score(gt_ans, output_ans)
            #     reward = max(0.0, min(1.0, score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                if gt_number is None or out_number is None:
                    reward = 0.0
                else:
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))
                    reward = 1 - rel_diff
            elif question_type in ('math', 'unify', 'multiple choice', 'numerical'):
                # print('Extracted description: ', description_extraction)
                # print('Generated answer: ', output_ans)
                # print('Sol: ', gt_ans)

                # description_reward = compute_math_score_single(description_extraction, gt_ans)
                answer_reward = compute_math_score_single(output_ans, gt_ans)
                # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
                # print('-' * 10)
                # reward = description_reward + answer_reward
                reward = answer_reward
            else:
                print('Falling back to zero reward')
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards


def simple_format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [0.1 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": simple_format_reward,
}

# SYSTEM_PROMPT = (
#     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
#     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
#     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
#     "<think> reasoning process here </think><answer> answer here </answer>"
# )

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags. "
    "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
    "and provide this step-by-step reasoning within <think> </think> tags. "
    "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags. "
    "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
)


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    # QUESTION_TEMPLATE = (
    #     "{Question}\n"
    #     "Please think about this question as if you were a human pondering deeply. "
    #     "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
    #     "It's encouraged to include self-reflection or verification in the reasoning process. "
    #     "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    # )

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
        "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
        "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
        "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
        "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
        "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
    )

    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
    }
|
| 336 |
+
|
| 337 |
+
def make_conversation_image(example):
|
| 338 |
+
|
| 339 |
+
return {
|
| 340 |
+
"prompt": [
|
| 341 |
+
{
|
| 342 |
+
"role": "user",
|
| 343 |
+
"content": [
|
| 344 |
+
{"type": "image"},
|
| 345 |
+
{"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
|
| 346 |
+
],
|
| 347 |
+
},
|
| 348 |
+
],
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def make_conversation_video(example):
|
| 353 |
+
return {
|
| 354 |
+
"prompt": [
|
| 355 |
+
{
|
| 356 |
+
"role": "user",
|
| 357 |
+
"content": [
|
| 358 |
+
{"type": "video"},
|
| 359 |
+
{"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
|
| 360 |
+
],
|
| 361 |
+
},
|
| 362 |
+
],
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
def make_conversation_image_and_video(example):
|
| 366 |
+
if example["problem_type"] == 'multiple choice':
|
| 367 |
+
question = example['problem'] + "Options:\n"
|
| 368 |
+
for op in example["options"]:
|
| 369 |
+
question += op + "\n"
|
| 370 |
+
else:
|
| 371 |
+
question = example['problem']
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
# msg ={
|
| 375 |
+
# "prompt":
|
| 376 |
+
# [{
|
| 377 |
+
# "role": "user",
|
| 378 |
+
# "content": [
|
| 379 |
+
# {
|
| 380 |
+
# "type": example['data_type'],
|
| 381 |
+
# # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
|
| 382 |
+
# },
|
| 383 |
+
# {
|
| 384 |
+
# "type": "text",
|
| 385 |
+
# "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
|
| 386 |
+
# }
|
| 387 |
+
# ]
|
| 388 |
+
# }]
|
| 389 |
+
# }
|
| 390 |
+
|
| 391 |
+
msg ={
|
| 392 |
+
"prompt":
|
| 393 |
+
[{
|
| 394 |
+
"role": "user",
|
| 395 |
+
"content": [
|
| 396 |
+
{
|
| 397 |
+
"type": example['data_type'],
|
| 398 |
+
# example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"type": "text",
|
| 402 |
+
"text": QUESTION_TEMPLATE.format(Question=question)
|
| 403 |
+
}
|
| 404 |
+
]
|
| 405 |
+
}]
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
return msg
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
dataset = dataset.map(make_conversation_image_and_video)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
|
| 415 |
+
print("using: ", trainer_cls)
|
| 416 |
+
|
| 417 |
+
# Initialize the GRPO trainer
|
| 418 |
+
trainer = trainer_cls(
|
| 419 |
+
model=model_args.model_name_or_path,
|
| 420 |
+
reward_funcs=reward_funcs,
|
| 421 |
+
args=training_args,
|
| 422 |
+
script_args=script_args,
|
| 423 |
+
train_dataset=dataset[script_args.dataset_train_split],
|
| 424 |
+
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
| 425 |
+
peft_config=get_peft_config(model_args),
|
| 426 |
+
attn_implementation=model_args.attn_implementation,
|
| 427 |
+
max_pixels=script_args.max_pixels,
|
| 428 |
+
min_pixels=script_args.min_pixels,
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
if training_args.resume_from_checkpoint is not None:
|
| 432 |
+
checkpoint = training_args.resume_from_checkpoint
|
| 433 |
+
trainer.train(resume_from_checkpoint=checkpoint)
|
| 434 |
+
else:
|
| 435 |
+
trainer.train()
|
| 436 |
+
|
| 437 |
+
# Save and push to hub
|
| 438 |
+
trainer.save_model(training_args.output_dir)
|
| 439 |
+
if training_args.push_to_hub:
|
| 440 |
+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
if __name__ == "__main__":
|
| 444 |
+
parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
|
| 445 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 446 |
+
main(script_args, training_args, model_args)
|
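All of the `grpo-cot-*` scripts in this upload share one reward-function calling convention: a list of chat-style completions plus a list of reference solutions, returning one float per completion. A minimal self-contained sketch of that interface (the `toy_accuracy_reward` name and the example data are illustrative, not part of the repository):

import re

def toy_accuracy_reward(completions, solution, **kwargs):
    # completions: [[{"role": "assistant", "content": "..."}], ...]
    # solution:    one reference string per completion, e.g. "<answer>B</answer>"
    def extract_answer(text):
        m = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
        return m.group(1).strip() if m else ""
    contents = [c[0]["content"] for c in completions]
    return [1.0 if extract_answer(c) == extract_answer(s) else 0.0
            for c, s in zip(contents, solution)]

# Example: two sampled completions for one multiple-choice problem
completions = [
    [{"role": "assistant", "content": "<think>...</think><answer>B</answer>"}],
    [{"role": "assistant", "content": "<think>...</think><answer>C</answer>"}],
]
print(toy_accuracy_reward(completions, ["<answer>B</answer>"] * 2))  # [1.0, 0.0]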
src/r1-v/src/open_r1/grpo-cot-noInfo.py
ADDED
@@ -0,0 +1,346 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from utils.math_cot_noInfo import *


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )


def accuracy_reward(completions, solution, **kwargs):

    def extract_answer(text):
        pattern = r'<answer>\s*(.*?)\s*</answer>'
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1).strip()
        return ""

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    def wer(reference, hypothesis):
        # word error rate: word-level edit distance normalized by reference length
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)

    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure

    question_type = kwargs['problem_type'][0]

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    for content, sol in zip(contents, solution):
        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            if question_type == "multiple choice":
                reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            elif question_type == "numerical":
                gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
                out_has_decimal = ("." in output_ans) or ("," in output_ans)
                if gt_has_decimal != out_has_decimal:
                    reward = 0.0
                else:
                    gt_number = normalize_number(gt_ans)
                    out_number = normalize_number(output_ans)
                    if gt_number is None or out_number is None:
                        reward = 0.0
                    else:
                        reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            elif question_type == "OCR":
                error_rate = wer(gt_ans, output_ans)
                reward = 1 - error_rate
                reward = max(0.0, min(1.0, reward))
            elif question_type == "free-form":
                score = compute_rouge_score(gt_ans, output_ans)
                reward = max(0.0, min(1.0, score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                if gt_number is None or out_number is None:
                    reward = 0.0
                else:  # fix: the original fell through here and computed rel_diff with None operands
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))
                    reward = 1 - rel_diff
            elif question_type == 'math':
                reward = compute_math_score_single(content, gt_ans)
            else:
                print('Falling back to none rewards')
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    # "format": format_reward,
}

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

# SYSTEM_PROMPT = (
#     "A conversation between User and Assistant. The user provides a question about an image, "
#     "and the Assistant is tasked with generating an exhaustive and detailed description of the image. "
#     "The assistant should extract and describe all possible information from the image—including objects, numbers, text, and their relationships—"
#     "and enclose this description within <info> </info> tags. "
#     "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
#     "and provide this step-by-step reasoning within <think> </think> tags. "
#     "Finally, the assistant provides a single word or phrase answer within <answer> </answer> tags. "
#     "The output format should be: <info> image description here </info> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
# )


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "Please think about this question as if you were a human pondering deeply. "
        "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
        "It's encouraged to include self-reflection or verification in the reasoning process. "
        "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    )

    # QUESTION_TEMPLATE = (
    #     "{Question}\n"
    #     "You are tasked with analyzing an image to generate an exhaustive and detailed description. "
    #     "Your goal is to extract and describe all possible information from the image, including but not limited to objects, numbers, text, and the relationships between these elements. "
    #     "The description should be as fine and detailed as possible, capturing every nuance, and should be enclosed within <info> </info> tags.\n"
    #     "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
    #     "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
    #     "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
    #     "The output format should be: <info> image description here </info> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
    # )

    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "math": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
    }

    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']

        msg = {
            "prompt": [{
                "role": "user",
                "content": [
                    {
                        "type": example['data_type'],
                        # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                    },
                    {
                        "type": "text",
                        "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                    }
                ]
            }]
        }

        return msg

    dataset = dataset.map(make_conversation_image_and_video)

    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
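The OCR reward in grpo-cot-noInfo.py is driven by the `wer` helper, a word-level Levenshtein distance normalized by the reference length. A standalone check of that logic with hand-verified values (the function body is repeated here in condensed form so the snippet runs on its own):

def wer(reference, hypothesis):
    ref, hyp = reference.split(), hypothesis.split()
    m, n = len(ref), len(hyp)
    d = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        d[i][0] = i
    for j in range(n + 1):
        d[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if ref[i - 1] == hyp[j - 1]:
                d[i][j] = d[i - 1][j - 1]
            else:
                d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
    return d[m][n] / max(1, m)

print(wer("the quick fox", "the slow fox"))  # 0.333... (one substitution over three reference words)
print(wer("a b c", "a b c"))                 # 0.0 (identical sequences)
# The script then clamps the reward: reward = max(0.0, min(1.0, 1 - wer(gt, out)))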
src/r1-v/src/open_r1/grpo-cot-qwenEval.py
ADDED
@@ -0,0 +1,523 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import ray
from datetime import datetime
from dataclasses import dataclass, field

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from typing import Dict, List, Optional
from mathruler.grader import extract_boxed_content, grade_answer

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import torch
# from utils.gpt_eval import infer
# from utils.math_cot import *
from qa_metrics.pedant import PEDANT
from concurrent.futures import ProcessPoolExecutor
import subprocess, sys
# from qa_metrics.answerBERT import AnswerBertActor
# from utils.self_eval import *
from vllm import LLM, SamplingParams

pedant = None  # NOTE: never re-assigned below; the OCR branch requires pedant = PEDANT() before use
# answerBERT = AnswerBertActor(device='cuda:7')

# curr_actor = VllmActor.options(num_gpus=1).remote("Qwen/Qwen2.5-3B-Instruct")

import torch.distributed as dist

MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
MAX_LEN = 32_768
RAY_NS = "grpo_qwen_vllm"
RAY_TMP = "/tmp/ray"

# ------------------------------------------------------------
# 1. Define the Ray actor class *before* we ever create it
#    (Ray just needs to see the decorator; it doesn't need an
#    active cluster at definition time)
# ------------------------------------------------------------
@ray.remote(num_gpus=1, resources={"gpu_7": 1})
class VllmActor:
    def __init__(self, model_id):
        self.engine = LLM(
            model_id,
            tensor_parallel_size=1,
            gpu_memory_utilization=0.80,
            max_model_len=MAX_LEN,
            trust_remote_code=True,
            dtype="bfloat16",
        )
        self.default = SamplingParams(top_p=0.9, temperature=0.7, max_tokens=128)

    def generate_batch(self, prompts, sampling=None):
        outs = self.engine.generate(prompts, sampling_params=sampling or self.default)
        return [o.outputs[0].text for o in outs]

# ------------------------------------------------------------
# 2. Torch-DDP initialisation
# ------------------------------------------------------------
dist.init_process_group("nccl")
rank = dist.get_rank()

# ------------------------------------------------------------
# 3. One designated rank (rank 1 here) starts the Ray head,
#    the others wait
# ------------------------------------------------------------
if rank == 1:
    ray.init(
        _temp_dir=RAY_TMP,
        object_store_memory=1 * 1024**3,
        namespace=RAY_NS,
        include_dashboard=False,
        resources={"gpu_7": 1}
    )
    # optional: confirm the head is up
    # from ray._private.internal_api import wait_for_gcs
    # wait_for_gcs()
dist.barrier()  # ---- head definitely running here ----

# ------------------------------------------------------------
# 4. Non-zero ranks attach to the head
#    NOTE: rank 1 falls into this branch too after initializing
#    the head above; kept as in the upload
# ------------------------------------------------------------
if rank != 0:
    ray.init(address="auto", _temp_dir=RAY_TMP, namespace=RAY_NS)

dist.barrier()  # ---- every rank now in the cluster ----

# ------------------------------------------------------------
# 5. Create / look-up the VllmActor
# ------------------------------------------------------------
if rank == 1:
    vllm_actor = (
        VllmActor.options(name="vllm", namespace=RAY_NS, lifetime="detached")
        .remote(MODEL_ID)
    )
    # block until the model finishes loading so other ranks don't race
    ray.get(vllm_actor.generate_batch.remote(["ping"]))
dist.barrier()  # ---- actor fully alive everywhere ----

if rank != 0:
    vllm_actor = ray.get_actor("vllm", namespace=RAY_NS)


eval_prompt_template = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. Provide your answer as a single final answer or a short phrase enclosed with <answer></answer>. If the question is a multiple choice, the final answer should be a single letter choice. \nText description: {}\nQuestion: {}'''

@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )

    # reward_funcs: list[str] = field(
    #     default_factory=lambda: ["accuracy"],
    #     metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
    # )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )


def accuracy_reward(completions, solution, **kwargs):
    def extract_answer(text: str) -> str:
        """
        1) Try the full <answer> … </answer> block.
        2) If that is missing, grab whatever follows the opening <answer> tag.
        3) Otherwise return the original text.
        """
        # ① normal case  <answer> … </answer>
        m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ② fallback  <answer> … <end-of-string>
        m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ③ nothing found
        return text.strip()

    def extract_description(predict: str) -> str:
        """
        Extracts the content of the <des>…</des> block from `predict`.
        Returns the inner text (with leading/trailing whitespace stripped),
        or the full text if no <des> tag is found.
        """
        match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
        if not match:
            return predict
        return match.group(1).strip()

    def single_accuracy_reward(predict: str, ground_truth: str) -> float:
        answer = predict
        return 1.0 if grade_answer(answer, ground_truth) else 0.0

    def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
        # collapse whitespace around <, >, / so mangled tags like "< answer >" still match downstream
        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
        # format_score = format_reward(predict)
        accuracy_score = single_accuracy_reward(predict, ground_truth)

        # return (1 - format_weight) * accuracy_score + format_weight * format_score
        return accuracy_score

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    def wer(reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)

    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure

    # print('Computing rewards now...')
    # second_prompts = kwargs.get("second_prompts")  # ← list[str] or None
    # second_completions = kwargs.get("second_completions")
    # second_contents = [comp[0]["content"] for comp in second_completions]
    # print('second prompts', second_prompts)
    # print('-'*10)
    # print('second completions', second_completions)
    # print('-'*10)

    # import time
    # time.sleep(30)
    question_type = kwargs['problem_type'][0]
    questions = kwargs['problem']

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    extracted_content_descriptions = [extract_description(ele) for ele in contents]
    description_eval_inputs = [eval_prompt_template.format(extracted_content_descriptions[count_index], questions[count_index]) for count_index in range(len(extracted_content_descriptions))]
    # extracted_content_answers = [extract_answer(ele) for ele in contents]
    # model = kwargs.get("model")      # may be None if called elsewhere
    # tokenizer = kwargs.get("tokenizer")
    # # (optional) example use: let the model score the generated answer
    # if model is not None and tokenizer is not None:
    #     model.eval()
    # description_inputs = [questions[index_count] + ' [SEP] ' + extracted_content_descriptions[index_count] for index_count in range(len(extracted_content_descriptions))]
    # description_rewards = answerBERT.batch_predict(description_inputs, batch_size=64)
    # description_rewards = [infer(extracted_content_descriptions[index_count], questions[index_count]) for index_count in range(len(extracted_content_descriptions))]
    # description_outputs = generate_batch(description_eval_inputs)
    print(len(description_eval_inputs))
    print('Computing rewards...')
    print('-'*10)
    # description_outputs = ray.get(vllm_actor.generate.remote(description_eval_inputs))
    # NOTE: the actor above only defines generate_batch; generate_batch_sequential
    # is assumed to exist on the actor actually deployed
    description_outputs = ray.get(
        vllm_actor.generate_batch_sequential.remote(description_eval_inputs,
                                                    batch_size=32)  # tune to taste
    )
    print('Finish computing generating batch')
    output_answers = [extract_answer(content) for content in contents]
    gt_answers = [extract_answer(sol) for sol in solution]
    description_rewards = [compute_math_score_single(description_outputs[curr_idx], gt_answers[curr_idx]) for curr_idx in range(len(description_outputs))]

    # for content, sol, description_reward in zip(contents, solution, description_rewards):
    # for content, sol, question in zip(contents, solution, questions):
    # for content, sol, second_content in zip(contents, solution, second_completions):
    for output_ans, gt_ans, description_reward in zip(output_answers, gt_answers, description_rewards):
        try:
            # output_ans = extract_answer(content)
            # gt_ans = extract_answer(sol)
            # description_extraction = extract_answer(second_content)
            # if question_type == "multiple choice":
            #     reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            # elif question_type == "numerical":
            #     gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
            #     out_has_decimal = ("." in output_ans) or ("," in output_ans)
            #     if gt_has_decimal != out_has_decimal:
            #         reward = 0.0
            #     else:
            #         gt_number = normalize_number(gt_ans)
            #         out_number = normalize_number(output_ans)
            #         if gt_number is None or out_number is None:
            #             reward = 0.0
            #         else:
            #             reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            if question_type == "OCR":
                # description_extraction = extract_answer(second_content)
                # description_error_rate = wer(gt_ans, description_extraction)
                # description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
                # error_rate = wer(gt_ans, output_ans)
                answer_pedant_reward = pedant.get_score(gt_ans, output_ans, questions[0])
                # reward = (1 - error_rate) + (1 - description_error_rate)
                # reward = max(0.0, min(2.0, reward))
                # print('Extracted description: ', description_extraction)
                print('Generated answer: ', output_ans)
                print('Sol: ', gt_ans)
                # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
                print('-' * 10)
                # reward = description_pendat_reward + answer_pedant_reward
                reward = answer_pedant_reward
            # elif question_type == "free-form":
            #     score = compute_rouge_score(gt_ans, output_ans)
            #     reward = max(0.0, min(1.0, score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                if gt_number is None or out_number is None:
                    reward = 0.0
                else:  # fix: the original fell through here and computed rel_diff with None operands
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))
                    reward = 1 - rel_diff
            elif question_type == 'math' or question_type == 'unify' or question_type == "multiple choice" or question_type == "numerical":
                # print('Extracted description: ', description_extraction)
                print('Generated answer: ', output_ans)
                print('Sol: ', gt_ans)

                # description_reward = compute_math_score_single(description_extraction, gt_ans)
                answer_reward = compute_math_score_single(output_ans, gt_ans)
                print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
                print('-' * 10)
                reward = description_reward + answer_reward
            else:
                print('Falling back to none rewards')
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {output_ans}\n")
                f.write(f"Solution: {gt_ans}\n")

    return rewards


def simple_format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [0.1 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": simple_format_reward,
}


SYSTEM_PROMPT = (
    "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags. "
    "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
    "and provide this step-by-step reasoning within <think> </think> tags. "
    "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags. "
    "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
)


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
        "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
        "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
        "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
        "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
        "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
    )

    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
    }

    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']

        msg = {
            "prompt": [{
                "role": "user",
                "content": [
                    {
                        "type": example['data_type'],
                        # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                    },
                    {
                        "type": "text",
                        "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                    }
                ]
            }]
        }

        return msg

    dataset = dataset.map(make_conversation_image_and_video)

    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
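The distinctive piece of grpo-cot-qwenEval.py is the named, detached Ray actor that serves the judge model to every DDP rank. The pattern itself is independent of vLLM; a minimal CPU-only sketch with the judge replaced by a stub (actor and method names here are illustrative, not the repository's):

import ray

ray.init(namespace="demo")  # participants must share a namespace to find the actor

@ray.remote
class EchoActor:
    def generate_batch(self, prompts):
        # stub standing in for the vLLM judge engine
        return [f"<answer>{p[-1]}</answer>" for p in prompts]

# one process creates the actor, detached so it outlives its creator
actor = EchoActor.options(name="judge", lifetime="detached").remote()
ray.get(actor.generate_batch.remote(["warm-up"]))  # block until it is alive

# any other process in the same namespace attaches by name
same_actor = ray.get_actor("judge", namespace="demo")
print(ray.get(same_actor.generate_batch.remote(["Q1", "Q2"])))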
src/r1-v/src/open_r1/grpo-cot-selfEval.py
ADDED
@@ -0,0 +1,457 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from typing import Dict, List, Optional
from mathruler.grader import extract_boxed_content, grade_answer

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
# from utils.math_cot import *
# from qa_metrics.pedant import PEDANT

# pedant = PEDANT()

'''
Alpha constant: When the description is wrong, but the final answer is right, the model is doing reward hacking,
so we give it a partial reward
'''
alpha = 1.0


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )

    # reward_funcs: list[str] = field(
    #     default_factory=lambda: ["accuracy"],
    #     metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
    # )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )

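# accuracy_reward grades each sample twice: the <answer> span of the first-pass completion,
# and the answer that a second verification pass recovers from the <des> description alone
# (provided by the customized GRPO trainer via kwargs["second_completions"]). Most branches
# below sum the two sub-rewards.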
def accuracy_reward(completions, solution, **kwargs):
    def extract_answer(text: str) -> str:
        """
        1) Try the full <answer> … </answer> block.
        2) If that is missing, grab whatever follows the opening <answer> tag.
        3) Otherwise return the original text.
        """
        # ① normal case: <answer> … </answer>
        m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ② fallback: <answer> … <end-of-string>
        m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ③ nothing found
        return text.strip()

    def single_accuracy_reward(predict: str, ground_truth: str) -> float:
        answer = predict
        return 1.0 if grade_answer(answer, ground_truth) else 0.0

    # Returns a bare accuracy score, hence the float annotation.
    def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
        # format_score = format_reward(predict)
        accuracy_score = single_accuracy_reward(predict, ground_truth)

        # return (1 - format_weight) * accuracy_score + format_weight * format_score
        return accuracy_score

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    # Word error rate: word-level Levenshtein distance, normalized by the reference length.
    def wer(reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)

    # Average F-measure over ROUGE-1, ROUGE-2 and ROUGE-L.
    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure

    # print('Computing rewards now...')
    # second_prompts = kwargs.get("second_prompts")  # ← list[str] or None
    second_completions = kwargs.get("second_completions")
    # second_contents = [comp[0]["content"] for comp in second_completions]
    # print('second prompts', second_prompts)
    # print('-'*10)
    # print('second completions', second_completions)
    # print('-'*10)

    # import time
    # time.sleep(30)
    question_type = kwargs['problem_type'][0]
    question = kwargs['problem'][0]

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    # model = kwargs.get("model")  # may be None if called elsewhere
    # tokenizer = kwargs.get("tokenizer")

    # # (optional) example use: let the model score the generated answer
    # if model is not None and tokenizer is not None:
    #     model.eval()

    # for content, sol in zip(contents, solution):
    for content, sol, second_content in zip(contents, solution, second_completions):
        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            description_extraction = extract_answer(second_content)
            # if question_type == "multiple choice":
            #     reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            # elif question_type == "numerical":
            #     gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
            #     out_has_decimal = ("." in output_ans) or ("," in output_ans)
            #     if gt_has_decimal != out_has_decimal:
            #         reward = 0.0
            #     else:
            #         gt_number = normalize_number(gt_ans)
            #         out_number = normalize_number(output_ans)
            #         if gt_number is None or out_number is None:
            #             reward = 0.0
            #         else:
            #             reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            # if question_type == "OCR":
            #     # description_extraction = extract_answer(second_content)
            #     # description_error_rate = wer(gt_ans, description_extraction)
            #     description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
            #     # error_rate = wer(gt_ans, output_ans)
            #     answer_pedant_reward = pedant.get_score(gt_ans, output_ans, question)
            #     # reward = (1 - error_rate) + (1 - description_error_rate)
            #     # reward = max(0.0, min(2.0, reward))
            #     print('Extracted description: ', description_extraction)
            #     print('Generated answer: ', output_ans)
            #     print('Sol: ', gt_ans)
            #     print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
            #     print('-' * 10)
            #     reward = description_pendat_reward + answer_pedant_reward
            if question_type == "free-form":
                score = compute_rouge_score(gt_ans, output_ans)
                description_score = compute_rouge_score(gt_ans, description_extraction)
                reward = max(0.0, min(1.0, score)) + max(0.0, min(1.0, description_score))
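            # Regression questions use a clipped relative error for both the final answer
            # and the number recovered from the description. Worked example: gt = 10 and
            # prediction = 12 give rel_diff ≈ 0.2, so that term contributes 1 - 0.2 = 0.8.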
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                description_number = normalize_number(description_extraction)
                if gt_number is None or out_number is None or description_number is None:
                    # Unparsable numbers earn no reward.
                    reward = 0.0
                else:
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))

                    description_diff = (abs(description_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    description_diff = min(1.0, max(0.0, description_diff))

                    reward = (1 - rel_diff) + (1 - description_diff)
            elif question_type == 'math' or question_type == 'unify' or question_type == 'multiple choice' or question_type == 'numerical':
                description_reward = compute_math_score_single(description_extraction, gt_ans)
                answer_reward = compute_math_score_single(output_ans, gt_ans)

                if description_reward == 0 and answer_reward == 1:
                    # A right answer on top of a wrong description looks like reward
                    # hacking, so grant only the partial reward alpha.
                    reward = alpha
                else:
                    reward = description_reward + answer_reward

                # print(f"Extracted description: {description_extraction} | Generated answer: {output_ans} | Sol: {gt_ans}")
                # print(f'Description reward: {description_reward} | answer reward: {answer_reward} | final reward: {reward}')
                # print('-' * 10)
            else:
                print(f"Unknown question type '{question_type}'; falling back to zero reward")
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)
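
        # Optional per-sample logging, enabled via the DEBUG_MODE=true and LOG_PATH environment variables.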
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards


def simple_format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [0.1 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": simple_format_reward,
}


SYSTEM_PROMPT = (
    "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags. "
    "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
    "and provide this step-by-step reasoning within <think> </think> tags. "
    "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags. "
    "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
)


def main(script_args, training_args, model_args):
    print('Start program...')
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    print('Loading dataset')
    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
        "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
        "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
        "Provide your detailed, step-by-step reasoning based on the image and image description, and enclose this part within <think> </think> tags.\n"
        "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
        "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please keep your final answer short and precise."
    )

    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
    }

    ABS_Verify_Prompt = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. First provide a step-by-step reasoning within <think> </think> tags, then provide your answer as a single final answer, single letter choice, or a short phrase ENCLOSED with <answer> </answer> tags. \nText description: {{Description}}\nQuestion: {Question}\nPlease only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions.'''

    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']

        msg = {
            "prompt":
                [{
                    "role": "user",
                    "content": [
                        {
                            "type": example['data_type'],
                            # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                        },
                        {
                            "type": "text",
                            # "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                            "text": QUESTION_TEMPLATE.format(Question=question)
                        }
                    ]
                }]
        }

        return msg

    def make_verify_conversation(example):
        # ➊ build the question text
        question = example["problem"]
        if example["problem_type"] == "multiple choice":
            question += "Options:\n" + "\n".join(example["options"])

        # ➋ verification template + suffix (no if/else)
        verify_text = (
            ABS_Verify_Prompt.format(Question=question.replace("<image>", ""))
            # + TYPE_TEMPLATE[example["problem_type"]]  # ← one-liner, no branching
        )

        # ➌ conversation dict
        conv_dict = {
            "prompt": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": verify_text}],
                }
            ]
        }

        # templated = maybe_apply_chat_template(conv_dict, processing_class)["prompt"]
        # return {"verify_prompt": templated}
        return {"verify_prompt": conv_dict}

    print('Start mapping dataset')
    dataset = dataset.map(make_conversation_image_and_video)
    dataset = dataset.map(
        make_verify_conversation,
        desc="add description verify prompt",
    )

    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
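As a reference point, here is a minimal, self-contained sketch (not part of the repository files; the sample strings are invented) of the three-stage answer extraction that both selfEval scripts rely on:

    import re

    def extract_answer(text: str) -> str:
        # 1) full <answer> ... </answer> block
        m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()
        # 2) an opening <answer> tag that was never closed
        m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()
        # 3) no tags at all: fall back to the whole text
        return text.strip()

    assert extract_answer("<des>...</des> <think>...</think> <answer> B </answer>") == "B"
    assert extract_answer("<think>...</think> <answer> 42") == "42"
    assert extract_answer("just 3.14") == "just 3.14"

The lenient fallbacks matter during early RL training, when completions often truncate before emitting the closing tag.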
src/r1-v/src/open_r1/grpo-cot-selfEvalConst.py
ADDED
@@ -0,0 +1,456 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerSelfConst
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from typing import Dict, List, Optional
from mathruler.grader import extract_boxed_content, grade_answer

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
# from utils.math_cot import *
# from qa_metrics.pedant import PEDANT

# pedant = PEDANT()

'''
Alpha constant: When the description is wrong, but the final answer is right, the model is doing reward hacking,
so we give it a partial reward
'''
alpha = 0.85


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )

    # reward_funcs: list[str] = field(
    #     default_factory=lambda: ["accuracy"],
    #     metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
    # )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )

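# This variant differs from grpo-cot-selfEval.py above mainly in that alpha is 0.85 rather
# than 1.0, the per-type TYPE_TEMPLATE answer-format suffixes are appended to both the
# question and the verification prompts, and training uses Qwen2VLGRPOVLLMTrainerSelfConst.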
def accuracy_reward(completions, solution, **kwargs):
    def extract_answer(text: str) -> str:
        """
        1) Try the full <answer> … </answer> block.
        2) If that is missing, grab whatever follows the opening <answer> tag.
        3) Otherwise return the original text.
        """
        # ① normal case: <answer> … </answer>
        m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ② fallback: <answer> … <end-of-string>
        m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ③ nothing found
        return text.strip()

    def single_accuracy_reward(predict: str, ground_truth: str) -> float:
        answer = predict
        return 1.0 if grade_answer(answer, ground_truth) else 0.0

    # Returns a bare accuracy score, hence the float annotation.
    def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
        # format_score = format_reward(predict)
        accuracy_score = single_accuracy_reward(predict, ground_truth)

        # return (1 - format_weight) * accuracy_score + format_weight * format_score
        return accuracy_score

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    # Word error rate: word-level Levenshtein distance, normalized by the reference length.
    def wer(reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)

    # Average F-measure over ROUGE-1, ROUGE-2 and ROUGE-L.
    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure

    # print('Computing rewards now...')
    # second_prompts = kwargs.get("second_prompts")  # ← list[str] or None
    second_completions = kwargs.get("second_completions")
    # second_contents = [comp[0]["content"] for comp in second_completions]
    # print('second prompts', second_prompts)
    # print('-'*10)
    # print('second completions', second_completions)
    # print('-'*10)

    # import time
    # time.sleep(30)
    question_type = kwargs['problem_type'][0]
    question = kwargs['problem'][0]

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    # model = kwargs.get("model")  # may be None if called elsewhere
    # tokenizer = kwargs.get("tokenizer")

    # # (optional) example use: let the model score the generated answer
    # if model is not None and tokenizer is not None:
    #     model.eval()

    # for content, sol in zip(contents, solution):
    for content, sol, second_content in zip(contents, solution, second_completions):
        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            description_extraction = extract_answer(second_content)
            # if question_type == "multiple choice":
            #     reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            # elif question_type == "numerical":
            #     gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
            #     out_has_decimal = ("." in output_ans) or ("," in output_ans)
            #     if gt_has_decimal != out_has_decimal:
            #         reward = 0.0
            #     else:
            #         gt_number = normalize_number(gt_ans)
            #         out_number = normalize_number(output_ans)
            #         if gt_number is None or out_number is None:
            #             reward = 0.0
            #         else:
            #             reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            # if question_type == "OCR":
            #     # description_extraction = extract_answer(second_content)
            #     # description_error_rate = wer(gt_ans, description_extraction)
            #     description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
            #     # error_rate = wer(gt_ans, output_ans)
            #     answer_pedant_reward = pedant.get_score(gt_ans, output_ans, question)
            #     # reward = (1 - error_rate) + (1 - description_error_rate)
            #     # reward = max(0.0, min(2.0, reward))
            #     print('Extracted description: ', description_extraction)
            #     print('Generated answer: ', output_ans)
            #     print('Sol: ', gt_ans)
            #     print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
            #     print('-' * 10)
            #     reward = description_pendat_reward + answer_pedant_reward
            if question_type == "free-form":
                score = compute_rouge_score(gt_ans, output_ans)
                description_score = compute_rouge_score(gt_ans, description_extraction)
                reward = max(0.0, min(1.0, score)) + max(0.0, min(1.0, description_score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                description_number = normalize_number(description_extraction)
                if gt_number is None or out_number is None or description_number is None:
                    # Unparsable numbers earn no reward.
                    reward = 0.0
                else:
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))

                    description_diff = (abs(description_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    description_diff = min(1.0, max(0.0, description_diff))

                    reward = (1 - rel_diff) + (1 - description_diff)
            elif question_type == 'math' or question_type == 'unify' or question_type == 'multiple choice' or question_type == 'numerical':
                description_reward = compute_math_score_single(description_extraction, gt_ans)
                answer_reward = compute_math_score_single(output_ans, gt_ans)

                if description_reward == 0 and answer_reward == 1:
                    # A right answer on top of a wrong description looks like reward
                    # hacking, so grant only the partial reward alpha.
                    reward = alpha
                else:
                    reward = description_reward + answer_reward

                # print(f"Extracted description: {description_extraction} | Generated answer: {output_ans} | Sol: {gt_ans}")
                # print(f'Description reward: {description_reward} | answer reward: {answer_reward} | final reward: {reward}')
                # print('-' * 10)
            else:
                print(f"Unknown question type '{question_type}'; falling back to zero reward")
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards


def simple_format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [0.1 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": simple_format_reward,
}


SYSTEM_PROMPT = (
    "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags. "
    "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
    "and provide this step-by-step reasoning within <think> </think> tags. "
    "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags. "
    "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
)


def main(script_args, training_args, model_args):
    print('Start program...')
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    print('Loading dataset')
    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
        "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
        "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
        "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
        "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
        "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
    )

    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
    }

    ABS_Verify_Prompt = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. First provide a step-by-step reasoning within <think> </think> tags, then provide your answer as a single final answer, single letter choice, or a short phrase ENCLOSED with <answer> </answer> tags. \nText description: {{Description}}\nQuestion: {Question}'''

    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']

        msg = {
            "prompt":
                [{
                    "role": "user",
                    "content": [
                        {
                            "type": example['data_type'],
                            # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                        },
                        {
                            "type": "text",
                            "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                        }
                    ]
                }]
        }

        return msg

    def make_verify_conversation(example):
        # ➊ build the question text
        question = example["problem"]
        if example["problem_type"] == "multiple choice":
            question += "Options:\n" + "\n".join(example["options"])

        # ➋ verification template + suffix (no if/else)
        verify_text = (
            ABS_Verify_Prompt.format(Question=question.replace("<image>", ""))
            + TYPE_TEMPLATE[example["problem_type"]]  # ← one-liner, no branching
        )

        # ➌ conversation dict
        conv_dict = {
            "prompt": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": verify_text}],
                }
            ]
        }

        # templated = maybe_apply_chat_template(conv_dict, processing_class)["prompt"]
        # return {"verify_prompt": templated}
        return {"verify_prompt": conv_dict}

    print('Start mapping dataset')
    dataset = dataset.map(make_conversation_image_and_video)
    dataset = dataset.map(
        make_verify_conversation,
        desc="add description verify prompt",
    )

    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerSelfConst
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
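A quick standalone sketch (with an abbreviated template and a made-up question) of how the double-braced ABS_Verify_Prompt behaves: str.format fills {Question}, while {{Description}} collapses to a literal {Description} placeholder that a later pass can fill with the model's generated <des> content:

    ABS_Verify_Prompt = '''Text description: {{Description}}\nQuestion: {Question}'''

    first_pass = ABS_Verify_Prompt.format(Question="What color is the car?")
    print(first_pass)
    # Text description: {Description}
    # Question: What color is the car?

    # A second format call injects the generated description.
    print(first_pass.format(Description="A red car parked beside a tree."))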
src/r1-v/src/open_r1/grpo-cot.py
ADDED
@@ -0,0 +1,351 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from utils.math_cot import *


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    # reward_funcs: list[str] = field(
    #     default_factory=lambda: ["accuracy", "format"],
    #     metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    # )

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )

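# Unlike the selfEval variants above, this reward grades only the <answer> span of a single
# completion; there is no second description-verification pass.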
def accuracy_reward(completions, solution, **kwargs):

    def extract_answer(text):
        pattern = r'<answer>\s*(.*?)\s*</answer>'
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1).strip()
        return ""

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    # Word error rate: word-level Levenshtein distance, normalized by the reference length.
    def wer(reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            d[i][0] = i
        for j in range(n + 1):
            d[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if ref_words[i - 1] == hyp_words[j - 1]:
                    d[i][j] = d[i - 1][j - 1]
                else:
                    d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1])
        return d[m][n] / max(1, m)

    # Average F-measure over ROUGE-1, ROUGE-2 and ROUGE-L.
    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure

    question_type = kwargs['problem_type'][0]

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    for content, sol in zip(contents, solution):

        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            if question_type == "multiple choice":
                reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            elif question_type == "numerical":
                gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
                out_has_decimal = ("." in output_ans) or ("," in output_ans)
                if gt_has_decimal != out_has_decimal:
                    reward = 0.0
                else:
                    gt_number = normalize_number(gt_ans)
                    out_number = normalize_number(output_ans)
                    if gt_number is None or out_number is None:
                        reward = 0.0
                    else:
                        reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            elif question_type == "OCR":
                error_rate = wer(gt_ans, output_ans)
                reward = 1 - error_rate
                reward = max(0.0, min(1.0, reward))
            elif question_type == "free-form":
                score = compute_rouge_score(gt_ans, output_ans)
                reward = max(0.0, min(1.0, score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                if gt_number is None or out_number is None:
                    # Unparsable numbers earn no reward.
                    reward = 0.0
                else:
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))
                    reward = 1 - rel_diff
            elif question_type == 'math':
                # compute_math_score_single comes from the utils.math_cot star import above.
                reward = compute_math_score_single(content, gt_ans)
            else:
                print(f"Unknown question type '{question_type}'; falling back to zero reward")
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    # "format": 0,
}
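
# Only the accuracy reward is registered above; format_reward is defined but left out of
# reward_funcs_registry (see the commented-out "format" entry).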
| 189 |
+
|
| 190 |
+
# SYSTEM_PROMPT = (
|
| 191 |
+
# "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
| 192 |
+
# "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
| 193 |
+
# "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
| 194 |
+
# "<think> reasoning process here </think><answer> answer here </answer>"
|
| 195 |
+
# )
|
| 196 |
+
|
| 197 |
+
SYSTEM_PROMPT = (
|
| 198 |
+
"A conversation between User and Assistant. The user provides a question about an image, "
|
| 199 |
+
"and the Assistant is tasked with generating an exhaustive and detailed description of the image. "
|
| 200 |
+
"The assistant should extract and describe all possible information from the image—including objects, numbers, text, and their relationships—"
|
| 201 |
+
"and enclose this description within <info> </info> tags. "
|
| 202 |
+
"Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
|
| 203 |
+
"and provide this step-by-step reasoning within <think> </think> tags. "
|
| 204 |
+
"Finally, the assistant provides a single word or phrase answer within <answer> </answer> tags. "
|
| 205 |
+
"The output format should be: <info> image description here </info> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)


    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }


    # QUESTION_TEMPLATE = (
    #     "{Question}\n"
    #     "Please think about this question as if you were a human pondering deeply. "
    #     "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
    #     "It's encouraged to include self-reflection or verification in the reasoning process. "
    #     "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    # )

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "You are tasked with analyzing an image to generate an exhaustive and detailed description. "
        "Your goal is to extract and describe all possible information from the image, including but not limited to objects, numbers, text, and the relationships between these elements. "
        "The description should be as fine and detailed as possible, capturing every nuance, and should be enclosed within <info> </info> tags.\n"
        "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
        "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
        "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
        "The output format should be: <info> image description here </info> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
    )


    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "math": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
    }

    def make_conversation_image(example):

        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }


    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']


        msg = {
            "prompt":
                [{
                    "role": "user",
                    "content": [
                        {
                            "type": example['data_type'],
                            # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                        },
                        {
                            "type": "text",
                            "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                        }
                    ]
                }]
        }

        return msg
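
    # Illustrative sketch (assumes a hypothetical example dict, not taken from
    # the dataset): for
    #   {"problem": "What color is the car?", "problem_type": "multiple choice",
    #    "options": ["A. red", "B. blue"], "data_type": "image"}
    # make_conversation_image_and_video yields one user turn whose first content
    # item is {"type": "image"} and whose second is the QUESTION_TEMPLATE text
    # (with the options appended to the question) plus the "multiple choice"
    # suffix from TYPE_TEMPLATE.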


    dataset = dataset.map(make_conversation_image_and_video)


    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo-description-LLMEval.py
ADDED
@@ -0,0 +1,579 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration
from openai import OpenAI
from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from typing import Dict, List, Optional
from mathruler.grader import extract_boxed_content, grade_answer

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
# from utils.gpt_eval import infer
# from utils.math_cot import *
# from qa_metrics.pedant import PEDANT
# from qa_metrics.answerBERT import AnswerBertActor

# pedant = PEDANT()
# answerBERT = AnswerBertActor(device='cuda:7')

alpha = 1.0

TYPE_TEMPLATE = {
    "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) in \\boxed{}.",
    "numerical": " Please provide the numerical value (e.g., 42 or 3.14) in \\boxed{}.",
    "OCR": " Please transcribe text from the image/video clearly and provide your text answer in \\boxed{}.",
    "free-form": " Please provide your text answer in \\boxed{}.",
    "regression": " Please provide the numerical value (e.g., 42 or 3.14) in \\boxed{}.",
    "math": " Please provide the final exact answer (single option letter for multiple choice) in \\boxed{}.",
}

'''
gpt infer
'''
import os
from openai import AzureOpenAI
import time

import base64
from mimetypes import guess_type


def azure_gpt4(messages, model):
    outputs = []
    for message in messages:
        input_prompt = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": [
                {
                    "type": "text",
                    "text": message["instruction"]
                },
                # {
                #     "type": "image_url",
                #     "image_url": {
                #         "url": message["image"]
                #     }
                # }
            ]}
        ]
        ## try N times if the API rate limit is exceeded ...
        for i in range(10):
            try:
                output = client.chat.completions.create(
                    model=model, messages=input_prompt, max_tokens=2000
                )

                output_text = output.choices[0].message.content
                break  ## exit if successful

            except Exception as e:
                print(f'Index {i} got error message: {e}')
                output_text = ''
                time.sleep(3)

        outputs.append(output_text)

    return outputs


client = AzureOpenAI(
    api_key = "83f30a2a22324395b854bd343db38d85",
    api_version = "2024-08-01-preview",
    azure_endpoint = "https://francecentral.api.cognitive.microsoft.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
)

model = "gpt-4o"
prompt_template = '''Text description: {text}\nQuestion: {question}\nYou are provided a text description of a problem and a question. Determine the answer to the question based on the text description. First provide an internal step-by-step reasoning within <think> </think> tags, then provide a single word or phrase answer in \\boxed{}.'''


# client = OpenAI(
#     base_url="http://29.81.244.54:8080/v1",  # your vLLM server
#     api_key="ANYKEY",  # if you set --api-key when launching
# )

client = OpenAI(
    base_url="http://29.81.224.188:8080/v1",  # your vLLM server
    api_key="ANYKEY",  # if you set --api-key when launching
)

def chat_batch(
    client,
    all_message_batches: List[List[Dict[str, str]]],
    *,
    # model: str = "Qwen2.5-32B-Instruct",
    model: str = "Qwen2.5-32B-finetune",
    max_workers: int = 8,
    retries: int = 2,
    backoff: float = 0.5,
    timeout: Optional[float] = None,
) -> List[str]:
    """
    Send many chat requests in parallel and return replies as a list of strings,
    preserving the order of `all_message_batches`.
    """

    def _chat_once_with_retry(messages: List[Dict[str, str]]) -> str:
        last_err: Optional[BaseException] = None
        for attempt in range(retries + 1):
            try:
                resp = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    timeout=timeout,
                )
                # Different SDKs expose content slightly differently; handle common cases.
                choice = resp.choices[0]
                if hasattr(choice, "message") and getattr(choice.message, "content", None) is not None:
                    return choice.message.content
                if hasattr(choice, "text") and choice.text is not None:
                    return choice.text
                # Fall back to stringifying the choice if the structure is unexpected.
                return str(choice)
            except Exception as e:
                last_err = e
                if attempt < retries:
                    time.sleep(backoff * (2 ** attempt))  # exponential backoff (the original called a bare `sleep`, which is undefined with `import time`)
        return f"Error: {last_err!r}"

    results: List[Optional[str]] = [None] * len(all_message_batches)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_idx = {
            executor.submit(_chat_once_with_retry, batch): i
            for i, batch in enumerate(all_message_batches)
        }
        for fut in as_completed(future_to_idx):
            i = future_to_idx[fut]
            results[i] = fut.result()

    # mypy-friendly cast: no Nones remain at this point
    return [r if r is not None else "Error: Unknown failure" for r in results]
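
# Usage sketch (illustrative only; assumes the vLLM server configured above is
# reachable and serves the named model):
#
#   batches = [
#       [{"role": "user", "content": "Say hi"}],
#       [{"role": "user", "content": "Say bye"}],
#   ]
#   replies = chat_batch(client, batches, max_workers=2)
#   # replies[i] lines up with batches[i]; failed requests come back as
#   # "Error: ..." strings instead of raising.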


def infer(prompt):
    # prompt_question = prompt_question.replace('<image>', '')
    # prompt = prompt_template.replace('{text}', text).replace('{question}', prompt_question)

    messages = [
        {"instruction": prompt},
    ]
    prompt_success = False
    prompt_time = 0
    outputs = ['\\boxed{None}']
    while not prompt_success and prompt_time <= 2:
        try:
            outputs = azure_gpt4(messages, model)
            prompt_success = True
        except Exception:
            prompt_time += 1
            time.sleep(5)

    return outputs[0]

'''
end of gpt infer
'''


from concurrent.futures import ThreadPoolExecutor, as_completed

def _call_infer(desc):
    return infer(desc)

@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )

    # reward_funcs: list[str] = field(
    #     default_factory=lambda: ["accuracy"],
    #     metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
    # )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )


def accuracy_reward(completions, solution, **kwargs):
    def extract_answer(text: str) -> str:
        """
        1) Try the full <answer> … </answer> block.
        2) If that is missing, grab whatever follows the opening <answer> tag.
        3) Otherwise return the original text.
        """
        # ① normal case <answer> … </answer>
        m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ② fallback <answer> … <end-of-string>
        m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
        if m:
            return m.group(1).strip()

        # ③ nothing found
        return text.strip()

    def extract_description(predict: str) -> str:
        """
        Extracts the content of the <des>…</des> block from `predict`.
        Returns the inner text (with leading/trailing whitespace stripped),
        or the full `predict` string if no <des> tag is found.
        """
        match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
        if not match:
            return predict
        return match.group(1).strip()

    def single_accuracy_reward(predict: str, ground_truth: str) -> float:
        answer = predict
        return 1.0 if grade_answer(answer, ground_truth) else 0.0

    def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
        predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
        accuracy_score = single_accuracy_reward(predict, ground_truth)
        # return (1 - format_weight) * accuracy_score + format_weight * format_score
        return accuracy_score

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    def wer(reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0]*(n+1) for _ in range(m+1)]
        for i in range(m+1):
            d[i][0] = i
        for j in range(n+1):
            d[0][j] = j
        for i in range(1, m+1):
            for j in range(1, n+1):
                if ref_words[i-1] == hyp_words[j-1]:
                    d[i][j] = d[i-1][j-1]
                else:
                    d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
        return d[m][n] / max(1, m)


    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure
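
    # Illustrative behaviour of `extract_answer` (cases ① – ③ above):
    #
    #   >>> extract_answer("<answer> 42 </answer>")
    #   '42'
    #   >>> extract_answer("<answer> 42")   # unclosed tag: fallback case ②
    #   '42'
    #   >>> extract_answer("just 42")       # no tag: the stripped text itself
    #   'just 42'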

    # print('Computing rewards now...')
    # second_prompts = kwargs.get("second_prompts")  # ← list[str] or None
    # second_completions = kwargs.get("second_completions")
    # second_contents = [comp[0]["content"] for comp in second_completions]
    # print('second prompts', second_prompts)
    # print('-'*10)
    # print('second completions', second_completions)
    # print('-'*10)

    # import time
    # time.sleep(30)
    question_type = kwargs['problem_type'][0]
    questions = kwargs['problem']

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    extracted_content_descriptions = [extract_description(ele) for ele in contents]

    description_query_inputs = []
    batch_messages = []
    vllm_batch_messages = []

    for index in range(len(extracted_content_descriptions)):
        prompt_question = questions[index]
        des_text = extracted_content_descriptions[index]
        prompt_question = prompt_question.replace('<image>', '')
        prompt_input = prompt_template.replace('{text}', des_text).replace('{question}', prompt_question) + TYPE_TEMPLATE[question_type]
        description_query_inputs.append(prompt_input)
        curr_msg = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt_input}
        ]
        vllm_batch_messages.append(curr_msg)


    batched_vllm_outputs = chat_batch(client, vllm_batch_messages)

    description_score_outputs = [extract_boxed_content(idx_input) for idx_input in batched_vllm_outputs]
    # with ThreadPoolExecutor(max_workers=8) as executor:
    #     futures = [
    #         executor.submit(_call_infer, desc)
    #         for desc in description_query_inputs
    #     ]
    #     # collect as they finish (optional—keeps order of completion)
    #     for fut in as_completed(futures):
    #         # description_score_outputs.append(extract_answer(fut.result()))
    #         # extract_boxed_content
    #         description_score_outputs.append(extract_boxed_content(fut.result()))


    gt_answers = [extract_answer(sol) for sol in solution]
    description_rewards = [compute_math_score_single(description_score_outputs[count_idx], gt_answers[count_idx]) for count_idx in range(len(description_score_outputs))]

    print(gt_answers)
    print(description_score_outputs)
    print(description_rewards)
    print('-'*10)


    for content, gt_ans, description_reward in zip(contents, gt_answers, description_rewards):
    # for content, sol, question in zip(contents, solution, questions):
    # for content, sol, second_content in zip(contents, solution, second_completions):
        try:
            # output_ans = extract_answer(content)
            output_ans = extract_boxed_content(content)

            if question_type != 'None':
                answer_reward = compute_math_score_single(output_ans, gt_ans)
                if description_reward == 0 and answer_reward == 1:
                    reward = alpha
                else:
                    reward = description_reward + answer_reward
                # reward = answer_reward
            else:
                print('Falling back to none rewards')
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {gt_ans}\n")

    return rewards


def simple_format_reward(completions, **kwargs):
    """Reward function that checks the same format as `format_reward`:
    <description>...</description><think>...</think>\\boxed{...}
    """
    pattern = re.compile(
        r"^\s*<description>.*?</description>\s*"
        r"<think>.*?</think>\s*"
        r"\\boxed\{.*?\}\s*$",
        re.DOTALL,
    )
    completion_contents = [completion[0]["content"] for completion in completions]
    return [0.1 if pattern.fullmatch(content or "") else 0.0
            for content in completion_contents]
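
# Illustrative sketch: a completion that earns the 0.1 format reward above looks
# like "<description>a red car</description><think>it is red</think>\boxed{red}";
# a missing tag, or trailing text after \boxed{...}, scores 0.0.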


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": simple_format_reward,
}

# SYSTEM_PROMPT = (
#     "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
#     "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
#     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
#     "<think> reasoning process here </think><answer> answer here </answer>"
# )

# NOTE: "\\boxed" must be escaped here; a bare "\b" inside a normal string literal is a backspace character.
SYSTEM_PROMPT = (
    "You are tasked with analyzing an image/video to generate a detailed description to help you answer the question. First analyze the image/video and produce a self-contained description—detailed enough that can lead to the correct answer. Wrap the entire description in <description> </description> tags.\n Next, engage in an internal dialogue and include self-reflection or verification in your reasoning process. Provide your detailed, step-by-step reasoning based on the image/video description information and image/video, and enclose this part within <think> </think> tags.\n Finally, provide a single word or phrase answer to the question in \\boxed{}.\nThe output format should be: <description> image/video description here </description> <think> reasoning process here </think> \\boxed{FINAL ANSWER here}."
)


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)


    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }


    # QUESTION_TEMPLATE = (
    #     "{Question}\n"
    #     "Please think about this question as if you were a human pondering deeply. "
    #     "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
    #     "It's encouraged to include self-reflection or verification in the reasoning process. "
    #     "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    # )

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "You are tasked with analyzing an image/video to generate a detailed description to help you answer the question. "
        "First analyze the image/video and produce a self-contained description—detailed enough that can lead to the correct answer. "
        "Wrap the entire description in <description> </description> tags.\n"
        "Next, engage in an internal dialogue and include self-reflection or verification in your reasoning process. "
        "Provide your detailed, step-by-step reasoning based on the image/video description information and image/video, and enclose this part within <think> </think> tags.\n"
        "Finally, provide a single word or phrase answer to the question in \\boxed{{}}.\n"
        "The output format should be: <description> image/video description here </description> "
        "<think> reasoning process here </think> \\boxed{{FINAL ANSWER here}}."
    )


    def make_conversation_image(example):

        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }


    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']


        msg = {
            "prompt":
                [{
                    "role": "user",
                    "content": [
                        {
                            "type": example['data_type'],
                            # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                        },
                        {
                            "type": "text",
                            # "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                            "text": QUESTION_TEMPLATE.format(Question=question)
                        }
                    ]
                }]
        }

        return msg


    dataset = dataset.map(make_conversation_image_and_video)


    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo.py
ADDED
@@ -0,0 +1,318 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset, load_from_disk
from transformers import Qwen2VLForConditionalGeneration

from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from datasets import Dataset, DatasetDict

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    temporal: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using temporal GRPO"},
    )
    len_control: Optional[bool] = field(
        default=True,
        metadata={"help": "whether using length reward"},
    )


def accuracy_reward(completions, solution, **kwargs):

    def extract_answer(text):
        pattern = r'<answer>\s*(.*?)\s*</answer>'
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1).strip()
        return ""

    def normalize_number(num_str):
        try:
            num_str = num_str.replace(',', '')
            return float(num_str)
        except Exception as e:
            print(f"Error converting '{num_str}' to float: {e}")
            return None

    def wer(reference, hypothesis):
        ref_words = reference.split()
        hyp_words = hypothesis.split()
        m = len(ref_words)
        n = len(hyp_words)
        d = [[0]*(n+1) for _ in range(m+1)]
        for i in range(m+1):
            d[i][0] = i
        for j in range(n+1):
            d[0][j] = j
        for i in range(1, m+1):
            for j in range(1, n+1):
                if ref_words[i-1] == hyp_words[j-1]:
                    d[i][j] = d[i-1][j-1]
                else:
                    d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
        return d[m][n] / max(1, m)


    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure
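
    # Worked example for `wer` (word-level edit distance over reference length):
    #
    #   >>> wer("hello world", "hello there world")   # one insertion / 2 ref words
    #   0.5
    #   >>> wer("a b c", "a b c")
    #   0.0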


    question_type = kwargs['problem_type'][0]

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    for content, sol in zip(contents, solution):

        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            if question_type == "multiple choice":
                reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
            elif question_type == "numerical":
                gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
                out_has_decimal = ("." in output_ans) or ("," in output_ans)
                if gt_has_decimal != out_has_decimal:
                    reward = 0.0
                else:
                    gt_number = normalize_number(gt_ans)
                    out_number = normalize_number(output_ans)
                    if gt_number is None or out_number is None:
                        reward = 0.0
                    else:
                        reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
            elif question_type == "OCR":
                error_rate = wer(gt_ans, output_ans)
                reward = 1 - error_rate
                reward = max(0.0, min(1.0, reward))
            elif question_type == "free-form":
                score = compute_rouge_score(gt_ans, output_ans)
                reward = max(0.0, min(1.0, score))
            elif question_type == "regression":
                gt_number = normalize_number(gt_ans)
                out_number = normalize_number(output_ans)
                if gt_number is None or out_number is None:
                    reward = 0.0
                else:  # guard added: the original fell through and raised on None values
                    rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
                    rel_diff = min(1.0, max(0.0, rel_diff))
                    reward = 1 - rel_diff
            else:
                reward = 0.0
        except Exception as e:
            print(f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            # local_rank = int(os.getenv("LOCAL_RANK", 0))
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")

    return rewards
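
# Illustrative reward semantics (values follow from the branches above): for a
# "numerical" question with solution "<answer>3.14</answer>", the completion
# "<answer>3.141</answer>" scores 1.0 (both round to 3.14 at two decimals),
# while "<answer>3</answer>" scores 0.0 because the decimal-format check fails
# before any numeric comparison.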


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [0.1 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        # Load the dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)


    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }


    QUESTION_TEMPLATE = (
        "{Question}\n"
        "Please think about this question as if you were a human pondering deeply. "
        "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
        "It's encouraged to include self-reflection or verification in the reasoning process. "
        "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    )

    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags."
    }
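
    # Illustrative sketch: the user text sent to the model is always
    # QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[problem_type];
    # e.g. a "numerical" problem ends with the suffix
    # " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags."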

    def make_conversation_image(example):

        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }


    def make_conversation_video(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "video"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }

    def make_conversation_image_and_video(example):
        if example["problem_type"] == 'multiple choice':
            question = example['problem'] + "Options:\n"
            for op in example["options"]:
                question += op + "\n"
        else:
            question = example['problem']


        msg = {
            "prompt":
                [{
                    "role": "user",
                    "content": [
                        {
                            "type": example['data_type'],
                            # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                        },
                        {
                            "type": "text",
                            "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                        }
                    ]
                }]
        }

        return msg


    dataset = dataset.map(make_conversation_image_and_video)


    trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
    print("using: ", trainer_cls)

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        script_args=script_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
        trainer.train(resume_from_checkpoint=checkpoint)
    else:
        trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo_vllm_caption.py
ADDED
@@ -0,0 +1,266 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from datetime import datetime
import json
from io import BytesIO
import base64

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from rouge_score import rouge_scorer
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
import Levenshtein
import wandb

from dataclasses import dataclass, field
from typing import Optional
from math_verify import parse, verify

from trainer.grpo_trainer_vllm_caption import Qwen2VLGRPOTrainerCap

os.environ["WANDB_MODE"] = "offline"


wandb.init(project="SelfEval-R1", name="SelfEval-R1")


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.

    Args:
        reward_funcs (`list[str]`):
            List of reward functions. Possible values: 'accuracy', 'format'.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"},
    )
    caption_reward: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use caption reward or not"},
    )
    caption_reward_weight: Optional[float] = field(
        default=0.1,
        metadata={"help": "Weight for the caption reward"},
    )


# This function is partially borrowed from Video-R1 [https://github.com/tulerfeng/Video-R1]
def accuracy_reward(completions, solution, **kwargs):

    def extract_answer(text):
        pattern = r'<answer>(.*?)</answer>'
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1).strip()
        return ""

    def extract_option(text):
        pattern = r'<option>(.*?)</option>'
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1).strip()
        return ""

    def is_number(num_str):
        try:
            float(num_str)
            return True
        except Exception:
            return False

    def extract_numbers(answer):
        pattern = r"[-+]?\d*\.?\d+"
        match = re.search(pattern, answer)
        if match:
            number_str = match.group()
            if answer.strip().endswith('%'):
                number = float(number_str) / 100
            else:
                number = float(number_str)
            return number
        else:
            return None

    def anls(reference, hypothesis):
        distance = Levenshtein.distance(reference, hypothesis)
        max_length = max(len(reference), len(hypothesis))
        similarity = 1 - (distance / max_length)

        return similarity

    def compute_rouge_score(reference, hypothesis, use_stemmer=True):
        scorer = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
        scores = scorer.score(reference, hypothesis)
        average_fmeasure = (
            scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
        return average_fmeasure
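
    # Worked example for `anls` (normalized Levenshtein similarity):
    #
    #   >>> anls("stop", "stop")
    #   1.0
    #   >>> anls("stop", "step")   # one edit over max length 4
    #   0.75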

    question_type = kwargs['problem_type'][0]

    contents = [completion[0]["content"] for completion in completions]
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    rewards = []

    for content, sol in zip(contents, solution):
        try:
            output_ans = extract_answer(content)
            gt_ans = extract_answer(sol)
            if question_type == "OCR":
                if is_number(gt_ans):
                    output_ans = extract_numbers(output_ans)
                    reward = 1.0 if output_ans == float(gt_ans) else 0.0
                else:
                    reward = anls(gt_ans.lower(), output_ans.lower())
                    reward = max(0.0, min(1.0, reward))
            elif question_type == "free-form":
                score = compute_rouge_score(gt_ans, output_ans)
                reward = max(0.0, min(1.0, score))
            else:
                if is_number(gt_ans):
                    output_ans = extract_numbers(output_ans)
                    reward = 1.0 if output_ans == float(gt_ans) else 0.0
                else:
                    reward = 1.0 if output_ans.lower() == gt_ans.lower() else 0.0
        except Exception as e:
            print(
                f"Error in reward_fn for question_type '{question_type}': {e}")
            reward = 0.0

        rewards.append(reward)

        if os.getenv("DEBUG_MODE") == "true":
            log_path = 'debug.log'
            with open(log_path, "a") as f:
                try:
                    f.write(
                        f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")
                    f.write(f"type: {question_type}\n")
                except Exception:
                    f.write("writing error")

    return rewards
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<info>.*?</info>\s<think>.*?</think>\s*<answer>.*?</answer>"
    completion_contents = [completion[0]["content"]
                           for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL)
               for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
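
# Example (added for illustration): a completion such as
# "<info> two red cubes on a table </info> <think> the description lists two cubes </think> <answer> 2 </answer>"
# fully matches the pattern above and earns a format reward of 1.0.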

reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}
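
# The registry is indexed by the --reward_funcs script argument in main() below, e.g.
# --reward_funcs accuracy format selects [accuracy_reward, format_reward].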

SYSTEM_PROMPT = (
    "You are tasked with analyzing an image to generate an exhaustive and detailed description. "
    "Your goal is to extract and describe all possible information from the image, including but not limited to objects, "
    "numbers, text, and the relationships between these elements. The description should be as fine and detailed as possible, "
    "capturing every nuance. After generating the detailed description, you need to analyze it and provide step-by-step "
    "detailed reasoning for the given question based on the information. Finally, provide a single word or phrase answer "
    "to the question. The description, reasoning process and answer are enclosed within <info> </info>, <think> </think> "
    "and <answer> </answer> tags, respectively, i.e., <info> image description here </info> <think> reasoning process here "
    "</think> <answer> answer here </answer>"
)
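
# A well-formed response under this prompt looks like (illustrative, not from the source):
#   <info> Two red cubes sit on a wooden table; a label reads 'A12'. </info>
#   <think> The question asks how many cubes are visible; the description lists two. </think>
#   <answer> 2 </answer>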


def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func]
                    for func in script_args.reward_funcs]

    # Load the dataset from a local JSON file
    # dataset = load_dataset(script_args.dataset_name,
    #                        name=script_args.dataset_config)
    dataset = load_dataset("json", data_files=script_args.dataset_name, split='train')

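    # Assumed record schema (inferred from make_conversation_image below): each JSON
    # record needs at least a "problem" field carrying the question text.
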
    # Format each example into a conversation
    def make_conversation_image(example):
        return {
            "prompt": [
                {"role": "system", "content": [
                    {"type": "text", "text": SYSTEM_PROMPT}]},
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": example["problem"]},
                    ],
                },
            ]
        }

    dataset = dataset.map(make_conversation_image)

if "Qwen" in model_args.model_name_or_path or "Aria" in model_args.model_name_or_path:
|
| 236 |
+
trainer_cls = Qwen2VLGRPOTrainerCap
|
| 237 |
+
else:
|
| 238 |
+
trainer_cls = GRPOTrainer
|
| 239 |
+
|
| 240 |
+
# Initialize the GRPO trainer
|
| 241 |
+
trainer = trainer_cls(
|
| 242 |
+
model=model_args.model_name_or_path,
|
| 243 |
+
reward_funcs=reward_funcs,
|
| 244 |
+
args=training_args,
|
| 245 |
+
train_dataset=dataset,
|
| 246 |
+
eval_dataset=None,
|
| 247 |
+
peft_config=get_peft_config(model_args),
|
| 248 |
+
attn_implementation=model_args.attn_implementation,
|
| 249 |
+
max_pixels=script_args.max_pixels,
|
| 250 |
+
min_pixels=script_args.min_pixels,
|
| 251 |
+
caption_reward=script_args.caption_reward,
|
| 252 |
+
caption_reward_weight=script_args.caption_reward_weight,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
trainer.train()
|
| 256 |
+
# trainer.train()
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    print('training_args:\n', training_args)
    print('script_args:\n', script_args)
    print('model_args:\n', model_args)
    main(script_args, training_args, model_args)
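
# Hypothetical launch command (added for illustration; the script path is an assumption,
# while the flag names follow the GRPOScriptArguments fields defined above):
#   torchrun src/open_r1/grpo_vllm_caption.py \
#       --dataset_name data/train.json \
#       --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
#       --reward_funcs accuracy format \
#       --caption_reward True --caption_reward_weight 0.1
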
src/r1-v/src/open_r1/sft_video.py
ADDED
@@ -0,0 +1,304 @@
# Copyright 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example usage:
accelerate launch \
    --config_file=deepspeed_zero2.yaml \
    train_video_llm.py \
    --dataset_name mfarre/simplevideoshorts \
    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --output_dir video-llm-output \
    --bf16 \
    --torch_dtype bfloat16 \
    --gradient_checkpointing
"""

import os
import json
import random
import requests
import torch
from torch.optim import AdamW
from datasets import load_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2VLProcessor,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration
)
from transformers import get_linear_schedule_with_warmup

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
)
from accelerate import Accelerator
from qwen_vl_utils import process_vision_info

from datasets import Dataset, DatasetDict

import wandb

from typing import List, Dict, Any

os.environ["DS_BUILD_FUSED_ADAM"] = "0"

def get_current_device():
    """Get the current device. For GPU we return the local process index to enable multi-GPU training."""
    return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"

def download_video(url: str, folder: str = '/tmp/videos/') -> str:
    """Download a video if it is not already present locally."""
    filename = url.split("/")[-1]
    local_path = os.path.join(folder, filename)

    if os.path.exists(local_path):
        return local_path

    os.makedirs(folder, exist_ok=True)  # added: ensure the cache folder exists before writing
    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        return local_path
    except requests.RequestException as e:
        raise Exception(f"Failed to download video: {e}")

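# Illustrative use (added comment): download_video("https://example.com/clip.mp4")
# returns "/tmp/videos/clip.mp4" and is a cache hit on subsequent calls.
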
def prepare_dataset(example: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
    """Prepare a dataset example for training."""

    system_message = "You are a helpful assistant"

    QUESTION_TEMPLATE = (
        "{Question}\n"
        "Please think about this question as if you were a human pondering deeply. "
        "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions. "
        "It's encouraged to include self-reflection or verification in the reasoning process. "
        "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
    )

    TYPE_TEMPLATE = {
        "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
        "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
        "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
        "free-form": " Please provide your text answer within the <answer> </answer> tags.",
        "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags."
    }

    if example["problem_type"] == 'multiple choice':
        question = example['problem'] + "Options:\n"
        for op in example["options"]:
            question += op + "\n"
    else:
        question = example['problem']

    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": example['data_type'],
                    example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
                    # "max_pixels": 360*420,
                    # "fps": 1.0
                },
                {
                    "type": "text",
                    "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
                }
            ]
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example['process'] + "\n" + example['solution']}]
        }
    ]

    return {"messages": messages}

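# Illustrative record (added comment; field values are hypothetical):
#   {"problem": "How many cubes appear?", "problem_type": "numerical", "data_type": "video",
#    "path": "./videos/clip.mp4", "process": "<think>...</think>", "solution": "<answer>3</answer>"}
# becomes a three-turn chat with the video attached to the user turn.
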
def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Collate a batch of examples for training."""
    texts = []
    # video_inputs = []
    # image_inputs = []

    for i, example in enumerate(examples):
        try:
            texts.append(processor.apply_chat_template(example["messages"], tokenize=False))
            # Note: image_inputs/video_inputs are overwritten on each iteration, so only the
            # last example's vision inputs survive; this effectively assumes batch size 1.
            image_inputs, video_inputs, video_kwargs = process_vision_info(example["messages"], return_video_kwargs=True)
        except Exception as e:
            raise ValueError(f"Failed to process example {i}: {e}")

    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        return_tensors="pt",
        padding=True
    )

    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Handle visual tokens based on processor type (the hard-coded ids appear to be
    # Qwen2-VL vision special tokens such as <|vision_start|>, <|vision_end|>, <|video_pad|>)
    visual_tokens = [151652, 151653, 151656] if isinstance(processor, Qwen2VLProcessor) else [
        processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    ]

    for visual_token_id in visual_tokens:
        labels[labels == visual_token_id] = -100

    inputs["labels"] = labels
    return inputs

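# Masking sketch (added comment, illustrative ids): for input ids [151652, 9906, 151653, PAD],
# the labels become [-100, 9906, -100, -100], so only real text tokens contribute to the loss.
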
if __name__ == "__main__":
    # Parse arguments
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # Configure training args
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    # Load dataset: from a local JSON/JSONL file if given one, otherwise from the Hub
    if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Setup model
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )

    # # Quantization configuration for 4-bit training
    # bnb_config = BitsAndBytesConfig(
    #     load_in_4bit=True,
    #     bnb_4bit_use_double_quant=True,
    #     bnb_4bit_quant_type="nf4",
    #     bnb_4bit_compute_dtype=torch.bfloat16
    # )

    # Model initialization
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map(),
        # quantization_config=bnb_config,
    )

    if "Qwen2-VL" in model_config.model_name_or_path:
        model = Qwen2VLForConditionalGeneration.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    elif "Qwen2.5-VL" in model_config.model_name_or_path:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    else:
        model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)

    processor = AutoProcessor.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code
    )

    # Prepare dataset
    prepared_dataset = [prepare_dataset(example) for example in dataset['train']]

    # Initialize wandb if specified
    if training_args.report_to == "wandb":
        wandb.init(project="video-llm-training")

    '''
    Below is added code
    '''
    base_lr = 2e-4
    optimizer = AdamW(
        params=model.parameters(),
        lr=base_lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
    )

    num_training_steps = len(prepared_dataset) // (
        training_args.per_device_train_batch_size
        * training_args.gradient_accumulation_steps
        * training_args.world_size
    ) * training_args.num_train_epochs

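    # Worked example (added comment, illustrative numbers): 1,000 samples, batch size 1,
    # grad accumulation 4, world size 2, 1 epoch -> 1000 // (1 * 4 * 2) * 1 = 125 optimizer
    # steps, giving int(0.05 * 125) = 6 warmup steps below.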
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.05 * num_training_steps),
        num_training_steps=num_training_steps,
    )
    '''
    Above is added code
    '''

    # Initialize trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=prepared_dataset,
        data_collator=collate_fn,
        peft_config=get_peft_config(model_config),
        # tokenizer=processor.tokenizer
        optimizers=(optimizer, lr_scheduler),
    )

    # Train model
    trainer.train()

    # Save final model
    trainer.save_model(training_args.output_dir)
    processor.save_pretrained(training_args.output_dir)

    if trainer.accelerator.is_main_process:
        # Restore k,v cache for fast inference
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    # Cleanup
    del model
    del trainer
    torch.cuda.empty_cache()
    wandb.finish()
src/r1-v/src/open_r1/trainer/__init__.py
ADDED
@@ -0,0 +1,12 @@
from .grpo_trainer import Qwen2VLGRPOTrainer
from .vllm_grpo_trainer_modified import Qwen2VLGRPOVLLMTrainerModified
from .vllm_grpo_trainer_modified_orig import Qwen2VLGRPOVLLMTrainerModifiedOrig
from .vllm_grpo_trainer_selfConst import Qwen2VLGRPOVLLMTrainerSelfConst


__all__ = [
    "Qwen2VLGRPOTrainer",
    "Qwen2VLGRPOVLLMTrainerModified",
    "Qwen2VLGRPOVLLMTrainerModifiedOrig",
    "Qwen2VLGRPOVLLMTrainerSelfConst"
]
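
# Typical downstream import (added for illustration):
#   from open_r1.trainer import Qwen2VLGRPOVLLMTrainerModified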
src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified_error.py
ADDED
@@ -0,0 +1,1061 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import textwrap
from collections import defaultdict
from typing import Any, Callable, Optional, Union
from accelerate.utils.other import is_compiled_module
from accelerate.utils import broadcast_object_list, gather, gather_object
import torch
import torch.utils.data
import transformers
import warnings
from unittest.mock import patch
from datasets import Dataset, IterableDataset
from packaging import version
from transformers import (
    AriaForConditionalGeneration,
    AriaProcessor,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl.data_utils import (
    apply_chat_template,
    is_conversational,
    maybe_apply_chat_template,
)
from trl.import_utils import is_vllm_available

from trl.models import (
    create_reference_model,
    prepare_deepspeed,
    unwrap_model_for_generation,
)
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
from trl import GRPOTrainer

import copy

if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_vllm_available():
    from vllm import LLM, SamplingParams

if is_wandb_available():
    import wandb
import torch.nn as nn
from torch.utils.data import Sampler
import gc
from qwen_vl_utils import process_vision_info

import re

def extract_answer(predict: str) -> Optional[str]:
    """
    Extracts the content of the <answer>...</answer> block from `predict`.
    Returns the inner text (with leading/trailing whitespace stripped),
    or None if no <answer> tag is found.
    """
    match = re.search(r"<answer>([\s\S]*?)</answer>", predict, re.DOTALL)
    if not match:
        return None
    return match.group(1).strip()

def extract_info(predict: str) -> Optional[str]:
    """
    Extracts the content of the <des>...</des> block from `predict`.
    Returns the inner text (with leading/trailing whitespace stripped),
    or None if no <des> tag is found.
    """
    match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
    if not match:
        return None
    return match.group(1).strip()
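
# Illustrative behavior (added comment): extract_answer("<think>...</think><answer> 42 </answer>")
# returns "42"; extract_info pulls the <des>...</des> description block instead.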


# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


class Qwen2VLGRPOVLLMTrainerModified(Trainer):
    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        script_args=None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[
            Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
        ] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[
            Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
        ] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[
            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
        ] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        # qwen2-vl related params
        max_pixels: Optional[int] = 12845056,
        min_pixels: Optional[int] = 3136,
        attn_implementation: str = "flash_attention_2",
    ):

        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        model_init_kwargs["attn_implementation"] = attn_implementation
        if isinstance(model, str):
            model_id = model
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if (
                isinstance(torch_dtype, torch.dtype)
                or torch_dtype == "auto"
                or torch_dtype is None
            ):
                pass  # torch_dtype is already a torch.dtype or "auto" or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
                model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            # Disable caching if gradient checkpointing is enabled (not supported)
            model_init_kwargs["use_cache"] = (
                False
                if args.gradient_checkpointing
                else model_init_kwargs.get("use_cache")
            )
            if "Qwen2-VL" in model_id:
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model, **model_init_kwargs
                )
            elif "Qwen2.5-VL" in model_id:
                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model, **model_init_kwargs
                )
            elif "Aria" in model_id:
                model_init_kwargs.pop("use_cache")
                model = AriaForConditionalGeneration.from_pretrained(
                    model, **model_init_kwargs
                )
            else:
                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                raise ValueError(
                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                    "This argument can only be used when the `model` argument is a string."
                )

        if peft_config is not None:
            model = get_peft_model(model, peft_config)

        # Reference model
        if is_deepspeed_zero3_enabled():
            if "Qwen2-VL" in model_id:
                self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_id, **model_init_kwargs
                )
            elif "Qwen2.5-VL" in model_id:
                self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model_id, **model_init_kwargs
                )
            elif "Aria" in model_id:
                self.ref_model = AriaForConditionalGeneration.from_pretrained(
                    model_id, **model_init_kwargs
                )
            else:
                self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model_id, **model_init_kwargs
                )
        elif peft_config is None:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None

        # Processing class
        if processing_class is None:
            if "Qwen" in model_id or "Aria" in model_id:
                processing_class = AutoProcessor.from_pretrained(model_id)
                pad_token_id = processing_class.tokenizer.pad_token_id
                processing_class.pad_token_id = pad_token_id
                processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
                if "Qwen" in model_id:
                    processing_class.image_processor.max_pixels = max_pixels
                    processing_class.image_processor.min_pixels = min_pixels
            else:
                processing_class = AutoTokenizer.from_pretrained(
                    model.config._name_or_path, padding_side="left"
                )
                pad_token_id = processing_class.pad_token_id

        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError(
                    "The number of reward processing classes must match the number of reward functions."
                )

        for i, (reward_processing_class, reward_func) in enumerate(
            zip(reward_processing_classes, reward_funcs)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(
                        reward_func.config._name_or_path
                    )
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = (
                        reward_processing_class.eos_token
                    )
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features

        # Training arguments
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = (
            args.max_completion_length
        )  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.temporal = script_args.temporal
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=1,  # HACK
            num_return_sequences=self.num_generations,
            pad_token_id=pad_token_id,
        )
        self.beta = args.beta

        self.shuffled_num_generations = self.num_generations // 2
        self.shuffled_generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            top_p=0.95,
            temperature=1,  # HACK
            num_return_sequences=self.shuffled_num_generations,
            pad_token_id=pad_token_id,
        )

        self.dummy_generation_config = GenerationConfig(
            max_new_tokens=1,
            do_sample=True,
            top_p=0.95,
            temperature=1,  # HACK
            num_return_sequences=1,
            pad_token_id=pad_token_id,
        )
        self.len_control = script_args.len_control

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)
        self.use_vllm = args.use_vllm

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )
        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        if self.use_vllm:
            if not is_vllm_available():
                raise ImportError(
                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                    "`pip install vllm` to use it."
                )

            if self.accelerator.is_main_process:
                vllm_device = self.args.vllm_device
                if vllm_device == "auto":
                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
                # Check that the requested device is available
                if (
                    vllm_device.split(":")[0] == "cuda"
                    and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
                ):
                    raise ValueError(
                        f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                        "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                        "value lower than the number of GPUs available on your machine—typically, reducing it by one "
                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
                    )
                # Check that the requested device is not also used for training
                if vllm_device in {
                    f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
                }:
                    warnings.warn(
                        f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
                        "behavior. It is recommended to use a dedicated device for vLLM."
                    )
                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
                # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
                # setting (profiling_patch).
                world_size_patch = patch(
                    "torch.distributed.get_world_size", return_value=1
                )
                profiling_patch = patch(
                    "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
                    return_value=None,
                )
                with world_size_patch, profiling_patch:
                    print("vllm is running on: ", vllm_device)
                    self.llm = LLM(
                        model=model.name_or_path,
                        device=vllm_device,
                        gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
                        dtype=torch.bfloat16,
                        # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
                        # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
                        # This is particularly useful here because we generate completions from the same prompts.
                        enable_prefix_caching=True,
                        enforce_eager=True,
                        mm_processor_kwargs=(
                            {
                                "max_pixels": max_pixels,
                                "min_pixels": min_pixels,
                            }
                            # if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
                            if False
                            else None
                        ),
                        max_model_len=args.max_prompt_length + args.max_completion_length,
                    )
                self.sampling_params = SamplingParams(
                    temperature=1.0,
                    top_p=0.95,
                    max_tokens=self.max_completion_length,
                )

            self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation

            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
            # synchronize all processes after vLLM has been fully initialized.
            self.accelerator.wait_for_everyone()
        else:
            raise ValueError(
                "GRPOVLLMTrainerModified only supports vllm generation, please set --use_vllm True"
            )

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]

    # Get the per-token log probabilities for the completions for the model and the reference model
    def _get_per_token_logps(self, model, input_ids, **kwargs):
        # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits  # (B, L, V)
        # import pdb
        # pdb.set_trace()
        logits = model(input_ids, **kwargs).logits
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
        input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
        # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)

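    # Shape sketch (added comment): for input_ids of shape (B, L), the returned tensor has
    # shape (B, L-1); entry [b, t] is log p(input_ids[b, t+1] | input_ids[b, :t+1]).
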
# Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
|
| 458 |
+
# Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
|
| 459 |
+
def _prepare_inputs(
|
| 460 |
+
self, inputs: dict[str, Union[torch.Tensor, Any]]
|
| 461 |
+
) -> dict[str, Union[torch.Tensor, Any]]:
|
| 462 |
+
return inputs
|
| 463 |
+
|
| 464 |
+
def remove_none_from_data(self, data):
|
| 465 |
+
for entry in data:
|
| 466 |
+
if "content" in entry and isinstance(entry["content"], list):
|
| 467 |
+
for sub_entry in entry["content"]:
|
| 468 |
+
if isinstance(sub_entry, dict):
|
| 469 |
+
keys_to_remove = [k for k, v in sub_entry.items() if v is None]
|
| 470 |
+
for k in keys_to_remove:
|
| 471 |
+
del sub_entry[k]
|
| 472 |
+
return data
|
| 473 |
+
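
    # Illustrative behavior (added comment):
    #   [{"content": [{"type": "image", "image": None, "text": "hi"}]}]
    # becomes
    #   [{"content": [{"type": "image", "text": "hi"}]}].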

    def _vllm_generate(self, prompts_text, mm_data, n):
        """
        Helper that wraps the whole 'gather-broadcast-slice-pad-decode' dance
        and returns (completion_ids, decoded_texts) *ON THIS RANK ONLY*.
        `mm_data` can be None/[] for pure-text inputs.
        """
        device = self.accelerator.device

        # --------------- gather everything to rank-0 ----------------
        all_prompts = gather_object(prompts_text)
        all_mm_data = gather_object(mm_data or [[]] * len(prompts_text))

        # build the multimodal inputs expected by vLLM
        vllm_inputs = [
            {"prompt": p, "multi_modal_data": m[0] if m else {}}
            for p, m in zip(all_prompts, all_mm_data)
        ]

        # -------------------------------------------------------------
        if self.accelerator.is_main_process:
            p = copy.deepcopy(self.sampling_params)
            p.n = n
            outs = self.llm.generate(vllm_inputs, sampling_params=p, use_tqdm=False)
            comp_ids = [o.token_ids for c in outs for o in c.outputs]
        else:
            comp_ids = [None] * (len(vllm_inputs) * n)

        # broadcast back, pick this rank's slice
        comp_ids = broadcast_object_list(comp_ids, from_process=0)
        lo = self.accelerator.process_index * len(prompts_text) * n
        hi = (self.accelerator.process_index + 1) * len(prompts_text) * n
        comp_ids = comp_ids[lo:hi]

        # pad, convert to tensor -> decode
        comp_ids = [torch.tensor(x, device=device) for x in comp_ids]
        comp_ids = pad(comp_ids, padding_value=self.processing_class.pad_token_id)
        decoded = self.processing_class.batch_decode(comp_ids, skip_special_tokens=True)
        return comp_ids, decoded

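    # Slicing example (added comment, illustrative): with 2 ranks, 1 local prompt and n=8,
    # rank 0 keeps completions [0:8] and rank 1 keeps [8:16] of the broadcast list.
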
    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
        # Compute the per-token log probabilities for the model

        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        # images = [x["image"] for x in inputs]
        prompts_text = [
            maybe_apply_chat_template(example, self.processing_class)["prompt"]
            for example in inputs
        ]

        input_copy = copy.deepcopy(inputs[0]['prompt'])

        input_copy = self.remove_none_from_data(input_copy)

        data_type = inputs[0]['data_type']

        if data_type == 'image':
            input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
        elif data_type == 'video':
            input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]

        image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)

        prompt_inputs = self.processing_class(
            text=copy.deepcopy(prompts_text),
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )

        mm_data = [[data_type, image_inputs if image_inputs else video_inputs]]
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        if self.temporal:
            if video_inputs:
                indices = torch.randperm(video_inputs[0].size(0))
                shuffled_video_inputs = [video_inputs[0][indices]]
                shuffled_prompt_inputs = self.processing_class(
                    text=copy.deepcopy(prompts_text),
                    images=image_inputs,
                    videos=shuffled_video_inputs,
                    return_tensors="pt",
                    padding=True,
                    padding_side="left",
                    add_special_tokens=False,
                )
                shuffled_mm_data = [[self.accelerator.process_index, data_type, image_inputs if image_inputs else video_inputs]]
                shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
                shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
                if self.max_prompt_length is not None:
                    shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
                    shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
            else:
                shuffled_mm_data = [None]

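        # (Added note) The shuffled-frame branch appears to support a temporal consistency
        # signal in the style of Video-R1's T-GRPO: completions generated from randomly
        # permuted frames can later be compared against those from the ordered video.
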
        if self.args.use_vllm:
            # First, have main process load weights if needed
            if self.state.global_step != self._last_loaded_step:
                with unwrap_model_for_generation(
                    self.model,
                    self.accelerator,
                    gather_deepspeed3_params=True,  # TODO: fix this, self.args.ds3_gather_for_generation,
                ) as unwrapped_model:
                    if is_compiled_module(unwrapped_model):
                        state_dict = unwrapped_model._orig_mod.state_dict()
                    else:
                        state_dict = unwrapped_model.state_dict()
                if self.accelerator.is_main_process:
                    llm_model = (
                        self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                    )
                    # import pdb
                    # pdb.set_trace()
                    llm_model.load_weights(state_dict.items())
                self._last_loaded_step = self.state.global_step
            '''
            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            all_mm_data = gather_object(mm_data)
            # group into pairs
            all_multimodal_inputs = []

            if self.temporal:
                shuffled_all_mm_data_none = gather_object(shuffled_mm_data)
                shuffled_all_mm_data = [x for x in shuffled_all_mm_data_none if x]
                shuffled_all_multimodal_inputs = []

            # 2. Refer to TobiasLee's implementation suggestions
            # this is a better implementation for vLLM sampling.
            for prompt, mm_item in zip(all_prompts_text, all_mm_data):
                all_multimodal_inputs.append({"prompt": prompt, "multi_modal_data": {mm_item[0]: mm_item[1]}})

            if self.temporal and shuffled_all_mm_data != []:
                for mm_item in shuffled_all_mm_data:
                    shuffled_all_multimodal_inputs.append({"prompt": all_prompts_text[mm_item[0]], "multi_modal_data": {mm_item[1]: mm_item[2]}})

            # Create sampling params with num_generations
            if self.accelerator.is_main_process:
                # Clone to avoid modifying original params
                sampling_params = copy.deepcopy(self.sampling_params)
                sampling_params.n = self.num_generations
                # Single generate call with all prompts
                if self.accelerator.is_main_process:
                    outputs = self.llm.generate(
                        all_multimodal_inputs,
                        sampling_params=sampling_params,
                        use_tqdm=False,
                    )
                    # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
                    completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]

                if self.temporal and shuffled_all_mm_data != []:
                    # Clone to avoid modifying original params
                    shuffled_sampling_params = copy.deepcopy(self.sampling_params)
                    shuffled_sampling_params.n = self.num_generations // 2
                    # Single generate call with all prompts
                    if self.accelerator.is_main_process:
                        shuffled_outputs = self.llm.generate(
                            shuffled_all_multimodal_inputs,
                            sampling_params=shuffled_sampling_params,
                            use_tqdm=False,
                        )
                        # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
                        shuffled_completion_ids = [out.token_ids for completion in shuffled_outputs for out in completion.outputs]

            else:
                completion_ids = [None] * len(all_multimodal_inputs) * self.num_generations

                if self.temporal and shuffled_all_mm_data != []:
                    shuffled_completion_ids = [None] * len(shuffled_all_multimodal_inputs) * (self.num_generations // 2)

            # broadcast and slice
            completion_ids = broadcast_object_list(completion_ids, from_process=0)
            process_slice = slice(
                self.accelerator.process_index * len(prompts) * self.num_generations,
                (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
            )
            completion_ids = completion_ids[process_slice]

            # Pad the completions, and concatenate them with the prompts
            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
            completion_ids = pad(
                completion_ids, padding_value=self.processing_class.pad_token_id
            )
            '''

            completion_ids, completions = self._vllm_generate(
                prompts_text,         # original text prompts
                mm_data,              # vision payload (may be empty for text-only)
                self.num_generations,
            )

            prompt_ids = prompt_ids.repeat_interleave(self.num_generations, dim=0)
            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)

            prompt_length = prompt_ids.size(1)
            print('prompt_length:', prompt_length)

            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]
            prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
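            # Shape bookkeeping for the block above (illustrative sketch): with
            # B = len(prompts) and G = self.num_generations, e.g. B = 2, G = 3:
            #     p = torch.tensor([[10], [20]])   # two single-token prompts
            #     p.repeat_interleave(3, dim=0)    # -> [[10],[10],[10],[20],[20],[20]]
            # completions are flattened the same way, so row i of
            # prompt_completion_ids pairs prompt i // G with its (i % G)-th sample.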
            # Additional code that avoids the `shuffled_all_mm_data` undefined-variable error.
            if self.temporal and video_inputs:
                # (1) make the shuffled video batch (shuffled_video_inputs was computed above)
                local_shuffled_mm = [[data_type, shuffled_video_inputs]]
                shuffled_prompts = copy.deepcopy(prompts_text)

                # (2) generate half as many completions for each prompt
                shuffled_completion_ids, _ = self._vllm_generate(
                    prompts_text=shuffled_prompts,
                    mm_data=local_shuffled_mm,
                    n=self.num_generations // 2,
                )

                # (3) mimic the old triple-list so the broadcast logic below works unchanged
                shuffled_all_mm_data = [[self.accelerator.process_index,
                                         data_type,
                                         shuffled_video_inputs]]

            if self.temporal and shuffled_all_mm_data != []:
                # Broadcast and slice
                shuffled_completion_ids = broadcast_object_list(shuffled_completion_ids, from_process=0)
                process_id_list = []
                for mm_item in shuffled_all_mm_data:
                    process_id_list += [mm_item[0]] * len(prompts) * (self.num_generations // 2)

                if video_inputs:
                    cur_shuffled_completion_ids = []
                    for i in range(len(process_id_list)):
                        if self.accelerator.process_index == process_id_list[i]:
                            cur_shuffled_completion_ids.append(shuffled_completion_ids[i])

                    # Pad the shuffled completions
                    cur_shuffled_completion_ids = [torch.tensor(ids, device=device) for ids in cur_shuffled_completion_ids]
                    cur_shuffled_completion_ids = pad(
                        cur_shuffled_completion_ids, padding_value=self.processing_class.pad_token_id
                    )
                    shuffled_completion_ids = cur_shuffled_completion_ids

        else:
            raise ValueError("Only vLLM generation is supported in this version")
        # Below is the same as yifan's code.
        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        device = self.accelerator.device
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
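        # A minimal check of the masking logic above (illustrative, assuming
        # eos_token_id == 2):
        #     ids = torch.tensor([[5, 2, 7], [5, 5, 5]])
        #     eos = ids == 2                      # [[F, T, F], [F, F, F]]
        #     idx = torch.full((2,), 3, dtype=torch.long)
        #     idx[eos.any(dim=1)] = eos.int().argmax(dim=1)[eos.any(dim=1)]  # [1, 3]
        #     (torch.arange(3).expand(2, -1) <= idx.unsqueeze(1)).int()
        #     # -> [[1, 1, 0], [1, 1, 1]]: the EOS token stays unmasked,
        #     # everything after it is dropped; rows with no EOS keep full length.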
prompt_inputs.pop("input_ids")
|
| 782 |
+
prompt_inputs.pop("attention_mask")
|
| 783 |
+
|
| 784 |
+
if data_type == 'image':
|
| 785 |
+
prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
|
| 786 |
+
prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
|
| 787 |
+
# import pdb; pdb.set_trace()
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
if data_type == 'video':
|
| 791 |
+
prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
|
| 792 |
+
prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
|
| 793 |
+
if 'second_per_grid_ts' in prompt_inputs:
|
| 794 |
+
del prompt_inputs["second_per_grid_ts"]
|
| 795 |
+
|
| 796 |
+
# import pdb
|
| 797 |
+
# pdb.set_trace()
|
| 798 |
+
|
| 799 |
+
# per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
|
| 800 |
+
per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
|
| 801 |
+
# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
|
| 802 |
+
per_token_logps = per_token_logps[:, prompt_length - 1 :]
|
| 803 |
+
|
| 804 |
+
gc.collect()
|
| 805 |
+
torch.cuda.empty_cache()
|
| 806 |
+
|
| 807 |
+
with torch.inference_mode():
|
| 808 |
+
if self.ref_model is not None:
|
| 809 |
+
ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
|
| 810 |
+
else:
|
| 811 |
+
with self.accelerator.unwrap_model(model).disable_adapter():
|
| 812 |
+
ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
|
| 813 |
+
ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
|
| 814 |
+
|
| 815 |
+
x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10) # 限制 x 的范围
|
| 816 |
+
per_token_kl = torch.exp(x_clamped) - x_clamped - 1
|
| 817 |
+
|
| 818 |
+
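        # The two lines above implement the k3 KL estimator used by GRPO: with
        # x = ref_logp - logp per token,
        #     KL(pi_theta || pi_ref) ~= exp(x) - x - 1,
        # which is non-negative and zero iff x == 0; clamping x to [-10, 10]
        # keeps exp(x) finite on outlier tokens.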
        gc.collect()
        torch.cuda.empty_cache()

        if self.temporal and video_inputs:
            shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
            if is_conversational(inputs[0]):
                shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]

            # Compute the rewards for the shuffled-frame generations
            shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
            shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
            for i, (reward_func, reward_processing_class) in enumerate(
                zip(self.reward_funcs, self.reward_processing_classes)
            ):
                # Repeat all input columns (except "prompt" and "completion") to match the number of generations
                shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
                for key in shuffled_reward_kwargs:
                    for example in inputs:
                        # Repeat each value in the column `shuffled_num_generations` times
                        shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
                shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
                shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
        # Decode the generated completions
        completions = self.processing_class.batch_decode(
            completion_ids, skip_special_tokens=True
        )
        if is_conversational(inputs[0]):
            completions = [
                [{"role": "assistant", "content": completion}]
                for completion in completions
            ]

        # --- Second-hop completion generation ---
        if is_conversational(inputs[0]):
            first_texts = [c[0]["content"] for c in completions]
        else:
            first_texts = completions

        # Build follow-up prompts with `extract_info`
        follow_up_prompts = [extract_info(txt) for txt in first_texts]

        # Second-hop generation -> `second_completions`
        _, second_texts = self._vllm_generate(
            follow_up_prompts,  # new prompts (pure text)
            None,               # no vision payload
            1,                  # one follow-up per prompt
        )

        # Pack in chat format if needed
        if is_conversational(inputs[0]):
            second_completions = [
                [{"role": "assistant", "content": t}] for t in second_texts
            ]
        else:
            second_completions = second_texts
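        # Flow of the second hop above (illustrative sketch; `extract_info` is
        # assumed to be defined or imported elsewhere in this file):
        #     first pass : "<think>...</think><answer>B</answer>"
        #     follow-up  : extract_info(first-pass text) -> a text-only prompt
        #     second pass: one fresh completion per follow-up prompt,
        # and the reward loop below scores the second-pass answers rather than
        # the first-pass completions.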
        # Compute the rewards
        prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
        rewards_per_func = torch.zeros(
            len(prompts), len(self.reward_funcs), device=device
        )
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            reward_kwargs = {
                key: []
                for key in inputs[0].keys()
                if key not in ["prompt", "completion"]
            }

            # Every original example contributes `self.num_generations` rows
            for example in inputs:
                for _ in range(self.num_generations):
                    for key in reward_kwargs:
                        reward_kwargs[key].append(example[key])

            # Call the reward function on the extracted info and the fresh answers
            outputs = reward_func(
                prompts=follow_up_prompts,       # extracted info
                completions=second_completions,  # fresh answers
                **reward_kwargs,
            )
            rewards_per_func[:, i] = torch.tensor(outputs, dtype=torch.float32, device=device)
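        # Sizes in the loop above (illustrative): with B examples and
        # G = num_generations, `prompts`, `follow_up_prompts`, `second_completions`,
        # and every list in `reward_kwargs` all have length B * G, ordered so that
        # rows i*G .. i*G + G - 1 belong to example i; `rewards_per_func` is
        # (B*G, F) for F reward functions.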
        # Legacy single-hop reward path, kept for reference:
        # for key in reward_kwargs:
        #     for example in inputs:
        #         # Repeat each value in the column `num_generations` times
        #         reward_kwargs[key].extend([example[key]] * self.num_generations)
        # output_reward_func = reward_func(
        #     prompts=prompts, completions=completions, **reward_kwargs
        # )
        # rewards_per_func[:, i] = torch.tensor(
        #     output_reward_func, dtype=torch.float32, device=device
        # )

        # rewards_per_func = gather(rewards_per_func)
        # # Sum the rewards from all reward functions
        # rewards = rewards_per_func.sum(dim=1)
        # process_slice = slice(
        #     self.accelerator.process_index * len(prompts),
        #     (self.accelerator.process_index + 1) * len(prompts),
        # )
        # rewards = rewards[process_slice]
        if self.temporal and video_inputs:
            temporal_rewards_per_func = rewards_per_func.clone()

            acc_mean = temporal_rewards_per_func[:, 0].mean()
            shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()

            if acc_mean >= 0.8 * shuffled_acc_mean:
                mask = temporal_rewards_per_func[:, 0] > 0.1
                temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
                temporal_rewards = torch.tensor([1.0]).to('cuda')
            else:
                temporal_rewards = torch.tensor([0.0]).to('cuda')
        else:
            temporal_rewards = torch.tensor([0.5]).to('cuda')
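        # Intuition for the branch above (illustrative): the mean accuracy reward on
        # ordered frames is compared against the mean on shuffled frames. If ordered
        # accuracy holds up (>= 0.8x the shuffled mean, e.g. acc_mean = 0.6 vs
        # shuffled_acc_mean = 0.5 passes since 0.6 >= 0.4), generations already judged
        # correct (reward > 0.1) get a +0.3 bonus, favoring answers that rely on the
        # frames being in temporal order.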
        # Sum the rewards from all reward functions
        if self.temporal and video_inputs:
            rewards = temporal_rewards_per_func.sum(dim=1)
        else:
            rewards = rewards_per_func.sum(dim=1)

        if self.len_control:
            mem_rewards = [0] * self.num_generations
            mask = rewards_per_func[:, 0] > 0.1
            length_list = completion_mask.sum(1)
            selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()

            # Earlier rank-based length shaping, kept for reference:
            # if len(selected_indices) > 1:
            #     selected_items = [(i, length_list[i]) for i in selected_indices]
            #     sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
            #     N = len(sorted_items)
            #     for rank, (idx, length) in enumerate(sorted_items):
            #         reward = 0.2 - 0.2 * (rank / N)
            #         rewards[idx] += reward
            #         mem_rewards[idx] = reward
            # for idx in range(len(length_list)):
            #     if length_list[idx] >= 512:
            #         rewards[idx] -= 0.5

            if len(selected_indices) > 1:
                for idx in selected_indices:
                    if 320 <= length_list[idx] <= 512:
                        rewards[idx] += 0.2

        print(rewards)
        print(completion_mask.sum(1))
        # Compute group-wise rewards
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute the advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
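        # Worked example (illustrative): one group with G = 4 and rewards
        # [1., 0., 1., 0.] has mean 0.5 and (Bessel-corrected) std ~0.577, giving
        # advantages ~[0.87, -0.87, 0.87, -0.87]: each sample is scored against its
        # own group, so no learned value function is needed.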
        # exp(x - x.detach()) equals 1 in value but preserves the gradients from x
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
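        # Gradient view of the surrogate above (sketch): since
        # exp(logp - logp.detach()) == 1 in value, per_token_loss contributes
        # -(A * d logp) + beta * d KL to the gradient, i.e. vanilla policy gradient
        # on the group-normalized advantages plus a KL pull toward the reference
        # policy.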
        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)

        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        gathered_rewards = self.accelerator.gather_for_metrics(rewards)

        num_devices = gathered_rewards.size(0) // self.num_generations
        rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
        wrong_devices = (rewards_per_device <= 1).all(dim=1)
        wrong_ratio = wrong_devices.sum().item() / num_devices

        correct_devices = (rewards_per_device >= 2).all(dim=1)
        correct_ratio = correct_devices.sum().item() / num_devices

        self._metrics["all_wrong"].append(wrong_ratio)
        self._metrics["all_correct"].append(correct_ratio)

        if self.temporal:
            temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
            self._metrics["temporal_rewards"].append(self.accelerator.gather_for_metrics(temporal_rewards_list).mean().item())

        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        return loss
    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if next(iter(logs.keys())).startswith("eval_"):
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:  # transformers <= 4.46
            super().log(logs)
        self._metrics.clear()
src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified_orig.py ADDED
@@ -0,0 +1,935 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import gc
import os
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Union
from unittest.mock import patch

import torch
import torch.nn as nn
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object
from accelerate.utils.other import is_compiled_module
from datasets import Dataset, IterableDataset
from packaging import version
from torch.utils.data import Sampler
from transformers import (
    AriaForConditionalGeneration,
    AriaProcessor,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available

from trl import GRPOTrainer
from trl.data_utils import (
    apply_chat_template,
    is_conversational,
    maybe_apply_chat_template,
)
from trl.import_utils import is_vllm_available
from trl.models import (
    create_reference_model,
    prepare_deepspeed,
    unwrap_model_for_generation,
)
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad

from qwen_vl_utils import process_vision_info

if is_peft_available():
    from peft import PeftConfig, get_peft_model

if is_vllm_available():
    from vllm import LLM, SamplingParams

if is_wandb_available():
    import wandb

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
class Qwen2VLGRPOVLLMTrainerModifiedOrig(Trainer):
    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        script_args=None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[
            Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
        ] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[
            Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
        ] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[
            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
        ] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
        # qwen2-vl related params
        max_pixels: Optional[int] = 12845056,
        min_pixels: Optional[int] = 3136,
        attn_implementation: str = "flash_attention_2",
    ):
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = GRPOConfig(f"{model_name}-GRPO")

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        model_init_kwargs["attn_implementation"] = attn_implementation
        if isinstance(model, str):
            model_id = model
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if (
                isinstance(torch_dtype, torch.dtype)
                or torch_dtype == "auto"
                or torch_dtype is None
            ):
                pass  # torch_dtype is already a torch.dtype or "auto" or None
            elif isinstance(torch_dtype, str):  # it's a str, but not "auto"
                torch_dtype = getattr(torch, torch_dtype)
                model_init_kwargs["torch_dtype"] = torch_dtype
            else:
                raise ValueError(
                    "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
                )
            # Disable caching if gradient checkpointing is enabled (not supported)
            model_init_kwargs["use_cache"] = (
                False
                if args.gradient_checkpointing
                else model_init_kwargs.get("use_cache")
            )
            if "Qwen2-VL" in model_id:
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model, **model_init_kwargs
                )
            elif "Qwen2.5-VL" in model_id:
                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model, **model_init_kwargs
                )
            elif "Aria" in model_id:
                model_init_kwargs.pop("use_cache")
                model = AriaForConditionalGeneration.from_pretrained(
                    model, **model_init_kwargs
                )
            else:
                model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                raise ValueError(
                    "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
                    "This argument can only be used when the `model` argument is a string."
                )

        if peft_config is not None:
            model = get_peft_model(model, peft_config)

        # Reference model
        if is_deepspeed_zero3_enabled():
            if "Qwen2-VL" in model_id:
                self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_id, **model_init_kwargs
                )
            elif "Qwen2.5-VL" in model_id:
                self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model_id, **model_init_kwargs
                )
            elif "Aria" in model_id:
                self.ref_model = AriaForConditionalGeneration.from_pretrained(
                    model_id, **model_init_kwargs
                )
            else:
                self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                    model_id, **model_init_kwargs
                )
        elif peft_config is None:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None
        # Processing class — legacy loading path, kept for reference:
        # if processing_class is None:
        #     if "Qwen" in model_id or "Aria" in model_id:
        #         processing_class = AutoProcessor.from_pretrained(model_id)
        #         pad_token_id = processing_class.tokenizer.pad_token_id
        #         processing_class.pad_token_id = pad_token_id
        #         processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
        #         if "Qwen" in model_id:
        #             processing_class.image_processor.max_pixels = max_pixels
        #             processing_class.image_processor.min_pixels = min_pixels
        #     else:
        #         processing_class = AutoTokenizer.from_pretrained(
        #             model.config._name_or_path, padding_side="left"
        #         )
        #         pad_token_id = processing_class.pad_token_id

        # ────────────────────────────────────────────────────────────────
        # Robust processor loading — works for both fresh models *and* checkpoints
        # ────────────────────────────────────────────────────────────────
        if processing_class is None:
            # 1. First try to load whatever lives in the directory we were given.
            #    This succeeds if you previously did `processor.save_pretrained(output_dir)`.
            try:
                processing_class = AutoProcessor.from_pretrained(model_id)
                pad_token_id = processing_class.tokenizer.pad_token_id
            except (OSError, ValueError):  # no processor files found
                # 2. Fall back to inspecting the *model object* instead of the path.
                is_vl_model = (
                    hasattr(model, "vision_tower") or                           # Qwen-VL, InternVL, etc.
                    getattr(model.config, "vision_config", None) is not None or
                    getattr(model.config, "image_vocab_size", None) is not None
                )

                if is_vl_model:
                    # Always use the *base* model name stored in the config.
                    base_name = model.config._name_or_path  # e.g. "Qwen/Qwen2.5-VL-7B-Instruct"
                    processing_class = AutoProcessor.from_pretrained(base_name)
                    pad_token_id = processing_class.tokenizer.pad_token_id

                    # Optional Qwen-specific limits
                    if hasattr(processing_class, "image_processor"):
                        processing_class.image_processor.max_pixels = max_pixels
                        processing_class.image_processor.min_pixels = min_pixels
                else:
                    # Pure text model → plain tokenizer
                    processing_class = AutoTokenizer.from_pretrained(
                        model.config._name_or_path, padding_side="left"
                    )
                    pad_token_id = processing_class.pad_token_id

            # 3. Harmonise attributes the rest of the trainer expects
            processing_class.pad_token_id = pad_token_id
            if not hasattr(processing_class, "eos_token_id"):
                processing_class.eos_token_id = pad_token_id
        # ────────────────────────────────────────────────────────────────
        # Reward functions
        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]
        for i, reward_func in enumerate(reward_funcs):
            if isinstance(reward_func, str):
                reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                    reward_func, num_labels=1, **model_init_kwargs
                )
        self.reward_funcs = reward_funcs

        # Reward processing class
        if reward_processing_classes is None:
            reward_processing_classes = [None] * len(reward_funcs)
        elif not isinstance(reward_processing_classes, list):
            reward_processing_classes = [reward_processing_classes]
        else:
            if len(reward_processing_classes) != len(reward_funcs):
                raise ValueError(
                    "The number of reward processing classes must match the number of reward functions."
                )

        for i, (reward_processing_class, reward_func) in enumerate(
            zip(reward_processing_classes, reward_funcs)
        ):
            if isinstance(reward_func, PreTrainedModel):
                if reward_processing_class is None:
                    reward_processing_class = AutoTokenizer.from_pretrained(
                        reward_func.config._name_or_path
                    )
                if reward_processing_class.pad_token_id is None:
                    reward_processing_class.pad_token = (
                        reward_processing_class.eos_token
                    )
                # The reward model computes the reward for the latest non-padded token in the input sequence.
                # So it's important to set the pad token ID to the padding token ID of the processing class.
                reward_func.config.pad_token_id = reward_processing_class.pad_token_id
                reward_processing_classes[i] = reward_processing_class
        self.reward_processing_classes = reward_processing_classes

        # Data collator
        def data_collator(features):  # No data collation is needed in GRPO
            return features
        # Training arguments
        self.max_prompt_length = args.max_prompt_length
        self.max_completion_length = (
            args.max_completion_length
        )  # = |o_i| in the GRPO paper
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.temporal = script_args.temporal
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            temperature=1,  # HACK
            num_return_sequences=self.num_generations,
            pad_token_id=pad_token_id,
        )
        self.beta = args.beta

        self.shuffled_num_generations = self.num_generations // 2
        self.shuffled_generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            top_p=0.95,
            temperature=1,  # HACK
            num_return_sequences=self.shuffled_num_generations,
            pad_token_id=pad_token_id,
        )

        self.dummy_generation_config = GenerationConfig(
            max_new_tokens=1,
            do_sample=True,
            top_p=0.95,
            temperature=1,  # HACK
            num_return_sequences=1,
            pad_token_id=pad_token_id,
        )
        self.len_control = script_args.len_control

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
        # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
        # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
        # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
        # This acts as a flag to indicate that the warning has already been issued.
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)
        self.use_vllm = args.use_vllm
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            callbacks=callbacks,
            optimizers=optimizers,
        )
        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether
        # the model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False
        if self.use_vllm:
            if not is_vllm_available():
                raise ImportError(
                    "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
                    "`pip install vllm` to use it."
                )

            if self.accelerator.is_main_process:
                vllm_device = self.args.vllm_device
                if vllm_device == "auto":
                    vllm_device = f"cuda:{self.accelerator.num_processes}"  # take the next GPU idx
                # Check that the requested device is available
                if (
                    vllm_device.split(":")[0] == "cuda"
                    and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
                ):
                    raise ValueError(
                        f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
                        "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
                        "value lower than the number of GPUs available on your machine—typically, reducing it by one "
                        f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
                    )
                # Check that the requested device is not also used for training
                if vllm_device in {
                    f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
                }:
                    warnings.warn(
                        f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
                        "behavior. It is recommended to use a dedicated device for vLLM."
                    )
                # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the
                # vLLM model on the desired device (world_size_patch) and (2) avoid a test that is not designed for
                # our setting (profiling_patch).
                world_size_patch = patch(
                    "torch.distributed.get_world_size", return_value=1
                )
                profiling_patch = patch(
                    "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
                    return_value=None,
                )
                with world_size_patch, profiling_patch:
                    print("vllm is running on: ", vllm_device)
                    self.llm = LLM(
                        model=model.name_or_path,
                        device=vllm_device,
                        gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
                        dtype=torch.bfloat16,
                        # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
                        # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
                        # This is particularly useful here because we generate completions from the same prompts.
                        enable_prefix_caching=True,
                        enforce_eager=True,
                        mm_processor_kwargs=(
                            {
                                "max_pixels": max_pixels,
                                "min_pixels": min_pixels,
                            }
                            # if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
                            if False
                            else None
                        ),
                        max_model_len=args.max_prompt_length + args.max_completion_length,
                    )
                self.sampling_params = SamplingParams(
                    temperature=1.0,
                    top_p=0.95,
                    max_tokens=self.max_completion_length,
                )

            self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation

            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
            # synchronize all processes after vLLM has been fully initialized.
            self.accelerator.wait_for_everyone()
        else:
            raise ValueError(
                "GRPOVLLMTrainerModified only supports vllm generation, please set --use_vllm True"
            )

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
        # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
        # Instead, we set them to the columns expected by the `training_step` method, hence the override.
        if self._signature_columns is None:
            self._signature_columns = ["prompt"]
    # Get the per-token log probabilities of the completions under the model (or the reference model)
    def _get_per_token_logps(self, model, input_ids, **kwargs):
        logits = model(input_ids, **kwargs).logits
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next-token prediction
        input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
        # Compute the log probabilities for the input tokens. Use a loop to reduce the memory peak.
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)
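    # Shift bookkeeping for the method above (illustrative): for input_ids
    # [t0, t1, t2, t3], logits[:, :-1] scores the predictions of [t1, t2, t3], so
    # gathering at input_ids[:, 1:] yields log p(t1|t0), log p(t2|t0,t1),
    # log p(t3|t0,t1,t2) -- one log-prob per generated token, length L-1, which is
    # why callers slice the result with `prompt_length - 1`.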
# Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
|
| 473 |
+
# Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
|
| 474 |
+
def _prepare_inputs(
|
| 475 |
+
self, inputs: dict[str, Union[torch.Tensor, Any]]
|
| 476 |
+
) -> dict[str, Union[torch.Tensor, Any]]:
|
| 477 |
+
return inputs
|
| 478 |
+
|
| 479 |
+
    def remove_none_from_data(self, data):
        # Strip `None`-valued keys from nested content dicts so the processor does not choke on them.
        for entry in data:
            if "content" in entry and isinstance(entry["content"], list):
                for sub_entry in entry["content"]:
                    if isinstance(sub_entry, dict):
                        keys_to_remove = [k for k, v in sub_entry.items() if v is None]
                        for k in keys_to_remove:
                            del sub_entry[k]
        return data
    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        # Compute the per-token log probabilities for the model
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        # images = [x["image"] for x in inputs]
        prompts_text = [
            maybe_apply_chat_template(example, self.processing_class)["prompt"]
            for example in inputs
        ]

        input_copy = copy.deepcopy(inputs[0]['prompt'])
        input_copy = self.remove_none_from_data(input_copy)

        data_type = inputs[0]['data_type']

        if data_type == 'image':
            input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
        elif data_type == 'video':
            input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]

        image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)

        prompt_inputs = self.processing_class(
            text=copy.deepcopy(prompts_text),
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )

        mm_data = [[data_type, image_inputs if image_inputs else video_inputs]]
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        if self.temporal:
            if video_inputs:
                indices = torch.randperm(video_inputs[0].size(0))
                shuffled_video_inputs = [video_inputs[0][indices]]
                shuffled_prompt_inputs = self.processing_class(
                    text=copy.deepcopy(prompts_text),
                    images=image_inputs,
                    videos=shuffled_video_inputs,
                    return_tensors="pt",
                    padding=True,
                    padding_side="left",
                    add_special_tokens=False,
                )
                shuffled_mm_data = [[self.accelerator.process_index, data_type, image_inputs if image_inputs else video_inputs]]
                shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
                shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
                if self.max_prompt_length is not None:
                    shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
                    shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
            else:
                shuffled_mm_data = [None]
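        # torch.randperm above permutes only the frame axis of the video tensor
        # (illustrative: a 4-frame clip may become order [2, 0, 3, 1]); pixel content
        # is untouched, so any accuracy gap between the ordered and shuffled copies
        # is attributable to temporal order alone.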
if self.args.use_vllm:
|
| 566 |
+
# First, have main process load weights if needed
|
| 567 |
+
if self.state.global_step != self._last_loaded_step:
|
| 568 |
+
with unwrap_model_for_generation(
|
| 569 |
+
self.model,
|
| 570 |
+
self.accelerator,
|
| 571 |
+
gather_deepspeed3_params=True, # TODO: fix this, self.args.ds3_gather_for_generation,
|
| 572 |
+
) as unwrapped_model:
|
| 573 |
+
if is_compiled_module(unwrapped_model):
|
| 574 |
+
state_dict = unwrapped_model._orig_mod.state_dict()
|
| 575 |
+
else:
|
| 576 |
+
state_dict = unwrapped_model.state_dict()
|
| 577 |
+
if self.accelerator.is_main_process:
|
| 578 |
+
llm_model = (
|
| 579 |
+
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
| 580 |
+
)
|
| 581 |
+
# import pdb
|
| 582 |
+
# pdb.set_trace()
|
| 583 |
+
llm_model.load_weights(state_dict.items())
|
| 584 |
+
self._last_loaded_step = self.state.global_step
|
| 585 |
+
|
| 586 |
+
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
| 587 |
+
all_prompts_text = gather_object(prompts_text)
|
| 588 |
+
all_mm_data = gather_object(mm_data)
|
| 589 |
+
# group into pairs
|
| 590 |
+
all_multimodal_inputs = []
|
| 591 |
+
|
| 592 |
+
if self.temporal:
|
| 593 |
+
shuffled_all_mm_data_none = gather_object(shuffled_mm_data)
|
| 594 |
+
shuffled_all_mm_data = [x for x in shuffled_all_mm_data_none if x]
|
| 595 |
+
shuffled_all_multimodal_inputs = []
|
| 596 |
+
|
| 597 |
+
# 2. Refer to TobiasLee's implementation suggestions
|
| 598 |
+
# this is a better implementation for vLLM sampling.
|
| 599 |
+
for prompt, mm_item in zip(all_prompts_text, all_mm_data):
|
| 600 |
+
all_multimodal_inputs.append({"prompt": prompt, "multi_modal_data": {mm_item[0]: mm_item[1]}})
|
| 601 |
+
|
| 602 |
+
if self.temporal and shuffled_all_mm_data!=[]:
|
| 603 |
+
for mm_item in shuffled_all_mm_data:
|
| 604 |
+
shuffled_all_multimodal_inputs.append({"prompt": all_prompts_text[mm_item[0]], "multi_modal_data": {mm_item[1]: mm_item[2]}})
|
| 605 |
+
|
| 606 |
+
# Create sampling params with num_generations
|
| 607 |
+
if self.accelerator.is_main_process:
|
| 608 |
+
# Clone to avoid modifying original params
|
| 609 |
+
sampling_params = copy.deepcopy(self.sampling_params)
|
| 610 |
+
sampling_params.n = self.num_generations
|
| 611 |
+
# Single generate call with all prompts
|
| 612 |
+
if self.accelerator.is_main_process:
|
| 613 |
+
outputs = self.llm.generate(
|
| 614 |
+
all_multimodal_inputs,
|
| 615 |
+
sampling_params=sampling_params,
|
| 616 |
+
use_tqdm=False,
|
| 617 |
+
)
|
| 618 |
+
# Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
|
| 619 |
+
completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]
|
| 620 |
+
|
| 621 |
+
if self.temporal and shuffled_all_mm_data!=[]:
|
| 622 |
+
# Clone to avoid modifying original params
|
| 623 |
+
shuffled_sampling_params = copy.deepcopy(self.sampling_params)
|
| 624 |
+
shuffled_sampling_params.n = self.num_generations // 2
|
| 625 |
+
# Single generate call with all prompts
|
| 626 |
+
if self.accelerator.is_main_process:
|
| 627 |
+
shuffled_outputs = self.llm.generate(
|
| 628 |
+
shuffled_all_multimodal_inputs,
|
| 629 |
+
sampling_params=shuffled_sampling_params,
|
| 630 |
+
use_tqdm=False,
|
| 631 |
+
)
|
| 632 |
+
# Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
|
| 633 |
+
shuffled_completion_ids = [out.token_ids for completion in shuffled_outputs for out in completion.outputs]
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
else:
|
| 637 |
+
completion_ids = [None] * len(all_multimodal_inputs) * self.num_generations
|
| 638 |
+
|
| 639 |
+
if self.temporal and shuffled_all_mm_data!=[]:
|
| 640 |
+
shuffled_completion_ids = [None] * len(shuffled_all_multimodal_inputs) * (self.num_generations // 2)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
# broadcast and slice
|
| 644 |
+
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
| 645 |
+
process_slice = slice(
|
| 646 |
+
self.accelerator.process_index * len(prompts) * self.num_generations,
|
| 647 |
+
(self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
|
| 648 |
+
)
|
| 649 |
+
completion_ids = completion_ids[process_slice]
|
+
+            # Pad the completions, and concatenate them with the prompts
+            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+            completion_ids = pad(
+                completion_ids, padding_value=self.processing_class.pad_token_id
+            )
+            prompt_ids = prompt_ids.repeat_interleave(self.num_generations, dim=0)
+            prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+
+            prompt_length = prompt_ids.size(1)
+
+            # print('prompt_length:', prompt_length)
+
+            prompt_ids = prompt_completion_ids[:, :prompt_length]
+            completion_ids = prompt_completion_ids[:, prompt_length:]
+            prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
+
+
+            if self.temporal and shuffled_all_mm_data:
+                # broadcast and slice
+                shuffled_completion_ids = broadcast_object_list(shuffled_completion_ids, from_process=0)
+                process_id_list = []
+                for mm_item in shuffled_all_mm_data:
+                    process_id_list += [mm_item[0]] * len(prompts) * (self.num_generations // 2)
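+                # mm_item[0] was used as a prompt index when building the shuffled inputs
+                # and is compared against the process index below, which appears to assume
+                # one prompt per process (prompt index == process index).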
+
+                if video_inputs:
+                    cur_shuffled_completion_ids = []
+                    for i in range(len(process_id_list)):
+                        if self.accelerator.process_index == process_id_list[i]:
+                            cur_shuffled_completion_ids.append(shuffled_completion_ids[i])
+
+                    # Pad the shuffled completions
+                    cur_shuffled_completion_ids = [torch.tensor(ids, device=device) for ids in cur_shuffled_completion_ids]
+                    cur_shuffled_completion_ids = pad(
+                        cur_shuffled_completion_ids, padding_value=self.processing_class.pad_token_id
+                    )
+                    shuffled_completion_ids = cur_shuffled_completion_ids
+
+
+        else:
+            raise ValueError("Only vLLM generation is supported in this version")
+
+        # The code below follows yifan's implementation
+        # Mask everything after the first EOS token
+        is_eos = completion_ids == self.processing_class.eos_token_id
+        device = self.accelerator.device
+        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
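+        # eos_idx defaults to the sequence length when no EOS is present, otherwise it is
+        # the position of the first EOS; the mask keeps every token up to and including
+        # that first EOS and zeroes out everything after it.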
+
+
+
+        prompt_inputs.pop("input_ids")
+        prompt_inputs.pop("attention_mask")
+
+        if data_type == 'image':
+            prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
+            prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
+            # import pdb; pdb.set_trace()
+
+
+        if data_type == 'video':
+            prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
+            prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
+            if 'second_per_grid_ts' in prompt_inputs:
+                del prompt_inputs["second_per_grid_ts"]
+
+        # import pdb
+        # pdb.set_trace()
+
+        # per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
+        per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+        # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
+        per_token_logps = per_token_logps[:, prompt_length - 1 :]
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            if self.ref_model is not None:
+                ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
+            else:
+                with self.accelerator.unwrap_model(model).disable_adapter():
+                    ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
+        ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
+
+        x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10)  # clamp x to a bounded range
+        per_token_kl = torch.exp(x_clamped) - x_clamped - 1
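+        # This is the non-negative k3 estimator of KL(policy || ref), exp(x) - x - 1 with
+        # x = ref_logp - logp; clamping x keeps the exp from overflowing.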
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        if self.temporal and video_inputs:
+
+            shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
+            if is_conversational(inputs[0]):
+                shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
+
+            # Compute the rewards
+            shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
+            shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
+            for i, (reward_func, reward_processing_class) in enumerate(
+                zip(self.reward_funcs, self.reward_processing_classes)
+            ):
+                # Repeat all input columns (except "prompt" and "completion") to match the number of generations
+                shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
+                for key in shuffled_reward_kwargs:
+                    for example in inputs:
+                        # Repeat each value in the column `shuffled_num_generations` times
+                        shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
+                shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
+                shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
+
+
+
+        # Decode the generated completions
+        completions = self.processing_class.batch_decode(
+            completion_ids, skip_special_tokens=True
+        )
+        if is_conversational(inputs[0]):
+            completions = [
+                [{"role": "assistant", "content": completion}]
+                for completion in completions
+            ]
+
+        # Compute the rewards
+        prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
+        rewards_per_func = torch.zeros(
+            len(prompts), len(self.reward_funcs), device=device
+        )
+        for i, (reward_func, reward_processing_class) in enumerate(
+            zip(self.reward_funcs, self.reward_processing_classes)
+        ):
+            reward_kwargs = {
+                key: []
+                for key in inputs[0].keys()
+                if key not in ["prompt", "completion"]
+            }
+            for key in reward_kwargs:
+                for example in inputs:
+                    # Repeat each value in the column `num_generations` times
+                    reward_kwargs[key].extend([example[key]] * self.num_generations)
+            output_reward_func = reward_func(
+                prompts=prompts, completions=completions, **reward_kwargs
+            )
+            rewards_per_func[:, i] = torch.tensor(
+                output_reward_func, dtype=torch.float32, device=device
+            )
+
+
+        # rewards_per_func = gather(rewards_per_func)
+        # # Sum the rewards from all reward functions
+        # rewards = rewards_per_func.sum(dim=1)
+
+        # process_slice = slice(
+        #     self.accelerator.process_index * len(prompts),
+        #     (self.accelerator.process_index + 1) * len(prompts),
+        # )
+
+        # rewards = rewards[process_slice]
+
+
+
if self.temporal and video_inputs:
|
| 815 |
+
temporal_rewards_per_func = rewards_per_func.clone()
|
| 816 |
+
|
| 817 |
+
acc_mean = temporal_rewards_per_func[:, 0].mean()
|
| 818 |
+
shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
|
| 819 |
+
|
| 820 |
+
if acc_mean >= 0.8 * shuffled_acc_mean:
|
| 821 |
+
mask = temporal_rewards_per_func[:, 0] > 0.1
|
| 822 |
+
temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
|
| 823 |
+
temporal_rewards = torch.tensor([1.0]).to('cuda')
|
| 824 |
+
else:
|
| 825 |
+
temporal_rewards = torch.tensor([0.0]).to('cuda')
|
| 826 |
+
else:
|
| 827 |
+
temporal_rewards = torch.tensor([0.5]).to('cuda')
|
| 828 |
+
|
| 829 |
+
# Sum the rewards from all reward functions
|
| 830 |
+
if self.temporal and video_inputs:
|
| 831 |
+
rewards = temporal_rewards_per_func.sum(dim=1)
|
| 832 |
+
else:
|
| 833 |
+
rewards = rewards_per_func.sum(dim=1)
|
| 834 |
+
|
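+        # Length control: when more than one completion in the group is correct, those whose
+        # completion length falls within [320, 1600] tokens receive an extra +0.2.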
+        if self.len_control:
+            mem_rewards = [0] * self.num_generations
+            mask = rewards_per_func[:, 0] > 0.1
+            length_list = completion_mask.sum(1)
+            selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
+            # if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
+            # if len(selected_indices) > 1:
+            #     selected_items = [(i, length_list[i]) for i in selected_indices]
+            #     sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
+            #     N = len(sorted_items)
+            #     for rank, (idx, length) in enumerate(sorted_items):
+            #         reward = 0.2 - 0.2 * (rank / N)
+            #         rewards[idx] += reward
+            #         mem_rewards[idx] = reward
+            # for idx in range(len(length_list)):
+            #     if length_list[idx] >= 512:
+            #         rewards[idx] -= 0.5
+
+            if len(selected_indices) > 1:
+                for idx in selected_indices:
+                    if 320 <= length_list[idx] <= 1600:
+                        rewards[idx] += 0.2
+
+        # print(rewards)
+        # print(completion_mask.sum(1))
+
+        # Compute group-wise rewards
+        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+        # Normalize the rewards to compute the advantages
+        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
| 869 |
+
|
| 870 |
+
# x - x.detach() allows for preserving gradients from x
|
| 871 |
+
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
| 872 |
+
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
|
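+        # exp(logp - logp.detach()) evaluates to 1 but carries the gradient of logp, so this
+        # is the advantage-weighted policy-gradient surrogate plus a beta-weighted KL penalty.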
+        # per_token_loss = -per_token_loss
+        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+
+
+        # import pdb
+        # pdb.set_trace()
+
+        # Log the metrics
+        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+        self._metrics["completion_length"].append(completion_length)
+
+        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, PreTrainedModel):
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
+
+        gathered_rewards = self.accelerator.gather_for_metrics(rewards)
+
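+        # Group-level diagnostics: each row below is one prompt's generation group gathered
+        # across processes; "all_wrong" and "all_correct" track the fraction of groups where
+        # every sample falls below (or above) the reward thresholds, i.e. groups with little
+        # reward variance left to learn from.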
+        num_groups = gathered_rewards.size(0) // self.num_generations
+        rewards_per_group = gathered_rewards.view(num_groups, self.num_generations)
+        all_wrong_groups = (rewards_per_group <= 1).all(dim=1)
+        wrong_ratio = all_wrong_groups.sum().item() / num_groups
+
+        all_correct_groups = (rewards_per_group >= 2).all(dim=1)
+        correct_ratio = all_correct_groups.sum().item() / num_groups
+
+        self._metrics["all_wrong"].append(wrong_ratio)
+        self._metrics["all_correct"].append(correct_ratio)
+
+        if self.temporal:
+            temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
+            self._metrics["temporal_rewards"].append(temporal_rewards_list.mean().item())
+
+        self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
+
+        self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
+
+        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+
+        return loss
+
+
+
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
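+        # Metrics are accumulated once per compute_loss call and averaged over the logging
+        # window here; they are cleared at the end of this method after being merged into `logs`.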
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if next(iter(logs.keys())).startswith("eval_"):
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs = {**logs, **metrics}
+        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
+            super().log(logs, start_time)
+        else:  # transformers<=4.46
+            super().log(logs)
+        self._metrics.clear()