DingZhenDojoCat committed
Commit 7ed0fb5 · verified · 1 Parent(s): bb7f76d

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. .gitattributes +9 -0
  2. images/dataset.png +3 -0
  3. images/demo1.png +3 -0
  4. previous_version/Video-R1-main-previous/images/2B_curve.png +3 -0
  5. previous_version/Video-R1-main-previous/images/7B_curve.png +3 -0
  6. previous_version/Video-R1-main-previous/images/7B_nextqa.png +3 -0
  7. previous_version/Video-R1-main-previous/images/CATER_new_003595.gif +3 -0
  8. previous_version/Video-R1-main-previous/images/sample.png +3 -0
  9. previous_version/Video-R1-main-previous/src/distill_r1/create_hf_dataset.py +119 -0
  10. previous_version/Video-R1-main-previous/src/distill_r1/generate_scene_qa_pairs.ipynb +569 -0
  11. previous_version/Video-R1-main-previous/src/distill_r1/grpo_r1_distilled.jpg +3 -0
  12. previous_version/Video-R1-main-previous/src/distill_r1/query_r1.py +114 -0
  13. previous_version/Video-R1-main-previous/src/eval/prompts/geoqa_test_prompts.jsonl +0 -0
  14. previous_version/Video-R1-main-previous/src/eval/prompts/superclevr_test200_counting_problems.jsonl +200 -0
  15. previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_counting_superclevr.py +136 -0
  16. previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_geoqa.py +149 -0
  17. previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_geoqa_multigpu.py +205 -0
  18. previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_video_counting.py +141 -0
  19. previous_version/Video-R1-main-previous/src/qwen-vl-utils/.python-version +1 -0
  20. previous_version/Video-R1-main-previous/src/qwen-vl-utils/README.md +94 -0
  21. previous_version/Video-R1-main-previous/src/qwen-vl-utils/pyproject.toml +75 -0
  22. previous_version/Video-R1-main-previous/src/qwen-vl-utils/requirements-dev.lock +84 -0
  23. previous_version/Video-R1-main-previous/src/qwen-vl-utils/requirements.lock +32 -0
  24. previous_version/Video-R1-main-previous/src/qwen-vl-utils/src/qwen_vl_utils/__init__.py +7 -0
  25. previous_version/Video-R1-main-previous/src/qwen-vl-utils/src/qwen_vl_utils/vision_process.py +379 -0
  26. previous_version/Video-R1-main-previous/src/r1-v/temp_image.png +3 -0
  27. src/r1-v/.gitignore +178 -0
  28. src/r1-v/LICENSE +201 -0
  29. src/r1-v/Makefile +20 -0
  30. src/r1-v/setup.cfg +41 -0
  31. src/r1-v/setup.py +132 -0
  32. src/r1-v/src/open_r1/__init__.py +0 -0
  33. src/r1-v/src/open_r1/evaluate.py +85 -0
  34. src/r1-v/src/open_r1/generate.py +156 -0
  35. src/r1-v/src/open_r1/grpo-cot-72BEval.py +489 -0
  36. src/r1-v/src/open_r1/grpo-cot-LLMEval.py +552 -0
  37. src/r1-v/src/open_r1/grpo-cot-answerBERT-eval.py +429 -0
  38. src/r1-v/src/open_r1/grpo-cot-noDesEval.py +446 -0
  39. src/r1-v/src/open_r1/grpo-cot-noInfo.py +346 -0
  40. src/r1-v/src/open_r1/grpo-cot-qwenEval.py +523 -0
  41. src/r1-v/src/open_r1/grpo-cot-selfEval.py +457 -0
  42. src/r1-v/src/open_r1/grpo-cot-selfEvalConst.py +456 -0
  43. src/r1-v/src/open_r1/grpo-cot.py +351 -0
  44. src/r1-v/src/open_r1/grpo-description-LLMEval.py +579 -0
  45. src/r1-v/src/open_r1/grpo.py +318 -0
  46. src/r1-v/src/open_r1/grpo_vllm_caption.py +266 -0
  47. src/r1-v/src/open_r1/sft_video.py +304 -0
  48. src/r1-v/src/open_r1/trainer/__init__.py +12 -0
  49. src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified_error.py +1061 -0
  50. src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified_orig.py +935 -0
.gitattributes CHANGED
@@ -36,3 +36,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  images/curves.png filter=lfs diff=lfs merge=lfs -text
  images/demo2.png filter=lfs diff=lfs merge=lfs -text
  images/performance.png filter=lfs diff=lfs merge=lfs -text
+ images/dataset.png filter=lfs diff=lfs merge=lfs -text
+ images/demo1.png filter=lfs diff=lfs merge=lfs -text
+ previous_version/Video-R1-main-previous/src/r1-v/temp_image.png filter=lfs diff=lfs merge=lfs -text
+ previous_version/Video-R1-main-previous/src/distill_r1/grpo_r1_distilled.jpg filter=lfs diff=lfs merge=lfs -text
+ previous_version/Video-R1-main-previous/images/7B_nextqa.png filter=lfs diff=lfs merge=lfs -text
+ previous_version/Video-R1-main-previous/images/sample.png filter=lfs diff=lfs merge=lfs -text
+ previous_version/Video-R1-main-previous/images/CATER_new_003595.gif filter=lfs diff=lfs merge=lfs -text
+ previous_version/Video-R1-main-previous/images/2B_curve.png filter=lfs diff=lfs merge=lfs -text
+ previous_version/Video-R1-main-previous/images/7B_curve.png filter=lfs diff=lfs merge=lfs -text
images/dataset.png ADDED

Git LFS Details

  • SHA256: e0affaa1cf8d870c6a6ec41be54494e073c51987fe5ad424a8ee3437b1dcc116
  • Pointer size: 131 Bytes
  • Size of remote file: 589 kB
images/demo1.png ADDED

Git LFS Details

  • SHA256: 94c40671d8761915a8de02f4548f0e1715069aa8d171f08d5b27af3f2a715548
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
previous_version/Video-R1-main-previous/images/2B_curve.png ADDED

Git LFS Details

  • SHA256: 8f5b2aaa2c501639cc570bf9f8b8a94dedf3e3d8f9b2ad2ef6e13d01478b733d
  • Pointer size: 131 Bytes
  • Size of remote file: 321 kB
previous_version/Video-R1-main-previous/images/7B_curve.png ADDED

Git LFS Details

  • SHA256: 38e42d31de8bf93659529b9334c1aa58c71d91fa55e1eeef7f4f6fece1ca4663
  • Pointer size: 131 Bytes
  • Size of remote file: 310 kB
previous_version/Video-R1-main-previous/images/7B_nextqa.png ADDED

Git LFS Details

  • SHA256: 99c0f930a3f67a870386ee16896b1f45a3c84dfd43b27dd4d128a8ae66406f19
  • Pointer size: 131 Bytes
  • Size of remote file: 334 kB
previous_version/Video-R1-main-previous/images/CATER_new_003595.gif ADDED

Git LFS Details

  • SHA256: 9ed0306a7a088e526eb2ccfb8e0f44d987fa48548248649f1cd4a270955634cd
  • Pointer size: 131 Bytes
  • Size of remote file: 777 kB
previous_version/Video-R1-main-previous/images/sample.png ADDED

Git LFS Details

  • SHA256: e616764501a3833e9035ccd48b79b19f23cc02c597cedde681edf0b63f27d09c
  • Pointer size: 131 Bytes
  • Size of remote file: 244 kB
previous_version/Video-R1-main-previous/src/distill_r1/create_hf_dataset.py ADDED
@@ -0,0 +1,119 @@
import json
import os
import random
from datasets import load_dataset
from tqdm import tqdm

random.seed(1234)
VAL_NUM = 5000


def create_r1_train_dataset(
    valid_pair_json,
    data_dir,
    img_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/trainA/",
):
    os.makedirs(data_dir, exist_ok=True)
    pairs = [json.loads(line) for line in open(valid_pair_json, "r")]
    mapped_pairs = []

    for idx, pair in tqdm(enumerate(pairs)):
        img_filename = pair["img_filename"]
        new_pair = {}
        try:
            new_pair["thinking"] = (
                pair["r1_response"]
                .split("<think>")[1]
                .split("</think>")[0]
                .replace("scene description", "image")
            )
        except Exception as e:
            print(f"Error processing pair response: ", pair["r1_response"])
            continue  # skip this pair
        # add index to distinguish the same image
        dataset_filename = (
            img_filename.split(".")[0] + "_" + str(idx) + "." + img_filename.split(".")[1]
        )
        if not os.path.exists(f"{data_dir}/{img_filename}"):
            os.system(f"cp {img_dir}/{img_filename} {data_dir}/{dataset_filename}")
        q, a = pair["q"], pair["a"]
        new_pair["problem"] = q
        # get the thinking path

        new_pair["thinking"] = "<think>" + new_pair["thinking"] + "</think>"
        new_pair["solution"] = f"<answer> {a} </answer>"
        new_pair["file_name"] = dataset_filename
        mapped_pairs.append(new_pair)
    with open(f"{data_dir}/metadata.jsonl", "w") as f:
        for pair in mapped_pairs:
            f.write(json.dumps(pair) + "\n")

    train_dataset = load_dataset(
        "imagefolder",
        data_dir=data_dir,
        split="train",
    )
    return train_dataset


def create_val_dataset(
    json_file,
    data_dir,
    val_num=VAL_NUM,
    image_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/valB",
):
    os.makedirs(data_dir, exist_ok=True)
    val = json.load(open(json_file))
    random.shuffle(val)
    val = val[:val_num]
    val_pairs = []
    for idx, pair in tqdm(enumerate(val)):
        q, a = pair["q"], pair["a"]
        img_filename = pair["img_filename"]
        # copy images to the DATA_DIR
        val_filename = (
            img_filename.split(".")[0] + f"_{idx}." + img_filename.split(".")[1]
        )
        if not os.path.exists(f"{data_dir}/{img_filename}"):
            os.system(f"cp {image_dir}/{img_filename} {data_dir}/{val_filename}")
        new_pair = {}
        new_pair["problem"] = q
        new_pair["solution"] = f"<answer> {a} </answer>"
        new_pair["file_name"] = val_filename
        val_pairs.append(new_pair)
    with open(f"{data_dir}/metadata.jsonl", "w") as f:
        for pair in val_pairs:
            f.write(json.dumps(pair) + "\n")
    val_dataset = load_dataset("imagefolder", data_dir=data_dir, split="train")
    return val_dataset


# valA split
VALA_DATA_DIR = "data/Clevr_CoGenT_ValA"
VALB_DATA_DIR = "data/Clevr_CoGenT_ValB"
valA_json = (
    "/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_valA.json"
)
valB_json = (
    "/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_valB.json"
)
TRAIN_DATADIR = "data/Clevr_CoGenT_TrainA_R1"
train_dataset = create_r1_train_dataset(
    "/home/lilei/Visual-R1/filter_results_v2/valid_pairs.jsonl",
    TRAIN_DATADIR,
)

# print(train_dataset)
valA_dataset = create_val_dataset(
    valA_json,
    VALA_DATA_DIR,
    image_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/valA",
)
valB_dataset = create_val_dataset(
    valB_json,
    VALB_DATA_DIR,
    image_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/valB",
)
valA_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_ValA")
valB_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_ValB")
train_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_TrainA_R1")
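
For orientation, the script above relies on the Hugging Face `imagefolder` loader picking up the metadata.jsonl written next to the copied images (records are linked to images via the file_name column). A minimal sketch of reading one of the pushed splits back, assuming the push_to_hub calls above completed; the printed values are illustrative:

from datasets import load_dataset

# Load the validation split pushed by create_val_dataset();
# push_to_hub on a single Dataset keeps the "train" split name.
val_a = load_dataset("MMInstruction/Clevr_CoGenT_ValA", split="train")

# Each record carries the decoded image plus the "problem"/"solution"
# fields taken from metadata.jsonl.
example = val_a[0]
print(example["problem"])
print(example["solution"])  # e.g. "<answer> 3 </answer>"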
previous_version/Video-R1-main-previous/src/distill_r1/generate_scene_qa_pairs.ipynb ADDED
@@ -0,0 +1,569 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "3a704ea6-2e61-4aaa-97aa-416579c9bc13",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import json\n",
11
+ "import random"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 4,
17
+ "id": "c4920a8f-cddd-4063-8cab-215d238b5dad",
18
+ "metadata": {},
19
+ "outputs": [
20
+ {
21
+ "name": "stdout",
22
+ "output_type": "stream",
23
+ "text": [
24
+ "CLEVR_trainA_scenes.json CLEVR_valA_scenes.json CLEVR_valB_scenes.json\n"
25
+ ]
26
+ }
27
+ ],
28
+ "source": [
29
+ "!ls CLEVR_CoGenT_v1.0/scenes"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 15,
35
+ "id": "934fa005-3b2a-43ed-8a71-6a12b7579546",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "split = \"valB\"\n",
40
+ "clevr_train_json = f\"CLEVR_CoGenT_v1.0/scenes/CLEVR_{split}_scenes.json\"\n",
41
+ "train_qs = f\"CLEVR_CoGenT_v1.0/questions/CLEVR_{split}_questions.json\"\n",
42
+ "data = json.load(open(clevr_train_json))\n",
43
+ "qs = json.load(open(train_qs))"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": 16,
49
+ "id": "1f0d6180-94c4-4aea-bd2b-8d5cfeb0aecb",
50
+ "metadata": {},
51
+ "outputs": [
52
+ {
53
+ "name": "stdout",
54
+ "output_type": "stream",
55
+ "text": [
56
+ "[{'pixel_coords': [343, 131, 11.278693199157715], 'size': 'small', 'color': 'green', 'material': 'metal', 'shape': 'sphere', '3d_coords': [0.9906095862388611, 2.083291530609131, 0.3499999940395355], 'rotation': 107.73596690369371}, {'pixel_coords': [396, 172, 9.857704162597656], 'size': 'small', 'color': 'cyan', 'material': 'rubber', 'shape': 'sphere', '3d_coords': [2.69626522064209, 1.5257188081741333, 0.3499999940395355], 'rotation': 305.3536122513589}, {'pixel_coords': [115, 182, 8.91348934173584], 'size': 'large', 'color': 'yellow', 'material': 'rubber', 'shape': 'cylinder', '3d_coords': [0.049163494259119034, -2.864100217819214, 0.699999988079071], 'rotation': 161.8370138842408}, {'pixel_coords': [203, 131, 10.548327445983887], 'size': 'large', 'color': 'purple', 'material': 'rubber', 'shape': 'cube', '3d_coords': [-0.4719269275665283, -0.5699371695518494, 0.699999988079071], 'rotation': 159.41862667811446}, {'pixel_coords': [253, 75, 13.141877174377441], 'size': 'large', 'color': 'red', 'material': 'rubber', 'shape': 'cube', '3d_coords': [-2.036878824234009, 2.222999334335327, 0.699999988079071], 'rotation': 37.40490732771224}]\n",
57
+ "len: 5\n"
58
+ ]
59
+ }
60
+ ],
61
+ "source": [
62
+ "print(data['scenes'][0]['objects'])\n",
63
+ "print(\"len: \", len(data['scenes'][0]['objects']))"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 17,
69
+ "id": "7c828ca4-08f9-4927-a745-224a95379c2f",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "def object_info_to_description(object_list):\n",
74
+ " descriptions = []\n",
75
+ " random.shuffle(object_list)\n",
76
+ " for obj in object_list:\n",
77
+ " desc = f\"A {obj['size']} {obj['color']} {obj['material']} {obj['shape']}\"\n",
78
+ " desc += f\" rotated {obj['rotation']:.1f}° located at\"\n",
79
+ " desc += f\" 3D coordinates ({obj['3d_coords'][0]:.2f}, {obj['3d_coords'][1]:.2f}, {obj['3d_coords'][2]:.2f})\"\n",
80
+ " desc += f\" and pixel coordinates ({obj['pixel_coords'][0]}, {obj['pixel_coords'][1]}, {obj['pixel_coords'][2]:.2f})\"\n",
81
+ " descriptions.append(desc)\n",
82
+ " \n",
83
+ " final_description = \"Scene Description:\\n\"\n",
84
+ " for i, desc in enumerate(descriptions, 1):\n",
85
+ " final_description += f\"{desc}\\n\"\n",
86
+ " \n",
87
+ " return final_description"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 18,
93
+ "id": "cb048e25-d554-4bd7-bf11-878e071b5987",
94
+ "metadata": {},
95
+ "outputs": [
96
+ {
97
+ "data": {
98
+ "text/plain": [
99
+ "'Scene Description:\\nA large yellow rubber cylinder rotated 161.8° located at 3D coordinates (0.05, -2.86, 0.70) and pixel coordinates (115, 182, 8.91)\\nA large purple rubber cube rotated 159.4° located at 3D coordinates (-0.47, -0.57, 0.70) and pixel coordinates (203, 131, 10.55)\\nA large red rubber cube rotated 37.4° located at 3D coordinates (-2.04, 2.22, 0.70) and pixel coordinates (253, 75, 13.14)\\nA small green metal sphere rotated 107.7° located at 3D coordinates (0.99, 2.08, 0.35) and pixel coordinates (343, 131, 11.28)\\nA small cyan rubber sphere rotated 305.4° located at 3D coordinates (2.70, 1.53, 0.35) and pixel coordinates (396, 172, 9.86)\\n'"
100
+ ]
101
+ },
102
+ "execution_count": 18,
103
+ "metadata": {},
104
+ "output_type": "execute_result"
105
+ }
106
+ ],
107
+ "source": [
108
+ "object_info_to_description(data['scenes'][0]['objects'])"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 19,
114
+ "id": "ffacd5f3-e9a4-46ca-8c50-187ab12c9f1b",
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "img2obj_dict = {}\n",
119
+ "for scene in data['scenes']:\n",
120
+ " obj_list = scene['objects']\n",
121
+ " img2obj_dict[scene['image_filename']] = obj_list"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 20,
127
+ "id": "db35f03c-1529-4776-bf4f-3bd44e960e5f",
128
+ "metadata": {},
129
+ "outputs": [
130
+ {
131
+ "data": {
132
+ "text/plain": [
133
+ "{'question_index': 0,\n",
134
+ " 'question_family_index': 29,\n",
135
+ " 'image_index': 0,\n",
136
+ " 'question': 'The big thing that is in front of the large rubber cube in front of the small thing that is behind the tiny matte ball is what color?',\n",
137
+ " 'answer': 'yellow',\n",
138
+ " 'image_filename': 'CLEVR_valB_000000.png',\n",
139
+ " 'split': 'valB',\n",
140
+ " 'program': [{'value_inputs': [], 'inputs': [], 'function': 'scene'},\n",
141
+ " {'value_inputs': ['small'], 'inputs': [0], 'function': 'filter_size'},\n",
142
+ " {'value_inputs': ['rubber'], 'inputs': [1], 'function': 'filter_material'},\n",
143
+ " {'value_inputs': ['sphere'], 'inputs': [2], 'function': 'filter_shape'},\n",
144
+ " {'value_inputs': [], 'inputs': [3], 'function': 'unique'},\n",
145
+ " {'value_inputs': ['behind'], 'inputs': [4], 'function': 'relate'},\n",
146
+ " {'value_inputs': ['small'], 'inputs': [5], 'function': 'filter_size'},\n",
147
+ " {'value_inputs': [], 'inputs': [6], 'function': 'unique'},\n",
148
+ " {'value_inputs': ['front'], 'inputs': [7], 'function': 'relate'},\n",
149
+ " {'value_inputs': ['large'], 'inputs': [8], 'function': 'filter_size'},\n",
150
+ " {'value_inputs': ['rubber'], 'inputs': [9], 'function': 'filter_material'},\n",
151
+ " {'value_inputs': ['cube'], 'inputs': [10], 'function': 'filter_shape'},\n",
152
+ " {'value_inputs': [], 'inputs': [11], 'function': 'unique'},\n",
153
+ " {'value_inputs': ['front'], 'inputs': [12], 'function': 'relate'},\n",
154
+ " {'value_inputs': ['large'], 'inputs': [13], 'function': 'filter_size'},\n",
155
+ " {'value_inputs': [], 'inputs': [14], 'function': 'unique'},\n",
156
+ " {'value_inputs': [], 'inputs': [15], 'function': 'query_color'}]}"
157
+ ]
158
+ },
159
+ "execution_count": 20,
160
+ "metadata": {},
161
+ "output_type": "execute_result"
162
+ }
163
+ ],
164
+ "source": [
165
+ "qs['questions'][0]"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": 21,
171
+ "id": "66b746fc-569c-4922-a442-79dbbc09e33b",
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "random.shuffle(qs['questions'])\n",
176
+ "cnt = 0 \n",
177
+ "qa_pairs = [] \n",
178
+ "added_pair = set()\n",
179
+ "for qd in qs['questions']:\n",
180
+ " img_idx = qd['image_filename']\n",
181
+ " total_count = len(img2obj_dict[img_idx]) # object list length\n",
182
+ " desc = object_info_to_description(img2obj_dict[img_idx])\n",
183
+ " question, answer = qd['question'], qd['answer']\n",
184
+ " if 'how many' in question.lower() or 'number' in question.lower():\n",
185
+ " qa_pairs.append({\n",
186
+ " \"img_filename\": img_idx,\n",
187
+ " 'q': question,\n",
188
+ " 'a': answer,\n",
189
+ " 'description': desc \n",
190
+ " })\n",
191
+ " if img_idx not in added_pair:\n",
192
+ " qa_pairs.append({\n",
193
+ " \"img_filename\": img_idx,\n",
194
+ " 'q': \"How many items are there in the described scene?\",\n",
195
+ " 'a': total_count,\n",
196
+ " 'description': desc \n",
197
+ " })\n",
198
+ " added_pair.add(img_idx)\n"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 22,
204
+ "id": "c271fa7b-fed5-472f-a302-6ec203c4b787",
205
+ "metadata": {},
206
+ "outputs": [
207
+ {
208
+ "data": {
209
+ "text/plain": [
210
+ "59978"
211
+ ]
212
+ },
213
+ "execution_count": 22,
214
+ "metadata": {},
215
+ "output_type": "execute_result"
216
+ }
217
+ ],
218
+ "source": [
219
+ "len(qa_pairs)"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 23,
225
+ "id": "b0da8a70-c3f5-4e48-b384-3684933d72ef",
226
+ "metadata": {},
227
+ "outputs": [
228
+ {
229
+ "data": {
230
+ "text/plain": [
231
+ "14884"
232
+ ]
233
+ },
234
+ "execution_count": 23,
235
+ "metadata": {},
236
+ "output_type": "execute_result"
237
+ }
238
+ ],
239
+ "source": [
240
+ "len(added_pair)"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 24,
246
+ "id": "c648587e-2ec0-427c-b594-f55dd187b4d9",
247
+ "metadata": {},
248
+ "outputs": [],
249
+ "source": [
250
+ "# save for later loading\n",
251
+ "with open(f\"clever_counting_problems_clevr_cogent_v1.0_{split}.json\", 'w') as fw:\n",
252
+ " json.dump( qa_pairs, fw, indent=4)"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": 20,
258
+ "id": "b3a8cbe4-4261-41d3-a481-43a0b1cc2795",
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "random.shuffle(qa_pairs)"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": 57,
268
+ "id": "d6dff4e7-65dd-4e82-82df-340ec2a57919",
269
+ "metadata": {},
270
+ "outputs": [
271
+ {
272
+ "data": {
273
+ "text/plain": [
274
+ "[{'img_filename': 'CLEVR_trainA_048403.png',\n",
275
+ " 'q': 'How many things are both on the right side of the big yellow rubber thing and left of the purple ball?',\n",
276
+ " 'a': '5',\n",
277
+ " 'description': 'Scene Description:\\nA large red rubber cylinder rotated 291.3° located at 3D coordinates (-0.89, -2.73, 0.70) and pixel coordinates (101, 152, 10.04)\\nA small purple metal sphere rotated 247.7° located at 3D coordinates (2.93, 0.87, 0.35) and pixel coordinates (379, 183, 9.66)\\nA large cyan rubber cylinder rotated 114.5° located at 3D coordinates (-2.40, 2.23, 0.70) and pixel coordinates (246, 82, 13.94)\\nA small red metal cylinder rotated 109.9° located at 3D coordinates (-0.95, 1.77, 0.35) and pixel coordinates (270, 113, 12.83)\\nA small red rubber cylinder rotated 343.7° located at 3D coordinates (-0.12, -0.74, 0.35) and pixel coordinates (209, 153, 10.82)\\nA large red rubber cylinder rotated 324.5° located at 3D coordinates (-2.71, -2.21, 0.70) and pixel coordinates (84, 119, 11.59)\\nA small red metal cylinder rotated 1.1° located at 3D coordinates (2.88, -0.12, 0.35) and pixel coordinates (342, 200, 9.12)\\nA small gray rubber cube rotated 144.9° located at 3D coordinates (0.79, 0.98, 0.35) and pixel coordinates (299, 145, 11.19)\\nA large yellow rubber cube rotated 90.0° located at 3D coordinates (-1.78, -0.31, 0.70) and pixel coordinates (180, 110, 12.05)\\n'},\n",
278
+ " {'img_filename': 'CLEVR_trainA_048403.png',\n",
279
+ " 'q': 'How many items are there in the described scene?',\n",
280
+ " 'a': 9,\n",
281
+ " 'description': 'Scene Description:\\nA large red rubber cylinder rotated 291.3° located at 3D coordinates (-0.89, -2.73, 0.70) and pixel coordinates (101, 152, 10.04)\\nA small purple metal sphere rotated 247.7° located at 3D coordinates (2.93, 0.87, 0.35) and pixel coordinates (379, 183, 9.66)\\nA large cyan rubber cylinder rotated 114.5° located at 3D coordinates (-2.40, 2.23, 0.70) and pixel coordinates (246, 82, 13.94)\\nA small red metal cylinder rotated 109.9° located at 3D coordinates (-0.95, 1.77, 0.35) and pixel coordinates (270, 113, 12.83)\\nA small red rubber cylinder rotated 343.7° located at 3D coordinates (-0.12, -0.74, 0.35) and pixel coordinates (209, 153, 10.82)\\nA large red rubber cylinder rotated 324.5° located at 3D coordinates (-2.71, -2.21, 0.70) and pixel coordinates (84, 119, 11.59)\\nA small red metal cylinder rotated 1.1° located at 3D coordinates (2.88, -0.12, 0.35) and pixel coordinates (342, 200, 9.12)\\nA small gray rubber cube rotated 144.9° located at 3D coordinates (0.79, 0.98, 0.35) and pixel coordinates (299, 145, 11.19)\\nA large yellow rubber cube rotated 90.0° located at 3D coordinates (-1.78, -0.31, 0.70) and pixel coordinates (180, 110, 12.05)\\n'}]"
282
+ ]
283
+ },
284
+ "execution_count": 57,
285
+ "metadata": {},
286
+ "output_type": "execute_result"
287
+ }
288
+ ],
289
+ "source": [
290
+ "qa_pairs[:2]"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": 26,
296
+ "id": "a6a66364-5b47-4138-91d6-a045404d21b1",
297
+ "metadata": {},
298
+ "outputs": [],
299
+ "source": [
300
+ "def query_r1(query='who are you?', model=\"deepseek-ai/DeepSeek-R1\"):\n",
301
+ " # Create the chat completion\n",
302
+ " response = client.chat.completions.create(\n",
303
+ " model=model,\n",
304
+ " messages=[\n",
305
+ " {'role': 'user', \n",
306
+ " 'content': query}\n",
307
+ " ],\n",
308
+ " stream=False,\n",
309
+ " )\n",
310
+ " # Print the response\n",
311
+ " return response.choices[0].message.content"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": 44,
317
+ "id": "e5d5649f-c4e3-4f3f-b76e-7f7ed27f68e8",
318
+ "metadata": {},
319
+ "outputs": [],
320
+ "source": [
321
+ "def format_query(qa_dict):\n",
322
+ " query = \"Answer the question according to scene description.\\n\\n\"\n",
323
+ " query += qa_dict['description']\n",
324
+ " query += f\"\\nQuestion:\\n{qa_dict['q']}\"\n",
325
+ " return query \n",
326
+ " "
327
+ ]
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "execution_count": 39,
332
+ "id": "7f568a4e-f217-464a-8329-bbefb64d9653",
333
+ "metadata": {},
334
+ "outputs": [
335
+ {
336
+ "name": "stdout",
337
+ "output_type": "stream",
338
+ "text": [
339
+ "<think>Okay, let's see. The user is asking how many items are there in the described scene. Let me go through the scene description step by step.\n",
340
+ "\n",
341
+ "So, the scene description lists each object with details like color, material, shape, rotation, 3D coordinates, and pixel coordinates. Each entry starts with \"A\" which usually indicates one item each. Let me count each one.\n",
342
+ "\n",
343
+ "First entry: \"A small green metal cylinder...\" That's one. Second: \"A small blue rubber cylinder...\" Second item. Third: \"A small cyan rubber cylinder...\" That's three. Fourth: \"A large cyan metal sphere...\" Four. Fifth: \"A large brown metal cube...\" Five. Sixth: \"A large yellow rubber cube...\" Six. Seventh: \"A large brown rubber cylinder...\" That's seven. \n",
344
+ "\n",
345
+ "Wait, did I miss any? Let me check again. The list has entries from \"A small green...\" up to the seventh one. Each sentence starts with \"A\", which suggests each is a separate item. No commas separating multiple items in a single entry. Each has different attributes and coordinates, so they must all be distinct. \n",
346
+ "\n",
347
+ "So the answer should be 7 items.\n",
348
+ "</think>\n",
349
+ "\n",
350
+ "There are 7 items in the described scene. Each entry corresponds to one distinct object, listed by their properties, coordinates, and rotations.\n",
351
+ "None\n"
352
+ ]
353
+ }
354
+ ],
355
+ "source": [
356
+ "debug_query = format_query(qa_pairs[0])\n",
357
+ "print(query_r1(debug_query))"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": 41,
363
+ "id": "cdc4231a-8ef4-4cf6-a575-d84ae7bbd0b5",
364
+ "metadata": {},
365
+ "outputs": [
366
+ {
367
+ "name": "stdout",
368
+ "output_type": "stream",
369
+ "text": [
370
+ "Answer the question accordingly to scene description.\n",
371
+ "\n",
372
+ "Scene Description:\n",
373
+ "A small green metal cylinder rotated 329.5° located at 3D coordinates (-2.49, -1.65, 0.35) and pixel coordinates (111, 132, 11.81)\n",
374
+ "A small blue rubber cylinder rotated 312.2° located at 3D coordinates (-1.73, -2.91, 0.35) and pixel coordinates (76, 163, 10.57)\n",
375
+ "A small cyan rubber cylinder rotated 48.4° located at 3D coordinates (-2.10, -0.22, 0.35) and pixel coordinates (172, 118, 12.41)\n",
376
+ "A large cyan metal sphere rotated 27.4° located at 3D coordinates (1.52, -1.26, 0.70) and pixel coordinates (247, 181, 9.33)\n",
377
+ "A large brown metal cube rotated 107.7° located at 3D coordinates (-0.73, 2.39, 0.70) and pixel coordinates (290, 92, 12.93)\n",
378
+ "A large yellow rubber cube rotated 288.2° located at 3D coordinates (0.52, 0.63, 0.70) and pixel coordinates (279, 130, 11.09)\n",
379
+ "A large brown rubber cylinder rotated 229.8° located at 3D coordinates (2.38, 0.38, 0.70) and pixel coordinates (343, 166, 9.77)\n",
380
+ "\n",
381
+ "Question:\n",
382
+ "How many items are there in the described scene?\n"
383
+ ]
384
+ }
385
+ ],
386
+ "source": [
387
+ "print(debug_query)"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": 42,
393
+ "id": "4cf90eb6-2cce-4e3d-8190-c44168a66dca",
394
+ "metadata": {},
395
+ "outputs": [
396
+ {
397
+ "data": {
398
+ "text/plain": [
399
+ "{'img_filename': 'CLEVR_train_044000.png',\n",
400
+ " 'q': 'How many rubber objects are either small blue spheres or small things?',\n",
401
+ " 'a': '2',\n",
402
+ " 'description': 'Scene Description:\\nA large purple rubber sphere rotated 78.4° located at 3D coordinates (2.27, 0.87, 0.70) and pixel coordinates (360, 156, 9.49)\\nA large gray metal cube rotated 152.7° located at 3D coordinates (2.79, -1.26, 0.70) and pixel coordinates (301, 213, 7.91)\\nA large purple metal sphere rotated 79.2° located at 3D coordinates (-2.66, -2.74, 0.70) and pixel coordinates (51, 126, 10.61)\\nA large blue rubber sphere rotated 279.5° located at 3D coordinates (1.31, 2.72, 0.70) and pixel coordinates (376, 112, 11.19)\\nA small brown rubber cube rotated 124.1° located at 3D coordinates (-2.49, 2.61, 0.35) and pixel coordinates (251, 82, 13.79)\\nA small green rubber sphere rotated 323.9° located at 3D coordinates (-2.02, 0.45, 0.35) and pixel coordinates (197, 109, 12.22)\\n'}"
403
+ ]
404
+ },
405
+ "execution_count": 42,
406
+ "metadata": {},
407
+ "output_type": "execute_result"
408
+ }
409
+ ],
410
+ "source": [
411
+ "qa_pairs[1]"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "code",
416
+ "execution_count": 45,
417
+ "id": "33fcd4eb-1f33-47d4-a453-76ef00e6d5d3",
418
+ "metadata": {},
419
+ "outputs": [
420
+ {
421
+ "name": "stdout",
422
+ "output_type": "stream",
423
+ "text": [
424
+ "<think>Okay, let's tackle this question. The user wants to know how many rubber objects are either small blue spheres or small things. Hmm, first, I need to parse each part of the question correctly.\n",
425
+ "\n",
426
+ "Let's go back to the scene description. The objects are listed with their attributes: size, color, material, shape, rotation, 3D and pixel coordinates. The key here is to filter the objects based on the criteria given.\n",
427
+ "\n",
428
+ "The question has two parts: \"small blue spheres\" OR \"small things\". So any rubber object that is either a small blue sphere or any small thing (regardless of other attributes) counts. But wait, do both categories need to be rubber? Because the question says \"rubber objects are either...\". So rubber is the material, and the condition is either being a small blue sphere or a small thing. So first, all rubber objects, and among them, count those that are either small blue spheres or small (regardless of color or shape). Wait, no. Let me re-read the question.\n",
429
+ "\n",
430
+ "The question is: \"How many rubber objects are either small blue spheres or small things?\" So rubber is the material. Within all rubber objects, count the ones that are either (1) small blue spheres OR (2) small things. Wait, does (2) being small things mean that even if they're small and of any color or shape, but they must be rubber?\n",
431
+ "\n",
432
+ "Yes, because the entire set is rubber objects. So first, select all objects where material is rubber. Then, within those, count how many are either (1) small, blue, sphere, or (2) small (any color or shape). Wait, but the structure is \"either X or Y\", where X is \"small blue sphere\" and Y is \"small things\". But \"small things\" would include all small objects, regardless of color and shape. However, since we've already narrowed it to rubber objects, \"small things\" here would be small rubber objects, regardless of color and shape.\n",
433
+ "\n",
434
+ "But wait, the condition is within rubber objects. So for the first part, small blue spheres (must check size, color, shape) and for the second part, small things (size is small, any color and shape, but since material is already rubber, that's covered). But wait, does the OR merge the two conditions, leading to rubber objects that are either (small blue spheres) or (small any-color any-shape).\n",
435
+ "\n",
436
+ "So the combined condition is: object is rubber AND ( (is small AND blue AND sphere) OR (is small) ). Wait, but if the condition for the second part is just \"small things\", which would imply any small object. But the entire group is rubber objects. So it's rubber objects that are small blue spheres OR rubber objects that are small (regardless of color or shape).\n",
437
+ "\n",
438
+ "Wait, no. Let's parse the sentence again: \"rubber objects are either small blue spheres or small things\". The \"either/or\" applies to \"small blue spheres\" and \"small things\". So, each rubber object has to be either (a small blue sphere) or (a small thing). However, \"small things\" here might refer to any small object regardless of other attributes. So if a rubber object is small, regardless of color or shape, it counts. But then, the first condition (small blue sphere) would also satisfy being a small thing. Wait, so there's an overlap. But when dealing with OR conditions, we have to avoid double-counting. So, the actual count is the number of small rubber objects (since any small rubber object is covered by the second part, which includes all small rubber objects, whether blue spheres or not) plus any objects that are small blue spheres but not rubber? But no, the question specifies \"rubber objects\", so we can ignore non-rubber ones.\n",
439
+ "\n",
440
+ "Wait, perhaps the wording is: \"rubber objects that are either small blue spheres or small things\". So \"small things\" here must reference other attributes. Wait, maybe there's ambiguity here. If the user is grouping \"small things\" as a separate category, regardless of being the other attributes. Let me try to approach this step by step.\n",
441
+ "\n",
442
+ "First, list all the rubber objects from the scene description:\n",
443
+ "\n",
444
+ "Looking through the list:\n",
445
+ "\n",
446
+ "1. A large purple rubber sphere ... location etc.\n",
447
+ "So material rubber, large, purple, sphere.\n",
448
+ "\n",
449
+ "2. A large gray metal cube ... metal, so not rubber.\n",
450
+ "\n",
451
+ "3. A large purple metal sphere ... metal, not rubber.\n",
452
+ "\n",
453
+ "4. A large blue rubber sphere ... rubber, large, blue, sphere.\n",
454
+ "\n",
455
+ "5. A small brown rubber cube ... rubber, small, brown, cube.\n",
456
+ "\n",
457
+ "6. A small green rubber sphere ... rubber, small, green, sphere.\n",
458
+ "\n",
459
+ "So the rubber objects are items 1,4,5,6.\n",
460
+ "\n",
461
+ "Now, for each of these four rubber objects, check if they are either (small blue sphere) or (small things).\n",
462
+ "\n",
463
+ "Let's check each:\n",
464
+ "\n",
465
+ "1. Large purple rubber sphere. Size: large. So for the first condition (small blue sphere): no. For the second condition (small thing): size large, so no. Not included.\n",
466
+ "\n",
467
+ "4. Large blue rubber sphere. Large. So, even though it's blue and sphere, the size is large. So (small blue sphere: no, since size is large). (small thing: no.) So not included.\n",
468
+ "\n",
469
+ "5. Small brown rubber cube. Size is small. So regardless of color and shape (brown, cube), this is a small thing. So it meets the second condition. So count it.\n",
470
+ "\n",
471
+ "6. Small green rubber sphere. Size: small. So meets the second condition (small thing). But it's a sphere and green, not blue. So this counts under the second category.\n",
472
+ "\n",
473
+ "Now, what about the first condition: small blue spheres. Are there any rubber objects that are small blue spheres? The only blue rubber sphere is item 4, which is large. So none of the rubber objects are small blue spheres.\n",
474
+ "\n",
475
+ "So now the two options under the OR are:\n",
476
+ "\n",
477
+ "- small blue sphere (none found)\n",
478
+ "- small things (items 5 and 6), both are rubber.\n",
479
+ "\n",
480
+ "Therefore total count is 0 + 2 = 2.\n",
481
+ "\n",
482
+ "Wait, but the initial question is: number of rubber objects that are either (small blue spheres) or (small things). Since \"small blue spheres\" are a subset of \"small things\" (since if something is a small blue sphere, it is a small thing). Therefore if there were any such objects, they would be counted twice, but in this case, since there are none, the total is just the count of small things.\n",
483
+ "\n",
484
+ "But according to the data, the rubber objects that are small are two: item 5 (small brown cube) and item 6 (small green sphere). Neither is a blue sphere, so no overlap. Therefore, sum is 2.\n",
485
+ "\n",
486
+ "So the answer should be 2.\n",
487
+ "\n",
488
+ "I need to confirm again whether the OR condition includes the union, meaning any rubber object that is either condition. So if either condition is true, count it.\n",
489
+ "\n",
490
+ "Since both 5 and 6 are small, even though they aren't blue spheres, they belong to the second category (small things). The blue sphere (if it were small) would have been counted as well. But in this case, the valid counts are 5 and 6. So yes, answer is 2.\n",
491
+ "</think>\n",
492
+ "\n",
493
+ "The answer is 2. \n",
494
+ "\n",
495
+ "**Explanation:**\n",
496
+ "First, identify all rubber objects from the scene description:\n",
497
+ "1. Large purple rubber sphere (not small)\n",
498
+ "2. Large blue rubber sphere (not small)\n",
499
+ "3. Small brown rubber cube (small)\n",
500
+ "4. Small green rubber sphere (small)\n",
501
+ "\n",
502
+ "Next, apply the criteria:\n",
503
+ "- **Small blue spheres**: None of the rubber objects meet this (the only blue rubber sphere is large).\n",
504
+ "- **Small rubber objects (regardless of color/shape)**: The small brown rubber cube and small green rubber sphere qualify (2 objects).\n",
505
+ "\n",
506
+ "Thus, there are **2 rubber objects** that fit either criterion.\n"
507
+ ]
508
+ }
509
+ ],
510
+ "source": [
511
+ "debug_query1 = format_query(qa_pairs[1])\n",
512
+ "res1 = query_r1(debug_query1)"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": 47,
518
+ "id": "8e516bd0-f1e5-4898-88a3-3afcaf0ae34e",
519
+ "metadata": {},
520
+ "outputs": [
521
+ {
522
+ "data": {
523
+ "text/plain": [
524
+ "{'img_filename': 'CLEVR_train_044000.png',\n",
525
+ " 'q': 'How many rubber objects are either small blue spheres or small things?',\n",
526
+ " 'a': '2',\n",
527
+ " 'description': 'Scene Description:\\nA large purple rubber sphere rotated 78.4° located at 3D coordinates (2.27, 0.87, 0.70) and pixel coordinates (360, 156, 9.49)\\nA large gray metal cube rotated 152.7° located at 3D coordinates (2.79, -1.26, 0.70) and pixel coordinates (301, 213, 7.91)\\nA large purple metal sphere rotated 79.2° located at 3D coordinates (-2.66, -2.74, 0.70) and pixel coordinates (51, 126, 10.61)\\nA large blue rubber sphere rotated 279.5° located at 3D coordinates (1.31, 2.72, 0.70) and pixel coordinates (376, 112, 11.19)\\nA small brown rubber cube rotated 124.1° located at 3D coordinates (-2.49, 2.61, 0.35) and pixel coordinates (251, 82, 13.79)\\nA small green rubber sphere rotated 323.9° located at 3D coordinates (-2.02, 0.45, 0.35) and pixel coordinates (197, 109, 12.22)\\n'}"
528
+ ]
529
+ },
530
+ "execution_count": 47,
531
+ "metadata": {},
532
+ "output_type": "execute_result"
533
+ }
534
+ ],
535
+ "source": [
536
+ "qa_pairs[1]"
537
+ ]
538
+ },
539
+ {
540
+ "cell_type": "code",
541
+ "execution_count": null,
542
+ "id": "92784518-49e2-443d-9541-2785cbb944cf",
543
+ "metadata": {},
544
+ "outputs": [],
545
+ "source": []
546
+ }
547
+ ],
548
+ "metadata": {
549
+ "kernelspec": {
550
+ "display_name": "Python 3 (ipykernel)",
551
+ "language": "python",
552
+ "name": "python3"
553
+ },
554
+ "language_info": {
555
+ "codemirror_mode": {
556
+ "name": "ipython",
557
+ "version": 3
558
+ },
559
+ "file_extension": ".py",
560
+ "mimetype": "text/x-python",
561
+ "name": "python",
562
+ "nbconvert_exporter": "python",
563
+ "pygments_lexer": "ipython3",
564
+ "version": "3.12.2"
565
+ }
566
+ },
567
+ "nbformat": 4,
568
+ "nbformat_minor": 5
569
+ }
previous_version/Video-R1-main-previous/src/distill_r1/grpo_r1_distilled.jpg ADDED

Git LFS Details

  • SHA256: e0f6135ef837a375090b07e29a18fd2d5cb819100c73d5dc7ea63401f66caf59
  • Pointer size: 131 Bytes
  • Size of remote file: 304 kB
previous_version/Video-R1-main-previous/src/distill_r1/query_r1.py ADDED
@@ -0,0 +1,114 @@
import json
import random
import os
from openai import OpenAI
from tqdm import tqdm
import concurrent.futures
from typing import List, Dict, Optional
from datetime import datetime
from threading import Lock
import time
from prompt import R1_SYS_PROMPT

# Initialize the client
client = OpenAI(
    api_key=os.environ.get("SL_KEY", "YOUR_SILCONFLOW_KEY"),
    base_url="https://api.siliconflow.cn/v1",
)

# Create a lock for thread-safe file writing
file_lock = Lock()

def format_query(qa_dict: Dict, v2=False) -> str:
    query = "Answer the question according to scene description.\n\n"
    query += qa_dict["description"]
    query += f"\nQuestion:\n{qa_dict['q']}"
    if v2:
        query += "\nInstructions:\n"
        query += "1. Carefully analyze the scene description\n"
        query += "2. Provide your reasoning if necessary\n"
        query += "3. For the final answer, start a new line with '**The answer is: **' followed by your answer\n"
    return query

def write_to_jsonl(result: Dict, filename: str):
    """Thread-safe function to write a result to JSONL file"""
    with file_lock:
        with open(filename, 'a') as f:
            f.write(json.dumps(result) + '\n')

def query_r1(qa_pair: Dict, output_file: str, model: str = "deepseek-ai/DeepSeek-R1", v2=False) -> Optional[Dict]:
    query = format_query(qa_pair, v2=v2)
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": R1_SYS_PROMPT},
                {"role": "user", "content": query}],
            stream=False,
            max_tokens=4096
        )
        result = {
            **qa_pair,
            "r1_response": response.choices[0].message.content,
            "timestamp": datetime.now().isoformat()
        }
        # Write result immediately
        write_to_jsonl(result, output_file)
        time.sleep(4)
        return result
    except Exception as e:
        print(f"Error processing query: {e}")
        error_result = {
            **qa_pair,
            "error": str(e),
            "timestamp": datetime.now().isoformat()
        }
        write_to_jsonl(error_result, f"errors_{output_file}")
        time.sleep(10)
        return None

def process_qa_pairs_parallel(qa_pairs: List[Dict], output_file: str, max_workers: int = 10) -> List[Dict]:
    successful_count = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Create futures for all qa_pairs
        futures = [executor.submit(query_r1, qa_pair, output_file, v2="v2" in output_file) for qa_pair in qa_pairs]

        # Process results as they complete with progress bar
        results = []
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            try:
                result = future.result()
                if result is not None:
                    results.append(result)
                    successful_count += 1
            except Exception as e:
                print(f"Failed to process query: {e}")

    return results

if __name__ == "__main__":
    # Load and shuffle QA pairs
    random.seed(1234)
    qa_pairs = json.load(open("/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_trainA.json"))
    random.shuffle(qa_pairs)
    qa_pairs = qa_pairs[:10000]
    # Create output filename with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = f"r1_results_clevr_cogent_v1.0_trainA_v2.jsonl"

    finished = set()
    with open(output_file, 'r') as f:
        for line in f:
            ins = json.loads(line)
            key = ins["img_filename"] + "-" + ins["q"] + "-" + str(ins["a"])
            finished.add(key)
    qa_pairs = [ins for ins in qa_pairs if ins["img_filename"] + "-" + ins["q"] + "-" + str(ins["a"]) not in finished]
    print("Finished: ", len(finished))
    print("Remaining: ", len(qa_pairs))
    # Process QA pairs in parallel
    r1_results = process_qa_pairs_parallel(qa_pairs, output_file)

    # Print final statistics
    print(f"Successfully processed {len(r1_results)} out of {len(qa_pairs)} queries")
    print(f"Results saved to {output_file}")
    print(f"Any errors were saved to errors_{output_file}")
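
As a quick illustration of the record shape format_query() expects (the same img_filename / q / a / description fields written by generate_scene_qa_pairs.ipynb above), a minimal sketch with made-up values:

# Hypothetical QA record; the field values are illustrative only.
qa_pair = {
    "img_filename": "CLEVR_trainA_000000.png",
    "q": "How many items are there in the described scene?",
    "a": 5,
    "description": "Scene Description:\nA small red rubber cube ...\n",
}

# v2=True appends the explicit answer-format instructions defined in format_query().
print(format_query(qa_pair, v2=True))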
previous_version/Video-R1-main-previous/src/eval/prompts/geoqa_test_prompts.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
previous_version/Video-R1-main-previous/src/eval/prompts/superclevr_test200_counting_problems.jsonl ADDED
@@ -0,0 +1,200 @@
1
+ {"image_path": "./images/superCLEVR_new_025000.png", "question": "How many different items are there in the image?", "ground_truth": 4}
2
+ {"image_path": "./images/superCLEVR_new_025001.png", "question": "How many different items are there in the image?", "ground_truth": 9}
3
+ {"image_path": "./images/superCLEVR_new_025002.png", "question": "How many different items are there in the image?", "ground_truth": 10}
4
+ {"image_path": "./images/superCLEVR_new_025003.png", "question": "How many different items are there in the image?", "ground_truth": 4}
5
+ {"image_path": "./images/superCLEVR_new_025004.png", "question": "How many different items are there in the image?", "ground_truth": 3}
6
+ {"image_path": "./images/superCLEVR_new_025005.png", "question": "How many different items are there in the image?", "ground_truth": 3}
7
+ {"image_path": "./images/superCLEVR_new_025006.png", "question": "How many different items are there in the image?", "ground_truth": 3}
8
+ {"image_path": "./images/superCLEVR_new_025007.png", "question": "How many different items are there in the image?", "ground_truth": 4}
9
+ {"image_path": "./images/superCLEVR_new_025008.png", "question": "How many different items are there in the image?", "ground_truth": 9}
10
+ {"image_path": "./images/superCLEVR_new_025009.png", "question": "How many different items are there in the image?", "ground_truth": 10}
11
+ {"image_path": "./images/superCLEVR_new_025010.png", "question": "How many different items are there in the image?", "ground_truth": 7}
12
+ {"image_path": "./images/superCLEVR_new_025011.png", "question": "How many different items are there in the image?", "ground_truth": 7}
13
+ {"image_path": "./images/superCLEVR_new_025012.png", "question": "How many different items are there in the image?", "ground_truth": 7}
14
+ {"image_path": "./images/superCLEVR_new_025013.png", "question": "How many different items are there in the image?", "ground_truth": 6}
15
+ {"image_path": "./images/superCLEVR_new_025014.png", "question": "How many different items are there in the image?", "ground_truth": 5}
16
+ {"image_path": "./images/superCLEVR_new_025015.png", "question": "How many different items are there in the image?", "ground_truth": 10}
17
+ {"image_path": "./images/superCLEVR_new_025016.png", "question": "How many different items are there in the image?", "ground_truth": 4}
18
+ {"image_path": "./images/superCLEVR_new_025017.png", "question": "How many different items are there in the image?", "ground_truth": 5}
19
+ {"image_path": "./images/superCLEVR_new_025018.png", "question": "How many different items are there in the image?", "ground_truth": 6}
20
+ {"image_path": "./images/superCLEVR_new_025019.png", "question": "How many different items are there in the image?", "ground_truth": 8}
21
+ {"image_path": "./images/superCLEVR_new_025020.png", "question": "How many different items are there in the image?", "ground_truth": 10}
22
+ {"image_path": "./images/superCLEVR_new_025021.png", "question": "How many different items are there in the image?", "ground_truth": 3}
23
+ {"image_path": "./images/superCLEVR_new_025022.png", "question": "How many different items are there in the image?", "ground_truth": 4}
24
+ {"image_path": "./images/superCLEVR_new_025023.png", "question": "How many different items are there in the image?", "ground_truth": 4}
25
+ {"image_path": "./images/superCLEVR_new_025024.png", "question": "How many different items are there in the image?", "ground_truth": 5}
26
+ {"image_path": "./images/superCLEVR_new_025025.png", "question": "How many different items are there in the image?", "ground_truth": 5}
27
+ {"image_path": "./images/superCLEVR_new_025026.png", "question": "How many different items are there in the image?", "ground_truth": 7}
28
+ {"image_path": "./images/superCLEVR_new_025027.png", "question": "How many different items are there in the image?", "ground_truth": 4}
29
+ {"image_path": "./images/superCLEVR_new_025028.png", "question": "How many different items are there in the image?", "ground_truth": 4}
30
+ {"image_path": "./images/superCLEVR_new_025029.png", "question": "How many different items are there in the image?", "ground_truth": 9}
31
+ {"image_path": "./images/superCLEVR_new_025030.png", "question": "How many different items are there in the image?", "ground_truth": 8}
32
+ {"image_path": "./images/superCLEVR_new_025031.png", "question": "How many different items are there in the image?", "ground_truth": 6}
33
+ {"image_path": "./images/superCLEVR_new_025032.png", "question": "How many different items are there in the image?", "ground_truth": 3}
34
+ {"image_path": "./images/superCLEVR_new_025033.png", "question": "How many different items are there in the image?", "ground_truth": 10}
35
+ {"image_path": "./images/superCLEVR_new_025034.png", "question": "How many different items are there in the image?", "ground_truth": 9}
36
+ {"image_path": "./images/superCLEVR_new_025035.png", "question": "How many different items are there in the image?", "ground_truth": 9}
37
+ {"image_path": "./images/superCLEVR_new_025036.png", "question": "How many different items are there in the image?", "ground_truth": 3}
38
+ {"image_path": "./images/superCLEVR_new_025037.png", "question": "How many different items are there in the image?", "ground_truth": 6}
39
+ {"image_path": "./images/superCLEVR_new_025038.png", "question": "How many different items are there in the image?", "ground_truth": 6}
40
+ {"image_path": "./images/superCLEVR_new_025039.png", "question": "How many different items are there in the image?", "ground_truth": 5}
41
+ {"image_path": "./images/superCLEVR_new_025040.png", "question": "How many different items are there in the image?", "ground_truth": 3}
42
+ {"image_path": "./images/superCLEVR_new_025041.png", "question": "How many different items are there in the image?", "ground_truth": 10}
43
+ {"image_path": "./images/superCLEVR_new_025042.png", "question": "How many different items are there in the image?", "ground_truth": 6}
44
+ {"image_path": "./images/superCLEVR_new_025043.png", "question": "How many different items are there in the image?", "ground_truth": 3}
45
+ {"image_path": "./images/superCLEVR_new_025044.png", "question": "How many different items are there in the image?", "ground_truth": 6}
46
+ {"image_path": "./images/superCLEVR_new_025045.png", "question": "How many different items are there in the image?", "ground_truth": 5}
47
+ {"image_path": "./images/superCLEVR_new_025046.png", "question": "How many different items are there in the image?", "ground_truth": 7}
48
+ {"image_path": "./images/superCLEVR_new_025047.png", "question": "How many different items are there in the image?", "ground_truth": 5}
49
+ {"image_path": "./images/superCLEVR_new_025048.png", "question": "How many different items are there in the image?", "ground_truth": 5}
50
+ {"image_path": "./images/superCLEVR_new_025049.png", "question": "How many different items are there in the image?", "ground_truth": 10}
51
+ {"image_path": "./images/superCLEVR_new_025050.png", "question": "How many different items are there in the image?", "ground_truth": 6}
52
+ {"image_path": "./images/superCLEVR_new_025051.png", "question": "How many different items are there in the image?", "ground_truth": 3}
53
+ {"image_path": "./images/superCLEVR_new_025052.png", "question": "How many different items are there in the image?", "ground_truth": 7}
54
+ {"image_path": "./images/superCLEVR_new_025053.png", "question": "How many different items are there in the image?", "ground_truth": 9}
55
+ {"image_path": "./images/superCLEVR_new_025054.png", "question": "How many different items are there in the image?", "ground_truth": 7}
56
+ {"image_path": "./images/superCLEVR_new_025055.png", "question": "How many different items are there in the image?", "ground_truth": 6}
57
+ {"image_path": "./images/superCLEVR_new_025056.png", "question": "How many different items are there in the image?", "ground_truth": 9}
58
+ {"image_path": "./images/superCLEVR_new_025057.png", "question": "How many different items are there in the image?", "ground_truth": 8}
59
+ {"image_path": "./images/superCLEVR_new_025058.png", "question": "How many different items are there in the image?", "ground_truth": 10}
60
+ {"image_path": "./images/superCLEVR_new_025059.png", "question": "How many different items are there in the image?", "ground_truth": 10}
61
+ {"image_path": "./images/superCLEVR_new_025060.png", "question": "How many different items are there in the image?", "ground_truth": 8}
62
+ {"image_path": "./images/superCLEVR_new_025061.png", "question": "How many different items are there in the image?", "ground_truth": 8}
63
+ {"image_path": "./images/superCLEVR_new_025062.png", "question": "How many different items are there in the image?", "ground_truth": 8}
64
+ {"image_path": "./images/superCLEVR_new_025063.png", "question": "How many different items are there in the image?", "ground_truth": 10}
65
+ {"image_path": "./images/superCLEVR_new_025064.png", "question": "How many different items are there in the image?", "ground_truth": 3}
66
+ {"image_path": "./images/superCLEVR_new_025065.png", "question": "How many different items are there in the image?", "ground_truth": 4}
67
+ {"image_path": "./images/superCLEVR_new_025066.png", "question": "How many different items are there in the image?", "ground_truth": 6}
68
+ {"image_path": "./images/superCLEVR_new_025067.png", "question": "How many different items are there in the image?", "ground_truth": 7}
69
+ {"image_path": "./images/superCLEVR_new_025068.png", "question": "How many different items are there in the image?", "ground_truth": 3}
70
+ {"image_path": "./images/superCLEVR_new_025069.png", "question": "How many different items are there in the image?", "ground_truth": 10}
71
+ {"image_path": "./images/superCLEVR_new_025070.png", "question": "How many different items are there in the image?", "ground_truth": 9}
72
+ {"image_path": "./images/superCLEVR_new_025071.png", "question": "How many different items are there in the image?", "ground_truth": 6}
73
+ {"image_path": "./images/superCLEVR_new_025072.png", "question": "How many different items are there in the image?", "ground_truth": 10}
74
+ {"image_path": "./images/superCLEVR_new_025073.png", "question": "How many different items are there in the image?", "ground_truth": 5}
75
+ {"image_path": "./images/superCLEVR_new_025074.png", "question": "How many different items are there in the image?", "ground_truth": 9}
76
+ {"image_path": "./images/superCLEVR_new_025075.png", "question": "How many different items are there in the image?", "ground_truth": 3}
77
+ {"image_path": "./images/superCLEVR_new_025076.png", "question": "How many different items are there in the image?", "ground_truth": 5}
78
+ {"image_path": "./images/superCLEVR_new_025077.png", "question": "How many different items are there in the image?", "ground_truth": 5}
79
+ {"image_path": "./images/superCLEVR_new_025078.png", "question": "How many different items are there in the image?", "ground_truth": 5}
80
+ {"image_path": "./images/superCLEVR_new_025079.png", "question": "How many different items are there in the image?", "ground_truth": 9}
81
+ {"image_path": "./images/superCLEVR_new_025080.png", "question": "How many different items are there in the image?", "ground_truth": 5}
82
+ {"image_path": "./images/superCLEVR_new_025081.png", "question": "How many different items are there in the image?", "ground_truth": 5}
83
+ {"image_path": "./images/superCLEVR_new_025082.png", "question": "How many different items are there in the image?", "ground_truth": 10}
84
+ {"image_path": "./images/superCLEVR_new_025083.png", "question": "How many different items are there in the image?", "ground_truth": 4}
85
+ {"image_path": "./images/superCLEVR_new_025084.png", "question": "How many different items are there in the image?", "ground_truth": 8}
86
+ {"image_path": "./images/superCLEVR_new_025085.png", "question": "How many different items are there in the image?", "ground_truth": 8}
87
+ {"image_path": "./images/superCLEVR_new_025086.png", "question": "How many different items are there in the image?", "ground_truth": 10}
88
+ {"image_path": "./images/superCLEVR_new_025087.png", "question": "How many different items are there in the image?", "ground_truth": 9}
89
+ {"image_path": "./images/superCLEVR_new_025088.png", "question": "How many different items are there in the image?", "ground_truth": 3}
90
+ {"image_path": "./images/superCLEVR_new_025089.png", "question": "How many different items are there in the image?", "ground_truth": 4}
91
+ {"image_path": "./images/superCLEVR_new_025090.png", "question": "How many different items are there in the image?", "ground_truth": 9}
92
+ {"image_path": "./images/superCLEVR_new_025091.png", "question": "How many different items are there in the image?", "ground_truth": 7}
93
+ {"image_path": "./images/superCLEVR_new_025092.png", "question": "How many different items are there in the image?", "ground_truth": 6}
94
+ {"image_path": "./images/superCLEVR_new_025093.png", "question": "How many different items are there in the image?", "ground_truth": 10}
95
+ {"image_path": "./images/superCLEVR_new_025094.png", "question": "How many different items are there in the image?", "ground_truth": 6}
96
+ {"image_path": "./images/superCLEVR_new_025095.png", "question": "How many different items are there in the image?", "ground_truth": 6}
97
+ {"image_path": "./images/superCLEVR_new_025096.png", "question": "How many different items are there in the image?", "ground_truth": 8}
98
+ {"image_path": "./images/superCLEVR_new_025097.png", "question": "How many different items are there in the image?", "ground_truth": 7}
99
+ {"image_path": "./images/superCLEVR_new_025098.png", "question": "How many different items are there in the image?", "ground_truth": 10}
100
+ {"image_path": "./images/superCLEVR_new_025099.png", "question": "How many different items are there in the image?", "ground_truth": 10}
101
+ {"image_path": "./images/superCLEVR_new_025100.png", "question": "How many different items are there in the image?", "ground_truth": 5}
102
+ {"image_path": "./images/superCLEVR_new_025101.png", "question": "How many different items are there in the image?", "ground_truth": 7}
103
+ {"image_path": "./images/superCLEVR_new_025102.png", "question": "How many different items are there in the image?", "ground_truth": 3}
104
+ {"image_path": "./images/superCLEVR_new_025103.png", "question": "How many different items are there in the image?", "ground_truth": 6}
105
+ {"image_path": "./images/superCLEVR_new_025104.png", "question": "How many different items are there in the image?", "ground_truth": 9}
106
+ {"image_path": "./images/superCLEVR_new_025105.png", "question": "How many different items are there in the image?", "ground_truth": 7}
107
+ {"image_path": "./images/superCLEVR_new_025106.png", "question": "How many different items are there in the image?", "ground_truth": 8}
108
+ {"image_path": "./images/superCLEVR_new_025107.png", "question": "How many different items are there in the image?", "ground_truth": 8}
109
+ {"image_path": "./images/superCLEVR_new_025108.png", "question": "How many different items are there in the image?", "ground_truth": 3}
110
+ {"image_path": "./images/superCLEVR_new_025109.png", "question": "How many different items are there in the image?", "ground_truth": 7}
111
+ {"image_path": "./images/superCLEVR_new_025110.png", "question": "How many different items are there in the image?", "ground_truth": 8}
112
+ {"image_path": "./images/superCLEVR_new_025111.png", "question": "How many different items are there in the image?", "ground_truth": 9}
113
+ {"image_path": "./images/superCLEVR_new_025112.png", "question": "How many different items are there in the image?", "ground_truth": 9}
114
+ {"image_path": "./images/superCLEVR_new_025113.png", "question": "How many different items are there in the image?", "ground_truth": 6}
115
+ {"image_path": "./images/superCLEVR_new_025114.png", "question": "How many different items are there in the image?", "ground_truth": 6}
116
+ {"image_path": "./images/superCLEVR_new_025115.png", "question": "How many different items are there in the image?", "ground_truth": 9}
117
+ {"image_path": "./images/superCLEVR_new_025116.png", "question": "How many different items are there in the image?", "ground_truth": 7}
118
+ {"image_path": "./images/superCLEVR_new_025117.png", "question": "How many different items are there in the image?", "ground_truth": 9}
119
+ {"image_path": "./images/superCLEVR_new_025118.png", "question": "How many different items are there in the image?", "ground_truth": 5}
120
+ {"image_path": "./images/superCLEVR_new_025119.png", "question": "How many different items are there in the image?", "ground_truth": 9}
121
+ {"image_path": "./images/superCLEVR_new_025120.png", "question": "How many different items are there in the image?", "ground_truth": 6}
122
+ {"image_path": "./images/superCLEVR_new_025121.png", "question": "How many different items are there in the image?", "ground_truth": 10}
123
+ {"image_path": "./images/superCLEVR_new_025122.png", "question": "How many different items are there in the image?", "ground_truth": 10}
124
+ {"image_path": "./images/superCLEVR_new_025123.png", "question": "How many different items are there in the image?", "ground_truth": 6}
125
+ {"image_path": "./images/superCLEVR_new_025124.png", "question": "How many different items are there in the image?", "ground_truth": 8}
126
+ {"image_path": "./images/superCLEVR_new_025125.png", "question": "How many different items are there in the image?", "ground_truth": 8}
127
+ {"image_path": "./images/superCLEVR_new_025126.png", "question": "How many different items are there in the image?", "ground_truth": 3}
128
+ {"image_path": "./images/superCLEVR_new_025127.png", "question": "How many different items are there in the image?", "ground_truth": 7}
129
+ {"image_path": "./images/superCLEVR_new_025128.png", "question": "How many different items are there in the image?", "ground_truth": 6}
130
+ {"image_path": "./images/superCLEVR_new_025129.png", "question": "How many different items are there in the image?", "ground_truth": 4}
131
+ {"image_path": "./images/superCLEVR_new_025130.png", "question": "How many different items are there in the image?", "ground_truth": 5}
132
+ {"image_path": "./images/superCLEVR_new_025131.png", "question": "How many different items are there in the image?", "ground_truth": 8}
133
+ {"image_path": "./images/superCLEVR_new_025132.png", "question": "How many different items are there in the image?", "ground_truth": 3}
134
+ {"image_path": "./images/superCLEVR_new_025133.png", "question": "How many different items are there in the image?", "ground_truth": 5}
135
+ {"image_path": "./images/superCLEVR_new_025134.png", "question": "How many different items are there in the image?", "ground_truth": 8}
136
+ {"image_path": "./images/superCLEVR_new_025135.png", "question": "How many different items are there in the image?", "ground_truth": 8}
137
+ {"image_path": "./images/superCLEVR_new_025136.png", "question": "How many different items are there in the image?", "ground_truth": 6}
138
+ {"image_path": "./images/superCLEVR_new_025137.png", "question": "How many different items are there in the image?", "ground_truth": 5}
139
+ {"image_path": "./images/superCLEVR_new_025138.png", "question": "How many different items are there in the image?", "ground_truth": 3}
140
+ {"image_path": "./images/superCLEVR_new_025139.png", "question": "How many different items are there in the image?", "ground_truth": 4}
141
+ {"image_path": "./images/superCLEVR_new_025140.png", "question": "How many different items are there in the image?", "ground_truth": 3}
142
+ {"image_path": "./images/superCLEVR_new_025141.png", "question": "How many different items are there in the image?", "ground_truth": 9}
143
+ {"image_path": "./images/superCLEVR_new_025142.png", "question": "How many different items are there in the image?", "ground_truth": 10}
144
+ {"image_path": "./images/superCLEVR_new_025143.png", "question": "How many different items are there in the image?", "ground_truth": 5}
145
+ {"image_path": "./images/superCLEVR_new_025144.png", "question": "How many different items are there in the image?", "ground_truth": 6}
146
+ {"image_path": "./images/superCLEVR_new_025145.png", "question": "How many different items are there in the image?", "ground_truth": 10}
147
+ {"image_path": "./images/superCLEVR_new_025146.png", "question": "How many different items are there in the image?", "ground_truth": 5}
148
+ {"image_path": "./images/superCLEVR_new_025147.png", "question": "How many different items are there in the image?", "ground_truth": 6}
149
+ {"image_path": "./images/superCLEVR_new_025148.png", "question": "How many different items are there in the image?", "ground_truth": 8}
150
+ {"image_path": "./images/superCLEVR_new_025149.png", "question": "How many different items are there in the image?", "ground_truth": 8}
151
+ {"image_path": "./images/superCLEVR_new_025150.png", "question": "How many different items are there in the image?", "ground_truth": 9}
152
+ {"image_path": "./images/superCLEVR_new_025151.png", "question": "How many different items are there in the image?", "ground_truth": 8}
153
+ {"image_path": "./images/superCLEVR_new_025152.png", "question": "How many different items are there in the image?", "ground_truth": 10}
154
+ {"image_path": "./images/superCLEVR_new_025153.png", "question": "How many different items are there in the image?", "ground_truth": 3}
155
+ {"image_path": "./images/superCLEVR_new_025154.png", "question": "How many different items are there in the image?", "ground_truth": 5}
156
+ {"image_path": "./images/superCLEVR_new_025155.png", "question": "How many different items are there in the image?", "ground_truth": 10}
157
+ {"image_path": "./images/superCLEVR_new_025156.png", "question": "How many different items are there in the image?", "ground_truth": 3}
158
+ {"image_path": "./images/superCLEVR_new_025157.png", "question": "How many different items are there in the image?", "ground_truth": 6}
159
+ {"image_path": "./images/superCLEVR_new_025158.png", "question": "How many different items are there in the image?", "ground_truth": 4}
160
+ {"image_path": "./images/superCLEVR_new_025159.png", "question": "How many different items are there in the image?", "ground_truth": 5}
161
+ {"image_path": "./images/superCLEVR_new_025160.png", "question": "How many different items are there in the image?", "ground_truth": 9}
162
+ {"image_path": "./images/superCLEVR_new_025161.png", "question": "How many different items are there in the image?", "ground_truth": 3}
163
+ {"image_path": "./images/superCLEVR_new_025162.png", "question": "How many different items are there in the image?", "ground_truth": 5}
164
+ {"image_path": "./images/superCLEVR_new_025163.png", "question": "How many different items are there in the image?", "ground_truth": 10}
165
+ {"image_path": "./images/superCLEVR_new_025164.png", "question": "How many different items are there in the image?", "ground_truth": 9}
166
+ {"image_path": "./images/superCLEVR_new_025165.png", "question": "How many different items are there in the image?", "ground_truth": 7}
167
+ {"image_path": "./images/superCLEVR_new_025166.png", "question": "How many different items are there in the image?", "ground_truth": 8}
168
+ {"image_path": "./images/superCLEVR_new_025167.png", "question": "How many different items are there in the image?", "ground_truth": 7}
169
+ {"image_path": "./images/superCLEVR_new_025168.png", "question": "How many different items are there in the image?", "ground_truth": 3}
170
+ {"image_path": "./images/superCLEVR_new_025169.png", "question": "How many different items are there in the image?", "ground_truth": 10}
171
+ {"image_path": "./images/superCLEVR_new_025170.png", "question": "How many different items are there in the image?", "ground_truth": 8}
172
+ {"image_path": "./images/superCLEVR_new_025171.png", "question": "How many different items are there in the image?", "ground_truth": 7}
173
+ {"image_path": "./images/superCLEVR_new_025172.png", "question": "How many different items are there in the image?", "ground_truth": 4}
174
+ {"image_path": "./images/superCLEVR_new_025173.png", "question": "How many different items are there in the image?", "ground_truth": 10}
175
+ {"image_path": "./images/superCLEVR_new_025174.png", "question": "How many different items are there in the image?", "ground_truth": 9}
176
+ {"image_path": "./images/superCLEVR_new_025175.png", "question": "How many different items are there in the image?", "ground_truth": 4}
177
+ {"image_path": "./images/superCLEVR_new_025176.png", "question": "How many different items are there in the image?", "ground_truth": 9}
178
+ {"image_path": "./images/superCLEVR_new_025177.png", "question": "How many different items are there in the image?", "ground_truth": 6}
179
+ {"image_path": "./images/superCLEVR_new_025178.png", "question": "How many different items are there in the image?", "ground_truth": 10}
180
+ {"image_path": "./images/superCLEVR_new_025179.png", "question": "How many different items are there in the image?", "ground_truth": 6}
181
+ {"image_path": "./images/superCLEVR_new_025180.png", "question": "How many different items are there in the image?", "ground_truth": 3}
182
+ {"image_path": "./images/superCLEVR_new_025181.png", "question": "How many different items are there in the image?", "ground_truth": 3}
183
+ {"image_path": "./images/superCLEVR_new_025182.png", "question": "How many different items are there in the image?", "ground_truth": 8}
184
+ {"image_path": "./images/superCLEVR_new_025183.png", "question": "How many different items are there in the image?", "ground_truth": 5}
185
+ {"image_path": "./images/superCLEVR_new_025184.png", "question": "How many different items are there in the image?", "ground_truth": 5}
186
+ {"image_path": "./images/superCLEVR_new_025185.png", "question": "How many different items are there in the image?", "ground_truth": 3}
187
+ {"image_path": "./images/superCLEVR_new_025186.png", "question": "How many different items are there in the image?", "ground_truth": 4}
188
+ {"image_path": "./images/superCLEVR_new_025187.png", "question": "How many different items are there in the image?", "ground_truth": 5}
189
+ {"image_path": "./images/superCLEVR_new_025188.png", "question": "How many different items are there in the image?", "ground_truth": 5}
190
+ {"image_path": "./images/superCLEVR_new_025189.png", "question": "How many different items are there in the image?", "ground_truth": 3}
191
+ {"image_path": "./images/superCLEVR_new_025190.png", "question": "How many different items are there in the image?", "ground_truth": 5}
192
+ {"image_path": "./images/superCLEVR_new_025191.png", "question": "How many different items are there in the image?", "ground_truth": 8}
193
+ {"image_path": "./images/superCLEVR_new_025192.png", "question": "How many different items are there in the image?", "ground_truth": 3}
194
+ {"image_path": "./images/superCLEVR_new_025193.png", "question": "How many different items are there in the image?", "ground_truth": 9}
195
+ {"image_path": "./images/superCLEVR_new_025194.png", "question": "How many different items are there in the image?", "ground_truth": 10}
196
+ {"image_path": "./images/superCLEVR_new_025195.png", "question": "How many different items are there in the image?", "ground_truth": 5}
197
+ {"image_path": "./images/superCLEVR_new_025196.png", "question": "How many different items are there in the image?", "ground_truth": 6}
198
+ {"image_path": "./images/superCLEVR_new_025197.png", "question": "How many different items are there in the image?", "ground_truth": 3}
199
+ {"image_path": "./images/superCLEVR_new_025198.png", "question": "How many different items are there in the image?", "ground_truth": 4}
200
+ {"image_path": "./images/superCLEVR_new_025199.png", "question": "How many different items are there in the image?", "ground_truth": 3}
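Each line of the prompts file above is an independent JSON record with an `image_path`, a `question`, and an integer `ground_truth`. A minimal sketch of loading it, matching the PROMPT_PATH used by the evaluation script below:

```python
import json

# Minimal sketch: each line of the prompts file is one self-contained JSON record.
with open("./prompts/superclevr_test200_counting_problems.jsonl", "r") as f:
    records = [json.loads(line) for line in f]

print(len(records))               # 200 counting problems
print(sorted(records[0].keys()))  # ['ground_truth', 'image_path', 'question']
```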
previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_counting_superclevr.py ADDED
@@ -0,0 +1,136 @@
1
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2
+ from qwen_vl_utils import process_vision_info
3
+ import torch
4
+ import json
5
+ from tqdm import tqdm
6
+ import re
7
+
8
+
9
+
10
+ MODEL_PATH="Qwen2-VL-2B-GRPO-CLEVR-70k/checkpoint-100" # Qwen2vl-2b-Instruct for original scores
11
+ BSZ=64 # reduce it if GPU OOM
12
+ OUTPUT_PATH="./logs/counting_results_superclevr_200_qwen2vl_2b_instruct_grpo_100.json"
13
+ PROMPT_PATH="./prompts/superclevr_test200_counting_problems.jsonl"
14
+
15
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
16
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
17
+ MODEL_PATH,
18
+ torch_dtype=torch.bfloat16,
19
+ attn_implementation="flash_attention_2",
20
+ device_map="auto",
21
+ )
22
+
23
+ # default processor
24
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
25
+
26
+ data = []
27
+ with open(PROMPT_PATH, "r") as f:
28
+ for line in f:
29
+ data.append(json.loads(line))
30
+
31
+
32
+ QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
33
+
34
+ messages = []
35
+
36
+ for i in data:
37
+ message = [{
38
+ "role": "user",
39
+ "content": [
40
+ {
41
+ "type": "image",
42
+ "image": f"file://{i['image_path']}"
43
+ },
44
+ {
45
+ "type": "text",
46
+ "text": QUESTION_TEMPLATE.format(Question=i['question'])
47
+ }
48
+ ]
49
+ }]
50
+ messages.append(message)
51
+
52
+
53
+
54
+
55
+ all_outputs = [] # List to store all answers
56
+
57
+ # Process data in batches
58
+ for i in tqdm(range(0, len(messages), BSZ)):
59
+ batch_messages = messages[i:i + BSZ]
60
+
61
+ # Preparation for inference
62
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
63
+
64
+ image_inputs, video_inputs = process_vision_info(batch_messages)
65
+ inputs = processor(
66
+ text=text,
67
+ images=image_inputs,
68
+ videos=video_inputs,
69
+ padding=True,
70
+ return_tensors="pt",
71
+ )
72
+ inputs = inputs.to("cuda")
73
+
74
+ # Inference: Generation of the output
75
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
76
+
77
+ generated_ids_trimmed = [
78
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
79
+ ]
80
+ batch_output_text = processor.batch_decode(
81
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
82
+ )
83
+
84
+ all_outputs.extend(batch_output_text)
85
+ print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
86
+
87
+
88
+ def extract_number_answer(output_str):
89
+ # Try to find the number within <answer> tags; if it cannot be found, return None
90
+ answer_pattern = r'<answer>\s*(\d+)\s*</answer>'
91
+ match = re.search(answer_pattern, output_str)
92
+
93
+ if match:
94
+ return int(match.group(1))
95
+ return None
96
+
97
+
98
+ final_output = []
99
+ correct_number = 0
100
+
101
+ for input_example, model_output in zip(data,all_outputs):
102
+ original_output = model_output
103
+ ground_truth = input_example['ground_truth']
104
+ model_answer = extract_number_answer(original_output)
105
+
106
+ # Create a result dictionary for this example
107
+ result = {
108
+ 'question': input_example,
109
+ 'ground_truth': ground_truth,
110
+ 'model_output': original_output,
111
+ 'extracted_answer': model_answer
112
+ }
113
+ final_output.append(result)
114
+
115
+ # Count correct answers
116
+ if model_answer is not None and model_answer == ground_truth:
117
+ correct_number += 1
118
+
119
+ # Calculate and print accuracy
120
+ accuracy = correct_number / len(data) * 100
121
+ print(f"\nAccuracy: {accuracy:.2f}%")
122
+
123
+ # Save results to a JSON file
124
+ output_path = OUTPUT_PATH
125
+ with open(output_path, "w") as f:
126
+ json.dump({
127
+ 'accuracy': accuracy,
128
+ 'results': final_output
129
+ }, f, indent=2)
130
+
131
+ print(f"Results saved to {output_path}")
132
+
133
+
134
+
135
+
136
+
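The counting script above only credits a prediction when a bare integer appears inside `<answer>...</answer>` tags; anything else is scored as incorrect. A standalone sketch of that extraction step:

```python
import re

def extract_number_answer(output_str):
    # Same pattern as the evaluation script above: an integer inside <answer> tags.
    match = re.search(r"<answer>\s*(\d+)\s*</answer>", output_str)
    return int(match.group(1)) if match else None

print(extract_number_answer("<think>Count the objects.</think> <answer> 7 </answer>"))  # 7
print(extract_number_answer("There are seven items."))                                  # None
```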
previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_geoqa.py ADDED
@@ -0,0 +1,149 @@
1
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2
+ from qwen_vl_utils import process_vision_info
3
+ import torch
4
+ import json
5
+ from tqdm import tqdm
6
+ import re
7
+ from math_verify import parse, verify
8
+
9
+
10
+ MODEL_PATH="<MODEL_PATH>" # qwen2vl model or grpoed model on geoqa train
11
+ BSZ=50 # reduce it if GPU OOM
12
+ OUTPUT_PATH="<OUTPUT_LOG>"
13
+ PROMPT_PATH="./prompts/geoqa_test_prompts.jsonl"
14
+
15
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
16
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
17
+ MODEL_PATH,
18
+ torch_dtype=torch.bfloat16,
19
+ attn_implementation="flash_attention_2",
20
+ device_map="auto",
21
+ )
22
+
23
+ # default processor
24
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
25
+
26
+ data = []
27
+ with open(PROMPT_PATH, "r") as f:
28
+ for line in f:
29
+ data.append(json.loads(line))
30
+
31
+
32
+ QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
33
+
34
+ messages = []
35
+
36
+ data = data
37
+
38
+ for i in data:
39
+ message = [{
40
+ "role": "user",
41
+ "content": [
42
+ {
43
+ "type": "image",
44
+ "image": f"file://{i['image_path']}"
45
+ },
46
+ {
47
+ "type": "text",
48
+ "text": QUESTION_TEMPLATE.format(Question=i['question'])
49
+ }
50
+ ]
51
+ }]
52
+ messages.append(message)
53
+
54
+
55
+
56
+
57
+ all_outputs = [] # List to store all answers
58
+
59
+ # Process data in batches
60
+ for i in tqdm(range(0, len(messages), BSZ)):
61
+ batch_messages = messages[i:i + BSZ]
62
+
63
+ # Preparation for inference
64
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
65
+
66
+ image_inputs, video_inputs = process_vision_info(batch_messages)
67
+ inputs = processor(
68
+ text=text,
69
+ images=image_inputs,
70
+ videos=video_inputs,
71
+ padding=True,
72
+ return_tensors="pt",
73
+ )
74
+ inputs = inputs.to("cuda")
75
+
76
+ # Inference: Generation of the output
77
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024, do_sample=False)
78
+
79
+ generated_ids_trimmed = [
80
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
81
+ ]
82
+ batch_output_text = processor.batch_decode(
83
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
84
+ )
85
+
86
+ all_outputs.extend(batch_output_text)
87
+ print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
88
+
89
+
90
+
91
+
92
+
93
+ final_output = []
94
+ correct_number = 0
95
+
96
+ for input_example, model_output in zip(data,all_outputs):
97
+ original_output = model_output
98
+ ground_truth = input_example['ground_truth']
99
+ model_answer = parse(original_output)
100
+
101
+ # Count correct answers
102
+ if model_answer is not None and float(verify(model_answer,parse(ground_truth)))>0:
103
+ correct_number += 1
104
+ is_correct = True
105
+ else:
106
+ is_correct = False
107
+
108
+ try:
109
+ result = {
110
+ 'question': input_example,
111
+ 'ground_truth': ground_truth,
112
+ 'model_output': original_output,
113
+ 'extracted_answer':str(model_answer[0]) if model_answer is not None else None,
114
+ 'is_correct':is_correct
115
+ }
116
+
117
+ except Exception as e:
118
+ print("no answer parsed",e,model_answer)
119
+ result = {
120
+ 'question': input_example,
121
+ 'ground_truth': ground_truth,
122
+ 'model_output': original_output,
123
+ 'extracted_answer':None,
124
+ 'is_correct':is_correct
125
+ }
126
+
127
+
128
+
129
+ final_output.append(result)
130
+
131
+
132
+ # Calculate and print accuracy
133
+ accuracy = correct_number / len(data) * 100
134
+ print(f"\nAccuracy: {accuracy:.2f}%")
135
+
136
+ # Save results to a JSON file
137
+ output_path = OUTPUT_PATH
138
+ with open(output_path, "w") as f:
139
+ json.dump({
140
+ 'accuracy': accuracy,
141
+ 'results': final_output
142
+ }, f, indent=2, ensure_ascii=False)
143
+
144
+ print(f"Results saved to {output_path}")
145
+
146
+
147
+
148
+
149
+
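Unlike the counting evaluation, GeoQA answers are scored with the `math_verify` package instead of a regex: `parse` pulls candidate expressions out of the free-form output and `verify` checks them against the parsed ground truth. A rough sketch of that call pattern as used above (exact parsing behaviour depends on the installed `math_verify` version, so treat the example values as illustrative):

```python
from math_verify import parse, verify  # same helpers imported by the script above

ground_truth = "12"
model_output = "<think>Adding the two angles gives 12.</think> <answer> 12 </answer>"

model_answer = parse(model_output)  # list of candidate expressions parsed from the text
is_correct = model_answer is not None and float(verify(model_answer, parse(ground_truth))) > 0
print(model_answer, is_correct)     # expected: a parsed form of 12, and True
```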
previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_geoqa_multigpu.py ADDED
@@ -0,0 +1,205 @@
1
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2
+ from qwen_vl_utils import process_vision_info
3
+ import torch
4
+ import json
5
+ import tqdm
6
+ from math_verify import parse, verify
7
+ import argparse
8
+ import pandas as pd
9
+ from torch.multiprocessing import Process, set_start_method, Manager
10
+ from transformers.utils.logging import disable_progress_bar
11
+ disable_progress_bar()
12
+
13
+ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
14
+ # >>>>> 1. get evaluation configuration <<<<<
15
+ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16
+ def get_eval_config():
17
+ parser = argparse.ArgumentParser(description="Inference script for GeoQA evaluation.")
18
+ parser.add_argument("--model_path", required=True, type=str, help="Path to the model checkpoint (e.g., qwen2vl model or a fine-tuned model).")
19
+ parser.add_argument("--batch_size", default=4, type=int, help="Batch size for inference. Reduce if GPU OOM (default: 4).")
20
+ parser.add_argument("--output_path", required=True, type=str, help="Path to save inference result (e.g., JSON file).")
21
+ parser.add_argument("--prompt_path", required=True, type=str, help="Path to the prompts JSONL file for GeoQA evaluation.")
22
+ all_gpu = ",".join(map(str, range(torch.cuda.device_count())))
23
+ parser.add_argument("--gpu_ids", default=all_gpu, help="comma-separated list of GPU IDs to use")
24
+ args = parser.parse_args()
25
+ return args
26
+
27
+ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
28
+ # >>>>>>>>>> 2. load testset <<<<<<<<<<<<<
29
+ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
30
+ def prepare_test_messages(testset_path):
31
+ testset_data = pd.read_json(testset_path, lines=True).to_dict(orient="records")
32
+ QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
33
+ tested_messages = []
34
+ for i in testset_data:
35
+ message = [{
36
+ "role": "user",
37
+ "content": [
38
+ {
39
+ "type": "image",
40
+ "image": f"file://{i['image_path']}"
41
+ },
42
+ {
43
+ "type": "text",
44
+ "text": QUESTION_TEMPLATE.format(Question=i['question'])
45
+ }
46
+ ]
47
+ }]
48
+ tested_messages.append(message)
49
+ return testset_data, tested_messages
50
+
51
+
52
+
53
+
54
+ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
55
+ # >>>>> 3. use several GPUs to accelerate inference on the test set <<<<<
56
+ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
57
+
58
+ def init_model(model_path, gpu_id):
59
+ """Initialize a model (model_path) on a specific GPU."""
60
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
61
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
62
+ model_path,
63
+ torch_dtype=torch.bfloat16,
64
+ attn_implementation="flash_attention_2",
65
+ device_map=f"cuda:{gpu_id}",
66
+ )
67
+
68
+ # default processor
69
+ processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
70
+ return model, processor
71
+
72
+ def answer_a_batch_question_qwen(batch_messages, model, processor):
73
+ """ let qwen answer a batch of questions """
74
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
75
+ image_inputs, video_inputs = process_vision_info(batch_messages)
76
+ inputs = processor(
77
+ text=text,
78
+ images=image_inputs,
79
+ videos=video_inputs,
80
+ padding=True,
81
+ return_tensors="pt",
82
+ )
83
+ inputs = inputs.to(model.device)
84
+
85
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024) # do_sample=False
86
+ generated_ids_trimmed = [
87
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
88
+ ]
89
+ batch_output_text = processor.batch_decode(
90
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
91
+ )
92
+ return batch_output_text
93
+
94
+ def infer_on_single_gpu(model_path, device_id, chunk_of_tested_messages, batch_size, results=None):
95
+ """Initialize the model on this single GPU and let it answer its assigned chunk of questions."""
96
+ model, processor = init_model(model_path, device_id)
97
+
98
+ ### split batch
99
+ responses = []
100
+ batch_messages_list = [chunk_of_tested_messages[start: start + batch_size]
101
+ for start in range(0, len(chunk_of_tested_messages), batch_size)]
102
+
103
+ for batch_messages in tqdm.auto.tqdm(batch_messages_list, desc=f"GPU {device_id} progress", position=device_id, leave=False):
104
+ batch_output_text = answer_a_batch_question_qwen(batch_messages, model, processor)
105
+
106
+ responses.extend(batch_output_text)
107
+
108
+ results[device_id] = responses
109
+ return
110
+
111
+
112
+ def multi_gpu_inference(prompts, gpu_ids, model_path, batch_size):
113
+ """ let each gpu (along with a model) answer a chunk of questions """
114
+ set_start_method("spawn", force=True)
115
+ manager = Manager()
116
+ gpu_id2result = manager.dict()
117
+
118
+ gpu_ids = [int(gpu_id.strip()) for gpu_id in gpu_ids.split(',')]
119
+ num_gpus = len(gpu_ids)
120
+
121
+ chunk_size = len(prompts) // num_gpus
122
+ processes = []
123
+ for i, gpu_id in enumerate(gpu_ids):
124
+ start_idx = i * chunk_size
125
+ end_idx = (i + 1) * chunk_size if i != num_gpus - 1 else len(prompts)
126
+ chunk = prompts[start_idx: end_idx]
127
+ process = Process(target=infer_on_single_gpu, args=(model_path, gpu_id, chunk, batch_size, gpu_id2result))
128
+ process.start()
129
+ processes.append(process)
130
+
131
+ # for process in tqdm.auto.tqdm(processes, desc="Inference progress", position=num_gpus, leave=True):
132
+ for process in processes:
133
+ process.join()
134
+
135
+ all_predicts = []
136
+ for gpu_id in gpu_ids:
137
+ all_predicts.extend(gpu_id2result[gpu_id])
138
+
139
+ return all_predicts
140
+
141
+ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
142
+ # >>>>>>>>>> 4. compute metrics <<<<<<<<<<<
143
+ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
144
+
145
+ def compute_metrics(testset_data, all_predicts):
146
+ final_output = []
147
+ correct_number = 0
148
+
149
+ for input_example, model_output in zip(testset_data, all_predicts):
150
+ original_output = model_output
151
+ ground_truth = input_example['ground_truth']
152
+ model_answer = parse(original_output)
153
+
154
+ # Count correct answers
155
+ if model_answer is not None and float(verify(model_answer,parse(ground_truth)))>0:
156
+ correct_number += 1
157
+ is_correct = True
158
+ else:
159
+ is_correct = False
160
+
161
+ try:
162
+ result = {
163
+ 'question': input_example,
164
+ 'ground_truth': ground_truth,
165
+ 'model_output': original_output,
166
+ 'extracted_answer':str(model_answer[0]) if model_answer is not None else None,
167
+ 'is_correct':is_correct
168
+ }
169
+
170
+ except Exception as e:
171
+ print("no answer parsed",e,model_answer)
172
+ result = {
173
+ 'question': input_example,
174
+ 'ground_truth': ground_truth,
175
+ 'model_output': original_output,
176
+ 'extracted_answer':None,
177
+ 'is_correct':is_correct
178
+ }
179
+
180
+
181
+
182
+ final_output.append(result)
183
+
184
+
185
+ # Calculate and print accuracy
186
+ accuracy = correct_number / len(testset_data) * 100
187
+ print(f"\nAccuracy: {accuracy:.2f}%")
188
+
189
+ # Save results to a JSON file
190
+ with open(args.output_path, "w") as f:
191
+ json.dump({
192
+ 'accuracy': accuracy,
193
+ 'results': final_output
194
+ }, f, indent=2, ensure_ascii=False)
195
+
196
+ print(f"Results saved to {args.output_path}")
197
+
198
+
199
+
200
+ if __name__ == "__main__":
201
+ args = get_eval_config()
202
+ testset_data, tested_messages = prepare_test_messages(testset_path=args.prompt_path)
203
+ all_predicts = multi_gpu_inference(tested_messages, args.gpu_ids, args.model_path, args.batch_size)
204
+ compute_metrics(testset_data, all_predicts)
205
+
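The multi-GPU runner splits the prompt list into one contiguous chunk per GPU, and the last GPU absorbs any remainder so no question is dropped. A minimal sketch of that chunking arithmetic with stand-in data (no GPUs required):

```python
# Stand-in for tested_messages; mirrors the chunking in multi_gpu_inference above.
prompts = list(range(10))
gpu_ids = [0, 1, 2]

chunk_size = len(prompts) // len(gpu_ids)
chunks = []
for i, _gpu_id in enumerate(gpu_ids):
    start = i * chunk_size
    end = (i + 1) * chunk_size if i != len(gpu_ids) - 1 else len(prompts)
    chunks.append(prompts[start:end])

print(chunks)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] -- the last chunk takes the remainder
```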
previous_version/Video-R1-main-previous/src/eval/test_qwen2vl_video_counting.py ADDED
@@ -0,0 +1,141 @@
1
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2
+ from qwen_vl_utils import process_vision_info
3
+ import torch
4
+ import json
5
+ from tqdm import tqdm
6
+ import re
7
+ import os
8
+
9
+
10
+
11
+ MODEL_PATH="YOUR_PATH" # Qwen2vl-2b-Instruct for original scores
12
+ BSZ=64 # reduce it if GPU OOM
13
+ OUTPUT_PATH="YOUR_PATH/test.json"
14
+ PROMPT_PATH="YOUR_PATH/test_dvd.jsonl"
15
+
16
+ # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
17
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
18
+ MODEL_PATH,
19
+ torch_dtype=torch.bfloat16,
20
+ attn_implementation="flash_attention_2",
21
+ device_map="auto",
22
+ )
23
+
24
+ # default processor
25
+ processor = AutoProcessor.from_pretrained(MODEL_PATH)
26
+
27
+ data = []
28
+ with open(PROMPT_PATH, "r") as f:
29
+ for line in f:
30
+ data.append(json.loads(line))
31
+
32
+ # detailed step-by-step
33
+ QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
34
+
35
+ messages = []
36
+
37
+
38
+ for x in data:
39
+ message = [{
40
+ "role": "user",
41
+ "content": [
42
+ {
43
+ "type": "video",
44
+ "video": os.getcwd() + "/src/r1-v/data" + x['video_filename'][1:]
45
+ },
46
+ {
47
+ "type": "text",
48
+ "text": QUESTION_TEMPLATE.format(Question=x['problem'])
49
+ }
50
+ ]
51
+ }]
52
+ messages.append(message)
53
+
54
+
55
+
56
+
57
+ all_outputs = [] # List to store all answers
58
+
59
+ # Process data in batches
60
+ for i in tqdm(range(0, len(messages), BSZ)):
61
+ batch_messages = messages[i:i + BSZ]
62
+
63
+ # Preparation for inference
64
+ text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
65
+
66
+
67
+ image_inputs, video_inputs = process_vision_info(batch_messages)
68
+ inputs = processor(
69
+ text=text,
70
+ images=image_inputs,
71
+ videos=video_inputs,
72
+ padding=True,
73
+ return_tensors="pt",
74
+ )
75
+ inputs = inputs.to("cuda")
76
+
77
+ # Inference: Generation of the output
78
+ generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
79
+
80
+ generated_ids_trimmed = [
81
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
82
+ ]
83
+ batch_output_text = processor.batch_decode(
84
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
85
+ )
86
+
87
+
88
+ all_outputs.extend(batch_output_text)
89
+ print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
90
+
91
+
92
+ def extract_number_answer(output_str):
93
+ # Try to find the number within <answer> tags; if it cannot be found, return None
94
+ answer_pattern = r'<answer>\s*(\d+)\s*</answer>'
95
+ match = re.search(answer_pattern, output_str)
96
+
97
+ if match:
98
+ return int(match.group(1))
99
+ return None
100
+
101
+
102
+ final_output = []
103
+ correct_number = 0
104
+
105
+ for input_example, model_output in zip(data,all_outputs):
106
+ original_output = model_output
107
+ ground_truth = extract_number_answer(input_example['solution'])
108
+ model_answer = extract_number_answer(original_output)
109
+
110
+
111
+ # Create a result dictionary for this example
112
+ result = {
113
+ 'question': input_example,
114
+ 'ground_truth': ground_truth,
115
+ 'model_output': original_output,
116
+ 'extracted_answer': model_answer
117
+ }
118
+ final_output.append(result)
119
+
120
+ # Count correct answers
121
+ if model_answer is not None and model_answer == ground_truth:
122
+ correct_number += 1
123
+
124
+ # Calculate and print accuracy
125
+ accuracy = correct_number / len(data) * 100
126
+ print(f"\nAccuracy: {accuracy:.2f}%")
127
+
128
+ # Save results to a JSON file
129
+ output_path = OUTPUT_PATH
130
+ with open(output_path, "w") as f:
131
+ json.dump({
132
+ 'accuracy': accuracy,
133
+ 'results': final_output
134
+ }, f, indent=2)
135
+
136
+ print(f"Results saved to {output_path}")
137
+
138
+
139
+
140
+
141
+
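One detail worth noting in the video script above: the video path is rebuilt from the current working directory, so the script assumes it is launched from the repository root. A small sketch of that path handling with a hypothetical record (the `video_filename` value here is made up for illustration):

```python
import os

# Hypothetical record; real entries come from the test_dvd.jsonl prompt file.
x = {"video_filename": "./videos/example_clip.mp4"}

# Mirrors the construction in the script above: drop the leading "." and prepend
# <cwd>/src/r1-v/data, which is why the script must run from the repo root.
video_path = os.getcwd() + "/src/r1-v/data" + x["video_filename"][1:]
print(video_path)  # e.g. /path/to/repo/src/r1-v/data/videos/example_clip.mp4
```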
previous_version/Video-R1-main-previous/src/qwen-vl-utils/.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.8.19
previous_version/Video-R1-main-previous/src/qwen-vl-utils/README.md ADDED
@@ -0,0 +1,94 @@
1
+ # qwen-vl-utils
2
+
3
+ Qwen-VL Utils contains a set of helper functions for processing and integrating visual language information with the Qwen-VL series models.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install qwen-vl-utils
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ### Qwen2VL
14
+
15
+ ```python
16
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
17
+ from qwen_vl_utils import process_vision_info
18
+
19
+
20
+ # You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text.
21
+ messages = [
22
+ # Image
23
+ ## Local file path
24
+ [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
25
+ ## Image URL
26
+ [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
27
+ ## Base64 encoded image
28
+ [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
29
+ ## PIL.Image.Image
30
+ [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
31
+ ## Model dynamically adjusts image size, specify dimensions if required.
32
+ [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
33
+ # Video
34
+ ## Local video path
35
+ [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
36
+ ## Local video frames
37
+ [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
38
+ ## Model dynamically adjusts video nframes, video height and width. specify args if required.
39
+ [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
40
+ ]
41
+
42
+ processor = AutoProcessor.from_pretrained(model_path)
43
+ model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
44
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
45
+ images, videos = process_vision_info(messages)
46
+ inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt")
47
+ print(inputs)
48
+ generated_ids = model.generate(**inputs)
49
+ print(generated_ids)
50
+ ```
51
+
52
+ ### Qwen2.5VL
53
+
54
+ ```python
55
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
56
+ from qwen_vl_utils import process_vision_info
57
+
58
+
59
+ # You can set the maximum tokens for a video through the environment variable VIDEO_MAX_PIXELS
60
+ # based on the maximum tokens that the model can accept.
61
+ # export VIDEO_MAX_PIXELS = 32000 * 28 * 28 * 0.9
62
+
63
+
64
+ # You can directly insert a local file path, a URL, or a base64-encoded image into the position where you want in the text.
65
+ messages = [
66
+ # Image
67
+ ## Local file path
68
+ [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
69
+ ## Image URL
70
+ [{"role": "user", "content": [{"type": "image", "image": "http://path/to/your/image.jpg"}, {"type": "text", "text": "Describe this image."}]}],
71
+ ## Base64 encoded image
72
+ [{"role": "user", "content": [{"type": "image", "image": "data:image;base64,/9j/..."}, {"type": "text", "text": "Describe this image."}]}],
73
+ ## PIL.Image.Image
74
+ [{"role": "user", "content": [{"type": "image", "image": pil_image}, {"type": "text", "text": "Describe this image."}]}],
75
+ ## Model dynamically adjusts image size, specify dimensions if required.
76
+ [{"role": "user", "content": [{"type": "image", "image": "file:///path/to/your/image.jpg", "resized_height": 280, "resized_width": 420}, {"type": "text", "text": "Describe this image."}]}],
77
+ # Video
78
+ ## Local video path
79
+ [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4"}, {"type": "text", "text": "Describe this video."}]}],
80
+ ## Local video frames
81
+ [{"role": "user", "content": [{"type": "video", "video": ["file:///path/to/extracted_frame1.jpg", "file:///path/to/extracted_frame2.jpg", "file:///path/to/extracted_frame3.jpg"],}, {"type": "text", "text": "Describe this video."},],}],
82
+ ## Model dynamically adjusts video nframes, video height and width. specify args if required.
83
+ [{"role": "user", "content": [{"type": "video", "video": "file:///path/to/video1.mp4", "fps": 2.0, "resized_height": 280, "resized_width": 280}, {"type": "text", "text": "Describe this video."}]}],
84
+ ]
85
+
86
+ processor = AutoProcessor.from_pretrained(model_path)
87
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
88
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
89
+ images, videos, video_kwargs = process_vision_info(messages, return_video_kwargs=True)
90
+ inputs = processor(text=text, images=images, videos=videos, padding=True, return_tensors="pt", **video_kwargs)
91
+ print(inputs)
92
+ generated_ids = model.generate(**inputs)
93
+ print(generated_ids)
94
+ ```
previous_version/Video-R1-main-previous/src/qwen-vl-utils/pyproject.toml ADDED
@@ -0,0 +1,75 @@
1
+ [project]
2
+ name = "qwen-vl-utils"
3
+ version = "0.0.10"
4
+ description = "Qwen Vision Language Model Utils - PyTorch"
5
+ authors = [
6
+ { name = "Qwen Team", email = "[email protected]" },
7
+ ]
8
+ dependencies = [
9
+ "requests",
10
+ "pillow",
11
+ "av",
12
+ "packaging",
13
+ ]
14
+ readme = "README.md"
15
+ requires-python = ">= 3.8"
16
+ license = {text = "Apache-2.0"}
17
+ keywords = [
18
+ 'large language model',
19
+ 'vision language model',
20
+ 'qwen-vl',
21
+ 'pytorch',
22
+ ]
23
+ classifiers = [
24
+ 'Development Status :: 4 - Beta',
25
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
26
+ 'Programming Language :: Python :: 3',
27
+ 'License :: OSI Approved :: Apache Software License',
28
+ ]
29
+
30
+ [project.urls]
31
+ Homepage = "https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils"
32
+ Repository = "https://github.com/QwenLM/Qwen2-VL.git"
33
+ Issues = "https://github.com/QwenLM/Qwen2-VL/issues"
34
+
35
+ [project.optional-dependencies]
36
+ decord = [
37
+ "decord",
38
+ ]
39
+
40
+ [build-system]
41
+ requires = ["hatchling"]
42
+ build-backend = "hatchling.build"
43
+
44
+ [tool.rye]
45
+ managed = true
46
+ dev-dependencies = [
47
+ "torch",
48
+ "torchvision",
49
+ ]
50
+
51
+ [tool.hatch.metadata]
52
+ allow-direct-references = true
53
+
54
+ [tool.hatch.build.targets.wheel]
55
+ packages = ["src/qwen_vl_utils"]
56
+
57
+ [tool.ruff]
58
+ line-length = 119
59
+
60
+ [tool.ruff.lint]
61
+ ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
62
+ select = ["C", "E", "F", "I", "W"]
63
+
64
+ [tool.ruff.lint.per-file-ignores]
65
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
66
+
67
+ [tool.ruff.lint.isort]
68
+ lines-after-imports = 2
69
+ known-first-party = ["qwen_vl_utils"]
70
+
71
+ [tool.ruff.format]
72
+ quote-style = "double"
73
+ indent-style = "space"
74
+ skip-magic-trailing-comma = false
75
+ line-ending = "auto"
previous_version/Video-R1-main-previous/src/qwen-vl-utils/requirements-dev.lock ADDED
@@ -0,0 +1,84 @@
1
+ # generated by rye
2
+ # use `rye lock` or `rye sync` to update this lockfile
3
+ #
4
+ # last locked with the following flags:
5
+ # pre: false
6
+ # features: ["decord"]
7
+ # all-features: false
8
+ # with-sources: false
9
+ # generate-hashes: false
10
+ # universal: false
11
+
12
+ -e file:.
13
+ av==12.3.0
14
+ # via qwen-vl-utils
15
+ certifi==2022.12.7
16
+ # via requests
17
+ charset-normalizer==2.1.1
18
+ # via requests
19
+ decord==0.6.0
20
+ # via qwen-vl-utils
21
+ filelock==3.13.1
22
+ # via torch
23
+ # via triton
24
+ fsspec==2024.2.0
25
+ # via torch
26
+ idna==3.4
27
+ # via requests
28
+ jinja2==3.1.3
29
+ # via torch
30
+ markupsafe==2.1.5
31
+ # via jinja2
32
+ mpmath==1.3.0
33
+ # via sympy
34
+ networkx==3.1
35
+ # via torch
36
+ numpy==1.24.1
37
+ # via decord
38
+ # via torchvision
39
+ nvidia-cublas-cu12==12.1.3.1
40
+ # via nvidia-cudnn-cu12
41
+ # via nvidia-cusolver-cu12
42
+ # via torch
43
+ nvidia-cuda-cupti-cu12==12.1.105
44
+ # via torch
45
+ nvidia-cuda-nvrtc-cu12==12.1.105
46
+ # via torch
47
+ nvidia-cuda-runtime-cu12==12.1.105
48
+ # via torch
49
+ nvidia-cudnn-cu12==9.1.0.70
50
+ # via torch
51
+ nvidia-cufft-cu12==11.0.2.54
52
+ # via torch
53
+ nvidia-curand-cu12==10.3.2.106
54
+ # via torch
55
+ nvidia-cusolver-cu12==11.4.5.107
56
+ # via torch
57
+ nvidia-cusparse-cu12==12.1.0.106
58
+ # via nvidia-cusolver-cu12
59
+ # via torch
60
+ nvidia-nccl-cu12==2.20.5
61
+ # via torch
62
+ nvidia-nvjitlink-cu12==12.6.68
63
+ # via nvidia-cusolver-cu12
64
+ # via nvidia-cusparse-cu12
65
+ nvidia-nvtx-cu12==12.1.105
66
+ # via torch
67
+ packaging==24.1
68
+ # via qwen-vl-utils
69
+ pillow==10.2.0
70
+ # via qwen-vl-utils
71
+ # via torchvision
72
+ requests==2.28.1
73
+ # via qwen-vl-utils
74
+ sympy==1.12
75
+ # via torch
76
+ torch==2.4.0
77
+ # via torchvision
78
+ torchvision==0.19.0
79
+ triton==3.0.0
80
+ # via torch
81
+ typing-extensions==4.9.0
82
+ # via torch
83
+ urllib3==1.26.13
84
+ # via requests
previous_version/Video-R1-main-previous/src/qwen-vl-utils/requirements.lock ADDED
@@ -0,0 +1,32 @@
1
+ # generated by rye
2
+ # use `rye lock` or `rye sync` to update this lockfile
3
+ #
4
+ # last locked with the following flags:
5
+ # pre: false
6
+ # features: ["decord"]
7
+ # all-features: false
8
+ # with-sources: false
9
+ # generate-hashes: false
10
+ # universal: false
11
+
12
+ -e file:.
13
+ av==12.3.0
14
+ # via qwen-vl-utils
15
+ certifi==2022.12.7
16
+ # via requests
17
+ charset-normalizer==2.1.1
18
+ # via requests
19
+ decord==0.6.0
20
+ # via qwen-vl-utils
21
+ idna==3.4
22
+ # via requests
23
+ numpy==1.24.4
24
+ # via decord
25
+ packaging==24.1
26
+ # via qwen-vl-utils
27
+ pillow==10.2.0
28
+ # via qwen-vl-utils
29
+ requests==2.28.1
30
+ # via qwen-vl-utils
31
+ urllib3==1.26.13
32
+ # via requests
previous_version/Video-R1-main-previous/src/qwen-vl-utils/src/qwen_vl_utils/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .vision_process import (
2
+ extract_vision_info,
3
+ fetch_image,
4
+ fetch_video,
5
+ process_vision_info,
6
+ smart_resize,
7
+ )
previous_version/Video-R1-main-previous/src/qwen-vl-utils/src/qwen_vl_utils/vision_process.py ADDED
@@ -0,0 +1,379 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import logging
5
+ import math
6
+ import os
7
+ import sys
8
+ import time
9
+ import warnings
10
+ from functools import lru_cache
11
+ from io import BytesIO
12
+
13
+ import requests
14
+ import torch
15
+ import torchvision
16
+ from packaging import version
17
+ from PIL import Image
18
+ from torchvision import io, transforms
19
+ from torchvision.transforms import InterpolationMode
20
+ from typing import Optional
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ IMAGE_FACTOR = 28
26
+ MIN_PIXELS = 4 * 28 * 28
27
+ MAX_PIXELS = 16384 * 28 * 28
28
+ MAX_RATIO = 200
29
+
30
+ # VIDEO_MIN_PIXELS = 128 * 28 * 28
31
+ # VIDEO_MAX_PIXELS = 768 * 28 * 28
32
+ VIDEO_MIN_PIXELS = 128 * 28 * 28
33
+ VIDEO_MAX_PIXELS = 128 * 28 * 28
34
+ FRAME_FACTOR = 2
35
+ FPS = 2.0
36
+ FPS_MIN_FRAMES = 4
37
+ FPS_MAX_FRAMES = 16
38
+
39
+ # Set the maximum number of video token inputs.
40
+ # Here, 128K represents the maximum number of input tokens for the VLLM model.
41
+ # Remember to adjust it according to your own configuration.
42
+ VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
43
+ logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
44
+
45
+
46
+ def round_by_factor(number: int, factor: int) -> int:
47
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
48
+ return round(number / factor) * factor
49
+
50
+
51
+ def ceil_by_factor(number: int, factor: int) -> int:
52
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
53
+ return math.ceil(number / factor) * factor
54
+
55
+
56
+ def floor_by_factor(number: int, factor: int) -> int:
57
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
58
+ return math.floor(number / factor) * factor
59
+
60
+
61
+ def smart_resize(
62
+ height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
63
+ ) -> tuple[int, int]:
64
+ """
65
+ Rescales the image so that the following conditions are met:
66
+
67
+ 1. Both dimensions (height and width) are divisible by 'factor'.
68
+
69
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
70
+
71
+ 3. The aspect ratio of the image is maintained as closely as possible.
72
+ """
73
+ if max(height, width) / min(height, width) > MAX_RATIO:
74
+ raise ValueError(
75
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
76
+ )
77
+ h_bar = max(factor, round_by_factor(height, factor))
78
+ w_bar = max(factor, round_by_factor(width, factor))
79
+ if h_bar * w_bar > max_pixels:
80
+ beta = math.sqrt((height * width) / max_pixels)
81
+ h_bar = floor_by_factor(height / beta, factor)
82
+ w_bar = floor_by_factor(width / beta, factor)
83
+ elif h_bar * w_bar < min_pixels:
84
+ beta = math.sqrt(min_pixels / (height * width))
85
+ h_bar = ceil_by_factor(height * beta, factor)
86
+ w_bar = ceil_by_factor(width * beta, factor)
87
+ return h_bar, w_bar
88
+
89
+
90
+ def to_rgb(pil_image: Image.Image) -> Image.Image:
91
+ if pil_image.mode == 'RGBA':
92
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
93
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
94
+ return white_background
95
+ else:
96
+ return pil_image.convert("RGB")
97
+
98
+
99
+ def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
100
+ if "image" in ele:
101
+ image = ele["image"]
102
+ else:
103
+ image = ele["image_url"]
104
+ image_obj = None
105
+ if isinstance(image, Image.Image):
106
+ image_obj = image
107
+ elif image.startswith("http://") or image.startswith("https://"):
108
+ response = requests.get(image, stream=True)
109
+ image_obj = Image.open(BytesIO(response.content))
110
+ elif image.startswith("file://"):
111
+ image_obj = Image.open(image[7:])
112
+ elif image.startswith("data:image"):
113
+ if "base64," in image:
114
+ _, base64_data = image.split("base64,", 1)
115
+ data = base64.b64decode(base64_data)
116
+ image_obj = Image.open(BytesIO(data))
117
+ else:
118
+ image_obj = Image.open(image)
119
+ if image_obj is None:
120
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
121
+ image = to_rgb(image_obj)
122
+ ## resize
123
+ if "resized_height" in ele and "resized_width" in ele:
124
+ resized_height, resized_width = smart_resize(
125
+ ele["resized_height"],
126
+ ele["resized_width"],
127
+ factor=size_factor,
128
+ )
129
+ else:
130
+ width, height = image.size
131
+ min_pixels = ele.get("min_pixels", MIN_PIXELS)
132
+ max_pixels = ele.get("max_pixels", MAX_PIXELS)
133
+ resized_height, resized_width = smart_resize(
134
+ height,
135
+ width,
136
+ factor=size_factor,
137
+ min_pixels=min_pixels,
138
+ max_pixels=max_pixels,
139
+ )
140
+ image = image.resize((resized_width, resized_height))
141
+
142
+ return image
143
+
144
+
145
+ def smart_nframes(
146
+ ele: dict,
147
+ total_frames: int,
148
+ video_fps: int | float,
149
+ ) -> int:
150
+ """calculate the number of frames for video used for model inputs.
151
+
152
+ Args:
153
+ ele (dict): a dict contains the configuration of video.
154
+ support either `fps` or `nframes`:
155
+ - nframes: the number of frames to extract for model inputs.
156
+ - fps: the fps to extract frames for model inputs.
157
+ - min_frames: the minimum number of frames of the video, only used when fps is provided.
158
+ - max_frames: the maximum number of frames of the video, only used when fps is provided.
159
+ total_frames (int): the original total number of frames of the video.
160
+ video_fps (int | float): the original fps of the video.
161
+
162
+ Raises:
163
+ ValueError: nframes should be in the interval [FRAME_FACTOR, total_frames].
164
+
165
+ Returns:
166
+ int: the number of frames for video used for model inputs.
167
+ """
168
+ assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
169
+ if "nframes" in ele:
170
+ nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
171
+ else:
172
+ fps = ele.get("fps", FPS)
173
+ min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
174
+ max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
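+ # Target frame count = clip duration (total_frames / video_fps) * requested fps, then clamped to [min_frames, max_frames] and to the clip length.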
175
+ nframes = total_frames / video_fps * fps
176
+ if nframes > total_frames:
177
+ logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
178
+ nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
179
+ nframes = floor_by_factor(nframes, FRAME_FACTOR)
180
+ if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
181
+ raise ValueError(f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
182
+ return nframes
183
+
184
+
185
+ def _read_video_torchvision(
186
+ ele: dict,
187
+ ) -> tuple[torch.Tensor, float]:
188
+ """read video using torchvision.io.read_video
189
+
190
+ Args:
191
+ ele (dict): a dict contains the configuration of video.
192
+ support keys:
193
+ - video: the path of video. support "file://", "http://", "https://" and local path.
194
+ - video_start: the start time of video.
195
+ - video_end: the end time of video.
196
+ Returns:
197
+ torch.Tensor: the video tensor with shape (T, C, H, W).
198
+ """
199
+ video_path = ele["video"]
200
+ if version.parse(torchvision.__version__) < version.parse("0.19.0"):
201
+ if "http://" in video_path or "https://" in video_path:
202
+ warnings.warn("torchvision < 0.19.0 does not support http/https video path, please upgrade to 0.19.0.")
203
+ if "file://" in video_path:
204
+ video_path = video_path[7:]
205
+ st = time.time()
206
+ video, audio, info = io.read_video(
207
+ video_path,
208
+ start_pts=ele.get("video_start", 0.0),
209
+ end_pts=ele.get("video_end", None),
210
+ pts_unit="sec",
211
+ output_format="TCHW",
212
+ )
213
+ total_frames, video_fps = video.size(0), info["video_fps"]
214
+ logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
215
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
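+ # Pick nframes evenly spaced frame indices; sample_fps is the effective frame rate after this subsampling.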
216
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long()
217
+ sample_fps = nframes / max(total_frames, 1e-6) * video_fps
218
+ video = video[idx]
219
+ return video, sample_fps
220
+
221
+
222
+ def is_decord_available() -> bool:
223
+ import importlib.util
224
+
225
+ return importlib.util.find_spec("decord") is not None
226
+
227
+
228
+ def _read_video_decord(
229
+ ele: dict,
230
+ ) -> tuple[torch.Tensor, float]:
231
+ """read video using decord.VideoReader
232
+
233
+ Args:
234
+ ele (dict): a dict contains the configuration of video.
235
+ support keys:
236
+ - video: the path of video. support "file://", "http://", "https://" and local path.
237
+ - video_start: the start time of video.
238
+ - video_end: the end time of video.
239
+ Returns:
240
+ torch.Tensor: the video tensor with shape (T, C, H, W).
241
+ """
242
+ import decord
243
+ video_path = ele["video"]
244
+ st = time.time()
245
+ vr = decord.VideoReader(video_path)
246
+ # TODO: support start_pts and end_pts
247
+ if 'video_start' in ele or 'video_end' in ele:
248
+ raise NotImplementedError("video_start and video_end are not supported by the decord backend yet.")
249
+ total_frames, video_fps = len(vr), vr.get_avg_fps()
250
+ logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
251
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
252
+ idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
253
+ video = vr.get_batch(idx).asnumpy()
254
+ video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
255
+ sample_fps = nframes / max(total_frames, 1e-6) * video_fps
256
+ return video, sample_fps
257
+
258
+
259
+ VIDEO_READER_BACKENDS = {
260
+ "decord": _read_video_decord,
261
+ "torchvision": _read_video_torchvision,
262
+ }
263
+
264
+ FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
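+ # Optional override: set FORCE_QWENVL_VIDEO_READER to "decord" or "torchvision" to pin the video backend.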
265
+
266
+
267
+ @lru_cache(maxsize=1)
268
+ def get_video_reader_backend() -> str:
269
+ if FORCE_QWENVL_VIDEO_READER is not None:
270
+ video_reader_backend = FORCE_QWENVL_VIDEO_READER
271
+ elif is_decord_available():
272
+ video_reader_backend = "decord"
273
+ else:
274
+ video_reader_backend = "torchvision"
275
+ print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
276
+ return video_reader_backend
277
+
278
+
279
+ def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
280
+ if isinstance(ele["video"], str):
281
+ video_reader_backend = get_video_reader_backend()
282
+ try:
283
+ video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
284
+ except Exception as e:
285
+ logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
286
+ video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)
287
+
288
+ nframes, _, height, width = video.shape
289
+ min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
290
+ total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
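+ # Per-frame pixel cap: split the total video pixel budget across the sampled frames (scaled by FRAME_FACTOR), but never below ~1.05x min_pixels.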
291
+ max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
292
+ max_pixels_supposed = ele.get("max_pixels", max_pixels)
293
+ if max_pixels_supposed > max_pixels:
294
+ logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
295
+ max_pixels = min(max_pixels_supposed, max_pixels)
296
+ if "resized_height" in ele and "resized_width" in ele:
297
+ resized_height, resized_width = smart_resize(
298
+ ele["resized_height"],
299
+ ele["resized_width"],
300
+ factor=image_factor,
301
+ )
302
+ else:
303
+ resized_height, resized_width = smart_resize(
304
+ height,
305
+ width,
306
+ factor=image_factor,
307
+ min_pixels=min_pixels,
308
+ max_pixels=max_pixels,
309
+ )
310
+ video = transforms.functional.resize(
311
+ video,
312
+ [resized_height, resized_width],
313
+ interpolation=InterpolationMode.BICUBIC,
314
+ antialias=True,
315
+ ).float()
316
+ if return_video_sample_fps:
317
+ return video, sample_fps
318
+ return video
319
+ else:
320
+ assert isinstance(ele["video"], (list, tuple))
321
+ process_info = ele.copy()
322
+ process_info.pop("type", None)
323
+ process_info.pop("video", None)
324
+ images = [
325
+ fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
326
+ for video_element in ele["video"]
327
+ ]
328
+ nframes = ceil_by_factor(len(images), FRAME_FACTOR)
329
+ if len(images) < nframes:
330
+ images.extend([images[-1]] * (nframes - len(images)))
331
+ if return_video_sample_fps:
332
+ return images, process_info.pop("fps", 2.0)
333
+ return images
334
+
335
+
336
+ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
337
+ vision_infos = []
338
+ if isinstance(conversations[0], dict):
339
+ conversations = [conversations]
340
+ for conversation in conversations:
341
+ for message in conversation:
342
+ if isinstance(message["content"], list):
343
+ for ele in message["content"]:
344
+ if (
345
+ "image" in ele
346
+ or "image_url" in ele
347
+ or "video" in ele
348
+ or ele["type"] in ("image", "image_url", "video")
349
+ ):
350
+ vision_infos.append(ele)
351
+ return vision_infos
352
+
353
+
354
+ def process_vision_info(
355
+ conversations: list[dict] | list[list[dict]],
356
+ return_video_kwargs: bool = False,
357
+ ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
358
+
359
+ vision_infos = extract_vision_info(conversations)
360
+ ## Read images or videos
361
+ image_inputs = []
362
+ video_inputs = []
363
+ video_sample_fps_list = []
364
+ for vision_info in vision_infos:
365
+ if "image" in vision_info or "image_url" in vision_info:
366
+ image_inputs.append(fetch_image(vision_info))
367
+ elif "video" in vision_info:
368
+ video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
369
+ video_sample_fps_list.append(video_sample_fps)
370
+ video_inputs.append(video_input)
371
+ else:
372
+ raise ValueError("image, image_url or video should be in content.")
373
+ if len(image_inputs) == 0:
374
+ image_inputs = None
375
+ if len(video_inputs) == 0:
376
+ video_inputs = None
377
+ if return_video_kwargs:
378
+ return image_inputs, video_inputs, {'fps': video_sample_fps_list}
379
+ return image_inputs, video_inputs
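+ # Illustrative usage (assumption: a Qwen2-VL-style processor consumes the outputs; the path below is hypothetical):
+ # messages = [{"role": "user", "content": [{"type": "video", "video": "file:///path/clip.mp4", "fps": 1.0},
+ # {"type": "text", "text": "Describe the clip."}]}]
+ # image_inputs, video_inputs, video_kwargs = process_vision_info(messages, return_video_kwargs=True)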
previous_version/Video-R1-main-previous/src/r1-v/temp_image.png ADDED

Git LFS Details

  • SHA256: 6d32d2be631fcae3fcf15b31fb57096fdba3c4c6e5417f8cab84f5c16e7ce18f
  • Pointer size: 131 Bytes
  • Size of remote file: 147 kB
src/r1-v/.gitignore ADDED
@@ -0,0 +1,178 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
172
+
173
+ # Temp folders
174
+ data/
175
+ wandb/
176
+ scripts/
177
+ checkpoints/
178
+ .vscode/
src/r1-v/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
src/r1-v/Makefile ADDED
@@ -0,0 +1,20 @@
1
+ .PHONY: style quality
2
+
3
+ # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4
+ export PYTHONPATH = src
5
+
6
+ check_dirs := src
7
+
8
+ style:
9
+ black --line-length 119 --target-version py310 $(check_dirs) setup.py
10
+ isort $(check_dirs) setup.py
11
+
12
+ quality:
13
+ black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
14
+ isort --check-only $(check_dirs) setup.py
15
+ flake8 --max-line-length 119 $(check_dirs) setup.py
16
+
17
+
18
+ # Evaluation
19
+
20
+ evaluate:
src/r1-v/setup.cfg ADDED
@@ -0,0 +1,41 @@
1
+ [isort]
2
+ default_section = FIRSTPARTY
3
+ ensure_newline_before_comments = True
4
+ force_grid_wrap = 0
5
+ include_trailing_comma = True
6
+ known_first_party = open_r1
7
+ known_third_party =
8
+ transformers
9
+ datasets
10
+ fugashi
11
+ git
12
+ h5py
13
+ matplotlib
14
+ nltk
15
+ numpy
16
+ packaging
17
+ pandas
18
+ psutil
19
+ pytest
20
+ rouge_score
21
+ sacrebleu
22
+ seqeval
23
+ sklearn
24
+ streamlit
25
+ torch
26
+ tqdm
27
+
28
+ line_length = 119
29
+ lines_after_imports = 2
30
+ multi_line_output = 3
31
+ use_parentheses = True
32
+
33
+ [flake8]
34
+ ignore = E203, E501, E741, W503, W605
35
+ max-line-length = 119
36
+ per-file-ignores =
37
+ # imported but unused
38
+ __init__.py: F401
39
+
40
+ [tool:pytest]
41
+ doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
src/r1-v/setup.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
16
+
17
+
18
+ import re
19
+ import shutil
20
+ from pathlib import Path
21
+
22
+ from setuptools import find_packages, setup
23
+
24
+
25
+ # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
26
+ stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
27
+ if stale_egg_info.exists():
28
+ print(
29
+ (
30
+ "Warning: {} exists.\n\n"
31
+ "If you recently updated open_r1, this is expected,\n"
32
+ "but it may prevent open_r1 from installing in editable mode.\n\n"
33
+ "This directory is automatically generated by Python's packaging tools.\n"
34
+ "I will remove it now.\n\n"
35
+ "See https://github.com/pypa/pip/issues/5466 for details.\n"
36
+ ).format(stale_egg_info)
37
+ )
38
+ shutil.rmtree(stale_egg_info)
39
+
40
+
41
+ # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
42
+ # * If a dependency is fast-moving (e.g. transformers), pin to the exact version
43
+ _deps = [
44
+ "accelerate>=1.2.1",
45
+ "bitsandbytes>=0.43.0",
46
+ "black>=24.4.2",
47
+ "datasets>=3.2.0",
48
+ "deepspeed==0.15.4",
49
+ "distilabel[vllm,ray,openai]>=1.5.2",
50
+ "einops>=0.8.0",
51
+ "flake8>=6.0.0",
52
+ "hf_transfer>=0.1.4",
53
+ "huggingface-hub[cli]>=0.19.2,<1.0",
54
+ "isort>=5.12.0",
55
+ "liger_kernel==0.5.2",
56
+ "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
57
+ "math-verify", # Used for math verification in grpo
58
+ "packaging>=23.0",
59
+ "parameterized>=0.9.0",
60
+ "pytest",
61
+ "safetensors>=0.3.3",
62
+ "sentencepiece>=0.1.99",
63
+ "torch>=2.5.1",
64
+ # "transformers @ git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef",
65
+ "trl==0.16.0",
66
+ "vllm==0.7.2",
67
+ "wandb>=0.19.1",
68
+ "pillow",
69
+ ]
70
+
71
+ # this is a lookup table with items like:
72
+ #
73
+ # tokenizers: "tokenizers==0.9.4"
74
+ # packaging: "packaging"
75
+ #
76
+ # some of the values are versioned whereas others aren't.
77
+ deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
78
+
79
+
80
+ def deps_list(*pkgs):
81
+ return [deps[pkg] for pkg in pkgs]
82
+
83
+
84
+ extras = {}
85
+ extras["tests"] = deps_list("pytest", "parameterized")
86
+ extras["torch"] = deps_list("torch")
87
+ extras["quality"] = deps_list("black", "isort", "flake8")
88
+ extras["eval"] = deps_list("lighteval", "math-verify")
89
+ extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
90
+
91
+ # core dependencies shared across the whole project - keep this to a bare minimum :)
92
+ install_requires = [
93
+ deps["accelerate"],
94
+ deps["bitsandbytes"],
95
+ deps["einops"],
96
+ deps["datasets"],
97
+ deps["deepspeed"],
98
+ deps["hf_transfer"],
99
+ deps["huggingface-hub"],
100
+ deps["liger_kernel"],
101
+ deps["packaging"], # utilities from PyPA to e.g., compare versions
102
+ deps["safetensors"],
103
+ deps["sentencepiece"],
104
+ # deps["transformers"],
105
+ deps["trl"],
106
+ ]
107
+
108
+ setup(
109
+ name="r1-v",
110
+ version="0.1.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
111
+ author="The r1-v team and the Hugging Face team (past and future)",
112
+ description="R1-V",
113
+ license="Apache",
114
+ url="https://github.com/Deep-Agent/R1-V",
115
+ package_dir={"": "src"},
116
+ packages=find_packages("src"),
117
+ zip_safe=False,
118
+ extras_require=extras,
119
+ python_requires=">=3.10.9",
120
+ install_requires=install_requires,
121
+ classifiers=[
122
+ "Development Status :: 3 - Alpha",
123
+ "Intended Audience :: Developers",
124
+ "Intended Audience :: Education",
125
+ "Intended Audience :: Science/Research",
126
+ "License :: OSI Approved :: Apache Software License",
127
+ "Operating System :: OS Independent",
128
+ "Programming Language :: Python :: 3",
129
+ "Programming Language :: Python :: 3.10",
130
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
131
+ ],
132
+ )
src/r1-v/src/open_r1/__init__.py ADDED
File without changes
src/r1-v/src/open_r1/evaluate.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Custom evaluation tasks for LightEval."""
16
+
17
+ from lighteval.metrics.dynamic_metrics import (
18
+ ExprExtractionConfig,
19
+ LatexExtractionConfig,
20
+ multilingual_extractive_match_metric,
21
+ )
22
+ from lighteval.tasks.lighteval_task import LightevalTaskConfig
23
+ from lighteval.tasks.requests import Doc
24
+ from lighteval.utils.language import Language
25
+
26
+
27
+ metric = multilingual_extractive_match_metric(
28
+ language=Language.ENGLISH,
29
+ fallback_mode="first_match",
30
+ precision=5,
31
+ gold_extraction_target=(LatexExtractionConfig(),),
32
+ pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
33
+ aggregation_function=max,
34
+ )
35
+
36
+
37
+ def prompt_fn(line, task_name: str = None):
38
+ """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
39
+ return Doc(
40
+ task_name=task_name,
41
+ query=line["problem"],
42
+ choices=[line["solution"]],
43
+ gold_index=0,
44
+ )
45
+
46
+
47
+ # Define tasks
48
+ aime24 = LightevalTaskConfig(
49
+ name="aime24",
50
+ suite=["custom"],
51
+ prompt_function=prompt_fn,
52
+ hf_repo="HuggingFaceH4/aime_2024",
53
+ hf_subset="default",
54
+ hf_avail_splits=["train"],
55
+ evaluation_splits=["train"],
56
+ few_shots_split=None,
57
+ few_shots_select=None,
58
+ generation_size=32768,
59
+ metric=[metric],
60
+ version=1,
61
+ )
62
+ math_500 = LightevalTaskConfig(
63
+ name="math_500",
64
+ suite=["custom"],
65
+ prompt_function=prompt_fn,
66
+ hf_repo="HuggingFaceH4/MATH-500",
67
+ hf_subset="default",
68
+ hf_avail_splits=["test"],
69
+ evaluation_splits=["test"],
70
+ few_shots_split=None,
71
+ few_shots_select=None,
72
+ generation_size=32768,
73
+ metric=[metric],
74
+ version=1,
75
+ )
76
+
77
+ # Add tasks to the table
78
+ TASKS_TABLE = []
79
+ TASKS_TABLE.append(aime24)
80
+ TASKS_TABLE.append(math_500)
81
+
82
+ # MODULE LOGIC
83
+ if __name__ == "__main__":
84
+ print([t["name"] for t in TASKS_TABLE])
85
+ print(len(TASKS_TABLE))
src/r1-v/src/open_r1/generate.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from distilabel.llms import OpenAILLM
18
+ from distilabel.pipeline import Pipeline
19
+ from distilabel.steps.tasks import TextGeneration
20
+
21
+
22
+ def build_distilabel_pipeline(
23
+ model: str,
24
+ base_url: str = "http://localhost:8000/v1",
25
+ prompt_column: Optional[str] = None,
26
+ temperature: Optional[float] = None,
27
+ top_p: Optional[float] = None,
28
+ max_new_tokens: int = 8192,
29
+ num_generations: int = 1,
30
+ ) -> Pipeline:
31
+ generation_kwargs = {"max_new_tokens": max_new_tokens}
32
+
33
+ if temperature is not None:
34
+ generation_kwargs["temperature"] = temperature
35
+
36
+ if top_p is not None:
37
+ generation_kwargs["top_p"] = top_p
38
+
39
+ with Pipeline().ray() as pipeline:
40
+ TextGeneration(
41
+ llm=OpenAILLM(
42
+ base_url=base_url,
43
+ api_key="something",
44
+ model=model,
45
+ # thinking can take some time...
46
+ timeout=10 * 60,
47
+ generation_kwargs=generation_kwargs,
48
+ ),
49
+ input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
50
+ input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
51
+ num_generations=num_generations,
52
+ )
53
+
54
+ return pipeline
55
+
56
+
57
+ if __name__ == "__main__":
58
+ import argparse
59
+
60
+ from datasets import load_dataset
61
+
62
+ parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
63
+ parser.add_argument(
64
+ "--hf-dataset",
65
+ type=str,
66
+ required=True,
67
+ help="HuggingFace dataset to load",
68
+ )
69
+ parser.add_argument(
70
+ "--hf-dataset-config",
71
+ type=str,
72
+ required=False,
73
+ help="Dataset config to use",
74
+ )
75
+ parser.add_argument(
76
+ "--hf-dataset-split",
77
+ type=str,
78
+ default="train",
79
+ help="Dataset split to use",
80
+ )
81
+ parser.add_argument("--prompt-column", type=str, default="prompt")
82
+ parser.add_argument(
83
+ "--model",
84
+ type=str,
85
+ required=True,
86
+ help="Model name to use for generation",
87
+ )
88
+ parser.add_argument(
89
+ "--vllm-server-url",
90
+ type=str,
91
+ default="http://localhost:8000/v1",
92
+ help="URL of the vLLM server",
93
+ )
94
+ parser.add_argument(
95
+ "--temperature",
96
+ type=float,
97
+ help="Temperature for generation",
98
+ )
99
+ parser.add_argument(
100
+ "--top-p",
101
+ type=float,
102
+ help="Top-p value for generation",
103
+ )
104
+ parser.add_argument(
105
+ "--max-new-tokens",
106
+ type=int,
107
+ default=8192,
108
+ help="Maximum number of new tokens to generate",
109
+ )
110
+ parser.add_argument(
111
+ "--num-generations",
112
+ type=int,
113
+ default=1,
114
+ help="Number of generations per problem",
115
+ )
116
+ parser.add_argument(
117
+ "--hf-output-dataset",
118
+ type=str,
119
+ required=False,
120
+ help="HuggingFace repo to push results to",
121
+ )
122
+ parser.add_argument(
123
+ "--private",
124
+ action="store_true",
125
+ help="Whether to make the output dataset private when pushing to HF Hub",
126
+ )
127
+
128
+ args = parser.parse_args()
129
+
130
+ print("\nRunning with arguments:")
131
+ for arg, value in vars(args).items():
132
+ print(f" {arg}: {value}")
133
+ print()
134
+
135
+ print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
136
+ dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split)
137
+ print("Dataset loaded!")
138
+
139
+ pipeline = build_distilabel_pipeline(
140
+ model=args.model,
141
+ base_url=args.vllm_server_url,
142
+ prompt_column=args.prompt_column,
143
+ temperature=args.temperature,
144
+ top_p=args.top_p,
145
+ max_new_tokens=args.max_new_tokens,
146
+ num_generations=args.num_generations,
147
+ )
148
+
149
+ print("Running generation pipeline...")
150
+ distiset = pipeline.run(dataset=dataset, use_cache=False)
151
+ print("Generation pipeline finished!")
152
+
153
+ if args.hf_output_dataset:
154
+ print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
155
+ distiset.push_to_hub(args.hf_output_dataset, private=args.private)
156
+ print("Dataset pushed!")
src/r1-v/src/open_r1/grpo-cot-72BEval.py ADDED
@@ -0,0 +1,489 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+
20
+ from datasets import load_dataset, load_from_disk
21
+ from transformers import Qwen2VLForConditionalGeneration
22
+
23
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
24
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
25
+
26
+ from datasets import Dataset, DatasetDict
27
+
28
+ from typing import Dict, List, Optional
29
+ from mathruler.grader import extract_boxed_content, grade_answer
30
+
31
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
32
+ from rouge_score import rouge_scorer
33
+ from openai import OpenAI
34
+ from concurrent.futures import ThreadPoolExecutor, as_completed
35
+ import time
36
+ # from utils.math_cot import *
37
+ # from qa_metrics.pedant import PEDANT
38
+ # from qa_metrics.answerBERT import AnswerBertActor
39
+
40
+ # pedant = PEDANT()
41
+ # answerBERT = AnswerBertActor(device='cuda:7')
42
+ client = OpenAI(
43
+ base_url="http://29.81.228.243:8081/v1", # your vLLM server
44
+ api_key="ANYKEY", # if you set --api-key when launching
45
+ )
46
+
47
+ def validate_description(description, question):
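+ """Ask the remote judge model to answer the question using only the generated image description; returns the raw completion text."""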
48
+ input_message = "You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. First provide a step-by-step reasoning within <think> </think> tags, then provide your answer as a single final answer, single letter choice, or a short phrase ENCLOSED with <answer> </answer> tags. \nText description: {Description}\nQuestion: {Question}\nPlease only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions.".format(Description=description, Question=question)
49
+ response = client.chat.completions.create(
50
+ model="Qwen2.5-72B-Instruct", # **must match** the returned id
51
+ messages=[
52
+ {"role": "system", "content": "You are a helpful assistant."},
53
+ {"role": "user", "content": input_message}
54
+ ]
55
+ )
56
+
57
+ # print('*'*10)
58
+ # print('Input Prompt: ', input_message)
59
+ # print('-'*10)
60
+ # print('Output Message: ', response.choices[0].message.content)
61
+ # print('-'*10)
62
+ # time.sleep(40)
63
+
64
+ return response.choices[0].message.content
65
+
66
+
67
+ @dataclass
68
+ class GRPOScriptArguments(ScriptArguments):
69
+ """
70
+ Script arguments for the GRPO training script.
71
+
72
+ Args:
73
+ reward_funcs (`list[str]`):
74
+ List of reward functions. Possible values: 'accuracy', 'format'.
75
+ """
76
+
77
+ reward_funcs: list[str] = field(
78
+ default_factory=lambda: ["accuracy", "format"],
79
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
80
+ )
81
+
82
+ # reward_funcs: list[str] = field(
83
+ # default_factory=lambda: ["accuracy"],
84
+ # metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
85
+ # )
86
+ max_pixels: Optional[int] = field(
87
+ default=12845056,
88
+ metadata={"help": "Maximum number of pixels for the image"},
89
+ )
90
+ min_pixels: Optional[int] = field(
91
+ default=3136,
92
+ metadata={"help": "Minimum number of pixels for the image"},
93
+ )
94
+ temporal: Optional[bool] = field(
95
+ default=True,
96
+ metadata={"help": "whether using temporal GRPO"},
97
+ )
98
+ len_control: Optional[bool] = field(
99
+ default=True,
100
+ metadata={"help": "whether using length reward"},
101
+ )
102
+
103
+
104
+ def accuracy_reward(completions, solution, **kwargs):
105
+ def extract_answer(text: str) -> str:
106
+ """
107
+ 1) Try the full <answer> … </answer> block.
108
+ 2) If that is missing, grab whatever follows the opening <answer> tag.
109
+ 3) Otherwise return the original text.
110
+ """
111
+ # ① normal case <answer> … </answer>
112
+ m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
113
+ if m:
114
+ return m.group(1).strip()
115
+
116
+ # ② fallback <answer> … <end-of-string>
117
+ m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
118
+ if m:
119
+ return m.group(1).strip()
120
+
121
+ # ③ nothing found
122
+ return text.strip()
123
+
124
+ def extract_description(predict: str) -> Optional[str]:
125
+ """
126
+ Extracts the content of the <answer>…</answer> block from `predict`.
127
+ Returns the inner text (with leading/trailing whitespace stripped),
128
+ or None if no <answer> tag is found.
129
+ """
130
+ match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
131
+ if not match:
132
+ return predict
133
+ return match.group(1).strip()
134
+
135
+ def single_accuracy_reward(predict: str, ground_truth: str) -> float:
136
+ answer = predict
137
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
138
+
139
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> Dict[str, float]:
140
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
141
+ # format_score = format_reward(predict)
142
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
143
+
144
+ # return (1 - format_weight) * accuracy_score + format_weight * format_score
145
+ return accuracy_score
146
+
147
+ def normalize_number(num_str):
148
+ try:
149
+ num_str = num_str.replace(',', '')
150
+ return float(num_str)
151
+ except Exception as e:
152
+ print(f"Error converting '{num_str}' to float: {e}")
153
+ return None
154
+
155
+ def wer(reference, hypothesis):
156
+ ref_words = reference.split()
157
+ hyp_words = hypothesis.split()
158
+ m = len(ref_words)
159
+ n = len(hyp_words)
160
+ d = [[0]*(n+1) for _ in range(m+1)]
161
+ for i in range(m+1):
162
+ d[i][0] = i
163
+ for j in range(n+1):
164
+ d[0][j] = j
165
+ for i in range(1, m+1):
166
+ for j in range(1, n+1):
167
+ if ref_words[i-1] == hyp_words[j-1]:
168
+ d[i][j] = d[i-1][j-1]
169
+ else:
170
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
171
+ return d[m][n] / max(1, m)
172
+
173
+
174
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
175
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
176
+ scores = scorer.score(reference, hypothesis)
177
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
178
+ return average_fmeasure
179
+
180
+
181
+ question_type = kwargs['problem_type'][0]
182
+ questions = kwargs['problem']
183
+ # questions = kwargs['prompt']
184
+
185
+ contents = [completion[0]["content"] for completion in completions]
186
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
187
+ rewards = []
188
+
189
+
190
+ extracted_content_descriptions = [extract_description(str(ele)) for ele in contents]
191
+ description_answer_outputs = []
192
+
193
+
194
+ with ThreadPoolExecutor(max_workers=8) as executor:
195
+ futures = [
196
+ executor.submit(validate_description, desc, q)
197
+ for desc, q in zip(extracted_content_descriptions, questions)
198
+ ]
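+ # Collect judge answers in submission order so they stay aligned with the questions and ground-truth answers used below.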
199
+ for future in futures:
200
+ try:
201
+ description_answer_outputs.append(future.result())
202
+ except Exception as e:
203
+ # handle/log e
204
+ # description_answer_outputs.append(None)
205
+ print('Description output error: ', e)
206
+ description_answer_outputs.append(0)
207
+
208
+
209
+ contents = [str(ele) for ele in contents]
210
+ description_answer_outputs = [str(ele) for ele in description_answer_outputs]
211
+
212
+ gt_answers = [extract_answer(str(sol)) for sol in solution]
213
+ extracted_description_outputs = [extract_answer(str(description_answer_outputs[index_description])) for index_description in range(len(description_answer_outputs))]
214
+
215
+
216
+ # print('GT answers: ', gt_answers)
217
+ # print('Description answers: ', description_answer_outputs[0])
218
+ # print('-'*10)
219
+ # import time
220
+ # time.sleep(10)
221
+
222
+ description_rewards = [compute_math_score_single(extracted_description_outputs[count_idx], gt_answers[count_idx]) for count_idx in range(len(description_answer_outputs))]
223
+
224
+ # print('()'*10)
225
+ # print("Question: ", questions[0])
226
+ # print(gt_answers)
227
+ # print('Description outputs', description_answer_outputs[0])
228
+ # print(description_rewards)
229
+ # print('-'*10)
230
+ # time.sleep(30)
231
+
232
+
233
+ # for content, sol, description_reward in zip(contents, solution, description_rewards):
234
+ for content, gt_ans, description_reward in zip(contents, gt_answers, description_rewards):
235
+ try:
236
+ output_ans = extract_answer(str(content))
237
+ # gt_ans = extract_answer(sol)
238
+
239
+ if question_type == "OCR":
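+ # NOTE: this branch still expects the commented-out PEDANT scorer and per-sample description_extraction/question variables; as written it raises NameError, which the outer except turns into reward = 0.0.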
240
+ # description_extraction = extract_answer(str(second_content))
241
+ # description_error_rate = wer(gt_ans, description_extraction)
242
+ description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
243
+ # error_rate = wer(gt_ans, output_ans)
244
+ answer_pedant_reward = pedant.get_score(gt_ans, output_ans, question)
245
+ # reward = (1 - error_rate) + (1- description_error_rate)
246
+ # reward = max(0.0, min(2.0, reward))
247
+ # print('Extracted description: ', description_extraction)
248
+ # print('Generated answer: ', output_ans)
249
+ # print('Sol: ', gt_ans)
250
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
251
+ # print('-' * 10)
252
+ reward = description_pendat_reward + answer_pedant_reward
253
+ # elif question_type == "free-form":
254
+ # score = compute_rouge_score(gt_ans, output_ans)
255
+ # reward = max(0.0, min(1.0, score))
256
+ elif question_type == "regression":
257
+ gt_number = normalize_number(gt_ans)
258
+ out_number = normalize_number(output_ans)
259
+ if gt_number is None or out_number is None:
260
+ reward = 0.0
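+ # Note: with no else branch here, a None value makes the rel_diff computation below raise a TypeError, which the outer except maps to reward = 0.0.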
261
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
262
+ rel_diff = min(1.0, max(0.0, rel_diff))
263
+ reward = 1 - rel_diff
264
+ elif question_type == 'math' or question_type == 'unify' or question_type == "multiple choice" or question_type == "numerical":
265
+ # description_reward = compute_math_score_single(description_extraction, gt_ans)
266
+ answer_reward = compute_math_score_single(output_ans, gt_ans)
267
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
268
+ # print('-' * 10)
269
+ reward = description_reward + answer_reward
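+ # Combined reward: correctness of the judge's answer derived from the description plus correctness of the model's own final answer (each 0 or 1).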
270
+ # reward = answer_reward
271
+ else:
272
+ print('Falling back to none rewards')
273
+ reward = 0.0
274
+ except Exception as e:
275
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
276
+ reward = 0.0
277
+
278
+ rewards.append(reward)
279
+ if os.getenv("DEBUG_MODE") == "true":
280
+ log_path = os.getenv("LOG_PATH")
281
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
282
+ with open(log_path, "a", encoding="utf-8") as f:
283
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
284
+ f.write(f"Content: {content}\n")
285
+ f.write(f"Solution: {gt_ans}\n")
286
+
287
+ # print("rewards: ", rewards)
288
+ return rewards
289
+
290
+
291
+ def simple_format_reward(completions, **kwargs):
292
+ """Reward function that checks if the completion has a specific format."""
293
+ # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
294
+ pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
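+ # The completion must be exactly <des>...</des> <think>...</think> <answer>...</answer> (re.fullmatch); matches earn a 0.1 format reward.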
295
+ completion_contents = [completion[0]["content"] for completion in completions]
296
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
297
+ return [0.1 if match else 0.0 for match in matches]
298
+
299
+
300
+ reward_funcs_registry = {
301
+ "accuracy": accuracy_reward,
302
+ "format": simple_format_reward,
303
+ }
304
+
305
+ # SYSTEM_PROMPT = (
306
+ # "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
307
+ # "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
308
+ # "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
309
+ # "<think> reasoning process here </think><answer> answer here </answer>"
310
+ # )
311
+
312
+ SYSTEM_PROMPT = (
313
+ "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags. "
314
+ "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
315
+ "and provide this step-by-step reasoning within <think> </think> tags. "
316
+ "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags. "
317
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
318
+ )
319
+
320
+
321
+ def main(script_args, training_args, model_args):
322
+ # Get reward functions
323
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
324
+
325
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
326
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
327
+ else:
328
+ # Load the dataset
329
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
330
+
331
+
332
+ # Format into conversation
333
+ def make_conversation(example):
334
+ return {
335
+ "prompt": [
336
+ {"role": "system", "content": SYSTEM_PROMPT},
337
+ {"role": "user", "content": example["problem"]},
338
+ ],
339
+ }
340
+
341
+
342
+ # QUESTION_TEMPLATE = (
343
+ # "{Question}\n"
344
+ # "Please think about this question as if you were a human pondering deeply. "
345
+ # "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
346
+ # "It's encouraged to include self-reflection or verification in the reasoning process. "
347
+ # "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
348
+ # )
349
+
350
+ QUESTION_TEMPLATE = (
351
+ "{Question}\n"
352
+ "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
353
+ "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
354
+ "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
355
+ "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
356
+ "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
357
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
358
+ )
359
+
360
+
361
+ TYPE_TEMPLATE = {
362
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
363
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
364
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
365
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
366
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
367
+ "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
368
+ }
369
+
370
+ def make_conversation_image(example):
371
+
372
+ return {
373
+ "prompt": [
374
+ {
375
+ "role": "user",
376
+ "content": [
377
+ {"type": "image"},
378
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
379
+ ],
380
+ },
381
+ ],
382
+ }
383
+
384
+
385
+ def make_conversation_video(example):
386
+ return {
387
+ "prompt": [
388
+ {
389
+ "role": "user",
390
+ "content": [
391
+ {"type": "video"},
392
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
393
+ ],
394
+ },
395
+ ],
396
+ }
397
+
398
+ def make_conversation_image_and_video(example):
399
+ if example["problem_type"] == 'multiple choice':
400
+ question = example['problem'] + "Options:\n"
401
+ for op in example["options"]:
402
+ question += op + "\n"
403
+ else:
404
+ question = example['problem']
405
+
406
+
407
+ # msg ={
408
+ # "prompt":
409
+ # [{
410
+ # "role": "user",
411
+ # "content": [
412
+ # {
413
+ # "type": example['data_type'],
414
+ # # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
415
+ # },
416
+ # {
417
+ # "type": "text",
418
+ # "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
419
+ # }
420
+ # ]
421
+ # }]
422
+ # }
423
+
424
+ msg ={
425
+ "prompt":
426
+ [{
427
+ "role": "user",
428
+ "content": [
429
+ {
430
+ "type": example['data_type'],
431
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
432
+ },
433
+ {
434
+ "type": "text",
435
+ "text": QUESTION_TEMPLATE.format(Question=question)
436
+ }
437
+ ]
438
+ }]
439
+ }
440
+
441
+ # return msg
442
+ return {
443
+ "prompt": msg["prompt"],
444
+ "problem": question,
445
+ }
446
+
447
+
448
+ dataset = dataset.map(make_conversation_image_and_video)
449
+
450
+
451
+ # print('Example problem')
452
+ # print(dataset['train']['problem'][10])
453
+ # time.sleep(30)
454
+
455
+
456
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
457
+ print("using: ", trainer_cls)
458
+
459
+
460
+ # Initialize the GRPO trainer
461
+ trainer = trainer_cls(
462
+ model=model_args.model_name_or_path,
463
+ reward_funcs=reward_funcs,
464
+ args=training_args,
465
+ script_args=script_args,
466
+ train_dataset=dataset[script_args.dataset_train_split],
467
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
468
+ peft_config=get_peft_config(model_args),
469
+ attn_implementation=model_args.attn_implementation,
470
+ max_pixels=script_args.max_pixels,
471
+ min_pixels=script_args.min_pixels,
472
+ )
473
+
474
+ if training_args.resume_from_checkpoint is not None:
475
+ checkpoint = training_args.resume_from_checkpoint
476
+ trainer.train(resume_from_checkpoint=checkpoint)
477
+ else:
478
+ trainer.train()
479
+
480
+ # Save and push to hub
481
+ trainer.save_model(training_args.output_dir)
482
+ if training_args.push_to_hub:
483
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
484
+
485
+
486
+ if __name__ == "__main__":
487
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
488
+ script_args, training_args, model_args = parser.parse_args_and_config()
489
+ main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo-cot-LLMEval.py ADDED
@@ -0,0 +1,552 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+
20
+ from datasets import load_dataset, load_from_disk
21
+ from transformers import Qwen2VLForConditionalGeneration
22
+
23
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
24
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
25
+
26
+ from datasets import Dataset, DatasetDict
27
+
28
+ from typing import Dict, List, Optional
29
+ from mathruler.grader import extract_boxed_content, grade_answer
30
+
31
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
32
+ from rouge_score import rouge_scorer
33
+ # from utils.gpt_eval import infer
34
+ # from utils.math_cot import *
35
+ # from qa_metrics.pedant import PEDANT
36
+ # from qa_metrics.answerBERT import AnswerBertActor
37
+
38
+ # pedant = PEDANT()
39
+ # answerBERT = AnswerBertActor(device='cuda:7')
40
+
41
+ alpha = 1.0
42
+
43
+ TYPE_TEMPLATE = {
44
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
45
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
46
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
47
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
48
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
49
+ "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
50
+ }
51
+
52
+ '''
53
+ gpt infer
54
+ '''
55
+ import os
56
+ from openai import AzureOpenAI
57
+ import time
58
+
59
+ import base64
60
+ from mimetypes import guess_type
61
+
62
+
63
+ def azure_gpt4(messages, model):
64
+ outputs = []
65
+ for message in messages:
66
+ input_prompt = [
67
+ { "role": "system", "content": "You are a helpful assistant." },
68
+ { "role": "user", "content": [
69
+ {
70
+ "type": "text",
71
+ "text": message["instruction"]
72
+ },
73
+ # {
74
+ # "type": "image_url",
75
+ # "image_url": {
76
+ # "url": message["image"]
77
+ # }
78
+ # }
79
+ ]}
80
+ ]
81
+ ## try N times if API exceed limit ...
82
+ for i in range(10):
83
+ try:
84
+ output = client.chat.completions.create(
85
+ model=model, messages=input_prompt, max_tokens=2000
86
+ )
87
+
88
+ output_text = output.choices[0].message.content
89
+ break ## exit if successful
90
+
91
+ except Exception as e:
92
+ print(f'Index {i} got error message: {e}')
93
+ output_text = ''
94
+ time.sleep(3)
95
+
96
+ outputs.append(output_text)
97
+
98
+ return outputs
99
+
100
+
101
+ client = AzureOpenAI(
102
+ api_key = "83f30a2a22324395b854bd343db38d85",
103
+ api_version = "2024-08-01-preview",
104
+ azure_endpoint = "https://francecentral.api.cognitive.microsoft.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
105
+ )
106
+
107
+ model = "gpt-4o"
108
+ prompt_template = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. Provide your answer as a single final answer or a short phrase enclosed with <answer></answer>. \nText description: {text}\nQuestion: {question}'''
109
+
110
+
111
+ def infer(prompt):
112
+ # prompt_question = prompt_question.replace('<image>', '')
113
+ # prompt = prompt_template.replace('{text}', text).replace('{question}', prompt_question)
114
+
115
+ messages = [
116
+ {"instruction": prompt},
117
+ ]
118
+ prompt_success = False
119
+ prompt_time = 0
120
+ outputs = ['<answer> None </answer>']
121
+ while not prompt_success and prompt_time <= 2:
122
+ try:
123
+ outputs = azure_gpt4(messages, model)
124
+ prompt_success = True
125
+ except:
126
+ prompt_time += 1
127
+ time.sleep(5)
128
+
129
+ return outputs[0]
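+ # Hypothetical usage sketch of the GPT judge above (example values are invented):
+ # filled = prompt_template.replace('{text}', 'A red cube sits left of a blue sphere.') \
+ #                         .replace('{question}', 'What color is the cube?')
+ # infer(filled) # expected to return something like '<answer> red </answer>'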
130
+
131
+ '''
132
+ end of gpt infer
133
+ '''
134
+
135
+
136
+ from concurrent.futures import ThreadPoolExecutor, as_completed
137
+
138
+ def _call_infer(desc):
139
+ return infer(desc)
140
+
141
+ @dataclass
142
+ class GRPOScriptArguments(ScriptArguments):
143
+ """
144
+ Script arguments for the GRPO training script.
145
+
146
+ Args:
147
+ reward_funcs (`list[str]`):
148
+ List of reward functions. Possible values: 'accuracy', 'format'.
149
+ """
150
+
151
+ reward_funcs: list[str] = field(
152
+ default_factory=lambda: ["accuracy", "format"],
153
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
154
+ )
155
+
156
+ # reward_funcs: list[str] = field(
157
+ # default_factory=lambda: ["accuracy"],
158
+ # metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
159
+ # )
160
+ max_pixels: Optional[int] = field(
161
+ default=12845056,
162
+ metadata={"help": "Maximum number of pixels for the image"},
163
+ )
164
+ min_pixels: Optional[int] = field(
165
+ default=3136,
166
+ metadata={"help": "Minimum number of pixels for the image"},
167
+ )
168
+ temporal: Optional[bool] = field(
169
+ default=True,
170
+ metadata={"help": "whether using temporal GRPO"},
171
+ )
172
+ len_control: Optional[bool] = field(
173
+ default=True,
174
+ metadata={"help": "whether using length reward"},
175
+ )
176
+
177
+
178
+
179
+ def accuracy_reward(completions, solution, **kwargs):
180
+ def extract_answer(text: str) -> str:
181
+ """
182
+ 1) Try the full <answer> … </answer> block.
183
+ 2) If that is missing, grab whatever follows the opening <answer> tag.
184
+ 3) Otherwise return the original text.
185
+ """
186
+ # ① normal case <answer> … </answer>
187
+ m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
188
+ if m:
189
+ return m.group(1).strip()
190
+
191
+ # ② fallback <answer> … <end-of-string>
192
+ m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
193
+ if m:
194
+ return m.group(1).strip()
195
+
196
+ # ③ nothing found
197
+ return text.strip()
198
+
199
+ def extract_description(predict: str) -> Optional[str]:
200
+ """
201
+ Extracts the content of the <des>…</des> block from `predict`.
202
+ Returns the inner text (with leading/trailing whitespace stripped),
203
+ or the original string if no <des> tag is found.
204
+ """
205
+ match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
206
+ if not match:
207
+ return predict
208
+ return match.group(1).strip()
209
+
210
+ def single_accuracy_reward(predict: str, ground_truth: str) -> float:
211
+ answer = predict
212
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
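+ # Assumption (for illustration): mathruler's grade_answer checks answer equivalence, so
+ # single_accuracy_reward("0.5", "1/2") would typically return 1.0 while
+ # single_accuracy_reward("0.4", "1/2") returns 0.0.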
213
+
214
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
215
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
216
+ # format_score = format_reward(predict)
217
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
218
+
219
+ # return (1 - format_weight) * accuracy_score + format_weight * format_score
220
+ return accuracy_score
221
+
222
+ def normalize_number(num_str):
223
+ try:
224
+ num_str = num_str.replace(',', '')
225
+ return float(num_str)
226
+ except Exception as e:
227
+ print(f"Error converting '{num_str}' to float: {e}")
228
+ return None
229
+
230
+ def wer(reference, hypothesis):
231
+ ref_words = reference.split()
232
+ hyp_words = hypothesis.split()
233
+ m = len(ref_words)
234
+ n = len(hyp_words)
235
+ d = [[0]*(n+1) for _ in range(m+1)]
236
+ for i in range(m+1):
237
+ d[i][0] = i
238
+ for j in range(n+1):
239
+ d[0][j] = j
240
+ for i in range(1, m+1):
241
+ for j in range(1, n+1):
242
+ if ref_words[i-1] == hyp_words[j-1]:
243
+ d[i][j] = d[i-1][j-1]
244
+ else:
245
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
246
+ return d[m][n] / max(1, m)
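+ # Worked example (added for clarity): wer("the cat sat", "the cat") gives edit distance 1 over
+ # a 3-word reference = 0.33; identical strings give 0.0 and an empty hypothesis gives 1.0.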
247
+
248
+
249
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
250
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
251
+ scores = scorer.score(reference, hypothesis)
252
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
253
+ return average_fmeasure
254
+
255
+ # print('Computing rewards now...')
256
+ # second_prompts = kwargs.get("second_prompts") # ← list[str] or None
257
+ # second_completions = kwargs.get("second_completions")
258
+ # second_contents = [comp[0]["content"] for comp in second_completions]
259
+ # print('second prompts', second_prompts)
260
+ # print('-'*10)
261
+ # print('second completions', second_completions)
262
+ # print('-'*10)
263
+
264
+ # import time
265
+ # time.sleep(30)
266
+ question_type = kwargs['problem_type'][0]
267
+ questions = kwargs['problem']
268
+
269
+ contents = [completion[0]["content"] for completion in completions]
270
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
271
+ rewards = []
272
+
273
+ extracted_content_descriptions = [extract_description(ele) for ele in contents]
274
+
275
+ description_query_inputs = []
276
+
277
+ for index in range(len(extracted_content_descriptions)):
278
+ prompt_question = questions[index]
279
+ des_text = extracted_content_descriptions[index]
280
+ prompt_question = prompt_question.replace('<image>', '')
281
+ prompt_input = prompt_template.replace('{text}', des_text).replace('{question}', prompt_question) + TYPE_TEMPLATE[question_type]
282
+ description_query_inputs.append(prompt_input)
283
+
284
+
285
+ description_score_outputs = []
286
+ with ThreadPoolExecutor(max_workers=8) as executor:
287
+ # kick off all the futures
288
+ # futures = [
289
+ # executor.submit(_call_infer, desc, ques)
290
+ # for desc, ques in zip(extracted_content_descriptions, questions)
291
+ # ]
292
+ futures = [
293
+ executor.submit(_call_infer, desc)
294
+ for desc in description_query_inputs
295
+ ]
296
+ # collect in submission order so each output stays aligned with its question/ground truth
297
+ for fut in futures:
298
+ description_score_outputs.append(extract_answer(fut.result()))
299
+
300
+ gt_answers = [extract_answer(sol) for sol in solution]
301
+ description_rewards = [compute_math_score_single(description_score_outputs[count_idx], gt_answers[count_idx]) for count_idx in range(len(description_score_outputs))]
302
+
303
+ # print(gt_answers)
304
+ # print(description_score_outputs)
305
+ # print(description_rewards)
306
+ # print('-'*10)
307
+
308
+
309
+ for content, gt_ans, description_reward in zip(contents, gt_answers, description_rewards):
310
+ # for content, sol, question in zip(contents, solution, questions):
311
+ # for content, sol, second_content in zip(contents, solution, second_completions):
312
+ try:
313
+ output_ans = extract_answer(content)
314
+ # gt_ans = extract_answer(sol)
315
+ # description_extraction = extract_answer(second_content)
316
+ # if question_type == "multiple choice":
317
+ # reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
318
+ # elif question_type == "numerical":
319
+ # gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
320
+ # out_has_decimal = ("." in output_ans) or ("," in output_ans)
321
+ # if gt_has_decimal != out_has_decimal:
322
+ # reward = 0.0
323
+ # else:
324
+ # gt_number = normalize_number(gt_ans)
325
+ # out_number = normalize_number(output_ans)
326
+ # if gt_number is None or out_number is None:
327
+ # reward = 0.0
328
+ # else:
329
+ # reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
330
+ if question_type == "OCR":
331
+ # description_extraction = extract_answer(second_content)
332
+ # description_error_rate = wer(gt_ans, description_extraction)
333
+ # description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
334
+ # error_rate = wer(gt_ans, output_ans)
335
+ answer_pedant_reward = pedant.get_score(gt_ans, output_ans, questions[0])
336
+ # reward = (1 - error_rate) + (1- description_error_rate)
337
+ # reward = max(0.0, min(2.0, reward))
338
+ # print('Extracted description: ', description_extraction)
339
+ # print('Generated answer: ', output_ans)
340
+ # print('Sol: ', gt_ans)
341
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
342
+ # print('-' * 10)
343
+ # reward = description_pendat_reward + answer_pedant_reward
344
+ reward = answer_pedant_reward
345
+ # elif question_type == "free-form":
346
+ # score = compute_rouge_score(gt_ans, output_ans)
347
+ # reward = max(0.0, min(1.0, score))
348
+ elif question_type == "regression":
349
+ gt_number = normalize_number(gt_ans)
350
+ out_number = normalize_number(output_ans)
351
+ if gt_number is None or out_number is None:
352
+ reward = 0.0
+ else:
353
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
354
+ rel_diff = min(1.0, max(0.0, rel_diff))
355
+ reward = 1 - rel_diff
356
+ elif question_type == 'math' or question_type == 'unify' or question_type == "multiple choice" or question_type == "numerical":
357
+ answer_reward = compute_math_score_single(output_ans, gt_ans)
358
+
359
+
360
+ # print(f"Extracted description: {description_extraction} | Generated answer: {output_ans} | Sol: {gt_ans}")
361
+ # print(f'Description reward: {description_reward} | answer reward: {answer_reward} | final reward: {reward}')
362
+ # print('-' * 10)
363
+
364
+ if description_reward == 0 and answer_reward == 1:
365
+ reward = alpha
366
+ else:
367
+ reward = description_reward + answer_reward
368
+ # reward = answer_reward
369
+ else:
370
+ print('Falling back to none rewards')
371
+ reward = 0.0
372
+ except Exception as e:
373
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
374
+ reward = 0.0
375
+
376
+ rewards.append(reward)
377
+
378
+ if os.getenv("DEBUG_MODE") == "true":
379
+ log_path = os.getenv("LOG_PATH")
380
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
381
+ with open(log_path, "a", encoding="utf-8") as f:
382
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
383
+ f.write(f"Content: {content}\n")
384
+ f.write(f"Solution: {gt_ans}\n")
385
+
386
+ return rewards
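+ # Summary example (added for clarity): if the GPT judge recovers the correct answer from the
+ # <des> description and the model's own <answer> also matches, reward = 2.0; description only
+ # gives 1.0; answer only gives alpha (1.0 as configured above).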
387
+
388
+
389
+ def simple_format_reward(completions, **kwargs):
390
+ """Reward function that checks if the completion has a specific format."""
391
+ # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
392
+ pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
393
+ completion_contents = [completion[0]["content"] for completion in completions]
394
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
395
+ return [0.1 if match else 0.0 for match in matches]
396
+
397
+
398
+ reward_funcs_registry = {
399
+ "accuracy": accuracy_reward,
400
+ "format": simple_format_reward,
401
+ }
402
+
403
+ # SYSTEM_PROMPT = (
404
+ # "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
405
+ # "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
406
+ # "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
407
+ # "<think> reasoning process here </think><answer> answer here </answer>"
408
+ # )
409
+
410
+ SYSTEM_PROMPT = (
411
+ "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags."
412
+ "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
413
+ "and provide this step-by-step reasoning within <think> </think> tags. "
414
+ "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags."
415
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
416
+ )
417
+
418
+
419
+ def main(script_args, training_args, model_args):
420
+ # Get reward functions
421
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
422
+
423
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
424
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
425
+ else:
426
+ # Load the dataset
427
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
428
+
429
+
430
+ # Format into conversation
431
+ def make_conversation(example):
432
+ return {
433
+ "prompt": [
434
+ {"role": "system", "content": SYSTEM_PROMPT},
435
+ {"role": "user", "content": example["problem"]},
436
+ ],
437
+ }
438
+
439
+
440
+ # QUESTION_TEMPLATE = (
441
+ # "{Question}\n"
442
+ # "Please think about this question as if you were a human pondering deeply. "
443
+ # "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
444
+ # "It's encouraged to include self-reflection or verification in the reasoning process. "
445
+ # "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
446
+ # )
447
+
448
+ QUESTION_TEMPLATE = (
449
+ "{Question}\n"
450
+ "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
451
+ "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
452
+ "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
453
+ "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
454
+ "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
455
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
456
+ )
457
+
458
+
459
+
460
+ def make_conversation_image(example):
461
+
462
+ return {
463
+ "prompt": [
464
+ {
465
+ "role": "user",
466
+ "content": [
467
+ {"type": "image"},
468
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
469
+ ],
470
+ },
471
+ ],
472
+ }
473
+
474
+
475
+ def make_conversation_video(example):
476
+ return {
477
+ "prompt": [
478
+ {
479
+ "role": "user",
480
+ "content": [
481
+ {"type": "video"},
482
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
483
+ ],
484
+ },
485
+ ],
486
+ }
487
+
488
+ def make_conversation_image_and_video(example):
489
+ if example["problem_type"] == 'multiple choice':
490
+ question = example['problem'] + "Options:\n"
491
+ for op in example["options"]:
492
+ question += op + "\n"
493
+ else:
494
+ question = example['problem']
495
+
496
+
497
+ msg ={
498
+ "prompt":
499
+ [{
500
+ "role": "user",
501
+ "content": [
502
+ {
503
+ "type": example['data_type'],
504
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
505
+ },
506
+ {
507
+ "type": "text",
508
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
509
+ }
510
+ ]
511
+ }]
512
+ }
513
+
514
+ return msg
515
+
516
+
517
+ dataset = dataset.map(make_conversation_image_and_video)
518
+
519
+
520
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
521
+ print("using: ", trainer_cls)
522
+
523
+ # Initialize the GRPO trainer
524
+ trainer = trainer_cls(
525
+ model=model_args.model_name_or_path,
526
+ reward_funcs=reward_funcs,
527
+ args=training_args,
528
+ script_args=script_args,
529
+ train_dataset=dataset[script_args.dataset_train_split],
530
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
531
+ peft_config=get_peft_config(model_args),
532
+ attn_implementation=model_args.attn_implementation,
533
+ max_pixels=script_args.max_pixels,
534
+ min_pixels=script_args.min_pixels,
535
+ )
536
+
537
+ if training_args.resume_from_checkpoint is not None:
538
+ checkpoint = training_args.resume_from_checkpoint
539
+ trainer.train(resume_from_checkpoint=checkpoint)
540
+ else:
541
+ trainer.train()
542
+
543
+ # Save and push to hub
544
+ trainer.save_model(training_args.output_dir)
545
+ if training_args.push_to_hub:
546
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
547
+
548
+
549
+ if __name__ == "__main__":
550
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
551
+ script_args, training_args, model_args = parser.parse_args_and_config()
552
+ main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo-cot-answerBERT-eval.py ADDED
@@ -0,0 +1,429 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+
20
+ from datasets import load_dataset, load_from_disk
21
+ from transformers import Qwen2VLForConditionalGeneration
22
+
23
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
24
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
25
+
26
+ from datasets import Dataset, DatasetDict
27
+
28
+ from typing import Dict, List, Optional
29
+ from mathruler.grader import extract_boxed_content, grade_answer
30
+
31
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
32
+ from rouge_score import rouge_scorer
33
+ # from utils.math_cot import *
34
+ # from qa_metrics.pedant import PEDANT
35
+ from qa_metrics.answerBERT import AnswerBertActor
36
+
37
+ # pedant = PEDANT()
38
+ answerBERT = AnswerBertActor(device='cuda:0')
39
+
40
+ @dataclass
41
+ class GRPOScriptArguments(ScriptArguments):
42
+ """
43
+ Script arguments for the GRPO training script.
44
+
45
+ Args:
46
+ reward_funcs (`list[str]`):
47
+ List of reward functions. Possible values: 'accuracy', 'format'.
48
+ """
49
+
50
+ reward_funcs: list[str] = field(
51
+ default_factory=lambda: ["accuracy", "format"],
52
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
53
+ )
54
+
55
+ # reward_funcs: list[str] = field(
56
+ # default_factory=lambda: ["accuracy"],
57
+ # metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
58
+ # )
59
+ max_pixels: Optional[int] = field(
60
+ default=12845056,
61
+ metadata={"help": "Maximum number of pixels for the image"},
62
+ )
63
+ min_pixels: Optional[int] = field(
64
+ default=3136,
65
+ metadata={"help": "Minimum number of pixels for the image"},
66
+ )
67
+ temporal: Optional[bool] = field(
68
+ default=True,
69
+ metadata={"help": "whether using temporal GRPO"},
70
+ )
71
+ len_control: Optional[bool] = field(
72
+ default=True,
73
+ metadata={"help": "whether using length reward"},
74
+ )
75
+
76
+
77
+
78
+ def accuracy_reward(completions, solution, **kwargs):
79
+ def extract_answer(text: str) -> str:
80
+ """
81
+ 1) Try the full <answer> … </answer> block.
82
+ 2) If that is missing, grab whatever follows the opening <answer> tag.
83
+ 3) Otherwise return the original text.
84
+ """
85
+ # ① normal case <answer> … </answer>
86
+ m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
87
+ if m:
88
+ return m.group(1).strip()
89
+
90
+ # ② fallback <answer> … <end-of-string>
91
+ m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
92
+ if m:
93
+ return m.group(1).strip()
94
+
95
+ # ③ nothing found
96
+ return text.strip()
97
+
98
+ def extract_description(predict: str) -> Optional[str]:
99
+ """
100
+ Extracts the content of the <des>…</des> block from `predict`.
101
+ Returns the inner text (with leading/trailing whitespace stripped),
102
+ or the original string if no <des> tag is found.
103
+ """
104
+ match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
105
+ if not match:
106
+ return predict
107
+ return match.group(1).strip()
108
+
109
+ def single_accuracy_reward(predict: str, ground_truth: str) -> float:
110
+ answer = predict
111
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
112
+
113
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
114
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
115
+ # format_score = format_reward(predict)
116
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
117
+
118
+ # return (1 - format_weight) * accuracy_score + format_weight * format_score
119
+ return accuracy_score
120
+
121
+ def normalize_number(num_str):
122
+ try:
123
+ num_str = num_str.replace(',', '')
124
+ return float(num_str)
125
+ except Exception as e:
126
+ print(f"Error converting '{num_str}' to float: {e}")
127
+ return None
128
+
129
+ def wer(reference, hypothesis):
130
+ ref_words = reference.split()
131
+ hyp_words = hypothesis.split()
132
+ m = len(ref_words)
133
+ n = len(hyp_words)
134
+ d = [[0]*(n+1) for _ in range(m+1)]
135
+ for i in range(m+1):
136
+ d[i][0] = i
137
+ for j in range(n+1):
138
+ d[0][j] = j
139
+ for i in range(1, m+1):
140
+ for j in range(1, n+1):
141
+ if ref_words[i-1] == hyp_words[j-1]:
142
+ d[i][j] = d[i-1][j-1]
143
+ else:
144
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
145
+ return d[m][n] / max(1, m)
146
+
147
+
148
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
149
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
150
+ scores = scorer.score(reference, hypothesis)
151
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
152
+ return average_fmeasure
153
+
154
+ # print('Computing rewards now...')
155
+ # second_prompts = kwargs.get("second_prompts") # ← list[str] or None
156
+ # second_completions = kwargs.get("second_completions")
157
+ # second_contents = [comp[0]["content"] for comp in second_completions]
158
+ # print('second prompts', second_prompts)
159
+ # print('-'*10)
160
+ # print('second completions', second_completions)
161
+ # print('-'*10)
162
+
163
+ # import time
164
+ # time.sleep(30)
165
+ question_type = kwargs['problem_type'][0]
166
+ questions = kwargs['problem']
167
+
168
+ contents = [completion[0]["content"] for completion in completions]
169
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
170
+ rewards = []
171
+
172
+ extracted_content_descriptions = [extract_description(ele) for ele in contents]
173
+ # extracted_content_answers = [extract_answer(ele) for ele in contents]
174
+ # model = kwargs.get("model") # may be None if called elsewhere
175
+ # tokenizer = kwargs.get("tokenizer")
176
+ # # (optional) example use: let the model score the generated answer
177
+ # if model is not None and tokenizer is not None:
178
+ # model.eval()
179
+ description_inputs = [questions[index_count] + ' [SEP] ' + extracted_content_descriptions[index_count] for index_count in range(len(extracted_content_descriptions))]
180
+ description_rewards = answerBERT.batch_predict(description_inputs, batch_size = 32)
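+ # Assumption: AnswerBertActor.batch_predict returns one score per "question [SEP] description"
+ # pair, used directly as description_reward below; the exact score range is model-dependent.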
181
+
182
+ for content, sol, description_reward in zip(contents, solution, description_rewards):
183
+ # for content, sol, question in zip(contents, solution, questions):
184
+ # for content, sol, second_content in zip(contents, solution, second_completions):
185
+ try:
186
+ output_ans = extract_answer(content)
187
+ gt_ans = extract_answer(sol)
188
+ # description_extraction = extract_answer(second_content)
189
+ # if question_type == "multiple choice":
190
+ # reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
191
+ # elif question_type == "numerical":
192
+ # gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
193
+ # out_has_decimal = ("." in output_ans) or ("," in output_ans)
194
+ # if gt_has_decimal != out_has_decimal:
195
+ # reward = 0.0
196
+ # else:
197
+ # gt_number = normalize_number(gt_ans)
198
+ # out_number = normalize_number(output_ans)
199
+ # if gt_number is None or out_number is None:
200
+ # reward = 0.0
201
+ # else:
202
+ # reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
203
+ if question_type == "OCR":
204
+ # description_extraction = extract_answer(second_content)
205
+ # description_error_rate = wer(gt_ans, description_extraction)
206
+ description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
207
+ # error_rate = wer(gt_ans, output_ans)
208
+ answer_pedant_reward = pedant.get_score(gt_ans, output_ans, question)
209
+ # reward = (1 - error_rate) + (1- description_error_rate)
210
+ # reward = max(0.0, min(2.0, reward))
211
+ # print('Extracted description: ', description_extraction)
212
+ # print('Generated answer: ', output_ans)
213
+ # print('Sol: ', gt_ans)
214
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
215
+ # print('-' * 10)
216
+ reward = description_pendat_reward + answer_pedant_reward
217
+ # elif question_type == "free-form":
218
+ # score = compute_rouge_score(gt_ans, output_ans)
219
+ # reward = max(0.0, min(1.0, score))
220
+ # elif question_type == "regression":
221
+ # gt_number = normalize_number(gt_ans)
222
+ # out_number = normalize_number(output_ans)
223
+ # if gt_number is None or out_number is None:
224
+ # reward = 0.0
225
+ # rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
226
+ # rel_diff = min(1.0, max(0.0, rel_diff))
227
+ # reward = 1 - rel_diff
228
+ elif question_type == 'math' or question_type == 'unify' or question_type == "multiple choice" or question_type == "numerical" or question_type == "regression":
229
+ # print('Extracted description: ', description_extraction)
230
+ # print('Generated answer: ', output_ans)
231
+ # print('Sol: ', gt_ans)
232
+
233
+ # description_reward = compute_math_score_single(description_extraction, gt_ans)
234
+ answer_reward = compute_math_score_single(output_ans, gt_ans)
235
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
236
+ # print('-' * 10)
237
+ reward = description_reward + answer_reward
238
+ else:
239
+ print('Falling back to none rewards')
240
+ reward = 0.0
241
+ except Exception as e:
242
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
243
+ reward = 0.0
244
+
245
+ rewards.append(reward)
246
+
247
+ if os.getenv("DEBUG_MODE") == "true":
248
+ log_path = os.getenv("LOG_PATH")
249
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
250
+ with open(log_path, "a", encoding="utf-8") as f:
251
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
252
+ f.write(f"Content: {content}\n")
253
+ f.write(f"Solution: {sol}\n")
254
+
255
+ return rewards
256
+
257
+
258
+ def simple_format_reward(completions, **kwargs):
259
+ """Reward function that checks if the completion has a specific format."""
260
+ # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
261
+ pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
262
+ completion_contents = [completion[0]["content"] for completion in completions]
263
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
264
+ return [0.1 if match else 0.0 for match in matches]
265
+
266
+
267
+ reward_funcs_registry = {
268
+ "accuracy": accuracy_reward,
269
+ "format": simple_format_reward,
270
+ }
271
+
272
+ # SYSTEM_PROMPT = (
273
+ # "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
274
+ # "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
275
+ # "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
276
+ # "<think> reasoning process here </think><answer> answer here </answer>"
277
+ # )
278
+
279
+ SYSTEM_PROMPT = (
280
+ "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags."
281
+ "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
282
+ "and provide this step-by-step reasoning within <think> </think> tags. "
283
+ "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags."
284
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
285
+ )
286
+
287
+
288
+ def main(script_args, training_args, model_args):
289
+ # Get reward functions
290
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
291
+
292
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
293
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
294
+ else:
295
+ # Load the dataset
296
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
297
+
298
+
299
+ # Format into conversation
300
+ def make_conversation(example):
301
+ return {
302
+ "prompt": [
303
+ {"role": "system", "content": SYSTEM_PROMPT},
304
+ {"role": "user", "content": example["problem"]},
305
+ ],
306
+ }
307
+
308
+
309
+ # QUESTION_TEMPLATE = (
310
+ # "{Question}\n"
311
+ # "Please think about this question as if you were a human pondering deeply. "
312
+ # "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
313
+ # "It's encouraged to include self-reflection or verification in the reasoning process. "
314
+ # "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
315
+ # )
316
+
317
+ QUESTION_TEMPLATE = (
318
+ "{Question}\n"
319
+ "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
320
+ "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
321
+ "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
322
+ "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
323
+ "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
324
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
325
+ )
326
+
327
+
328
+ TYPE_TEMPLATE = {
329
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
330
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
331
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
332
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
333
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
334
+ "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
335
+ }
336
+
337
+ def make_conversation_image(example):
338
+
339
+ return {
340
+ "prompt": [
341
+ {
342
+ "role": "user",
343
+ "content": [
344
+ {"type": "image"},
345
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
346
+ ],
347
+ },
348
+ ],
349
+ }
350
+
351
+
352
+ def make_conversation_video(example):
353
+ return {
354
+ "prompt": [
355
+ {
356
+ "role": "user",
357
+ "content": [
358
+ {"type": "video"},
359
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
360
+ ],
361
+ },
362
+ ],
363
+ }
364
+
365
+ def make_conversation_image_and_video(example):
366
+ if example["problem_type"] == 'multiple choice':
367
+ question = example['problem'] + "Options:\n"
368
+ for op in example["options"]:
369
+ question += op + "\n"
370
+ else:
371
+ question = example['problem']
372
+
373
+
374
+ msg ={
375
+ "prompt":
376
+ [{
377
+ "role": "user",
378
+ "content": [
379
+ {
380
+ "type": example['data_type'],
381
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
382
+ },
383
+ {
384
+ "type": "text",
385
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
386
+ }
387
+ ]
388
+ }]
389
+ }
390
+
391
+ return msg
392
+
393
+
394
+ dataset = dataset.map(make_conversation_image_and_video)
395
+
396
+
397
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
398
+ print("using: ", trainer_cls)
399
+
400
+ # Initialize the GRPO trainer
401
+ trainer = trainer_cls(
402
+ model=model_args.model_name_or_path,
403
+ reward_funcs=reward_funcs,
404
+ args=training_args,
405
+ script_args=script_args,
406
+ train_dataset=dataset[script_args.dataset_train_split],
407
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
408
+ peft_config=get_peft_config(model_args),
409
+ attn_implementation=model_args.attn_implementation,
410
+ max_pixels=script_args.max_pixels,
411
+ min_pixels=script_args.min_pixels,
412
+ )
413
+
414
+ if training_args.resume_from_checkpoint is not None:
415
+ checkpoint = training_args.resume_from_checkpoint
416
+ trainer.train(resume_from_checkpoint=checkpoint)
417
+ else:
418
+ trainer.train()
419
+
420
+ # Save and push to hub
421
+ trainer.save_model(training_args.output_dir)
422
+ if training_args.push_to_hub:
423
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
424
+
425
+
426
+ if __name__ == "__main__":
427
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
428
+ script_args, training_args, model_args = parser.parse_args_and_config()
429
+ main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo-cot-noDesEval.py ADDED
@@ -0,0 +1,446 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+
20
+ from datasets import load_dataset, load_from_disk
21
+ from transformers import Qwen2VLForConditionalGeneration
22
+
23
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
24
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
25
+
26
+ from datasets import Dataset, DatasetDict
27
+
28
+ from typing import Dict, List, Optional
29
+ from mathruler.grader import extract_boxed_content, grade_answer
30
+
31
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
32
+ from rouge_score import rouge_scorer
33
+ # from utils.math_cot import *
34
+ # from qa_metrics.pedant import PEDANT
35
+ # from qa_metrics.answerBERT import AnswerBertActor
36
+
37
+ # pedant = PEDANT()
38
+ # answerBERT = AnswerBertActor(device='cuda:7')
39
+
40
+ @dataclass
41
+ class GRPOScriptArguments(ScriptArguments):
42
+ """
43
+ Script arguments for the GRPO training script.
44
+
45
+ Args:
46
+ reward_funcs (`list[str]`):
47
+ List of reward functions. Possible values: 'accuracy', 'format'.
48
+ """
49
+
50
+ reward_funcs: list[str] = field(
51
+ default_factory=lambda: ["accuracy", "format"],
52
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
53
+ )
54
+
55
+ # reward_funcs: list[str] = field(
56
+ # default_factory=lambda: ["accuracy"],
57
+ # metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
58
+ # )
59
+ max_pixels: Optional[int] = field(
60
+ default=12845056,
61
+ metadata={"help": "Maximum number of pixels for the image"},
62
+ )
63
+ min_pixels: Optional[int] = field(
64
+ default=3136,
65
+ metadata={"help": "Minimum number of pixels for the image"},
66
+ )
67
+ temporal: Optional[bool] = field(
68
+ default=True,
69
+ metadata={"help": "whether using temporal GRPO"},
70
+ )
71
+ len_control: Optional[bool] = field(
72
+ default=True,
73
+ metadata={"help": "whether using length reward"},
74
+ )
75
+
76
+
77
+ def accuracy_reward(completions, solution, **kwargs):
78
+ def extract_answer(text: str) -> str:
79
+ """
80
+ 1) Try the full <answer> … </answer> block.
81
+ 2) If that is missing, grab whatever follows the opening <answer> tag.
82
+ 3) Otherwise return the original text.
83
+ """
84
+ # ① normal case <answer> … </answer>
85
+ m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
86
+ if m:
87
+ return m.group(1).strip()
88
+
89
+ # ② fallback <answer> … <end-of-string>
90
+ m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
91
+ if m:
92
+ return m.group(1).strip()
93
+
94
+ # ③ nothing found
95
+ return text.strip()
96
+
97
+ def extract_description(predict: str) -> Optional[str]:
98
+ """
99
+ Extracts the content of the <answer>…</answer> block from `predict`.
100
+ Returns the inner text (with leading/trailing whitespace stripped),
101
+ or None if no <answer> tag is found.
102
+ """
103
+ match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
104
+ if not match:
105
+ return predict
106
+ return match.group(1).strip()
107
+
108
+ def single_accuracy_reward(predict: str, ground_truth: str) -> float:
109
+ answer = predict
110
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
111
+
112
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> Dict[str, float]:
113
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
114
+ # format_score = format_reward(predict)
115
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
116
+
117
+ # return (1 - format_weight) * accuracy_score + format_weight * format_score
118
+ return accuracy_score
119
+
120
+ def normalize_number(num_str):
121
+ try:
122
+ num_str = num_str.replace(',', '')
123
+ return float(num_str)
124
+ except Exception as e:
125
+ print(f"Error converting '{num_str}' to float: {e}")
126
+ return None
127
+
128
+ def wer(reference, hypothesis):
129
+ ref_words = reference.split()
130
+ hyp_words = hypothesis.split()
131
+ m = len(ref_words)
132
+ n = len(hyp_words)
133
+ d = [[0]*(n+1) for _ in range(m+1)]
134
+ for i in range(m+1):
135
+ d[i][0] = i
136
+ for j in range(n+1):
137
+ d[0][j] = j
138
+ for i in range(1, m+1):
139
+ for j in range(1, n+1):
140
+ if ref_words[i-1] == hyp_words[j-1]:
141
+ d[i][j] = d[i-1][j-1]
142
+ else:
143
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
144
+ return d[m][n] / max(1, m)
145
+
146
+
147
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
148
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
149
+ scores = scorer.score(reference, hypothesis)
150
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
151
+ return average_fmeasure
152
+
153
+ # print('Computing rewards now...')
154
+ # second_prompts = kwargs.get("second_prompts") # ← list[str] or None
155
+ # second_completions = kwargs.get("second_completions")
156
+ # second_contents = [comp[0]["content"] for comp in second_completions]
157
+ # print('second prompts', second_prompts)
158
+ # print('-'*10)
159
+ # print('second completions', second_completions)
160
+ # print('-'*10)
161
+
162
+ # import time
163
+ # time.sleep(30)
164
+ question_type = kwargs['problem_type'][0]
165
+ questions = kwargs['problem']
166
+
167
+ contents = [completion[0]["content"] for completion in completions]
168
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
169
+ rewards = []
170
+
171
+ # extracted_content_descriptions = [extract_description(ele) for ele in contents]
172
+ # extracted_content_answers = [extract_answer(ele) for ele in contents]
173
+ # model = kwargs.get("model") # may be None if called elsewhere
174
+ # tokenizer = kwargs.get("tokenizer")
175
+ # # (optional) example use: let the model score the generated answer
176
+ # if model is not None and tokenizer is not None:
177
+ # model.eval()
178
+ # description_inputs = [questions[index_count] + ' [SEP] ' + extracted_content_descriptions[index_count] for index_count in range(len(extracted_content_descriptions))]
179
+ # description_rewards = answerBERT.batch_predict(description_inputs, batch_size = 64)
180
+
181
+ for idx, (content, sol) in enumerate(zip(contents, solution)):
182
+ # for content, sol, question in zip(contents, solution, questions):
183
+ # for content, sol, second_content in zip(contents, solution, second_completions):
184
+ try:
185
+ output_ans = extract_answer(content)
186
+ gt_ans = extract_answer(sol)
187
+ # description_extraction = extract_answer(second_content)
188
+ # if question_type == "multiple choice":
189
+ # reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
190
+ # elif question_type == "numerical":
191
+ # gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
192
+ # out_has_decimal = ("." in output_ans) or ("," in output_ans)
193
+ # if gt_has_decimal != out_has_decimal:
194
+ # reward = 0.0
195
+ # else:
196
+ # gt_number = normalize_number(gt_ans)
197
+ # out_number = normalize_number(output_ans)
198
+ # if gt_number is None or out_number is None:
199
+ # reward = 0.0
200
+ # else:
201
+ # reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
202
+ if question_type == "OCR":
203
+ # description_extraction = extract_answer(second_content)
204
+ # description_error_rate = wer(gt_ans, description_extraction)
205
+ description_pedant_reward = pedant.get_score(gt_ans, extract_description(content), questions[idx])
206
+ # error_rate = wer(gt_ans, output_ans)
207
+ answer_pedant_reward = pedant.get_score(gt_ans, output_ans, questions[idx])
208
+ # reward = (1 - error_rate) + (1- description_error_rate)
209
+ # reward = max(0.0, min(2.0, reward))
210
+ # print('Extracted description: ', description_extraction)
211
+ # print('Generated answer: ', output_ans)
212
+ # print('Sol: ', gt_ans)
213
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
214
+ # print('-' * 10)
215
+ reward = description_pedant_reward + answer_pedant_reward
216
+ # elif question_type == "free-form":
217
+ # score = compute_rouge_score(gt_ans, output_ans)
218
+ # reward = max(0.0, min(1.0, score))
219
+ elif question_type == "regression":
220
+ gt_number = normalize_number(gt_ans)
221
+ out_number = normalize_number(output_ans)
222
+ if gt_number is None or out_number is None:
+ reward = 0.0
+ else:
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
+ rel_diff = min(1.0, max(0.0, rel_diff))
+ reward = 1 - rel_diff
227
+ elif question_type == 'math' or question_type == 'unify' or question_type == "multiple choice" or question_type == "numerical":
228
+ # print('Extracted description: ', description_extraction)
229
+ # print('Generated answer: ', output_ans)
230
+ # print('Sol: ', gt_ans)
231
+
232
+ # description_reward = compute_math_score_single(description_extraction, gt_ans)
233
+ answer_reward = compute_math_score_single(output_ans, gt_ans)
234
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
235
+ # print('-' * 10)
236
+ # reward = description_reward + answer_reward
237
+ reward = answer_reward
238
+ else:
239
+ print(f"Unhandled question type '{question_type}'; falling back to zero reward")
240
+ reward = 0.0
241
+ except Exception as e:
242
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
243
+ reward = 0.0
244
+
245
+ rewards.append(reward)
246
+
247
+ if os.getenv("DEBUG_MODE") == "true":
248
+ log_path = os.getenv("LOG_PATH")
249
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
250
+ with open(log_path, "a", encoding="utf-8") as f:
251
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
252
+ f.write(f"Content: {content}\n")
253
+ f.write(f"Solution: {sol}\n")
254
+
255
+ return rewards
256
+
257
+
258
+ def simple_format_reward(completions, **kwargs):
259
+ """Reward function that checks if the completion has a specific format."""
260
+ # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
261
+ pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
262
+ completion_contents = [completion[0]["content"] for completion in completions]
263
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
264
+ return [0.1 if match else 0.0 for match in matches]
265
+
266
+
267
+ reward_funcs_registry = {
268
+ "accuracy": accuracy_reward,
269
+ "format": simple_format_reward,
270
+ }
271
+
272
+ # SYSTEM_PROMPT = (
273
+ # "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
274
+ # "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
275
+ # "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
276
+ # "<think> reasoning process here </think><answer> answer here </answer>"
277
+ # )
278
+
279
+ SYSTEM_PROMPT = (
280
+ "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags."
281
+ "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
282
+ "and provide this step-by-step reasoning within <think> </think> tags. "
283
+ "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags."
284
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
285
+ )
286
+
287
+
288
+ def main(script_args, training_args, model_args):
289
+ # Get reward functions
290
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
291
+
292
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
293
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
294
+ else:
295
+ # Load the dataset
296
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
297
+
298
+
299
+ # Format into conversation
300
+ def make_conversation(example):
301
+ return {
302
+ "prompt": [
303
+ {"role": "system", "content": SYSTEM_PROMPT},
304
+ {"role": "user", "content": example["problem"]},
305
+ ],
306
+ }
307
+
308
+
309
+ # QUESTION_TEMPLATE = (
310
+ # "{Question}\n"
311
+ # "Please think about this question as if you were a human pondering deeply. "
312
+ # "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
313
+ # "It's encouraged to include self-reflection or verification in the reasoning process. "
314
+ # "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
315
+ # )
316
+
317
+ QUESTION_TEMPLATE = (
318
+ "{Question}\n"
319
+ "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
320
+ "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
321
+ "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
322
+ "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
323
+ "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
324
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
325
+ )
326
+
327
+
328
+ TYPE_TEMPLATE = {
329
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
330
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
331
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
332
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
333
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
334
+ "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
335
+ }
336
+
337
+ def make_conversation_image(example):
338
+
339
+ return {
340
+ "prompt": [
341
+ {
342
+ "role": "user",
343
+ "content": [
344
+ {"type": "image"},
345
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
346
+ ],
347
+ },
348
+ ],
349
+ }
350
+
351
+
352
+ def make_conversation_video(example):
353
+ return {
354
+ "prompt": [
355
+ {
356
+ "role": "user",
357
+ "content": [
358
+ {"type": "video"},
359
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
360
+ ],
361
+ },
362
+ ],
363
+ }
364
+
365
+ def make_conversation_image_and_video(example):
366
+ if example["problem_type"] == 'multiple choice':
367
+ question = example['problem'] + "Options:\n"
368
+ for op in example["options"]:
369
+ question += op + "\n"
370
+ else:
371
+ question = example['problem']
372
+
373
+
374
+ # msg ={
375
+ # "prompt":
376
+ # [{
377
+ # "role": "user",
378
+ # "content": [
379
+ # {
380
+ # "type": example['data_type'],
381
+ # # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
382
+ # },
383
+ # {
384
+ # "type": "text",
385
+ # "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
386
+ # }
387
+ # ]
388
+ # }]
389
+ # }
390
+
391
+ msg ={
392
+ "prompt":
393
+ [{
394
+ "role": "user",
395
+ "content": [
396
+ {
397
+ "type": example['data_type'],
398
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
399
+ },
400
+ {
401
+ "type": "text",
402
+ "text": QUESTION_TEMPLATE.format(Question=question)
403
+ }
404
+ ]
405
+ }]
406
+ }
407
+
408
+ return msg
409
+
410
+
411
+ dataset = dataset.map(make_conversation_image_and_video)
412
+
413
+
414
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
415
+ print("using: ", trainer_cls)
416
+
417
+ # Initialize the GRPO trainer
418
+ trainer = trainer_cls(
419
+ model=model_args.model_name_or_path,
420
+ reward_funcs=reward_funcs,
421
+ args=training_args,
422
+ script_args=script_args,
423
+ train_dataset=dataset[script_args.dataset_train_split],
424
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
425
+ peft_config=get_peft_config(model_args),
426
+ attn_implementation=model_args.attn_implementation,
427
+ max_pixels=script_args.max_pixels,
428
+ min_pixels=script_args.min_pixels,
429
+ )
430
+
431
+ if training_args.resume_from_checkpoint is not None:
432
+ checkpoint = training_args.resume_from_checkpoint
433
+ trainer.train(resume_from_checkpoint=checkpoint)
434
+ else:
435
+ trainer.train()
436
+
437
+ # Save and push to hub
438
+ trainer.save_model(training_args.output_dir)
439
+ if training_args.push_to_hub:
440
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
441
+
442
+
443
+ if __name__ == "__main__":
444
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
445
+ script_args, training_args, model_args = parser.parse_args_and_config()
446
+ main(script_args, training_args, model_args)
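For orientation, the reward functions in this family of scripts all share the same calling convention: `completions` is a list of single-message lists and `solution` wraps the ground truth in `<answer>` tags. Below is a minimal, self-contained sketch of that interface using toy data and simplified exact-match scoring only, not the PEDANT/ROUGE/mathruler scoring used above:

```python
import re

def extract_answer(text: str) -> str:
    # Same fallback behaviour as the helpers above: return the raw text if no tag is found.
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", text, flags=re.DOTALL | re.IGNORECASE)
    return m.group(1).strip() if m else text.strip()

def toy_accuracy_reward(completions, solution):
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if extract_answer(c) == extract_answer(s) else 0.0
            for c, s in zip(contents, solution)]

completions = [[{"content": "<des>A bar chart with four bars.</des> <think>...</think> <answer>B</answer>"}]]
solution = ["<answer>B</answer>"]
print(toy_accuracy_reward(completions, solution))  # [1.0]
```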
src/r1-v/src/open_r1/grpo-cot-noInfo.py ADDED
@@ -0,0 +1,346 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+ from typing import Optional
20
+
21
+ from datasets import load_dataset, load_from_disk
22
+ from transformers import Qwen2VLForConditionalGeneration
23
+
24
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
25
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
26
+
27
+ from datasets import Dataset, DatasetDict
28
+
29
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
30
+ from rouge_score import rouge_scorer
31
+ from utils.math_cot_noInfo import *
32
+
33
+
34
+ @dataclass
35
+ class GRPOScriptArguments(ScriptArguments):
36
+ """
37
+ Script arguments for the GRPO training script.
38
+
39
+ Args:
40
+ reward_funcs (`list[str]`):
41
+ List of reward functions. Possible values: 'accuracy', 'format'.
42
+ """
43
+
44
+ reward_funcs: list[str] = field(
45
+ default_factory=lambda: ["accuracy"],
46
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
47
+ )
48
+ max_pixels: Optional[int] = field(
49
+ default=12845056,
50
+ metadata={"help": "Maximum number of pixels for the image"},
51
+ )
52
+ min_pixels: Optional[int] = field(
53
+ default=3136,
54
+ metadata={"help": "Minimum number of pixels for the image"},
55
+ )
56
+ temporal: Optional[bool] = field(
57
+ default=True,
58
+ metadata={"help": "whether using temporal GRPO"},
59
+ )
60
+ len_control: Optional[bool] = field(
61
+ default=True,
62
+ metadata={"help": "whether using length reward"},
63
+ )
64
+
65
+
66
+
67
+ def accuracy_reward(completions, solution, **kwargs):
68
+
69
+ def extract_answer(text):
70
+ pattern = r'<answer>\s*(.*?)\s*</answer>'
71
+ match = re.search(pattern, text, re.DOTALL)
72
+ if match:
73
+ return match.group(1).strip()
74
+ return ""
75
+
76
+ def normalize_number(num_str):
77
+ try:
78
+ num_str = num_str.replace(',', '')
79
+ return float(num_str)
80
+ except Exception as e:
81
+ print(f"Error converting '{num_str}' to float: {e}")
82
+ return None
83
+
84
+ def wer(reference, hypothesis):
85
+ ref_words = reference.split()
86
+ hyp_words = hypothesis.split()
87
+ m = len(ref_words)
88
+ n = len(hyp_words)
89
+ d = [[0]*(n+1) for _ in range(m+1)]
90
+ for i in range(m+1):
91
+ d[i][0] = i
92
+ for j in range(n+1):
93
+ d[0][j] = j
94
+ for i in range(1, m+1):
95
+ for j in range(1, n+1):
96
+ if ref_words[i-1] == hyp_words[j-1]:
97
+ d[i][j] = d[i-1][j-1]
98
+ else:
99
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
100
+ return d[m][n] / max(1, m)
101
+
102
+
103
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
104
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
105
+ scores = scorer.score(reference, hypothesis)
106
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
107
+ return average_fmeasure
108
+
109
+
110
+ question_type = kwargs['problem_type'][0]
111
+
112
+ contents = [completion[0]["content"] for completion in completions]
113
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
114
+ rewards = []
115
+
116
+ for content, sol in zip(contents, solution):
117
+
118
+ try:
119
+ output_ans = extract_answer(content)
120
+ gt_ans = extract_answer(sol)
121
+ if question_type == "multiple choice":
122
+ reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
123
+ elif question_type == "numerical":
124
+ gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
125
+ out_has_decimal = ("." in output_ans) or ("," in output_ans)
126
+ if gt_has_decimal != out_has_decimal:
127
+ reward = 0.0
128
+ else:
129
+ gt_number = normalize_number(gt_ans)
130
+ out_number = normalize_number(output_ans)
131
+ if gt_number is None or out_number is None:
132
+ reward = 0.0
133
+ else:
134
+ reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
135
+ elif question_type == "OCR":
136
+ error_rate = wer(gt_ans, output_ans)
137
+ reward = 1 - error_rate
138
+ reward = max(0.0, min(1.0, reward))
139
+ elif question_type == "free-form":
140
+ score = compute_rouge_score(gt_ans, output_ans)
141
+ reward = max(0.0, min(1.0, score))
142
+ elif question_type == "regression":
143
+ gt_number = normalize_number(gt_ans)
144
+ out_number = normalize_number(output_ans)
145
+ if gt_number is None or out_number is None:
146
+ reward = 0.0
147
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
148
+ rel_diff = min(1.0, max(0.0, rel_diff))
149
+ reward = 1 - rel_diff
150
+ elif question_type == 'math':
151
+ reward = compute_math_score_single(content, gt_ans)
152
+ else:
153
+ print(f"Unhandled question type '{question_type}'; falling back to zero reward")
154
+ reward = 0.0
155
+ except Exception as e:
156
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
157
+ reward = 0.0
158
+
159
+ rewards.append(reward)
160
+
161
+ if os.getenv("DEBUG_MODE") == "true":
162
+ log_path = os.getenv("LOG_PATH")
163
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
164
+ with open(log_path, "a", encoding="utf-8") as f:
165
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
166
+ f.write(f"Content: {content}\n")
167
+ f.write(f"Solution: {sol}\n")
168
+
169
+ return rewards
170
+
171
+
172
+ def format_reward(completions, **kwargs):
173
+ """Reward function that checks if the completion has a specific format."""
174
+ pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
175
+ completion_contents = [completion[0]["content"] for completion in completions]
176
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
177
+ return [1.0 if match else 0.0 for match in matches]
178
+
179
+
180
+ reward_funcs_registry = {
181
+ "accuracy": accuracy_reward,
182
+ # "format": format_reward,
183
+ }
184
+
185
+ SYSTEM_PROMPT = (
186
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
187
+ "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
188
+ "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
189
+ "<think> reasoning process here </think><answer> answer here </answer>"
190
+ )
191
+
192
+ # SYSTEM_PROMPT = (
193
+ # "A conversation between User and Assistant. The user provides a question about an image, "
194
+ # "and the Assistant is tasked with generating an exhaustive and detailed description of the image. "
195
+ # "The assistant should extract and describe all possible information from the image—including objects, numbers, text, and their relationships—"
196
+ # "and enclose this description within <info> </info> tags. "
197
+ # "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
198
+ # "and provide this step-by-step reasoning within <think> </think> tags. "
199
+ # "Finally, the assistant provides a single word or phrase answer within <answer> </answer> tags. "
200
+ # "The output format should be: <info> image description here </info> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
201
+ # )
202
+
203
+
204
+ def main(script_args, training_args, model_args):
205
+ # Get reward functions
206
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
207
+
208
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
209
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
210
+ else:
211
+ # Load the dataset
212
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
213
+
214
+
215
+ # Format into conversation
216
+ def make_conversation(example):
217
+ return {
218
+ "prompt": [
219
+ {"role": "system", "content": SYSTEM_PROMPT},
220
+ {"role": "user", "content": example["problem"]},
221
+ ],
222
+ }
223
+
224
+
225
+ QUESTION_TEMPLATE = (
226
+ "{Question}\n"
227
+ "Please think about this question as if you were a human pondering deeply. "
228
+ "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
229
+ "It's encouraged to include self-reflection or verification in the reasoning process. "
230
+ "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
231
+ )
232
+
233
+ # QUESTION_TEMPLATE = (
234
+ # "{Question}\n"
235
+ # "You are tasked with analyzing an image to generate an exhaustive and detailed description. "
236
+ # "Your goal is to extract and describe all possible information from the image, including but not limited to objects, numbers, text, and the relationships between these elements. "
237
+ # "The description should be as fine and detailed as possible, capturing every nuance, and should be enclosed within <info> </info> tags.\n"
238
+ # "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
239
+ # "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
240
+ # "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
241
+ # "The output format should be: <info> image description here </info> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
242
+ # )
243
+
244
+
245
+ TYPE_TEMPLATE = {
246
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
247
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
248
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
249
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
250
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
251
+ "math": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
252
+ }
253
+
254
+ def make_conversation_image(example):
255
+
256
+ return {
257
+ "prompt": [
258
+ {
259
+ "role": "user",
260
+ "content": [
261
+ {"type": "image"},
262
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
263
+ ],
264
+ },
265
+ ],
266
+ }
267
+
268
+
269
+ def make_conversation_video(example):
270
+ return {
271
+ "prompt": [
272
+ {
273
+ "role": "user",
274
+ "content": [
275
+ {"type": "video"},
276
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
277
+ ],
278
+ },
279
+ ],
280
+ }
281
+
282
+ def make_conversation_image_and_video(example):
283
+ if example["problem_type"] == 'multiple choice':
284
+ question = example['problem'] + "Options:\n"
285
+ for op in example["options"]:
286
+ question += op + "\n"
287
+ else:
288
+ question = example['problem']
289
+
290
+
291
+ msg ={
292
+ "prompt":
293
+ [{
294
+ "role": "user",
295
+ "content": [
296
+ {
297
+ "type": example['data_type'],
298
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
299
+ },
300
+ {
301
+ "type": "text",
302
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
303
+ }
304
+ ]
305
+ }]
306
+ }
307
+
308
+ return msg
309
+
310
+
311
+ dataset = dataset.map(make_conversation_image_and_video)
312
+
313
+
314
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
315
+ print("using: ", trainer_cls)
316
+
317
+ # Initialize the GRPO trainer
318
+ trainer = trainer_cls(
319
+ model=model_args.model_name_or_path,
320
+ reward_funcs=reward_funcs,
321
+ args=training_args,
322
+ script_args=script_args,
323
+ train_dataset=dataset[script_args.dataset_train_split],
324
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
325
+ peft_config=get_peft_config(model_args),
326
+ attn_implementation=model_args.attn_implementation,
327
+ max_pixels=script_args.max_pixels,
328
+ min_pixels=script_args.min_pixels,
329
+ )
330
+
331
+ if training_args.resume_from_checkpoint is not None:
332
+ checkpoint = training_args.resume_from_checkpoint
333
+ trainer.train(resume_from_checkpoint=checkpoint)
334
+ else:
335
+ trainer.train()
336
+
337
+ # Save and push to hub
338
+ trainer.save_model(training_args.output_dir)
339
+ if training_args.push_to_hub:
340
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
341
+
342
+
343
+ if __name__ == "__main__":
344
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
345
+ script_args, training_args, model_args = parser.parse_args_and_config()
346
+ main(script_args, training_args, model_args)
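The regression reward used in these scripts is one minus the relative error, clipped to [0, 1]. A small, self-contained numeric check of that formula (the function name is illustrative, not from the repository):

```python
def regression_reward(pred: float, gt: float) -> float:
    # 1 - |pred - gt| / |gt|, with small epsilons to handle zero ground truths, clipped to [0, 1].
    rel_diff = (abs(pred - gt) + 1e-9) / (abs(gt) + 1e-9)
    return 1.0 - min(1.0, max(0.0, rel_diff))

print(regression_reward(3.8, 4.0))  # ~0.95 (5% relative error)
print(regression_reward(8.0, 4.0))  # 0.0 (relative error clipped at 100%)
```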
src/r1-v/src/open_r1/grpo-cot-qwenEval.py ADDED
@@ -0,0 +1,523 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ import ray
18
+ from datetime import datetime
19
+ from dataclasses import dataclass, field
20
+
21
+ from datasets import load_dataset, load_from_disk
22
+ from transformers import Qwen2VLForConditionalGeneration
23
+
24
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
25
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
26
+
27
+ from datasets import Dataset, DatasetDict
28
+
29
+ from typing import Dict, List, Optional
30
+ from mathruler.grader import extract_boxed_content, grade_answer
31
+
32
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
33
+ from rouge_score import rouge_scorer
34
+ import torch
35
+ # from utils.gpt_eval import infer
36
+ # from utils.math_cot import *
37
+ from qa_metrics.pedant import PEDANT
38
+ from concurrent.futures import ProcessPoolExecutor
39
+ import os, subprocess, sys
40
+ # from qa_metrics.answerBERT import AnswerBertActor
41
+ # from utils.self_eval import *
42
+ from vllm import LLM, SamplingParams
43
+
44
+ pedant = PEDANT()
45
+ # answerBERT = AnswerBertActor(device='cuda:7')
46
+
47
+ # curr_actor = VllmActor.options(num_gpus=1).remote("Qwen/Qwen2.5-3B-Instruct")
48
+
49
+ from typing import List
50
+ import os
51
+ import ray, os, subprocess, torch.distributed as dist
52
+
53
+ MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
54
+ MAX_LEN = 32_768
55
+ RAY_NS = "grpo_qwen_vllm"
56
+ RAY_TMP = "/tmp/ray"
57
+
58
+ # ------------------------------------------------------------
59
+ # 1. Define the Ray actor class *before* we ever create it
60
+ # (Ray just needs to see the decorator; it doesn’t need an
61
+ # active cluster at definition time)
62
+ # ------------------------------------------------------------
63
+ @ray.remote(num_gpus=1, resources={"gpu_7": 1})
64
+ class VllmActor:
65
+ def __init__(self, model_id):
66
+ self.engine = LLM(
67
+ model_id,
68
+ tensor_parallel_size=1,
69
+ gpu_memory_utilization=0.80,
70
+ max_model_len=MAX_LEN,
71
+ trust_remote_code=True,
72
+ dtype="bfloat16",
73
+ )
74
+ self.default = SamplingParams(top_p=0.9, temperature=0.7, max_tokens=128)
75
+
76
+ def generate_batch(self, prompts, sampling=None):
77
+ outs = self.engine.generate(prompts, sampling_params=sampling or self.default)
78
+ return [o.outputs[0].text for o in outs]
79
+
80
+ # ------------------------------------------------------------
81
+ # 2. Torch-DDP initialisation
82
+ # ------------------------------------------------------------
83
+ dist.init_process_group("nccl")
84
+ rank = dist.get_rank()
85
+
86
+ # ------------------------------------------------------------
87
+ # 3. Rank-0 starts the Ray head, others wait
88
+ # ------------------------------------------------------------
89
+ if rank == 1:
90
+ ray.init(
91
+ _temp_dir=RAY_TMP,
92
+ object_store_memory=1 * 1024**3,
93
+ namespace=RAY_NS,
94
+ include_dashboard=False,
95
+ resources={"gpu_7": 1}
96
+ )
97
+ # optional: confirm the head is up
98
+ # from ray._private.internal_api import wait_for_gcs
99
+ # wait_for_gcs()
100
+ dist.barrier() # ---- head definitely running here ----
101
+
102
+ # ------------------------------------------------------------
103
+ # 4. Non-zero ranks attach to the head
104
+ # ------------------------------------------------------------
105
+ if rank != 0:
106
+ ray.init(address="auto", _temp_dir=RAY_TMP, namespace=RAY_NS)
107
+
108
+ dist.barrier() # ---- every rank now in the cluster ----
109
+
110
+ # ------------------------------------------------------------
111
+ # 5. Create / look-up the VllmActor
112
+ # ------------------------------------------------------------
113
+ if rank == 1:
114
+ vllm_actor = (
115
+ VllmActor.options(name="vllm", namespace=RAY_NS, lifetime="detached")
116
+ .remote(MODEL_ID)
117
+ )
118
+ # block until the model finishes loading so other ranks don’t race
119
+ ray.get(vllm_actor.generate_batch.remote(["ping"]))
120
+ dist.barrier() # ---- actor fully alive everywhere ----
121
+
122
+ if rank != 0:
123
+ vllm_actor = ray.get_actor("vllm", namespace=RAY_NS)
124
+
125
+
126
+ eval_prompt_template = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. Provide your answer as a single final answer or a short phrase enclosed with <answer></answer>. If the question is a multiple choice, the final answer should be a single letter choice. \nText description: {}\nQuestion: {}'''
127
+
128
+ @dataclass
129
+ class GRPOScriptArguments(ScriptArguments):
130
+ """
131
+ Script arguments for the GRPO training script.
132
+
133
+ Args:
134
+ reward_funcs (`list[str]`):
135
+ List of reward functions. Possible values: 'accuracy', 'format'.
136
+ """
137
+
138
+ reward_funcs: list[str] = field(
139
+ default_factory=lambda: ["accuracy", "format"],
140
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
141
+ )
142
+
143
+ # reward_funcs: list[str] = field(
144
+ # default_factory=lambda: ["accuracy"],
145
+ # metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
146
+ # )
147
+ max_pixels: Optional[int] = field(
148
+ default=12845056,
149
+ metadata={"help": "Maximum number of pixels for the image"},
150
+ )
151
+ min_pixels: Optional[int] = field(
152
+ default=3136,
153
+ metadata={"help": "Minimum number of pixels for the image"},
154
+ )
155
+ temporal: Optional[bool] = field(
156
+ default=True,
157
+ metadata={"help": "whether using temporal GRPO"},
158
+ )
159
+ len_control: Optional[bool] = field(
160
+ default=True,
161
+ metadata={"help": "whether using length reward"},
162
+ )
163
+
164
+
165
+ def accuracy_reward(completions, solution, **kwargs):
166
+ def extract_answer(text: str) -> str:
167
+ """
168
+ 1) Try the full <answer> … </answer> block.
169
+ 2) If that is missing, grab whatever follows the opening <answer> tag.
170
+ 3) Otherwise return the original text.
171
+ """
172
+ # ① normal case <answer> … </answer>
173
+ m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
174
+ if m:
175
+ return m.group(1).strip()
176
+
177
+ # ② fallback <answer> … <end-of-string>
178
+ m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
179
+ if m:
180
+ return m.group(1).strip()
181
+
182
+ # ③ nothing found
183
+ return text.strip()
184
+
185
+ def extract_description(predict: str) -> Optional[str]:
186
+ """
187
+ Extracts the content of the <answer>…</answer> block from `predict`.
188
+ Returns the inner text (with leading/trailing whitespace stripped),
189
+ or None if no <answer> tag is found.
190
+ """
191
+ match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
192
+ if not match:
193
+ return predict
194
+ return match.group(1).strip()
195
+
196
+ def single_accuracy_reward(predict: str, ground_truth: str) -> float:
197
+ answer = predict
198
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
199
+
200
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> Dict[str, float]:
201
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
202
+ # format_score = format_reward(predict)
203
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
204
+
205
+ # return (1 - format_weight) * accuracy_score + format_weight * format_score
206
+ return accuracy_score
207
+
208
+ def normalize_number(num_str):
209
+ try:
210
+ num_str = num_str.replace(',', '')
211
+ return float(num_str)
212
+ except Exception as e:
213
+ print(f"Error converting '{num_str}' to float: {e}")
214
+ return None
215
+
216
+ def wer(reference, hypothesis):
217
+ ref_words = reference.split()
218
+ hyp_words = hypothesis.split()
219
+ m = len(ref_words)
220
+ n = len(hyp_words)
221
+ d = [[0]*(n+1) for _ in range(m+1)]
222
+ for i in range(m+1):
223
+ d[i][0] = i
224
+ for j in range(n+1):
225
+ d[0][j] = j
226
+ for i in range(1, m+1):
227
+ for j in range(1, n+1):
228
+ if ref_words[i-1] == hyp_words[j-1]:
229
+ d[i][j] = d[i-1][j-1]
230
+ else:
231
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
232
+ return d[m][n] / max(1, m)
233
+
234
+
235
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
236
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
237
+ scores = scorer.score(reference, hypothesis)
238
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
239
+ return average_fmeasure
240
+
241
+ # print('Computing rewards now...')
242
+ # second_prompts = kwargs.get("second_prompts") # ← list[str] or None
243
+ # second_completions = kwargs.get("second_completions")
244
+ # second_contents = [comp[0]["content"] for comp in second_completions]
245
+ # print('second prompts', second_prompts)
246
+ # print('-'*10)
247
+ # print('second completions', second_completions)
248
+ # print('-'*10)
249
+
250
+ # import time
251
+ # time.sleep(30)
252
+ question_type = kwargs['problem_type'][0]
253
+ questions = kwargs['problem']
254
+
255
+ contents = [completion[0]["content"] for completion in completions]
256
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
257
+ rewards = []
258
+
259
+ extracted_content_descriptions = [extract_description(ele) for ele in contents]
260
+ description_eval_inputs = [eval_prompt_template.format(extracted_content_descriptions[count_index], questions[count_index]) for count_index in range(len(extracted_content_descriptions))]
261
+ # extracted_content_answers = [extract_answer(ele) for ele in contents]
262
+ # model = kwargs.get("model") # may be None if called elsewhere
263
+ # tokenizer = kwargs.get("tokenizer")
264
+ # # (optional) example use: let the model score the generated answer
265
+ # if model is not None and tokenizer is not None:
266
+ # model.eval()
267
+ # description_inputs = [questions[index_count] + ' [SEP] ' + extracted_content_descriptions[index_count] for index_count in range(len(extracted_content_descriptions))]
268
+ # description_rewards = answerBERT.batch_predict(description_inputs, batch_size = 64)
269
+ # description_rewards = [infer(extracted_content_descriptions[index_count], questions[index_count]) for index_count in range(len(extracted_content_descriptions))]
270
+ # description_outputs = generate_batch(description_eval_inputs)
271
+ print(len(description_eval_inputs))
272
+ print('Computing rewards...')
273
+ print('-'*10)
274
+ # description_outputs = ray.get(vllm_actor.generate.remote(description_eval_inputs))
275
+ description_outputs = ray.get(
276
+ vllm_actor.generate_batch_sequential.remote(description_eval_inputs,
277
+ batch_size=32) # tune to taste
278
+ )
279
+ print('Finish computing generating batch')
280
+ output_answers = [extract_answer(content) for content in contents]
281
+ gt_answers = [extract_answer(sol) for sol in solution]
282
+ description_rewards = [compute_math_score_single(description_outputs[curr_idx], gt_answers[curr_idx]) for curr_idx in range(len(description_outputs))]
283
+
284
+
285
+
286
+
287
+ # for content, sol, description_reward in zip(contents, solution, description_rewards):
288
+ # for content, sol, question in zip(contents, solution, questions):
289
+ # for content, sol, second_content in zip(contents, solution, second_completions):
290
+ for output_ans, gt_ans, description_reward in zip(output_answers, gt_answers, description_rewards):
291
+ try:
292
+ # output_ans = extract_answer(content)
293
+ # gt_ans = extract_answer(sol)
294
+ # description_extraction = extract_answer(second_content)
295
+ # if question_type == "multiple choice":
296
+ # reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
297
+ # elif question_type == "numerical":
298
+ # gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
299
+ # out_has_decimal = ("." in output_ans) or ("," in output_ans)
300
+ # if gt_has_decimal != out_has_decimal:
301
+ # reward = 0.0
302
+ # else:
303
+ # gt_number = normalize_number(gt_ans)
304
+ # out_number = normalize_number(output_ans)
305
+ # if gt_number is None or out_number is None:
306
+ # reward = 0.0
307
+ # else:
308
+ # reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
309
+ if question_type == "OCR":
310
+ # description_extraction = extract_answer(second_content)
311
+ # description_error_rate = wer(gt_ans, description_extraction)
312
+ # description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
313
+ # error_rate = wer(gt_ans, output_ans)
314
+ answer_pedant_reward = pedant.get_score(gt_ans, output_ans, questions[0])
315
+ # reward = (1 - error_rate) + (1- description_error_rate)
316
+ # reward = max(0.0, min(2.0, reward))
317
+ # print('Extracted description: ', description_extraction)
318
+ print('Generated answer: ', output_ans)
319
+ print('Sol: ', gt_ans)
320
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
321
+ print('-' * 10)
322
+ # reward = description_pendat_reward + answer_pedant_reward
323
+ reward = answer_pedant_reward
324
+ # elif question_type == "free-form":
325
+ # score = compute_rouge_score(gt_ans, output_ans)
326
+ # reward = max(0.0, min(1.0, score))
327
+ elif question_type == "regression":
328
+ gt_number = normalize_number(gt_ans)
329
+ out_number = normalize_number(output_ans)
330
+ if gt_number is None or out_number is None:
331
+ reward = 0.0
332
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
333
+ rel_diff = min(1.0, max(0.0, rel_diff))
334
+ reward = 1 - rel_diff
335
+ elif question_type == 'math' or question_type == 'unify' or question_type == "multiple choice" or question_type == "numerical":
336
+ # print('Extracted description: ', description_extraction)
337
+ print('Generated answer: ', output_ans)
338
+ print('Sol: ', gt_ans)
339
+
340
+ # description_reward = compute_math_score_single(description_extraction, gt_ans)
341
+ answer_reward = compute_math_score_single(output_ans, gt_ans)
342
+ print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
343
+ print('-' * 10)
344
+ reward = description_reward + answer_reward
345
+ else:
346
+ print(f"Unhandled question type '{question_type}'; falling back to zero reward")
347
+ reward = 0.0
348
+ except Exception as e:
349
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
350
+ reward = 0.0
351
+
352
+ rewards.append(reward)
353
+
354
+ if os.getenv("DEBUG_MODE") == "true":
355
+ log_path = os.getenv("LOG_PATH")
356
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
357
+ with open(log_path, "a", encoding="utf-8") as f:
358
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
359
+ f.write(f"Content: {output_ans}\n")
360
+ f.write(f"Solution: {gt_ans}\n")
361
+
362
+ return rewards
363
+
364
+
365
+ def simple_format_reward(completions, **kwargs):
366
+ """Reward function that checks if the completion has a specific format."""
367
+ # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
368
+ pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
369
+ completion_contents = [completion[0]["content"] for completion in completions]
370
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
371
+ return [0.1 if match else 0.0 for match in matches]
372
+
373
+
374
+ reward_funcs_registry = {
375
+ "accuracy": accuracy_reward,
376
+ "format": simple_format_reward,
377
+ }
378
+
379
+
380
+ SYSTEM_PROMPT = (
381
+ "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags."
382
+ "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
383
+ "and provide this step-by-step reasoning within <think> </think> tags. "
384
+ "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags."
385
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
386
+ )
387
+
388
+
389
+ def main(script_args, training_args, model_args):
390
+ # Get reward functions
391
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
392
+
393
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
394
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
395
+ else:
396
+ # Load the dataset
397
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
398
+
399
+
400
+ # Format into conversation
401
+ def make_conversation(example):
402
+ return {
403
+ "prompt": [
404
+ {"role": "system", "content": SYSTEM_PROMPT},
405
+ {"role": "user", "content": example["problem"]},
406
+ ],
407
+ }
408
+
409
+
410
+
411
+ QUESTION_TEMPLATE = (
412
+ "{Question}\n"
413
+ "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
414
+ "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
415
+ "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
416
+ "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
417
+ "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
418
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
419
+ )
420
+
421
+
422
+ TYPE_TEMPLATE = {
423
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
424
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
425
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
426
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
427
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
428
+ "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
429
+ }
430
+
431
+ def make_conversation_image(example):
432
+
433
+ return {
434
+ "prompt": [
435
+ {
436
+ "role": "user",
437
+ "content": [
438
+ {"type": "image"},
439
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
440
+ ],
441
+ },
442
+ ],
443
+ }
444
+
445
+
446
+ def make_conversation_video(example):
447
+ return {
448
+ "prompt": [
449
+ {
450
+ "role": "user",
451
+ "content": [
452
+ {"type": "video"},
453
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
454
+ ],
455
+ },
456
+ ],
457
+ }
458
+
459
+ def make_conversation_image_and_video(example):
460
+ if example["problem_type"] == 'multiple choice':
461
+ question = example['problem'] + "Options:\n"
462
+ for op in example["options"]:
463
+ question += op + "\n"
464
+ else:
465
+ question = example['problem']
466
+
467
+
468
+ msg ={
469
+ "prompt":
470
+ [{
471
+ "role": "user",
472
+ "content": [
473
+ {
474
+ "type": example['data_type'],
475
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
476
+ },
477
+ {
478
+ "type": "text",
479
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
480
+ }
481
+ ]
482
+ }]
483
+ }
484
+
485
+ return msg
486
+
487
+
488
+ dataset = dataset.map(make_conversation_image_and_video)
489
+
490
+
491
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
492
+ print("using: ", trainer_cls)
493
+
494
+ # Initialize the GRPO trainer
495
+ trainer = trainer_cls(
496
+ model=model_args.model_name_or_path,
497
+ reward_funcs=reward_funcs,
498
+ args=training_args,
499
+ script_args=script_args,
500
+ train_dataset=dataset[script_args.dataset_train_split],
501
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
502
+ peft_config=get_peft_config(model_args),
503
+ attn_implementation=model_args.attn_implementation,
504
+ max_pixels=script_args.max_pixels,
505
+ min_pixels=script_args.min_pixels,
506
+ )
507
+
508
+ if training_args.resume_from_checkpoint is not None:
509
+ checkpoint = training_args.resume_from_checkpoint
510
+ trainer.train(resume_from_checkpoint=checkpoint)
511
+ else:
512
+ trainer.train()
513
+
514
+ # Save and push to hub
515
+ trainer.save_model(training_args.output_dir)
516
+ if training_args.push_to_hub:
517
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
518
+
519
+
520
+ if __name__ == "__main__":
521
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
522
+ script_args, training_args, model_args = parser.parse_args_and_config()
523
+ main(script_args, training_args, model_args)
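For reference, the script above hands description grading to a detached Ray actor that wraps a vLLM engine. A minimal sketch of how such a judge is looked up and queried from another process; the actor and namespace names follow the script, while the helper function itself is illustrative and not part of the repository:

```python
import ray

def query_judge(description_prompts, namespace="grpo_qwen_vllm"):
    # Attach to the running Ray cluster and fetch the detached "vllm" actor created at startup.
    if not ray.is_initialized():
        ray.init(address="auto", namespace=namespace)
    judge = ray.get_actor("vllm", namespace=namespace)
    # Each prompt asks the judge to answer the question from the description alone.
    return ray.get(judge.generate_batch.remote(description_prompts))
```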
src/r1-v/src/open_r1/grpo-cot-selfEval.py ADDED
@@ -0,0 +1,457 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+
20
+ from datasets import load_dataset, load_from_disk
21
+ from transformers import Qwen2VLForConditionalGeneration
22
+
23
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
24
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
25
+
26
+ from datasets import Dataset, DatasetDict
27
+
28
+ from typing import Dict, List, Optional
29
+ from mathruler.grader import extract_boxed_content, grade_answer
30
+
31
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
32
+ from rouge_score import rouge_scorer
33
+ # from utils.math_cot import *
34
+ # from qa_metrics.pedant import PEDANT
35
+
36
+ # pedant = PEDANT()
37
+
38
+ '''
39
+ Alpha constant: When the description is wrong, but the final answer is right, the model is doing reward hacking,
40
+ so we give it a partial reward
41
+ '''
42
+ alpha = 1.0
43
+
44
+ @dataclass
45
+ class GRPOScriptArguments(ScriptArguments):
46
+ """
47
+ Script arguments for the GRPO training script.
48
+
49
+ Args:
50
+ reward_funcs (`list[str]`):
51
+ List of reward functions. Possible values: 'accuracy', 'format'.
52
+ """
53
+
54
+ reward_funcs: list[str] = field(
55
+ default_factory=lambda: ["accuracy", "format"],
56
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
57
+ )
58
+
59
+ # reward_funcs: list[str] = field(
60
+ # default_factory=lambda: ["accuracy"],
61
+ # metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
62
+ # )
63
+ max_pixels: Optional[int] = field(
64
+ default=12845056,
65
+ metadata={"help": "Maximum number of pixels for the image"},
66
+ )
67
+ min_pixels: Optional[int] = field(
68
+ default=3136,
69
+ metadata={"help": "Minimum number of pixels for the image"},
70
+ )
71
+ temporal: Optional[bool] = field(
72
+ default=True,
73
+ metadata={"help": "whether using temporal GRPO"},
74
+ )
75
+ len_control: Optional[bool] = field(
76
+ default=True,
77
+ metadata={"help": "whether using length reward"},
78
+ )
79
+
80
+
81
+
82
+ def accuracy_reward(completions, solution, **kwargs):
83
+ def extract_answer(text: str) -> str:
84
+ """
85
+ 1) Try the full <answer> … </answer> block.
86
+ 2) If that is missing, grab whatever follows the opening <answer> tag.
87
+ 3) Otherwise return the original text.
88
+ """
89
+ # ① normal case <answer> … </answer>
90
+ m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
91
+ if m:
92
+ return m.group(1).strip()
93
+
94
+ # ② fallback <answer> … <end-of-string>
95
+ m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
96
+ if m:
97
+ return m.group(1).strip()
98
+
99
+ # ③ nothing found
100
+ return text.strip()
101
+
102
+ def single_accuracy_reward(predict: str, ground_truth: str) -> float:
103
+ answer = predict
104
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
105
+
106
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> Dict[str, float]:
107
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
108
+ # format_score = format_reward(predict)
109
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
110
+
111
+ # return (1 - format_weight) * accuracy_score + format_weight * format_score
112
+ return accuracy_score
113
+
114
+ def normalize_number(num_str):
115
+ try:
116
+ num_str = num_str.replace(',', '')
117
+ return float(num_str)
118
+ except Exception as e:
119
+ print(f"Error converting '{num_str}' to float: {e}")
120
+ return None
121
+
122
+ def wer(reference, hypothesis):
123
+ ref_words = reference.split()
124
+ hyp_words = hypothesis.split()
125
+ m = len(ref_words)
126
+ n = len(hyp_words)
127
+ d = [[0]*(n+1) for _ in range(m+1)]
128
+ for i in range(m+1):
129
+ d[i][0] = i
130
+ for j in range(n+1):
131
+ d[0][j] = j
132
+ for i in range(1, m+1):
133
+ for j in range(1, n+1):
134
+ if ref_words[i-1] == hyp_words[j-1]:
135
+ d[i][j] = d[i-1][j-1]
136
+ else:
137
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
138
+ return d[m][n] / max(1, m)
139
+
140
+
141
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
142
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
143
+ scores = scorer.score(reference, hypothesis)
144
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
145
+ return average_fmeasure
146
+
147
+ # print('Computing rewards now...')
148
+ # second_prompts = kwargs.get("second_prompts") # ← list[str] or None
149
+ second_completions = kwargs.get("second_completions")
150
+ # second_contents = [comp[0]["content"] for comp in second_completions]
151
+ # print('second prompts', second_prompts)
152
+ # print('-'*10)
153
+ # print('second completions', second_completions)
154
+ # print('-'*10)
155
+
156
+ # import time
157
+ # time.sleep(30)
158
+ question_type = kwargs['problem_type'][0]
159
+ question = kwargs['problem'][0]
160
+
161
+ contents = [completion[0]["content"] for completion in completions]
162
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
163
+ rewards = []
164
+
165
+
166
+ # model = kwargs.get("model") # may be None if called elsewhere
167
+ # tokenizer = kwargs.get("tokenizer")
168
+
169
+ # # (optional) example use: let the model score the generated answer
170
+ # if model is not None and tokenizer is not None:
171
+ # model.eval()
172
+
173
+ # for content, sol in zip(contents, solution):
174
+ for content, sol, second_content in zip(contents, solution, second_completions):
175
+ try:
176
+ output_ans = extract_answer(content)
177
+ gt_ans = extract_answer(sol)
178
+ description_extraction = extract_answer(second_content)
179
+ # if question_type == "multiple choice":
180
+ # reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
181
+ # elif question_type == "numerical":
182
+ # gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
183
+ # out_has_decimal = ("." in output_ans) or ("," in output_ans)
184
+ # if gt_has_decimal != out_has_decimal:
185
+ # reward = 0.0
186
+ # else:
187
+ # gt_number = normalize_number(gt_ans)
188
+ # out_number = normalize_number(output_ans)
189
+ # if gt_number is None or out_number is None:
190
+ # reward = 0.0
191
+ # else:
192
+ # reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
193
+ # if question_type == "OCR":
194
+ # # description_extraction = extract_answer(second_content)
195
+ # # description_error_rate = wer(gt_ans, description_extraction)
196
+ # description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
197
+ # # error_rate = wer(gt_ans, output_ans)
198
+ # answer_pedant_reward = pedant.get_score(gt_ans, output_ans, question)
199
+ # # reward = (1 - error_rate) + (1- description_error_rate)
200
+ # # reward = max(0.0, min(2.0, reward))
201
+ # print('Extracted description: ', description_extraction)
202
+ # print('Generated answer: ', output_ans)
203
+ # print('Sol: ', gt_ans)
204
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
205
+ # print('-' * 10)
206
+ # reward = description_pendat_reward + answer_pedant_reward
207
+ if question_type == "free-form":
208
+ score = compute_rouge_score(gt_ans, output_ans)
209
+ description_score = compute_rouge_score(gt_ans, description_extraction)
210
+ reward = max(0.0, min(1.0, score)) + max(0.0, min(1.0, description_score))
211
+ elif question_type == "regression":
212
+ gt_number = normalize_number(gt_ans)
213
+ out_number = normalize_number(output_ans)
214
+ description_number = normalize_number(description_extraction)
215
+ if gt_number is None or out_number is None:
+ reward = 0.0
+ else:
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
+ rel_diff = min(1.0, max(0.0, rel_diff))
+ # a description without a usable number earns no credit for the description term
+ if description_number is None:
+ description_diff = 1.0
+ else:
+ description_diff = (abs(description_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
+ description_diff = min(1.0, max(0.0, description_diff))
+ reward = (1 - rel_diff) + (1 - description_diff)
229
+ elif question_type == 'math' or question_type == 'unify' or question_type == 'multiple choice' or question_type == 'numerical':
230
+ description_reward = compute_math_score_single(description_extraction, gt_ans)
231
+ answer_reward = compute_math_score_single(output_ans, gt_ans)
232
+
233
+ if description_reward == 0 and answer_reward == 1:
234
+ # Right answer but unsupported description: likely reward hacking, so give only partial credit
235
+ reward = alpha
236
+ else:
237
+ reward = description_reward + answer_reward
238
+
239
+ # print(f"Extracted description: {description_extraction} | Generated answer: {output_ans} | Sol: {gt_ans}")
240
+ # print(f'Description reward: {description_reward} | answer reward: {answer_reward} | final reward: {reward}')
241
+ # print('-' * 10)
242
+ else:
243
+ print('Unknown question type; falling back to zero reward')
244
+ reward = 0.0
245
+ except Exception as e:
246
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
247
+ reward = 0.0
248
+
249
+ rewards.append(reward)
250
+
251
+ if os.getenv("DEBUG_MODE") == "true":
252
+ log_path = os.getenv("LOG_PATH")
253
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
254
+ with open(log_path, "a", encoding="utf-8") as f:
255
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
256
+ f.write(f"Content: {content}\n")
257
+ f.write(f"Solution: {sol}\n")
258
+
259
+ return rewards
260
+
261
+
262
+ def simple_format_reward(completions, **kwargs):
263
+ """Reward function that checks if the completion has a specific format."""
264
+ # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
265
+ pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
266
+ completion_contents = [completion[0]["content"] for completion in completions]
267
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
268
+ return [0.1 if match else 0.0 for match in matches]
269
+
270
+
271
+ reward_funcs_registry = {
272
+ "accuracy": accuracy_reward,
273
+ "format": simple_format_reward,
274
+ }
275
+
276
+
277
+ SYSTEM_PROMPT = (
278
+ "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags."
279
+ "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
280
+ "and provide this step-by-step reasoning within <think> </think> tags. "
281
+ "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags."
282
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions."
283
+ )
284
+
285
+
286
+ def main(script_args, training_args, model_args):
287
+ print('Start program..')
288
+ # Get reward functions
289
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
290
+
291
+
292
+ print('Loading dataset')
293
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
294
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
295
+ else:
296
+ # Load the dataset
297
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
298
+
299
+
300
+ # Format into conversation
301
+ def make_conversation(example):
302
+ return {
303
+ "prompt": [
304
+ {"role": "system", "content": SYSTEM_PROMPT},
305
+ {"role": "user", "content": example["problem"]},
306
+ ],
307
+ }
308
+
309
+
310
+ QUESTION_TEMPLATE = (
311
+ "{Question}\n"
312
+ "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
313
+ "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
314
+ "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
315
+ "Provide your detailed, step-by-step reasoning based on the image and image description, and enclose this part within <think> </think> tags.\n"
316
+ "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
317
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>. Please keep your final answer short and precise."
318
+ )
319
+
320
+
321
+ TYPE_TEMPLATE = {
322
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
323
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
324
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
325
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
326
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
327
+ "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
328
+ }
329
+
330
+ ABS_Verify_Prompt = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. First provide a step-by-step reasoning within <think> </think> tags, then provide your answer as a single final answer, single letter choice, or a short phrase ENCLOSED with <answer> </answer> tags. \nText description: {{Description}}\nQuestion: {Question}\nPlease only return the final single letter choice within the <answer> </answer> tags for multiple choice questions; Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags for numerical questions.'''
331
+
332
+ def make_conversation_image(example):
333
+
334
+ return {
335
+ "prompt": [
336
+ {
337
+ "role": "user",
338
+ "content": [
339
+ {"type": "image"},
340
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
341
+ ],
342
+ },
343
+ ],
344
+ }
345
+
346
+
347
+ def make_conversation_video(example):
348
+ return {
349
+ "prompt": [
350
+ {
351
+ "role": "user",
352
+ "content": [
353
+ {"type": "video"},
354
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
355
+ ],
356
+ },
357
+ ],
358
+ }
359
+
360
+ def make_conversation_image_and_video(example):
361
+ if example["problem_type"] == 'multiple choice':
362
+ question = example['problem'] + "Options:\n"
363
+ for op in example["options"]:
364
+ question += op + "\n"
365
+ else:
366
+ question = example['problem']
367
+
368
+
369
+ msg ={
370
+ "prompt":
371
+ [{
372
+ "role": "user",
373
+ "content": [
374
+ {
375
+ "type": example['data_type'],
376
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
377
+ },
378
+ {
379
+ "type": "text",
380
+ # "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
381
+ "text": QUESTION_TEMPLATE.format(Question=question)
382
+ }
383
+ ]
384
+ }]
385
+ }
386
+
387
+ return msg
388
+
389
+ def make_verify_conversation(example):
390
+ # ➊ build the question text
391
+ question = example["problem"]
392
+ if example["problem_type"] == "multiple choice":
393
+ question += "Options:\n" + "\n".join(example["options"])
394
+
395
+ # ➋ verification template + suffix (no if/else)
396
+ verify_text = (
397
+ ABS_Verify_Prompt.format(Question=question.replace("<image>", ""))
398
+ # + TYPE_TEMPLATE[example["problem_type"]] # ← one-liner, no branching
399
+ )
400
+
401
+ # ➌ conversation dict
402
+ conv_dict = {
403
+ "prompt": [
404
+ {
405
+ "role": "user",
406
+ "content": [{"type": "text", "text": verify_text}],
407
+ }
408
+ ]
409
+ }
410
+
411
+ # templated = maybe_apply_chat_template(conv_dict, processing_class)["prompt"]
412
+ # return {"verify_prompt": templated}
413
+ return {"verify_prompt": conv_dict}
414
+
415
+
416
+
417
+
418
+ print('Start mapping dataset')
419
+ dataset = dataset.map(make_conversation_image_and_video)
420
+ dataset = dataset.map(
421
+ make_verify_conversation,
422
+ desc="add description verify prompt",
423
+ )
424
+
425
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
426
+ print("using: ", trainer_cls)
427
+
428
+ # Initialize the GRPO trainer
429
+ trainer = trainer_cls(
430
+ model=model_args.model_name_or_path,
431
+ reward_funcs=reward_funcs,
432
+ args=training_args,
433
+ script_args=script_args,
434
+ train_dataset=dataset[script_args.dataset_train_split],
435
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
436
+ peft_config=get_peft_config(model_args),
437
+ attn_implementation=model_args.attn_implementation,
438
+ max_pixels=script_args.max_pixels,
439
+ min_pixels=script_args.min_pixels,
440
+ )
441
+
442
+ if training_args.resume_from_checkpoint is not None:
443
+ checkpoint = training_args.resume_from_checkpoint
444
+ trainer.train(resume_from_checkpoint=checkpoint)
445
+ else:
446
+ trainer.train()
447
+
448
+ # Save and push to hub
449
+ trainer.save_model(training_args.output_dir)
450
+ if training_args.push_to_hub:
451
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
452
+
453
+
454
+ if __name__ == "__main__":
455
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
456
+ script_args, training_args, model_args = parser.parse_args_and_config()
457
+ main(script_args, training_args, model_args)
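For reference, the accuracy reward above sums the score of the model's own <answer> with the score of the verifier's answer recovered from the <des> description, and collapses to a partial reward when the final answer is right but the description cannot support it. A minimal, self-contained sketch of that combination (assumptions: `grade` is a stand-in for `mathruler.grader.grade_answer`, and the `alpha` default mirrors the constant used in these scripts):

```python
# Hedged sketch of the description + answer reward combination (not the exact training code).
def combined_reward(answer: str, description_answer: str, ground_truth: str,
                    alpha: float = 0.85) -> float:
    def grade(pred: str, gt: str) -> float:
        # placeholder exact-match grader; the script itself uses mathruler's grade_answer
        return 1.0 if pred.strip().lower() == gt.strip().lower() else 0.0

    answer_reward = grade(answer, ground_truth)
    description_reward = grade(description_answer, ground_truth)
    if description_reward == 0 and answer_reward == 1:
        # right answer, unsupported description: likely reward hacking, give partial credit only
        return alpha
    return description_reward + answer_reward

if __name__ == "__main__":
    print(combined_reward("B", "B", "B"))  # 2.0: answer and description both check out
    print(combined_reward("B", "C", "B"))  # 0.85: correct answer that the description does not support
```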
src/r1-v/src/open_r1/grpo-cot-selfEvalConst.py ADDED
@@ -0,0 +1,456 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+
20
+ from datasets import load_dataset, load_from_disk
21
+ from transformers import Qwen2VLForConditionalGeneration
22
+
23
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerSelfConst
24
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
25
+
26
+ from datasets import Dataset, DatasetDict
27
+
28
+ from typing import Dict, List, Optional
29
+ from mathruler.grader import extract_boxed_content, grade_answer
30
+
31
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
32
+ from rouge_score import rouge_scorer
33
+ # from utils.math_cot import *
34
+ # from qa_metrics.pedant import PEDANT
35
+
36
+ # pedant = PEDANT()
37
+
38
+ '''
39
+ Alpha constant: when the description is wrong but the final answer is right, the model is likely reward hacking,
40
+ so we only give it a partial reward.
41
+ '''
42
+ alpha = 0.85
43
+
44
+ @dataclass
45
+ class GRPOScriptArguments(ScriptArguments):
46
+ """
47
+ Script arguments for the GRPO training script.
48
+
49
+ Args:
50
+ reward_funcs (`list[str]`):
51
+ List of reward functions. Possible values: 'accuracy', 'format'.
52
+ """
53
+
54
+ reward_funcs: list[str] = field(
55
+ default_factory=lambda: ["accuracy", "format"],
56
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
57
+ )
58
+
59
+ # reward_funcs: list[str] = field(
60
+ # default_factory=lambda: ["accuracy"],
61
+ # metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
62
+ # )
63
+ max_pixels: Optional[int] = field(
64
+ default=12845056,
65
+ metadata={"help": "Maximum number of pixels for the image"},
66
+ )
67
+ min_pixels: Optional[int] = field(
68
+ default=3136,
69
+ metadata={"help": "Minimum number of pixels for the image"},
70
+ )
71
+ temporal: Optional[bool] = field(
72
+ default=True,
73
+ metadata={"help": "whether using temporal GRPO"},
74
+ )
75
+ len_control: Optional[bool] = field(
76
+ default=True,
77
+ metadata={"help": "whether using length reward"},
78
+ )
79
+
80
+
81
+
82
+ def accuracy_reward(completions, solution, **kwargs):
83
+ def extract_answer(text: str) -> str:
84
+ """
85
+ 1) Try the full <answer> … </answer> block.
86
+ 2) If that is missing, grab whatever follows the opening <answer> tag.
87
+ 3) Otherwise return the original text.
88
+ """
89
+ # ① normal case <answer> … </answer>
90
+ m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
91
+ if m:
92
+ return m.group(1).strip()
93
+
94
+ # ② fallback <answer> … <end-of-string>
95
+ m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
96
+ if m:
97
+ return m.group(1).strip()
98
+
99
+ # ③ nothing found
100
+ return text.strip()
101
+
102
+ def single_accuracy_reward(predict: str, ground_truth: str) -> float:
103
+ answer = predict
104
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
105
+
106
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> float:
107
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
108
+ # format_score = format_reward(predict)
109
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
110
+
111
+ # return (1 - format_weight) * accuracy_score + format_weight * format_score
112
+ return accuracy_score
113
+
114
+ def normalize_number(num_str):
115
+ try:
116
+ num_str = num_str.replace(',', '')
117
+ return float(num_str)
118
+ except Exception as e:
119
+ print(f"Error converting '{num_str}' to float: {e}")
120
+ return None
121
+
122
+ def wer(reference, hypothesis):
123
+ ref_words = reference.split()
124
+ hyp_words = hypothesis.split()
125
+ m = len(ref_words)
126
+ n = len(hyp_words)
127
+ d = [[0]*(n+1) for _ in range(m+1)]
128
+ for i in range(m+1):
129
+ d[i][0] = i
130
+ for j in range(n+1):
131
+ d[0][j] = j
132
+ for i in range(1, m+1):
133
+ for j in range(1, n+1):
134
+ if ref_words[i-1] == hyp_words[j-1]:
135
+ d[i][j] = d[i-1][j-1]
136
+ else:
137
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
138
+ return d[m][n] / max(1, m)
139
+
140
+
141
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
142
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
143
+ scores = scorer.score(reference, hypothesis)
144
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
145
+ return average_fmeasure
146
+
147
+ # print('Computing rewards now...')
148
+ # second_prompts = kwargs.get("second_prompts") # ← list[str] or None
149
+ second_completions = kwargs.get("second_completions")
150
+ # second_contents = [comp[0]["content"] for comp in second_completions]
151
+ # print('second prompts', second_prompts)
152
+ # print('-'*10)
153
+ # print('second completions', second_completions)
154
+ # print('-'*10)
155
+
156
+ # import time
157
+ # time.sleep(30)
158
+ question_type = kwargs['problem_type'][0]
159
+ question = kwargs['problem'][0]
160
+
161
+ contents = [completion[0]["content"] for completion in completions]
162
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
163
+ rewards = []
164
+
165
+
166
+ # model = kwargs.get("model") # may be None if called elsewhere
167
+ # tokenizer = kwargs.get("tokenizer")
168
+
169
+ # # (optional) example use: let the model score the generated answer
170
+ # if model is not None and tokenizer is not None:
171
+ # model.eval()
172
+
173
+ # for content, sol in zip(contents, solution):
174
+ for content, sol, second_content in zip(contents, solution, second_completions):
175
+ try:
176
+ output_ans = extract_answer(content)
177
+ gt_ans = extract_answer(sol)
178
+ description_extraction = extract_answer(second_content)
179
+ # if question_type == "multiple choice":
180
+ # reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
181
+ # elif question_type == "numerical":
182
+ # gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
183
+ # out_has_decimal = ("." in output_ans) or ("," in output_ans)
184
+ # if gt_has_decimal != out_has_decimal:
185
+ # reward = 0.0
186
+ # else:
187
+ # gt_number = normalize_number(gt_ans)
188
+ # out_number = normalize_number(output_ans)
189
+ # if gt_number is None or out_number is None:
190
+ # reward = 0.0
191
+ # else:
192
+ # reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
193
+ # if question_type == "OCR":
194
+ # # description_extraction = extract_answer(second_content)
195
+ # # description_error_rate = wer(gt_ans, description_extraction)
196
+ # description_pendat_reward = pedant.get_score(gt_ans, description_extraction, question)
197
+ # # error_rate = wer(gt_ans, output_ans)
198
+ # answer_pedant_reward = pedant.get_score(gt_ans, output_ans, question)
199
+ # # reward = (1 - error_rate) + (1- description_error_rate)
200
+ # # reward = max(0.0, min(2.0, reward))
201
+ # print('Extracted description: ', description_extraction)
202
+ # print('Generated answer: ', output_ans)
203
+ # print('Sol: ', gt_ans)
204
+ # print(f'Description reward: {description_reward}; answer reward: {answer_reward}')
205
+ # print('-' * 10)
206
+ # reward = description_pendat_reward + answer_pedant_reward
207
+ if question_type == "free-form":
208
+ score = compute_rouge_score(gt_ans, output_ans)
209
+ description_score = compute_rouge_score(gt_ans, description_extraction)
210
+ reward = max(0.0, min(1.0, score)) + max(0.0, min(1.0, description_score))
211
+ elif question_type == "regression":
212
+ gt_number = normalize_number(gt_ans)
213
+ out_number = normalize_number(output_ans)
214
+ description_number = normalize_number(description_extraction)
215
+ if gt_number is None or out_number is None:
216
+ reward = 0.0
217
+
218
+ if description_number is None:
219
+ description_reward = 0.0
220
+
221
+
222
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
223
+ rel_diff = min(1.0, max(0.0, rel_diff))
224
+
225
+ description_diff = (abs(description_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
226
+ description_diff = min(1.0, max(0.0, description_diff))
227
+
228
+ reward = (1 - rel_diff) + (1 - description_diff)
229
+ elif question_type == 'math' or question_type == 'unify' or question_type == 'multiple choice' or question_type == 'numerical':
230
+ description_reward = compute_math_score_single(description_extraction, gt_ans)
231
+ answer_reward = compute_math_score_single(output_ans, gt_ans)
232
+
233
+ if description_reward == 0 and answer_reward == 1:
234
+ # Right answer but unsupported description: likely reward hacking, so give only partial credit
235
+ reward = alpha
236
+ else:
237
+ reward = description_reward + answer_reward
238
+
239
+ # print(f"Extracted description: {description_extraction} | Generated answer: {output_ans} | Sol: {gt_ans}")
240
+ # print(f'Description reward: {description_reward} | answer reward: {answer_reward} | final reward: {reward}')
241
+ # print('-' * 10)
242
+ else:
243
+ print('Unknown question type; falling back to zero reward')
244
+ reward = 0.0
245
+ except Exception as e:
246
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
247
+ reward = 0.0
248
+
249
+ rewards.append(reward)
250
+
251
+ if os.getenv("DEBUG_MODE") == "true":
252
+ log_path = os.getenv("LOG_PATH")
253
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
254
+ with open(log_path, "a", encoding="utf-8") as f:
255
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
256
+ f.write(f"Content: {content}\n")
257
+ f.write(f"Solution: {sol}\n")
258
+
259
+ return rewards
260
+
261
+
262
+ def simple_format_reward(completions, **kwargs):
263
+ """Reward function that checks if the completion has a specific format."""
264
+ # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
265
+ pattern = r"<des>.*?</des>\s*<think>.*?</think>\s*<answer>.*?</answer>"
266
+ completion_contents = [completion[0]["content"] for completion in completions]
267
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
268
+ return [0.1 if match else 0.0 for match in matches]
269
+
270
+
271
+ reward_funcs_registry = {
272
+ "accuracy": accuracy_reward,
273
+ "format": simple_format_reward,
274
+ }
275
+
276
+
277
+ SYSTEM_PROMPT = (
278
+ "A conversation between User and Assistant. After the user asks a question about an image, write a rich, self-contained description of that image—detailed enough that someone could answer the question from the description alone, without ever seeing the image. Enclose the entire description in <des> </des> tags."
279
+ "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
280
+ "and provide this step-by-step reasoning within <think> </think> tags. "
281
+ "Finally, the assistant provides a single word, single letter choice, or phrase answer within <answer> </answer> tags."
282
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
283
+ )
284
+
285
+
286
+ def main(script_args, training_args, model_args):
287
+ print('Start program..')
288
+ # Get reward functions
289
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
290
+
291
+
292
+ print('Loading dataset')
293
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
294
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
295
+ else:
296
+ # Load the dataset
297
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
298
+
299
+
300
+ # Format into conversation
301
+ def make_conversation(example):
302
+ return {
303
+ "prompt": [
304
+ {"role": "system", "content": SYSTEM_PROMPT},
305
+ {"role": "user", "content": example["problem"]},
306
+ ],
307
+ }
308
+
309
+
310
+ QUESTION_TEMPLATE = (
311
+ "{Question}\n"
312
+ "You are tasked with analyzing an image to generate an exhaustive and detailed description to answer a question. "
313
+ "Analyze the image and produce a thorough, self-contained description—detailed enough for someone to answer the question using the description alone. Wrap the entire description in <des> </des> tags.\n"
314
+ "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
315
+ "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
316
+ "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
317
+ "The output format should be: <des> image description here </des> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
318
+ )
319
+
320
+
321
+ TYPE_TEMPLATE = {
322
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
323
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
324
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
325
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
326
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
327
+ "math": " Please provide the final exact answer (single option letter for multiple choice) within the <answer> </answer> tags.",
328
+ }
329
+
330
+ ABS_Verify_Prompt = '''You are provided a text description of a problem and a question. Determine the answer to the question based on the text description. First provide a step-by-step reasoning within <think> </think> tags, then provide your answer as a single final answer, single letter choice, or a short phrase ENCLOSED with <answer> </answer> tags. \nText description: {{Description}}\nQuestion: {Question}'''
331
+
332
+ def make_conversation_image(example):
333
+
334
+ return {
335
+ "prompt": [
336
+ {
337
+ "role": "user",
338
+ "content": [
339
+ {"type": "image"},
340
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
341
+ ],
342
+ },
343
+ ],
344
+ }
345
+
346
+
347
+ def make_conversation_video(example):
348
+ return {
349
+ "prompt": [
350
+ {
351
+ "role": "user",
352
+ "content": [
353
+ {"type": "video"},
354
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
355
+ ],
356
+ },
357
+ ],
358
+ }
359
+
360
+ def make_conversation_image_and_video(example):
361
+ if example["problem_type"] == 'multiple choice':
362
+ question = example['problem'] + "Options:\n"
363
+ for op in example["options"]:
364
+ question += op + "\n"
365
+ else:
366
+ question = example['problem']
367
+
368
+
369
+ msg ={
370
+ "prompt":
371
+ [{
372
+ "role": "user",
373
+ "content": [
374
+ {
375
+ "type": example['data_type'],
376
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
377
+ },
378
+ {
379
+ "type": "text",
380
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
381
+ }
382
+ ]
383
+ }]
384
+ }
385
+
386
+ return msg
387
+
388
+ def make_verify_conversation(example):
389
+ # ➊ build the question text
390
+ question = example["problem"]
391
+ if example["problem_type"] == "multiple choice":
392
+ question += "Options:\n" + "\n".join(example["options"])
393
+
394
+ # ➋ verification template + suffix (no if/else)
395
+ verify_text = (
396
+ ABS_Verify_Prompt.format(Question=question.replace("<image>", ""))
397
+ + TYPE_TEMPLATE[example["problem_type"]] # ← one-liner, no branching
398
+ )
399
+
400
+ # ➌ conversation dict
401
+ conv_dict = {
402
+ "prompt": [
403
+ {
404
+ "role": "user",
405
+ "content": [{"type": "text", "text": verify_text}],
406
+ }
407
+ ]
408
+ }
409
+
410
+ # templated = maybe_apply_chat_template(conv_dict, processing_class)["prompt"]
411
+ # return {"verify_prompt": templated}
412
+ return {"verify_prompt": conv_dict}
413
+
414
+
415
+
416
+
417
+ print('Start mapping dataset')
418
+ dataset = dataset.map(make_conversation_image_and_video)
419
+ dataset = dataset.map(
420
+ make_verify_conversation,
421
+ desc="add description verify prompt",
422
+ )
423
+
424
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerSelfConst
425
+ print("using: ", trainer_cls)
426
+
427
+ # Initialize the GRPO trainer
428
+ trainer = trainer_cls(
429
+ model=model_args.model_name_or_path,
430
+ reward_funcs=reward_funcs,
431
+ args=training_args,
432
+ script_args=script_args,
433
+ train_dataset=dataset[script_args.dataset_train_split],
434
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
435
+ peft_config=get_peft_config(model_args),
436
+ attn_implementation=model_args.attn_implementation,
437
+ max_pixels=script_args.max_pixels,
438
+ min_pixels=script_args.min_pixels,
439
+ )
440
+
441
+ if training_args.resume_from_checkpoint is not None:
442
+ checkpoint = training_args.resume_from_checkpoint
443
+ trainer.train(resume_from_checkpoint=checkpoint)
444
+ else:
445
+ trainer.train()
446
+
447
+ # Save and push to hub
448
+ trainer.save_model(training_args.output_dir)
449
+ if training_args.push_to_hub:
450
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
451
+
452
+
453
+ if __name__ == "__main__":
454
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
455
+ script_args, training_args, model_args = parser.parse_args_and_config()
456
+ main(script_args, training_args, model_args)
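This variant builds its verify prompt with doubled braces: the first `str.format(Question=...)` call bakes the question into `ABS_Verify_Prompt` while leaving a literal `{Description}` slot, which is assumed to be filled later inside the trainer with the model's generated <des> text. A short sketch of that two-stage fill (template abbreviated, function names illustrative):

```python
# Two-stage template fill: "{{Description}}" survives the first .format() call as "{Description}".
TEMPLATE = "Text description: {{Description}}\nQuestion: {Question}"

def build_verify_prompt(question: str) -> str:
    # stage 1: insert the question, keep the description placeholder for later
    return TEMPLATE.format(Question=question)

def fill_description(prompt: str, description: str) -> str:
    # stage 2 (assumed to happen inside the GRPO trainer after generation)
    return prompt.format(Description=description)

stage1 = build_verify_prompt("How many red cubes are visible?")
print(stage1)  # still contains the literal "{Description}" slot
print(fill_description(stage1, "Three red cubes sit on a grey table."))
```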
src/r1-v/src/open_r1/grpo-cot.py ADDED
@@ -0,0 +1,351 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+ from typing import Optional
20
+
21
+ from datasets import load_dataset, load_from_disk
22
+ from transformers import Qwen2VLForConditionalGeneration
23
+
24
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
25
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
26
+
27
+ from datasets import Dataset, DatasetDict
28
+
29
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
30
+ from rouge_score import rouge_scorer
31
+ from utils.math_cot import *
32
+
33
+
34
+ @dataclass
35
+ class GRPOScriptArguments(ScriptArguments):
36
+ """
37
+ Script arguments for the GRPO training script.
38
+
39
+ Args:
40
+ reward_funcs (`list[str]`):
41
+ List of reward functions. Possible values: 'accuracy', 'format'.
42
+ """
43
+
44
+ # reward_funcs: list[str] = field(
45
+ # default_factory=lambda: ["accuracy", "format"],
46
+ # metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
47
+ # )
48
+
49
+ reward_funcs: list[str] = field(
50
+ default_factory=lambda: ["accuracy"],
51
+ metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
52
+ )
53
+ max_pixels: Optional[int] = field(
54
+ default=12845056,
55
+ metadata={"help": "Maximum number of pixels for the image"},
56
+ )
57
+ min_pixels: Optional[int] = field(
58
+ default=3136,
59
+ metadata={"help": "Minimum number of pixels for the image"},
60
+ )
61
+ temporal: Optional[bool] = field(
62
+ default=True,
63
+ metadata={"help": "whether using temporal GRPO"},
64
+ )
65
+ len_control: Optional[bool] = field(
66
+ default=True,
67
+ metadata={"help": "whether using length reward"},
68
+ )
69
+
70
+
71
+
72
+ def accuracy_reward(completions, solution, **kwargs):
73
+
74
+ def extract_answer(text):
75
+ pattern = r'<answer>\s*(.*?)\s*</answer>'
76
+ match = re.search(pattern, text, re.DOTALL)
77
+ if match:
78
+ return match.group(1).strip()
79
+ return ""
80
+
81
+ def normalize_number(num_str):
82
+ try:
83
+ num_str = num_str.replace(',', '')
84
+ return float(num_str)
85
+ except Exception as e:
86
+ print(f"Error converting '{num_str}' to float: {e}")
87
+ return None
88
+
89
+ def wer(reference, hypothesis):
90
+ ref_words = reference.split()
91
+ hyp_words = hypothesis.split()
92
+ m = len(ref_words)
93
+ n = len(hyp_words)
94
+ d = [[0]*(n+1) for _ in range(m+1)]
95
+ for i in range(m+1):
96
+ d[i][0] = i
97
+ for j in range(n+1):
98
+ d[0][j] = j
99
+ for i in range(1, m+1):
100
+ for j in range(1, n+1):
101
+ if ref_words[i-1] == hyp_words[j-1]:
102
+ d[i][j] = d[i-1][j-1]
103
+ else:
104
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
105
+ return d[m][n] / max(1, m)
106
+
107
+
108
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
109
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
110
+ scores = scorer.score(reference, hypothesis)
111
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
112
+ return average_fmeasure
113
+
114
+
115
+ question_type = kwargs['problem_type'][0]
116
+
117
+ contents = [completion[0]["content"] for completion in completions]
118
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
119
+ rewards = []
120
+
121
+ for content, sol in zip(contents, solution):
122
+
123
+ try:
124
+ output_ans = extract_answer(content)
125
+ gt_ans = extract_answer(sol)
126
+ if question_type == "multiple choice":
127
+ reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
128
+ elif question_type == "numerical":
129
+ gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
130
+ out_has_decimal = ("." in output_ans) or ("," in output_ans)
131
+ if gt_has_decimal != out_has_decimal:
132
+ reward = 0.0
133
+ else:
134
+ gt_number = normalize_number(gt_ans)
135
+ out_number = normalize_number(output_ans)
136
+ if gt_number is None or out_number is None:
137
+ reward = 0.0
138
+ else:
139
+ reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
140
+ elif question_type == "OCR":
141
+ error_rate = wer(gt_ans, output_ans)
142
+ reward = 1 - error_rate
143
+ reward = max(0.0, min(1.0, reward))
144
+ elif question_type == "free-form":
145
+ score = compute_rouge_score(gt_ans, output_ans)
146
+ reward = max(0.0, min(1.0, score))
147
+ elif question_type == "regression":
148
+ gt_number = normalize_number(gt_ans)
149
+ out_number = normalize_number(output_ans)
150
+ if gt_number is None or out_number is None:
151
+ reward = 0.0
152
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
153
+ rel_diff = min(1.0, max(0.0, rel_diff))
154
+ reward = 1 - rel_diff
155
+ elif question_type == 'math':
156
+ reward = compute_math_score_single(content, gt_ans)
157
+ else:
158
+ print('Unknown question type; falling back to zero reward')
159
+ reward = 0.0
160
+ except Exception as e:
161
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
162
+ reward = 0.0
163
+
164
+ rewards.append(reward)
165
+
166
+ if os.getenv("DEBUG_MODE") == "true":
167
+ log_path = os.getenv("LOG_PATH")
168
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
169
+ with open(log_path, "a", encoding="utf-8") as f:
170
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
171
+ f.write(f"Content: {content}\n")
172
+ f.write(f"Solution: {sol}\n")
173
+
174
+ return rewards
175
+
176
+
177
+ def format_reward(completions, **kwargs):
178
+ """Reward function that checks if the completion has a specific format."""
179
+ pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
180
+ completion_contents = [completion[0]["content"] for completion in completions]
181
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
182
+ return [1.0 if match else 0.0 for match in matches]
183
+
184
+
185
+ reward_funcs_registry = {
186
+ "accuracy": accuracy_reward,
187
+ # "format": 0,
188
+ }
189
+
190
+ # SYSTEM_PROMPT = (
191
+ # "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
192
+ # "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
193
+ # "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
194
+ # "<think> reasoning process here </think><answer> answer here </answer>"
195
+ # )
196
+
197
+ SYSTEM_PROMPT = (
198
+ "A conversation between User and Assistant. The user provides a question about an image, "
199
+ "and the Assistant is tasked with generating an exhaustive and detailed description of the image. "
200
+ "The assistant should extract and describe all possible information from the image—including objects, numbers, text, and their relationships—"
201
+ "and enclose this description within <info> </info> tags. "
202
+ "Next, the assistant should think deeply about the reasoning process, engaging in an internal dialogue and self-reflection, "
203
+ "and provide this step-by-step reasoning within <think> </think> tags. "
204
+ "Finally, the assistant provides a single word or phrase answer within <answer> </answer> tags. "
205
+ "The output format should be: <info> image description here </info> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>."
206
+ )
207
+
208
+
209
+ def main(script_args, training_args, model_args):
210
+ # Get reward functions
211
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
212
+
213
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
214
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
215
+ else:
216
+ # Load the dataset
217
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
218
+
219
+
220
+ # Format into conversation
221
+ def make_conversation(example):
222
+ return {
223
+ "prompt": [
224
+ {"role": "system", "content": SYSTEM_PROMPT},
225
+ {"role": "user", "content": example["problem"]},
226
+ ],
227
+ }
228
+
229
+
230
+ # QUESTION_TEMPLATE = (
231
+ # "{Question}\n"
232
+ # "Please think about this question as if you were a human pondering deeply. "
233
+ # "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
234
+ # "It's encouraged to include self-reflection or verification in the reasoning process. "
235
+ # "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
236
+ # )
237
+
238
+ QUESTION_TEMPLATE = (
239
+ "{Question}\n"
240
+ "You are tasked with analyzing an image to generate an exhaustive and detailed description. "
241
+ "Your goal is to extract and describe all possible information from the image, including but not limited to objects, numbers, text, and the relationships between these elements. "
242
+ "The description should be as fine and detailed as possible, capturing every nuance, and should be enclosed within <info> </info> tags.\n"
243
+ "Next, engage in an internal dialogue as if you were a human pondering deeply—use expressions such as 'let me think', 'wait', 'hmm', 'oh, I see', 'let's break it down', etc., and include self-reflection or verification in your reasoning process. "
244
+ "Provide your detailed, step-by-step reasoning based on the image description, and enclose this part within <think> </think> tags.\n"
245
+ "Finally, provide a single word or phrase answer to the question, enclosed within <answer> </answer> tags.\n"
246
+ "The output format should be: <info> image description here </info> <think> reasoning process here </think> <answer> FINAL ANSWER here </answer>"
247
+ )
248
+
249
+
250
+ TYPE_TEMPLATE = {
251
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
252
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
253
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
254
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
255
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
256
+ "math": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
257
+ }
258
+
259
+ def make_conversation_image(example):
260
+
261
+ return {
262
+ "prompt": [
263
+ {
264
+ "role": "user",
265
+ "content": [
266
+ {"type": "image"},
267
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
268
+ ],
269
+ },
270
+ ],
271
+ }
272
+
273
+
274
+ def make_conversation_video(example):
275
+ return {
276
+ "prompt": [
277
+ {
278
+ "role": "user",
279
+ "content": [
280
+ {"type": "video"},
281
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
282
+ ],
283
+ },
284
+ ],
285
+ }
286
+
287
+ def make_conversation_image_and_video(example):
288
+ if example["problem_type"] == 'multiple choice':
289
+ question = example['problem'] + "Options:\n"
290
+ for op in example["options"]:
291
+ question += op + "\n"
292
+ else:
293
+ question = example['problem']
294
+
295
+
296
+ msg ={
297
+ "prompt":
298
+ [{
299
+ "role": "user",
300
+ "content": [
301
+ {
302
+ "type": example['data_type'],
303
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
304
+ },
305
+ {
306
+ "type": "text",
307
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
308
+ }
309
+ ]
310
+ }]
311
+ }
312
+
313
+ return msg
314
+
315
+
316
+ dataset = dataset.map(make_conversation_image_and_video)
317
+
318
+
319
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
320
+ print("using: ", trainer_cls)
321
+
322
+ # Initialize the GRPO trainer
323
+ trainer = trainer_cls(
324
+ model=model_args.model_name_or_path,
325
+ reward_funcs=reward_funcs,
326
+ args=training_args,
327
+ script_args=script_args,
328
+ train_dataset=dataset[script_args.dataset_train_split],
329
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
330
+ peft_config=get_peft_config(model_args),
331
+ attn_implementation=model_args.attn_implementation,
332
+ max_pixels=script_args.max_pixels,
333
+ min_pixels=script_args.min_pixels,
334
+ )
335
+
336
+ if training_args.resume_from_checkpoint is not None:
337
+ checkpoint = training_args.resume_from_checkpoint
338
+ trainer.train(resume_from_checkpoint=checkpoint)
339
+ else:
340
+ trainer.train()
341
+
342
+ # Save and push to hub
343
+ trainer.save_model(training_args.output_dir)
344
+ if training_args.push_to_hub:
345
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
346
+
347
+
348
+ if __name__ == "__main__":
349
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
350
+ script_args, training_args, model_args = parser.parse_args_and_config()
351
+ main(script_args, training_args, model_args)
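The OCR branch in `grpo-cot.py` rewards transcriptions with `1 - WER`, where the word error rate comes from the Levenshtein-style dynamic program in the `wer` helper, clipped to the [0, 1] range. A standalone sketch of that reward (the `ocr_reward` name is illustrative; the helper mirrors the logic above):

```python
# Self-contained sketch of the OCR reward: 1 - word error rate, clipped to [0, 1].
def wer(reference: str, hypothesis: str) -> float:
    ref, hyp = reference.split(), hypothesis.split()
    # edit-distance DP over words (insertions, deletions, substitutions all cost 1)
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[len(ref)][len(hyp)] / max(1, len(ref))

def ocr_reward(ground_truth: str, prediction: str) -> float:
    return max(0.0, min(1.0, 1.0 - wer(ground_truth, prediction)))

print(ocr_reward("stop sign ahead", "stop sign ahead"))  # 1.0
print(ocr_reward("stop sign ahead", "stop sign"))        # ~0.67 (one deleted word out of three)
```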
src/r1-v/src/open_r1/grpo-description-LLMEval.py ADDED
@@ -0,0 +1,579 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+
20
+ from datasets import load_dataset, load_from_disk
21
+ from transformers import Qwen2VLForConditionalGeneration
22
+ from openai import OpenAI
23
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModifiedOrig
24
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
25
+
26
+ from datasets import Dataset, DatasetDict
27
+
28
+ from typing import Dict, List, Optional
29
+ from mathruler.grader import extract_boxed_content, grade_answer
30
+
31
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
32
+ from rouge_score import rouge_scorer
33
+ # from utils.gpt_eval import infer
34
+ # from utils.math_cot import *
35
+ # from qa_metrics.pedant import PEDANT
36
+ # from qa_metrics.answerBERT import AnswerBertActor
37
+
38
+ # pedant = PEDANT()
39
+ # answerBERT = AnswerBertActor(device='cuda:7')
40
+
41
+ alpha = 1.0
42
+
43
+ TYPE_TEMPLATE = {
44
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) in \\boxed{}.",
45
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) in \\boxed{}.",
46
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer in \\boxed{}.",
47
+ "free-form": " Please provide your text answer in \\boxed{}.",
48
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) in \\boxed{}.",
49
+ "math": " Please provide the final exact answer (single option letter for multiple choice) in \\boxed{}.",
50
+ }
51
+
52
+ '''
53
+ gpt infer
54
+ '''
55
+ import os
56
+ from openai import AzureOpenAI
57
+ import time
58
+
59
+ import base64
60
+ from mimetypes import guess_type
61
+
62
+
63
+ def azure_gpt4(messages, model):
64
+ outputs = []
65
+ for message in messages:
66
+ input_prompt = [
67
+ { "role": "system", "content": "You are a helpful assistant." },
68
+ { "role": "user", "content": [
69
+ {
70
+ "type": "text",
71
+ "text": message["instruction"]
72
+ },
73
+ # {
74
+ # "type": "image_url",
75
+ # "image_url": {
76
+ # "url": message["image"]
77
+ # }
78
+ # }
79
+ ]}
80
+ ]
81
+ ## try N times if API exceed limit ...
82
+ for i in range(10):
83
+ try:
84
+ output = client.chat.completions.create(
85
+ model=model, messages=input_prompt, max_tokens=2000
86
+ )
87
+
88
+ output_text = output.choices[0].message.content
89
+ break ## exit if successful
90
+
91
+ except Exception as e:
92
+ print(f'Index {i} got error message: {e}')
93
+ output_text = ''
94
+ time.sleep(3)
95
+
96
+ outputs.append(output_text)
97
+
98
+ return outputs
99
+
100
+
101
+ client = AzureOpenAI(
102
+ api_key = "83f30a2a22324395b854bd343db38d85",
103
+ api_version = "2024-08-01-preview",
104
+ azure_endpoint = "https://francecentral.api.cognitive.microsoft.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
105
+ )
106
+
107
+ model = "gpt-4o"
108
+ prompt_template = '''Text description: {text}\nQuestion: {question}\nYou are provided a text description of a problem and a question. Determine the answer to the question based on the text description. First provide an internal step-by-step reasoning within <think> </think> tags, then provide a single word or phrase answer in \\boxed{}.'''
109
+
110
+
111
+ # client = OpenAI(
112
+ # base_url="http://29.81.244.54:8080/v1", # your vLLM server
113
+ # api_key="ANYKEY", # if you set --api-key when launching
114
+ # )
115
+
116
+ client = OpenAI(
117
+ base_url="http://29.81.224.188:8080/v1", # your vLLM server
118
+ api_key="ANYKEY", # if you set --api-key when launching
119
+ )
120
+
121
+ def chat_batch(
122
+ client,
123
+ all_message_batches: List[List[Dict[str, str]]],
124
+ *,
125
+ # model: str = "Qwen2.5-32B-Instruct",
126
+ model: str = "Qwen2.5-32B-finetune",
127
+ max_workers: int = 8,
128
+ retries: int = 2,
129
+ backoff: float = 0.5,
130
+ timeout: Optional[float] = None,
131
+ ) -> List[str]:
132
+ """
133
+ Send many chat requests in parallel and return replies as a list of strings,
134
+ preserving the order of `all_message_batches`.
135
+ """
136
+
137
+ def _chat_once_with_retry(messages: List[Dict[str, str]]) -> str:
138
+ last_err: Optional[BaseException] = None
139
+ for attempt in range(retries + 1):
140
+ try:
141
+ resp = client.chat.completions.create(
142
+ model=model,
143
+ messages=messages,
144
+ timeout=timeout,
145
+ )
146
+ # Different SDKs expose content slightly differently; handle common cases.
147
+ choice = resp.choices[0]
148
+ if hasattr(choice, "message") and getattr(choice.message, "content", None) is not None:
149
+ return choice.message.content
150
+ if hasattr(choice, "text") and choice.text is not None:
151
+ return choice.text
152
+ # Fallback to stringifying the choice if structure is unexpected.
153
+ return str(choice)
154
+ except Exception as e:
155
+ last_err = e
156
+ if attempt < retries:
157
+ sleep(backoff * (2 ** attempt))
158
+ return f"Error: {last_err!r}"
159
+
160
+ results: List[Optional[str]] = [None] * len(all_message_batches)
161
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
162
+ future_to_idx = {
163
+ executor.submit(_chat_once_with_retry, batch): i
164
+ for i, batch in enumerate(all_message_batches)
165
+ }
166
+ for fut in as_completed(future_to_idx):
167
+ i = future_to_idx[fut]
168
+ results[i] = fut.result()
169
+
170
+ # mypy-friendly cast: no Nones remain at this point
171
+ return [r if r is not None else "Error: Unknown failure" for r in results]
172
+
173
+
174
+ def infer(prompt):
175
+ # prompt_question = prompt_question.replace('<image>', '')
176
+ # prompt = prompt_template.replace('{text}', text).replace('{question}', prompt_question)
177
+
178
+ messages = [
179
+ {"instruction": prompt},
180
+ ]
181
+ prompt_success = False
182
+ prompt_time = 0
183
+ outputs = ['\\boxed{None}']
184
+ while prompt_success == False and prompt_time <= 2:
185
+ try:
186
+ outputs = azure_gpt4(messages, model)
187
+ prompt_success = True
188
+ except:
189
+ prompt_time += 1
190
+ time.sleep(5)
191
+
192
+ return outputs[0]
193
+
194
+ '''
195
+ end of gpt infer
196
+ '''
197
+
198
+
199
+ from concurrent.futures import ThreadPoolExecutor, as_completed
200
+
201
+ def _call_infer(desc):
202
+ return infer(desc)
203
+
204
+ @dataclass
205
+ class GRPOScriptArguments(ScriptArguments):
206
+ """
207
+ Script arguments for the GRPO training script.
208
+
209
+ Args:
210
+ reward_funcs (`list[str]`):
211
+ List of reward functions. Possible values: 'accuracy', 'format'.
212
+ """
213
+
214
+ reward_funcs: list[str] = field(
215
+ default_factory=lambda: ["accuracy", "format"],
216
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
217
+ )
218
+
219
+ # reward_funcs: list[str] = field(
220
+ # default_factory=lambda: ["accuracy"],
221
+ # metadata={"help": "List of reward functions. Possible values: 'accuracy'"},
222
+ # )
223
+ max_pixels: Optional[int] = field(
224
+ default=12845056,
225
+ metadata={"help": "Maximum number of pixels for the image"},
226
+ )
227
+ min_pixels: Optional[int] = field(
228
+ default=3136,
229
+ metadata={"help": "Minimum number of pixels for the image"},
230
+ )
231
+ temporal: Optional[bool] = field(
232
+ default=True,
233
+ metadata={"help": "whether using temporal GRPO"},
234
+ )
235
+ len_control: Optional[bool] = field(
236
+ default=True,
237
+ metadata={"help": "whether using length reward"},
238
+ )
239
+
240
+
241
+
242
+ def accuracy_reward(completions, solution, **kwargs):
243
+ def extract_answer(text: str) -> str:
244
+ """
245
+ 1) Try the full <answer> … </answer> block.
246
+ 2) If that is missing, grab whatever follows the opening <answer> tag.
247
+ 3) Otherwise return the original text.
248
+ """
249
+ # ① normal case <answer> … </answer>
250
+ m = re.search(r'<answer>\s*(.*?)\s*</answer>', text, flags=re.DOTALL | re.IGNORECASE)
251
+ if m:
252
+ return m.group(1).strip()
253
+
254
+ # ② fallback <answer> … <end-of-string>
255
+ m = re.search(r'<answer>\s*(.*)$', text, flags=re.DOTALL | re.IGNORECASE)
256
+ if m:
257
+ return m.group(1).strip()
258
+
259
+ # ③ nothing found
260
+ return text.strip()
261
+
262
+ def extract_description(predict: str) -> Optional[str]:
263
+ """
264
+ Extracts the content of the <des>…</des> block from `predict`.
265
+ Returns the inner text (with leading/trailing whitespace stripped),
266
+ or the full input text if no <des> tag is found.
267
+ """
268
+ match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
269
+ if not match:
270
+ return predict
271
+ return match.group(1).strip()
272
+
273
+ def single_accuracy_reward(predict: str, ground_truth: str) -> float:
274
+ answer = predict
275
+ return 1.0 if grade_answer(answer, ground_truth) else 0.0
276
+
277
+ def compute_math_score_single(predict: str, ground_truth: str, format_weight: float = 0.0) -> Dict[str, float]:
278
+ predict = re.sub(r"\s*(<|>|/)\s*", r"\1", predict)
279
+ accuracy_score = single_accuracy_reward(predict, ground_truth)
280
+ # return (1 - format_weight) * accuracy_score + format_weight * format_score
281
+ return accuracy_score
282
+
283
+ def normalize_number(num_str):
284
+ try:
285
+ num_str = num_str.replace(',', '')
286
+ return float(num_str)
287
+ except Exception as e:
288
+ print(f"Error converting '{num_str}' to float: {e}")
289
+ return None
290
+
291
+ def wer(reference, hypothesis):
292
+ ref_words = reference.split()
293
+ hyp_words = hypothesis.split()
294
+ m = len(ref_words)
295
+ n = len(hyp_words)
296
+ d = [[0]*(n+1) for _ in range(m+1)]
297
+ for i in range(m+1):
298
+ d[i][0] = i
299
+ for j in range(n+1):
300
+ d[0][j] = j
301
+ for i in range(1, m+1):
302
+ for j in range(1, n+1):
303
+ if ref_words[i-1] == hyp_words[j-1]:
304
+ d[i][j] = d[i-1][j-1]
305
+ else:
306
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
307
+ return d[m][n] / max(1, m)
308
+
309
+
310
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
311
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
312
+ scores = scorer.score(reference, hypothesis)
313
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
314
+ return average_fmeasure
315
+
316
+ # print('Computing rewards now...')
317
+ # second_prompts = kwargs.get("second_prompts") # ← list[str] or None
318
+ # second_completions = kwargs.get("second_completions")
319
+ # second_contents = [comp[0]["content"] for comp in second_completions]
320
+ # print('second prompts', second_prompts)
321
+ # print('-'*10)
322
+ # print('second completions', second_completions)
323
+ # print('-'*10)
324
+
325
+ # import time
326
+ # time.sleep(30)
327
+ question_type = kwargs['problem_type'][0]
328
+ questions = kwargs['problem']
329
+
330
+ contents = [completion[0]["content"] for completion in completions]
331
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
332
+ rewards = []
333
+
334
+ extracted_content_descriptions = [extract_description(ele) for ele in contents]
335
+
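+ # Flow of the judge-based description reward (the helpers `prompt_template` and
+ # `TYPE_TEMPLATE` are assumed to be defined earlier in this script): for every rollout,
+ # the extracted description (the <des>...</des> block, or the whole completion if the
+ # tag is missing) is spliced into the judge prompt together with the question, the
+ # whole batch is sent to the vLLM server via `chat_batch`, and the judge's \boxed{...}
+ # answer is graded against the ground-truth answer.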
336
+ description_query_inputs = []
337
+ batch_messages = []
338
+ vllm_batch_messages = []
339
+
340
+ for index in range(len(extracted_content_descriptions)):
341
+ prompt_question = questions[index]
342
+ des_text = extracted_content_descriptions[index]
343
+ prompt_question = prompt_question.replace('<image>', '')
344
+ prompt_input = prompt_template.replace('{text}', des_text).replace('{question}', prompt_question) + TYPE_TEMPLATE[question_type]
345
+ description_query_inputs.append(prompt_input)
346
+ curr_msg = [
347
+ {"role": "system", "content": "You are a helpful assistant."},
348
+ {"role": "user", "content": prompt_input}
349
+ ]
350
+ vllm_batch_messages.append(curr_msg)
351
+
352
+
353
+ batched_vllm_outputs = chat_batch(client, vllm_batch_messages)
354
+
355
+ description_score_outputs = [extract_boxed_content(idx_input) for idx_input in batched_vllm_outputs]
356
+ # with ThreadPoolExecutor(max_workers=8) as executor:
357
+ # futures = [
358
+ # executor.submit(_call_infer, desc)
359
+ # for desc in description_query_inputs
360
+ # ]
361
+ # # collect as they finish (optional—keeps order of completion)
362
+ # for fut in as_completed(futures):
363
+ # # description_score_outputs.append(extract_answer(fut.result()))
364
+ # # extract_boxed_content
365
+ # description_score_outputs.append(extract_boxed_content(fut.result()))
366
+
367
+
368
+ gt_answers = [extract_answer(sol) for sol in solution]
369
+ description_rewards = [compute_math_score_single(description_score_outputs[count_idx], gt_answers[count_idx]) for count_idx in range(len(description_score_outputs))]
370
+
371
+ print(gt_answers)
372
+ print(description_score_outputs)
373
+ print(description_rewards)
374
+ print('-'*10)
375
+
376
+
377
+ for content, gt_ans, description_reward in zip(contents, gt_answers, description_rewards):
378
+ # for content, sol, question in zip(contents, solution, questions):
379
+ # for content, sol, second_content in zip(contents, solution, second_completions):
380
+ try:
381
+ # output_ans = extract_answer(content)
382
+ output_ans = extract_boxed_content(content)
383
+
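+ # Combine the judge-based description reward with the policy's own \boxed{} answer
+ # reward; if the description alone fails (0) while the final answer is correct (1),
+ # fall back to `alpha` (assumed to be defined earlier in this script) instead of the sum.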
384
+ if question_type != 'None':
385
+ answer_reward = compute_math_score_single(output_ans, gt_ans)
386
+ if description_reward == 0 and answer_reward == 1:
387
+ reward = alpha
388
+ else:
389
+ reward = description_reward + answer_reward
390
+ # reward = answer_reward
391
+ else:
392
+ print('Falling back to none rewards')
393
+ reward = 0.0
394
+ except Exception as e:
395
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
396
+ reward = 0.0
397
+
398
+ rewards.append(reward)
399
+
400
+ if os.getenv("DEBUG_MODE") == "true":
401
+ log_path = os.getenv("LOG_PATH")
402
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
403
+ with open(log_path, "a", encoding="utf-8") as f:
404
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
405
+ f.write(f"Content: {content}\n")
406
+ f.write(f"Solution: {gt_ans}\n")
407
+
408
+ return rewards
409
+
410
+
411
+ def simple_format_reward(completions, **kwargs):
412
+ """Reward function that checks the same format as `format_reward`:
413
+ <description>...</description><think>...</think>\boxed{...}
414
+ """
415
+ pattern = re.compile(
416
+ r"^\s*<description>.*?</description>\s*"
417
+ r"<think>.*?</think>\s*"
418
+ r"\\boxed\{.*?\}\s*$",
419
+ re.DOTALL,
420
+ )
421
+ completion_contents = [completion[0]["content"] for completion in completions]
422
+ return [0.1 if pattern.fullmatch(content or "") else 0.0
423
+ for content in completion_contents]
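+ # Example (illustrative) of a completion that earns the 0.1 format reward:
+ # "<description>A red cube on a wooden table.</description> <think>Only one cube
+ # is visible, so the count is 1.</think> \boxed{1}"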
424
+
425
+
426
+ reward_funcs_registry = {
427
+ "accuracy": accuracy_reward,
428
+ "format": simple_format_reward,
429
+ }
430
+
431
+ # SYSTEM_PROMPT = (
432
+ # "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
433
+ # "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
434
+ # "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
435
+ # "<think> reasoning process here </think><answer> answer here </answer>"
436
+ # )
437
+
438
+ SYSTEM_PROMPT = (
439
+ "You are tasked with analyzing an image/video to generate a detailed description to help you answer the question. First analyze the image/video and produce a self-contained description—detailed enough that can lead to the correct answer. Wrap the entire description in <description> </description> tags.\n Next, engage in an internal dialogue and include self-reflection or verification in your reasoning process. Provide your detailed, step-by-step reasoning based on the image/video description information and image/video, and enclose this part within <think> </think> tags.\n Finally, provide a single word or phrase answer to the question in \boxed{}.\nThe output format should be: <description> image/video description here </description> <think> reasoning process here </think> \boxed{FINAL ANSWER here}."
440
+ )
441
+
442
+
443
+ def main(script_args, training_args, model_args):
444
+ # Get reward functions
445
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
446
+
447
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
448
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
449
+ else:
450
+ # Load the dataset
451
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
452
+
453
+
454
+ # Format into conversation
455
+ def make_conversation(example):
456
+ return {
457
+ "prompt": [
458
+ {"role": "system", "content": SYSTEM_PROMPT},
459
+ {"role": "user", "content": example["problem"]},
460
+ ],
461
+ }
462
+
463
+
464
+ # QUESTION_TEMPLATE = (
465
+ # "{Question}\n"
466
+ # "Please think about this question as if you were a human pondering deeply. "
467
+ # "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
468
+ # "It's encouraged to include self-reflection or verification in the reasoning process. "
469
+ # "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
470
+ # )
471
+
472
+ QUESTION_TEMPLATE = (
473
+ "{Question}\n"
474
+ "You are tasked with analyzing an image/video to generate a detailed description to help you answer the question. "
475
+ "First analyze the image/video and produce a self-contained description—detailed enough that can lead to the correct answer. "
476
+ "Wrap the entire description in <description> </description> tags.\n"
477
+ "Next, engage in an internal dialogue and include self-reflection or verification in your reasoning process. "
478
+ "Provide your detailed, step-by-step reasoning based on the image/video description information and image/video, and enclose this part within <think> </think> tags.\n"
479
+ "Finally, provide a single word or phrase answer to the question in \\boxed{{}}.\n"
480
+ "The output format should be: <description> image/video description here </description> "
481
+ "<think> reasoning process here </think> \\boxed{{FINAL ANSWER here}}."
482
+ )
483
+
484
+
485
+
486
+ def make_conversation_image(example):
487
+
488
+ return {
489
+ "prompt": [
490
+ {
491
+ "role": "user",
492
+ "content": [
493
+ {"type": "image"},
494
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
495
+ ],
496
+ },
497
+ ],
498
+ }
499
+
500
+
501
+ def make_conversation_video(example):
502
+ return {
503
+ "prompt": [
504
+ {
505
+ "role": "user",
506
+ "content": [
507
+ {"type": "video"},
508
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
509
+ ],
510
+ },
511
+ ],
512
+ }
513
+
514
+ def make_conversation_image_and_video(example):
515
+ if example["problem_type"] == 'multiple choice':
516
+ question = example['problem'] + "Options:\n"
517
+ for op in example["options"]:
518
+ question += op + "\n"
519
+ else:
520
+ question = example['problem']
521
+
522
+
523
+ msg ={
524
+ "prompt":
525
+ [{
526
+ "role": "user",
527
+ "content": [
528
+ {
529
+ "type": example['data_type'],
530
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
531
+ },
532
+ {
533
+ "type": "text",
534
+ # "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
535
+ "text": QUESTION_TEMPLATE.format(Question=question)
536
+ }
537
+ ]
538
+ }]
539
+ }
540
+
541
+ return msg
542
+
543
+
544
+ dataset = dataset.map(make_conversation_image_and_video)
545
+
546
+
547
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModifiedOrig
548
+ print("using: ", trainer_cls)
549
+
550
+ # Initialize the GRPO trainer
551
+ trainer = trainer_cls(
552
+ model=model_args.model_name_or_path,
553
+ reward_funcs=reward_funcs,
554
+ args=training_args,
555
+ script_args=script_args,
556
+ train_dataset=dataset[script_args.dataset_train_split],
557
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
558
+ peft_config=get_peft_config(model_args),
559
+ attn_implementation=model_args.attn_implementation,
560
+ max_pixels=script_args.max_pixels,
561
+ min_pixels=script_args.min_pixels,
562
+ )
563
+
564
+ if training_args.resume_from_checkpoint is not None:
565
+ checkpoint = training_args.resume_from_checkpoint
566
+ trainer.train(resume_from_checkpoint=checkpoint)
567
+ else:
568
+ trainer.train()
569
+
570
+ # Save and push to hub
571
+ trainer.save_model(training_args.output_dir)
572
+ if training_args.push_to_hub:
573
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
574
+
575
+
576
+ if __name__ == "__main__":
577
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
578
+ script_args, training_args, model_args = parser.parse_args_and_config()
579
+ main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo.py ADDED
@@ -0,0 +1,318 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from datetime import datetime
18
+ from dataclasses import dataclass, field
19
+ from typing import Optional
20
+
21
+ from datasets import load_dataset, load_from_disk
22
+ from transformers import Qwen2VLForConditionalGeneration
23
+
24
+ from trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainerModified
25
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
26
+
27
+ from datasets import Dataset, DatasetDict
28
+
29
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
30
+ from rouge_score import rouge_scorer
31
+
32
+
33
+ @dataclass
34
+ class GRPOScriptArguments(ScriptArguments):
35
+ """
36
+ Script arguments for the GRPO training script.
37
+
38
+ Args:
39
+ reward_funcs (`list[str]`):
40
+ List of reward functions. Possible values: 'accuracy', 'format'.
41
+ """
42
+
43
+ reward_funcs: list[str] = field(
44
+ default_factory=lambda: ["accuracy", "format"],
45
+ metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
46
+ )
47
+ max_pixels: Optional[int] = field(
48
+ default=12845056,
49
+ metadata={"help": "Maximum number of pixels for the image"},
50
+ )
51
+ min_pixels: Optional[int] = field(
52
+ default=3136,
53
+ metadata={"help": "Minimum number of pixels for the image"},
54
+ )
55
+ temporal: Optional[bool] = field(
56
+ default=True,
57
+ metadata={"help": "whether using temporal GRPO"},
58
+ )
59
+ len_control: Optional[bool] = field(
60
+ default=True,
61
+ metadata={"help": "whether using length reward"},
62
+ )
63
+
64
+
65
+
66
+ def accuracy_reward(completions, solution, **kwargs):
67
+
68
+ def extract_answer(text):
69
+ pattern = r'<answer>\s*(.*?)\s*</answer>'
70
+ match = re.search(pattern, text, re.DOTALL)
71
+ if match:
72
+ return match.group(1).strip()
73
+ return ""
74
+
75
+ def normalize_number(num_str):
76
+ try:
77
+ num_str = num_str.replace(',', '')
78
+ return float(num_str)
79
+ except Exception as e:
80
+ print(f"Error converting '{num_str}' to float: {e}")
81
+ return None
82
+
83
+ def wer(reference, hypothesis):
84
+ ref_words = reference.split()
85
+ hyp_words = hypothesis.split()
86
+ m = len(ref_words)
87
+ n = len(hyp_words)
88
+ d = [[0]*(n+1) for _ in range(m+1)]
89
+ for i in range(m+1):
90
+ d[i][0] = i
91
+ for j in range(n+1):
92
+ d[0][j] = j
93
+ for i in range(1, m+1):
94
+ for j in range(1, n+1):
95
+ if ref_words[i-1] == hyp_words[j-1]:
96
+ d[i][j] = d[i-1][j-1]
97
+ else:
98
+ d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
99
+ return d[m][n] / max(1, m)
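+ # Worked example: wer("the cat sat", "the cat sat down") = 1 edit / 3 reference
+ # words, i.e. about 0.33; lower is better.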
100
+
101
+
102
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
103
+ scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
104
+ scores = scorer.score(reference, hypothesis)
105
+ average_fmeasure = (scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
106
+ return average_fmeasure
107
+
108
+
109
+ question_type = kwargs['problem_type'][0]
110
+
111
+ contents = [completion[0]["content"] for completion in completions]
112
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
113
+ rewards = []
114
+
115
+ for content, sol in zip(contents, solution):
116
+
117
+ try:
118
+ output_ans = extract_answer(content)
119
+ gt_ans = extract_answer(sol)
120
+ if question_type == "multiple choice":
121
+ reward = 1.0 if output_ans.strip() == gt_ans.strip() else 0.0
122
+ elif question_type == "numerical":
123
+ gt_has_decimal = ("." in gt_ans) or ("," in gt_ans)
124
+ out_has_decimal = ("." in output_ans) or ("," in output_ans)
125
+ if gt_has_decimal != out_has_decimal:
126
+ reward = 0.0
127
+ else:
128
+ gt_number = normalize_number(gt_ans)
129
+ out_number = normalize_number(output_ans)
130
+ if gt_number is None or out_number is None:
131
+ reward = 0.0
132
+ else:
133
+ reward = 1.0 if round(gt_number, 2) == round(out_number, 2) else 0.0
134
+ elif question_type == "OCR":
135
+ error_rate = wer(gt_ans, output_ans)
136
+ reward = 1 - error_rate
137
+ reward = max(0.0, min(1.0, reward))
138
+ elif question_type == "free-form":
139
+ score = compute_rouge_score(gt_ans, output_ans)
140
+ reward = max(0.0, min(1.0, score))
141
+ elif question_type == "regression":
142
+ gt_number = normalize_number(gt_ans)
143
+ out_number = normalize_number(output_ans)
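+ # Regression answers are scored by relative error: reward = 1 - |pred - gt| / |gt|,
+ # clipped to [0, 1]; e.g. ground truth 10.0 with prediction 9.0 gives roughly 0.9.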
144
+ if gt_number is None or out_number is None:
+ reward = 0.0
+ else:
+ rel_diff = (abs(out_number - gt_number) + 1e-9) / (abs(gt_number) + 1e-9)
+ rel_diff = min(1.0, max(0.0, rel_diff))
+ reward = 1 - rel_diff
149
+ else:
150
+ reward = 0.0
151
+ except Exception as e:
152
+ print(f"Error in reward_fn for question_type '{question_type}': {e}")
153
+ reward = 0.0
154
+
155
+ rewards.append(reward)
156
+
157
+ if os.getenv("DEBUG_MODE") == "true":
158
+ log_path = os.getenv("LOG_PATH")
159
+ # local_rank = int(os.getenv("LOCAL_RANK", 0))
160
+ with open(log_path, "a", encoding="utf-8") as f:
161
+ f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
162
+ f.write(f"Content: {content}\n")
163
+ f.write(f"Solution: {sol}\n")
164
+
165
+ return rewards
166
+
167
+
168
+ def format_reward(completions, **kwargs):
169
+ """Reward function that checks if the completion has a specific format."""
170
+ pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
171
+ completion_contents = [completion[0]["content"] for completion in completions]
172
+ matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
173
+ return [0.1 if match else 0.0 for match in matches]
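+ # Example (illustrative): "<think>Let me check the options...</think> <answer>A</answer>"
+ # fully matches the pattern and earns 0.1; any other shape earns 0.0.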
174
+
175
+
176
+ reward_funcs_registry = {
177
+ "accuracy": accuracy_reward,
178
+ "format": format_reward,
179
+ }
180
+
181
+ SYSTEM_PROMPT = (
182
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
183
+ "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
184
+ "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
185
+ "<think> reasoning process here </think><answer> answer here </answer>"
186
+ )
187
+
188
+
189
+ def main(script_args, training_args, model_args):
190
+ # Get reward functions
191
+ reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
192
+
193
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
194
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
195
+ else:
196
+ # Load the dataset
197
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
198
+
199
+
200
+ # Format into conversation
201
+ def make_conversation(example):
202
+ return {
203
+ "prompt": [
204
+ {"role": "system", "content": SYSTEM_PROMPT},
205
+ {"role": "user", "content": example["problem"]},
206
+ ],
207
+ }
208
+
209
+
210
+ QUESTION_TEMPLATE = (
211
+ "{Question}\n"
212
+ "Please think about this question as if you were a human pondering deeply. "
213
+ "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
214
+ "It's encouraged to include self-reflection or verification in the reasoning process. "
215
+ "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
216
+ )
217
+
218
+ TYPE_TEMPLATE = {
219
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
220
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
221
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
222
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
223
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags."
224
+ }
225
+
226
+ def make_conversation_image(example):
227
+
228
+ return {
229
+ "prompt": [
230
+ {
231
+ "role": "user",
232
+ "content": [
233
+ {"type": "image"},
234
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
235
+ ],
236
+ },
237
+ ],
238
+ }
239
+
240
+
241
+ def make_conversation_video(example):
242
+ return {
243
+ "prompt": [
244
+ {
245
+ "role": "user",
246
+ "content": [
247
+ {"type": "video"},
248
+ {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
249
+ ],
250
+ },
251
+ ],
252
+ }
253
+
254
+ def make_conversation_image_and_video(example):
255
+ if example["problem_type"] == 'multiple choice':
256
+ question = example['problem'] + "Options:\n"
257
+ for op in example["options"]:
258
+ question += op + "\n"
259
+ else:
260
+ question = example['problem']
261
+
262
+
263
+ msg ={
264
+ "prompt":
265
+ [{
266
+ "role": "user",
267
+ "content": [
268
+ {
269
+ "type": example['data_type'],
270
+ # example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
271
+ },
272
+ {
273
+ "type": "text",
274
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
275
+ }
276
+ ]
277
+ }]
278
+ }
279
+
280
+ return msg
281
+
282
+
283
+ dataset = dataset.map(make_conversation_image_and_video)
284
+
285
+
286
+ trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainerModified
287
+ print("using: ", trainer_cls)
288
+
289
+ # Initialize the GRPO trainer
290
+ trainer = trainer_cls(
291
+ model=model_args.model_name_or_path,
292
+ reward_funcs=reward_funcs,
293
+ args=training_args,
294
+ script_args=script_args,
295
+ train_dataset=dataset[script_args.dataset_train_split],
296
+ eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
297
+ peft_config=get_peft_config(model_args),
298
+ attn_implementation=model_args.attn_implementation,
299
+ max_pixels=script_args.max_pixels,
300
+ min_pixels=script_args.min_pixels,
301
+ )
302
+
303
+ if training_args.resume_from_checkpoint is not None:
304
+ checkpoint = training_args.resume_from_checkpoint
305
+ trainer.train(resume_from_checkpoint=checkpoint)
306
+ else:
307
+ trainer.train()
308
+
309
+ # Save and push to hub
310
+ trainer.save_model(training_args.output_dir)
311
+ if training_args.push_to_hub:
312
+ trainer.push_to_hub(dataset_name=script_args.dataset_name)
313
+
314
+
315
+ if __name__ == "__main__":
316
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
317
+ script_args, training_args, model_args = parser.parse_args_and_config()
318
+ main(script_args, training_args, model_args)
src/r1-v/src/open_r1/grpo_vllm_caption.py ADDED
@@ -0,0 +1,266 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import re
16
+ from datetime import datetime
17
+ import json
18
+ from io import BytesIO
19
+ import base64
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from datasets import load_dataset
25
+ from rouge_score import rouge_scorer
26
+ from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
27
+ import Levenshtein
28
+ import wandb
29
+
30
+ from dataclasses import dataclass, field
31
+ from typing import Optional
32
+ from math_verify import parse, verify
33
+
34
+ from trainer.grpo_trainer_vllm_caption import Qwen2VLGRPOTrainerCap
35
+
36
+ os.environ["WANDB_MODE"] = "offline"
37
+
38
+
39
+ wandb.init(project="SelfEval-R1", name="SelfEval-R1")
40
+
41
+
42
+ @dataclass
43
+ class GRPOScriptArguments(ScriptArguments):
44
+ """
45
+ Script arguments for the GRPO training script.
46
+
47
+ Args:
48
+ reward_funcs (`list[str]`):
49
+ List of reward functions. Possible values: 'accuracy', 'format'.
50
+ """
51
+
52
+ reward_funcs: list[str] = field(
53
+ default_factory=lambda: ["accuracy", "format"],
54
+ metadata={
55
+ "help": "List of reward functions. Possible values: 'accuracy', 'format'"},
56
+ )
57
+ max_pixels: Optional[int] = field(
58
+ default=12845056,
59
+ metadata={"help": "Maximum number of pixels for the image"},
60
+ )
61
+ min_pixels: Optional[int] = field(
62
+ default=3136,
63
+ metadata={"help": "Minimum number of pixels for the image"},
64
+ )
65
+ caption_reward: Optional[bool] = field(
66
+ default=True,
67
+ metadata={"help": "Whether to use caption reward or not"},
68
+ )
69
+ caption_reward_weight: Optional[float] = field(
70
+ default=0.1,
71
+ metadata={"help": "Weight for the caption reward"},
72
+ )
73
+
74
+
75
+ # This function is partially borrowed from Video-R1[https://github.com/tulerfeng/Video-R1]
76
+ def accuracy_reward(completions, solution, **kwargs):
77
+
78
+ def extract_answer(text):
79
+ pattern = r'<answer>(.*?)</answer>'
80
+ match = re.search(pattern, text, re.DOTALL)
81
+ if match:
82
+ return match.group(1).strip()
83
+ return ""
84
+
85
+ def extract_option(text):
86
+ pattern = r'<option>(.*?)</option>'
87
+ match = re.search(pattern, text, re.DOTALL)
88
+ if match:
89
+ return match.group(1).strip()
90
+ return ""
91
+
92
+ def is_number(num_str):
93
+ try:
94
+ float(num_str)
95
+ return True
96
+ except Exception as e:
97
+ return False
98
+
99
+ def extract_numbers(answer):
100
+ pattern = r"[-+]?\d*\.?\d+"
101
+ match = re.search(pattern, answer)
102
+ if match:
103
+ number_str = match.group()
104
+ if answer.strip().endswith('%'):
105
+ number = float(number_str) / 100
106
+ else:
107
+ number = float(number_str)
108
+ return number
109
+ else:
110
+ return None
111
+
112
+ def anls(reference, hypothesis):
113
+ distance = Levenshtein.distance(reference, hypothesis)
114
+ max_length = max(len(reference), len(hypothesis))
115
+ similarity = 1 - (distance / max_length)
116
+
117
+ return similarity
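+ # Worked example: anls("receipt", "reciept") -> 1 - 2/7, about 0.71
+ # (Levenshtein distance 2 over the longer length 7).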
118
+
119
+ def compute_rouge_score(reference, hypothesis, use_stemmer=True):
120
+ scorer = rouge_scorer.RougeScorer(
121
+ ['rouge1', 'rouge2', 'rougeL'], use_stemmer=use_stemmer)
122
+ scores = scorer.score(reference, hypothesis)
123
+ average_fmeasure = (
124
+ scores['rouge1'].fmeasure + scores['rouge2'].fmeasure + scores['rougeL'].fmeasure) / 3
125
+ return average_fmeasure
126
+
127
+ question_type = kwargs['problem_type'][0]
128
+
129
+ contents = [completion[0]["content"] for completion in completions]
130
+ current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
131
+ rewards = []
132
+
133
+ for content, sol in zip(contents, solution):
134
+ try:
135
+ output_ans = extract_answer(content)
136
+ gt_ans = extract_answer(sol)
137
+ if question_type == "OCR":
138
+ if is_number(gt_ans):
139
+ output_ans = extract_numbers(output_ans)
140
+ reward = 1.0 if output_ans == float(
141
+ gt_ans) else 0.0
142
+ else:
143
+ reward = anls(gt_ans.lower(),
144
+ output_ans.lower())
145
+ reward = max(0.0, min(1.0, reward))
146
+ elif question_type == "free-form":
147
+ score = compute_rouge_score(gt_ans, output_ans)
148
+ reward = max(0.0, min(1.0, score))
149
+ else:
150
+ if is_number(gt_ans):
151
+ output_ans = extract_numbers(output_ans)
152
+ reward = 1.0 if output_ans == float(
153
+ gt_ans) else 0.0
154
+ else:
155
+ reward = 1.0 if output_ans.lower() == gt_ans.lower() else 0.0
156
+ except Exception as e:
157
+ print(
158
+ f"Error in reward_fn for question_type '{question_type}': {e}")
159
+ reward = 0.0
160
+
161
+ rewards.append(reward)
162
+
163
+ if os.getenv("DEBUG_MODE") == "true":
164
+ log_path = 'debug.log'
165
+ with open(log_path, "a") as f:
166
+ try:
167
+ f.write(
168
+ f"------------- {current_time} Accuracy reward: {reward} -------------\n")
169
+ f.write(f"Content: {content}\n")
170
+ f.write(f"Solution: {sol}\n")
171
+ f.write(f"type: {question_type}\n")
172
+ except BaseException:
173
+ f.write("writeing error")
174
+
175
+ return rewards
176
+
177
+
178
+ def format_reward(completions, **kwargs):
179
+ """Reward function that checks if the completion has a specific format."""
180
+ pattern = r"<info>.*?</info>\s<think>.*?</think>\s*<answer>.*?</answer>"
181
+ completion_contents = [completion[0]["content"]
182
+ for completion in completions]
183
+ matches = [re.fullmatch(pattern, content, re.DOTALL)
184
+ for content in completion_contents]
185
+ return [1.0 if match else 0.0 for match in matches]
186
+
187
+
188
+ reward_funcs_registry = {
189
+ "accuracy": accuracy_reward,
190
+ "format": format_reward,
191
+ }
192
+
193
+
194
+ SYSTEM_PROMPT = (
195
+ "You are tasked with analyzing an image to generate an exhaustive and detailed description. "
196
+ "Your goal is to extract and describe all possible information from the image, including but not limited to objects, "
197
+ "numbers, text, and the relationships between these elements. The description should be as fine and detailed as possible, "
198
+ "capturing every nuance. After generating the detailed description, you need to analyze it and provide step-by-step "
199
+ "detailed reasoning for the given question based on the information. Finally, provide a single word or phrase answer "
200
+ "to the question. The description, reasoning process and answer are enclosed within <info> </info>, <think> </think> "
201
+ "and <answer> </answer> tags, respectively, i.e., <info> image description here </info> <think> reasoning process here "
202
+ "</think> <answer> answer here </answer>"
203
+ )
204
+
205
+
206
+ def main(script_args, training_args, model_args):
207
+ # Get reward functions
208
+ reward_funcs = [reward_funcs_registry[func]
209
+ for func in script_args.reward_funcs]
210
+
211
+ # Load the dataset
212
+ # dataset = load_dataset(script_args.dataset_name,
213
+ # name=script_args.dataset_config)
214
+ dataset = load_dataset("json", data_files=script_args.dataset_name, split='train')
215
+
216
+
217
+ # Format into conversation
218
+ def make_conversation_image(example):
219
+ return {
220
+ "prompt": [
221
+ {"role": "system", "content": [
222
+ {"type": "text", "text": SYSTEM_PROMPT}]},
223
+ {
224
+ "role": "user",
225
+ "content": [
226
+ {"type": "image"},
227
+ {"type": "text", "text": example["problem"]},
228
+ ],
229
+ },
230
+ ]
231
+ }
232
+
233
+ dataset = dataset.map(make_conversation_image)
234
+
235
+ if "Qwen" in model_args.model_name_or_path or "Aria" in model_args.model_name_or_path:
236
+ trainer_cls = Qwen2VLGRPOTrainerCap
237
+ else:
238
+ trainer_cls = GRPOTrainer
239
+
240
+ # Initialize the GRPO trainer
241
+ trainer = trainer_cls(
242
+ model=model_args.model_name_or_path,
243
+ reward_funcs=reward_funcs,
244
+ args=training_args,
245
+ train_dataset=dataset,
246
+ eval_dataset=None,
247
+ peft_config=get_peft_config(model_args),
248
+ attn_implementation=model_args.attn_implementation,
249
+ max_pixels=script_args.max_pixels,
250
+ min_pixels=script_args.min_pixels,
251
+ caption_reward=script_args.caption_reward,
252
+ caption_reward_weight=script_args.caption_reward_weight,
253
+ )
254
+
255
+ trainer.train()
256
+ # trainer.train()
257
+
258
+
259
+ if __name__ == "__main__":
260
+ parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
261
+ script_args, training_args, model_args = parser.parse_args_and_config()
262
+
263
+ print('training_args:\n', training_args)
264
+ print('script_args:\n', script_args)
265
+ print('model_args:\n', model_args)
266
+ main(script_args, training_args, model_args)
src/r1-v/src/open_r1/sft_video.py ADDED
@@ -0,0 +1,304 @@
1
+ # Copyright 2024. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Example usage:
16
+ accelerate launch \
17
+ --config_file=deepspeed_zero2.yaml \
18
+ train_video_llm.py \
19
+ --dataset_name mfarre/simplevideoshorts \
20
+ --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
21
+ --per_device_train_batch_size 1 \
22
+ --gradient_accumulation_steps 4 \
23
+ --output_dir video-llm-output \
24
+ --bf16 \
25
+ --torch_dtype bfloat16 \
26
+ --gradient_checkpointing
27
+ """
28
+
29
+ import os
30
+ import json
31
+ import random
32
+ import requests
33
+ import torch
34
+ from torch.optim import AdamW
35
+ from datasets import load_dataset
36
+ from transformers import (
37
+ AutoModelForVision2Seq,
38
+ AutoProcessor,
39
+ BitsAndBytesConfig,
40
+ Qwen2VLProcessor,
41
+ Qwen2VLForConditionalGeneration,
42
+ Qwen2_5_VLForConditionalGeneration
43
+ )
44
+ from transformers import get_linear_schedule_with_warmup
45
+
46
+ from trl import (
47
+ ModelConfig,
48
+ ScriptArguments,
49
+ SFTConfig,
50
+ SFTTrainer,
51
+ TrlParser,
52
+ get_kbit_device_map,
53
+ get_peft_config,
54
+ )
55
+ from accelerate import Accelerator
56
+ from qwen_vl_utils import process_vision_info
57
+
58
+ from datasets import Dataset, DatasetDict
59
+
60
+ import wandb
61
+
62
+ from typing import List, Dict, Any
63
+
64
+ os.environ["DS_BUILD_FUSED_ADAM"] = "0"
65
+
66
+ def get_current_device():
67
+ """Get the current device. For GPU we return the local process index to enable multiple GPU training."""
68
+ return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
69
+
70
+ def download_video(url: str, folder: str = '/tmp/videos/') -> str:
71
+ """Download video if not already present locally."""
72
+ filename = url.split("/")[-1]
73
+ local_path = os.path.join(folder, filename)
74
+
75
+ if os.path.exists(local_path):
76
+ return local_path
77
+
78
+ try:
79
+ with requests.get(url, stream=True) as r:
80
+ r.raise_for_status()
81
+ with open(local_path, 'wb') as f:
82
+ for chunk in r.iter_content(chunk_size=8192):
83
+ if chunk:
84
+ f.write(chunk)
85
+ return local_path
86
+ except requests.RequestException as e:
87
+ raise Exception(f"Failed to download video: {e}")
88
+
89
+ def prepare_dataset(example: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
90
+ """Prepare dataset example for training."""
91
+
92
+
93
+
94
+ system_message = "You are a helpful assistant"
95
+
96
+
97
+ QUESTION_TEMPLATE = (
98
+ "{Question}\n"
99
+ "Please think about this question as if you were a human pondering deeply. "
100
+ "Engage in an internal dialogue using expressions such as 'let me think', 'wait', 'Hmm', 'oh, I see', 'let's break it down', etc, or other natural language thought expressions "
101
+ "It's encouraged to include self-reflection or verification in the reasoning process. "
102
+ "Provide your detailed reasoning between the <think> </think> tags, and then give your final answer between the <answer> </answer> tags."
103
+ )
104
+
105
+ TYPE_TEMPLATE = {
106
+ "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
107
+ "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
108
+ "OCR": " Please transcribe text from the image/video clearly and provide your text answer within the <answer> </answer> tags.",
109
+ "free-form": " Please provide your text answer within the <answer> </answer> tags.",
110
+ "regression": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags."
111
+ }
112
+
113
+
114
+
115
+ if example["problem_type"] == 'multiple choice':
116
+ question = example['problem'] + "Options:\n"
117
+ for op in example["options"]:
118
+ question += op + "\n"
119
+ else:
120
+ question = example['problem']
121
+
122
+
123
+ messages = [
124
+ {
125
+ "role": "system",
126
+ "content": [{"type": "text", "text": system_message}]
127
+ },
128
+ {
129
+ "role": "user",
130
+ "content": [
131
+ {
132
+ "type": example['data_type'],
133
+ example['data_type']: os.getcwd() + "/Video-R1-data" + example['path'][1:]
134
+ # "max_pixels": 360*420,
135
+ # "fps": 1.0
136
+ },
137
+ {
138
+ "type": "text",
139
+ "text": QUESTION_TEMPLATE.format(Question=question) + TYPE_TEMPLATE[example['problem_type']]
140
+ }
141
+ ]
142
+ },
143
+ {
144
+ "role": "assistant",
145
+ "content": [{"type": "text", "text": example['process'] + "\n" + example['solution']}]
146
+ }
147
+ ]
148
+
149
+
150
+ return {"messages": messages}
151
+
152
+ def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
153
+ """Collate batch of examples for training."""
154
+ texts = []
155
+ # video_inputs = []
156
+ # image_inputs = []
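+ # Note: image_inputs / video_inputs are re-assigned on every loop iteration below,
+ # so only the last example's visual inputs reach the processor; this collator is
+ # effectively written for a per-device batch size of 1, as in the example launch
+ # command in the module docstring.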
157
+
158
+ for i, example in enumerate(examples):
159
+ try:
160
+
161
+ texts.append(processor.apply_chat_template(example["messages"], tokenize=False))
162
+ image_inputs, video_inputs, video_kwargs = process_vision_info(example["messages"], return_video_kwargs=True)
163
+
164
+ except Exception as e:
165
+ raise ValueError(f"Failed to process example {i}: {e}")
166
+
167
+ inputs = processor(
168
+ text=texts,
169
+ images=image_inputs,
170
+ videos=video_inputs,
171
+ return_tensors="pt",
172
+ padding=True
173
+ )
174
+
175
+ labels = inputs["input_ids"].clone()
176
+ labels[labels == processor.tokenizer.pad_token_id] = -100
177
+
178
+ # Handle visual tokens based on processor type
179
+ visual_tokens = [151652, 151653, 151656] if isinstance(processor, Qwen2VLProcessor) else [
180
+ processor.tokenizer.convert_tokens_to_ids(processor.image_token)
181
+ ]
182
+
183
+ for visual_token_id in visual_tokens:
184
+ labels[labels == visual_token_id] = -100
185
+
186
+ inputs["labels"] = labels
187
+ return inputs
188
+
189
+ if __name__ == "__main__":
190
+ # Parse arguments
191
+ parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
192
+ script_args, training_args, model_config = parser.parse_args_and_config()
193
+
194
+ # Configure training args
195
+ training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
196
+ training_args.remove_unused_columns = False
197
+ training_args.dataset_kwargs = {"skip_prepare_dataset": True}
198
+
199
+ # Load dataset
200
+ if script_args.dataset_name.endswith('.json') or script_args.dataset_name.endswith('.jsonl'):
201
+ dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
202
+ else:
203
+ # Load the dataset
204
+ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
205
+
206
+ # Setup model
207
+ torch_dtype = (
208
+ model_config.torch_dtype
209
+ if model_config.torch_dtype in ["auto", None]
210
+ else getattr(torch, model_config.torch_dtype)
211
+ )
212
+
213
+ # # Quantization configuration for 4-bit training
214
+ # bnb_config = BitsAndBytesConfig(
215
+ # load_in_4bit=True,
216
+ # bnb_4bit_use_double_quant=True,
217
+ # bnb_4bit_quant_type="nf4",
218
+ # bnb_4bit_compute_dtype=torch.bfloat16
219
+ # )
220
+
221
+ # Model initialization
222
+ model_kwargs = dict(
223
+ revision=model_config.model_revision,
224
+ trust_remote_code=model_config.trust_remote_code,
225
+ torch_dtype=torch_dtype,
226
+ device_map=get_kbit_device_map(),
227
+ # quantization_config=bnb_config,
228
+ )
229
+
230
+
231
+ if "Qwen2-VL" in model_config.model_name_or_path:
232
+ model = Qwen2VLForConditionalGeneration.from_pretrained(model_config.model_name_or_path, **model_kwargs)
233
+ elif "Qwen2.5-VL" in model_config.model_name_or_path:
234
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_config.model_name_or_path, **model_kwargs)
235
+ else:
236
+ model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
237
+
238
+ processor = AutoProcessor.from_pretrained(
239
+ model_config.model_name_or_path,
240
+ trust_remote_code=model_config.trust_remote_code
241
+ )
242
+
243
+ # Prepare dataset
244
+ prepared_dataset = [prepare_dataset(example) for example in dataset['train']]
245
+
246
+ # Initialize wandb if specified
247
+ if training_args.report_to == "wandb":
248
+ wandb.init(project="video-llm-training")
249
+
250
+
251
+ '''
252
+ Below is added code
253
+ '''
254
+ base_lr = 2e-4
255
+ optimizer = AdamW(
256
+ params=model.parameters(),
257
+ lr=base_lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
258
+ )
259
+
260
+ num_training_steps = len(prepared_dataset) // (
261
+ training_args.per_device_train_batch_size
262
+ * training_args.gradient_accumulation_steps
263
+ * training_args.world_size
264
+ ) * training_args.num_train_epochs
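+ # e.g. 4,000 examples with per-device batch size 1, gradient accumulation 4 and
+ # 8 processes give 125 optimizer steps per epoch before multiplying by the epoch
+ # count (illustrative numbers, not taken from any config in this repo).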
265
+
266
+ lr_scheduler = get_linear_schedule_with_warmup(
267
+ optimizer,
268
+ num_warmup_steps=int(0.05 * num_training_steps),
269
+ num_training_steps=num_training_steps,
270
+ )
271
+ '''
272
+ Above is added code
273
+ '''
274
+
275
+
276
+ # Initialize trainer
277
+ trainer = SFTTrainer(
278
+ model=model,
279
+ args=training_args,
280
+ train_dataset=prepared_dataset,
281
+ data_collator=collate_fn,
282
+ peft_config=get_peft_config(model_config),
283
+ # tokenizer=processor.tokenizer
284
+ optimizers=(optimizer, lr_scheduler),
285
+ )
286
+
287
+ # Train model
288
+ trainer.train()
289
+
290
+ # Save final model
291
+
292
+ trainer.save_model(training_args.output_dir)
293
+ processor.save_pretrained(training_args.output_dir)
294
+
295
+ if trainer.accelerator.is_main_process:
296
+ # Restore k,v cache for fast inference
297
+ trainer.model.config.use_cache = True
298
+ trainer.model.config.save_pretrained(training_args.output_dir)
299
+
300
+ # Cleanup
301
+ del model
302
+ del trainer
303
+ torch.cuda.empty_cache()
304
+ wandb.finish()
src/r1-v/src/open_r1/trainer/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ from .grpo_trainer import Qwen2VLGRPOTrainer
2
+ from .vllm_grpo_trainer_modified import Qwen2VLGRPOVLLMTrainerModified
3
+ from .vllm_grpo_trainer_modified_orig import Qwen2VLGRPOVLLMTrainerModifiedOrig
4
+ from .vllm_grpo_trainer_selfConst import Qwen2VLGRPOVLLMTrainerSelfConst
5
+
6
+
7
+ __all__ = [
8
+ "Qwen2VLGRPOTrainer",
9
+ "Qwen2VLGRPOVLLMTrainerModified",
10
+ "Qwen2VLGRPOVLLMTrainerModifiedOrig",
11
+ "Qwen2VLGRPOVLLMTrainerSelfConst"
12
+ ]
src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified_error.py ADDED
@@ -0,0 +1,1061 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import textwrap
17
+ from collections import defaultdict
18
+ from typing import Any, Callable, Optional, Union
19
+ from accelerate.utils.other import is_compiled_module
20
+ from accelerate.utils import broadcast_object_list, gather, gather_object
21
+ import torch
22
+ import torch.utils.data
23
+ import transformers
24
+ import warnings
25
+ from unittest.mock import patch
26
+ from datasets import Dataset, IterableDataset
27
+ from packaging import version
28
+ from transformers import (
29
+ AriaForConditionalGeneration,
30
+ AriaProcessor,
31
+ AutoModelForCausalLM,
32
+ AutoModelForSequenceClassification,
33
+ AutoProcessor,
34
+ AutoTokenizer,
35
+ GenerationConfig,
36
+ PreTrainedModel,
37
+ PreTrainedTokenizerBase,
38
+ Qwen2VLForConditionalGeneration,
39
+ Qwen2_5_VLForConditionalGeneration,
40
+ Trainer,
41
+ TrainerCallback,
42
+ is_wandb_available,
43
+ )
44
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
45
+ from transformers.utils import is_peft_available
46
+
47
+ from trl.data_utils import (
48
+ apply_chat_template,
49
+ is_conversational,
50
+ maybe_apply_chat_template,
51
+ )
52
+ from trl.import_utils import is_vllm_available
53
+
54
+ from trl.models import (
55
+ create_reference_model,
56
+ prepare_deepspeed,
57
+ unwrap_model_for_generation,
58
+ )
59
+ from trl.trainer.grpo_config import GRPOConfig
60
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
61
+ from trl import GRPOTrainer
62
+
63
+ import copy
64
+
65
+ if is_peft_available():
66
+ from peft import PeftConfig, get_peft_model
67
+
68
+ if is_vllm_available():
69
+ from vllm import LLM, SamplingParams
70
+
71
+ if is_wandb_available():
72
+ import wandb
73
+ import torch.nn as nn
74
+ from torch.utils.data import Sampler
75
+ import gc
76
+ from qwen_vl_utils import process_vision_info
77
+
78
+ import re
79
+
80
+ def extract_answer(predict: str) -> Optional[str]:
81
+ """
82
+ Extracts the content of the <answer>…</answer> block from `predict`.
83
+ Returns the inner text (with leading/trailing whitespace stripped),
84
+ or None if no <answer> tag is found.
85
+ """
86
+ match = re.search(r"<answer>([\s\S]*?)</answer>", predict, re.DOTALL)
87
+ if not match:
88
+ return None
89
+ return match.group(1).strip()
90
+
91
+ def extract_info(predict: str) -> Optional[str]:
92
+ """
93
+ Extracts the content of the <des>…</des> block from `predict`.
94
+ Returns the inner text (with leading/trailing whitespace stripped),
95
+ or None if no <des> tag is found.
96
+ """
97
+ match = re.search(r"<des>([\s\S]*?)</des>", predict, re.DOTALL)
98
+ if not match:
99
+ return None
100
+ return match.group(1).strip()
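+ # e.g. extract_info("<des>Two dogs playing.</des><answer>2</answer>") -> "Two dogs playing."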
101
+
102
+
103
+
104
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
105
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
106
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
107
+
108
+
109
+ class Qwen2VLGRPOVLLMTrainerModified(Trainer):
110
+ def __init__(
111
+ self,
112
+ model: Union[str, PreTrainedModel],
113
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
114
+ args: GRPOConfig = None,
115
+ script_args = None,
116
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
117
+ eval_dataset: Optional[
118
+ Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
119
+ ] = None,
120
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
121
+ reward_processing_classes: Optional[
122
+ Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
123
+ ] = None,
124
+ callbacks: Optional[list[TrainerCallback]] = None,
125
+ optimizers: tuple[
126
+ Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
127
+ ] = (None, None),
128
+ peft_config: Optional["PeftConfig"] = None,
129
+ # qwen2-vl related params
130
+ max_pixels: Optional[int] = 12845056,
131
+ min_pixels: Optional[int] = 3136,
132
+ attn_implementation: str = "flash_attention_2",
133
+ ):
134
+
135
+ # Args
136
+ if args is None:
137
+ model_name = model if isinstance(model, str) else model.config._name_or_path
138
+ model_name = model_name.split("/")[-1]
139
+ args = GRPOConfig(f"{model_name}-GRPO")
140
+
141
+ # Models
142
+ # Trained model
143
+ model_init_kwargs = args.model_init_kwargs or {}
144
+ model_init_kwargs["attn_implementation"] = attn_implementation
145
+ if isinstance(model, str):
146
+ model_id = model
147
+ torch_dtype = model_init_kwargs.get("torch_dtype")
148
+ if (
149
+ isinstance(torch_dtype, torch.dtype)
150
+ or torch_dtype == "auto"
151
+ or torch_dtype is None
152
+ ):
153
+ pass # torch_dtype is already a torch.dtype or "auto" or None
154
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
155
+ torch_dtype = getattr(torch, torch_dtype)
156
+ model_init_kwargs["torch_dtype"] = torch_dtype
157
+ else:
158
+ raise ValueError(
159
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
160
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
161
+ )
162
+ # Disable caching if gradient checkpointing is enabled (not supported)
163
+ model_init_kwargs["use_cache"] = (
164
+ False
165
+ if args.gradient_checkpointing
166
+ else model_init_kwargs.get("use_cache")
167
+ )
168
+ if "Qwen2-VL" in model_id:
169
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
170
+ model, **model_init_kwargs
171
+ )
172
+ elif "Qwen2.5-VL" in model_id:
173
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
174
+ model, **model_init_kwargs
175
+ )
176
+ elif "Aria" in model_id:
177
+ model_init_kwargs.pop("use_cache")
178
+ model = AriaForConditionalGeneration.from_pretrained(
179
+ model, **model_init_kwargs
180
+ )
181
+ else:
182
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
183
+ else:
184
+ model_id = model.config._name_or_path
185
+ if args.model_init_kwargs is not None:
186
+ raise ValueError(
187
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
188
+ "This argument can only be used when the `model` argument is a string."
189
+ )
190
+
191
+ if peft_config is not None:
192
+ model = get_peft_model(model, peft_config)
193
+
194
+ # Reference model
195
+ if is_deepspeed_zero3_enabled():
196
+ if "Qwen2-VL" in model_id:
197
+ self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
198
+ model_id, **model_init_kwargs
199
+ )
200
+ elif "Qwen2.5-VL" in model_id:
201
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
202
+ model_id, **model_init_kwargs
203
+ )
204
+ elif "Aria" in model_id:
205
+ self.ref_model = AriaForConditionalGeneration.from_pretrained(
206
+ model_id, **model_init_kwargs
207
+ )
208
+ else:
209
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
210
+ model_id, **model_init_kwargs
211
+ )
212
+ elif peft_config is None:
213
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
214
+ self.ref_model = create_reference_model(model)
215
+ else:
216
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
217
+ # to revert to the initial model.
218
+ self.ref_model = None
219
+
220
+ # Processing class
221
+ if processing_class is None:
222
+ if "Qwen" in model_id or "Aria" in model_id:
223
+ processing_class = AutoProcessor.from_pretrained(model_id)
224
+ pad_token_id = processing_class.tokenizer.pad_token_id
225
+ processing_class.pad_token_id = pad_token_id
226
+ processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
227
+ if "Qwen" in model_id:
228
+ processing_class.image_processor.max_pixels = max_pixels
229
+ processing_class.image_processor.min_pixels = min_pixels
230
+ else:
231
+ processing_class = AutoTokenizer.from_pretrained(
232
+ model.config._name_or_path, padding_side="left"
233
+ )
234
+ pad_token_id = processing_class.pad_token_id
235
+
236
+ # Reward functions
237
+ if not isinstance(reward_funcs, list):
238
+ reward_funcs = [reward_funcs]
239
+ for i, reward_func in enumerate(reward_funcs):
240
+ if isinstance(reward_func, str):
241
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
242
+ reward_func, num_labels=1, **model_init_kwargs
243
+ )
244
+ self.reward_funcs = reward_funcs
245
+
246
+ # Reward processing class
247
+ if reward_processing_classes is None:
248
+ reward_processing_classes = [None] * len(reward_funcs)
249
+ elif not isinstance(reward_processing_classes, list):
250
+ reward_processing_classes = [reward_processing_classes]
251
+ else:
252
+ if len(reward_processing_classes) != len(reward_funcs):
253
+ raise ValueError(
254
+ "The number of reward processing classes must match the number of reward functions."
255
+ )
256
+
257
+ for i, (reward_processing_class, reward_func) in enumerate(
258
+ zip(reward_processing_classes, reward_funcs)
259
+ ):
260
+ if isinstance(reward_func, PreTrainedModel):
261
+ if reward_processing_class is None:
262
+ reward_processing_class = AutoTokenizer.from_pretrained(
263
+ reward_func.config._name_or_path
264
+ )
265
+ if reward_processing_class.pad_token_id is None:
266
+ reward_processing_class.pad_token = (
267
+ reward_processing_class.eos_token
268
+ )
269
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
270
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
271
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
272
+ reward_processing_classes[i] = reward_processing_class
273
+ self.reward_processing_classes = reward_processing_classes
274
+
275
+ # Data collator
276
+ def data_collator(features): # No data collation is needed in GRPO
277
+ return features
278
+
279
+ # Training arguments
280
+ self.max_prompt_length = args.max_prompt_length
281
+ self.max_completion_length = (
282
+ args.max_completion_length
283
+ ) # = |o_i| in the GRPO paper
284
+ self.num_generations = args.num_generations # = G in the GRPO paper
285
+ self.temporal = script_args.temporal
286
+ self.generation_config = GenerationConfig(
287
+ max_new_tokens=self.max_completion_length,
288
+ do_sample=True,
289
+ temperature=1, # HACK
290
+ num_return_sequences=self.num_generations,
291
+ pad_token_id=pad_token_id,
292
+ )
293
+ self.beta = args.beta
294
+
295
+ self.shuffled_num_generations = self.num_generations // 2
296
+ self.shuffled_generation_config = GenerationConfig(
297
+ max_new_tokens=self.max_completion_length,
298
+ do_sample=True,
299
+ top_p=0.95,
300
+ temperature=1, # HACK
301
+ num_return_sequences=self.shuffled_num_generations,
302
+ pad_token_id=pad_token_id,
303
+ )
304
+
305
+ self.dummy_generation_config = GenerationConfig(
306
+ max_new_tokens=1,
307
+ do_sample=True,
308
+ top_p=0.95,
309
+ temperature=1, # HACK
310
+ num_return_sequences=1,
311
+ pad_token_id=pad_token_id,
312
+ )
313
+ self.len_control = script_args.len_control
314
+ self.beta = args.beta
315
+
316
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
317
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
318
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
319
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
320
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
321
+ # This acts as a flag to indicate that the warning has already been issued.
322
+ model.warnings_issued["estimate_tokens"] = True
323
+
324
+ # Initialize the metrics
325
+ self._metrics = defaultdict(list)
326
+ self.use_vllm = args.use_vllm
327
+
328
+ super().__init__(
329
+ model=model,
330
+ args=args,
331
+ data_collator=data_collator,
332
+ train_dataset=train_dataset,
333
+ eval_dataset=eval_dataset,
334
+ processing_class=processing_class,
335
+ callbacks=callbacks,
336
+ optimizers=optimizers,
337
+ )
338
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
339
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
340
+ # self.model_accepts_loss_kwargs to False to enable scaling.
341
+ self.model_accepts_loss_kwargs = False
342
+
343
+ if self.use_vllm:
344
+ if not is_vllm_available():
345
+ raise ImportError(
346
+ "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
347
+ "`pip install vllm` to use it."
348
+ )
349
+
350
+ if self.accelerator.is_main_process:
351
+ vllm_device = self.args.vllm_device
352
+ if vllm_device == "auto":
353
+ vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
354
+ # Check that the requested device is available
355
+ if (
356
+ vllm_device.split(":")[0] == "cuda"
357
+ and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
358
+ ):
359
+ raise ValueError(
360
+ f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
361
+ "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
362
+ "value lower than the number of GPUs available on your machine—typically, reducing it by one "
363
+ f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
364
+ )
365
+ # Check that the requested device is not also used for training
366
+ if vllm_device in {
367
+ f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
368
+ }:
369
+ warnings.warn(
370
+ f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
371
+ "behavior. It is recommended to use a dedicated device for vLLM."
372
+ )
373
+ # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
374
+ # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
375
+ # setting (profiling_patch).
376
+ world_size_patch = patch(
377
+ "torch.distributed.get_world_size", return_value=1
378
+ )
379
+ profiling_patch = patch(
380
+ "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
381
+ return_value=None,
382
+ )
383
+ with world_size_patch, profiling_patch:
384
+ print("vllm is running on: ", vllm_device)
385
+ self.llm = LLM(
386
+ model=model.name_or_path,
387
+ device=vllm_device,
388
+ gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
389
+ dtype=torch.bfloat16,
390
+ # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
391
+ # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
392
+ # This is particularly useful here because we generate completions from the same prompts.
393
+ enable_prefix_caching=True,
394
+ enforce_eager=True,
395
+ mm_processor_kwargs=(
396
+ {
397
+ "max_pixels": max_pixels,
398
+ "min_pixels": min_pixels,
399
+ }
400
+ # if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
401
+ if False
402
+ else None
403
+ ),
404
+ max_model_len=args.max_prompt_length + args.max_completion_length,
405
+ )
406
+ self.sampling_params = SamplingParams(
407
+ temperature=1.0,
408
+ top_p=0.95,
409
+ max_tokens=self.max_completion_length,
410
+ )
411
+
412
+ self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation
413
+
414
+ # When using vLLM, the main process is responsible for loading the model weights. This can cause process
415
+ # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
416
+ # synchronize all processes after vLLM has been fully initialized.
417
+ self.accelerator.wait_for_everyone()
418
+ else:
419
+ raise ValueError(
420
+ "GRPOVLLMTrainerModified only supports vllm generation, please set --use_vllm True"
421
+ )
422
+
423
+ if self.ref_model is not None:
424
+ if self.is_deepspeed_enabled:
425
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
426
+ else:
427
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
428
+
429
+ for i, reward_func in enumerate(self.reward_funcs):
430
+ if isinstance(reward_func, PreTrainedModel):
431
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
432
+
433
+ def _set_signature_columns_if_needed(self):
434
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
435
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
436
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
437
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
438
+ if self._signature_columns is None:
439
+ self._signature_columns = ["prompt"]
440
+
441
+ # Get the per-token log probabilities for the completions for the model and the reference model
442
+ def _get_per_token_logps(self, model, input_ids, **kwargs):
443
+ # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits # (B, L, V)
444
+ # import pdb
445
+ # pdb.set_trace()
446
+ logits = model(input_ids, **kwargs).logits
447
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
448
+ input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
449
+ # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
450
+ per_token_logps = []
451
+ for logits_row, input_ids_row in zip(logits, input_ids):
452
+ log_probs = logits_row.log_softmax(dim=-1)
453
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
454
+ per_token_logps.append(token_log_prob)
455
+ return torch.stack(per_token_logps)
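# Shape sketch (illustrative): for B=2 sequences of length L=5, `logits` is (2, 4, V)
# after dropping the last position, the shifted `input_ids` are (2, 4), and the stacked
# result is (2, 4) with one log-probability per realized next token. A one-shot
# equivalent (higher peak memory) would be:
#     torch.gather(logits.log_softmax(-1), 2, input_ids.unsqueeze(-1)).squeeze(-1)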
456
+
457
+ # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
458
+ # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
459
+ def _prepare_inputs(
460
+ self, inputs: dict[str, Union[torch.Tensor, Any]]
461
+ ) -> dict[str, Union[torch.Tensor, Any]]:
462
+ return inputs
463
+
464
+ def remove_none_from_data(self, data):
465
+ for entry in data:
466
+ if "content" in entry and isinstance(entry["content"], list):
467
+ for sub_entry in entry["content"]:
468
+ if isinstance(sub_entry, dict):
469
+ keys_to_remove = [k for k, v in sub_entry.items() if v is None]
470
+ for k in keys_to_remove:
471
+ del sub_entry[k]
472
+ return data
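# Usage sketch (illustrative): a message such as
#     [{"content": [{"type": "video", "video": "clip.mp4", "image": None}]}]
# comes back with the None-valued "image" key dropped, so the chat template and
# vision preprocessing never see placeholder fields.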
473
+
474
+ def _vllm_generate(self, prompts_text, mm_data, n):
475
+ """
476
+ Helper that wraps the whole ‘gather-broadcast-slice-pad-decode’ dance
477
+ and returns (completion_ids, decoded_texts) *ON THIS RANK ONLY*.
478
+ `mm_data` can be None/[] for pure-text inputs.
479
+ """
480
+ device = self.accelerator.device
481
+
482
+ # --------------- gather everything to rank-0 ----------------
483
+ all_prompts = gather_object(prompts_text)
484
+ all_mm_data = gather_object(mm_data or [[]] * len(prompts_text))
485
+
486
+ # build the multimodal inputs expected by vLLM
487
+ vllm_inputs = [
488
+ {"prompt": p, "multi_modal_data": m[0] if m else {}}
489
+ for p, m in zip(all_prompts, all_mm_data)
490
+ ]
491
+
492
+ # -------------------------------------------------------------
493
+ if self.accelerator.is_main_process:
494
+ p = copy.deepcopy(self.sampling_params)
495
+ p.n = n
496
+ outs = self.llm.generate(vllm_inputs, sampling_params=p, use_tqdm=False)
497
+ comp_ids = [o.token_ids for c in outs for o in c.outputs]
498
+ else:
499
+ comp_ids = [None] * (len(vllm_inputs) * n)
500
+
501
+ # broadcast back, pick this rank’s slice
502
+ comp_ids = broadcast_object_list(comp_ids, from_process=0)
503
+ lo = self.accelerator.process_index * len(prompts_text) * n
504
+ hi = (self.accelerator.process_index + 1) * len(prompts_text) * n
505
+ comp_ids = comp_ids[lo:hi]
506
+
507
+ # pad, convert to tensor → decode
508
+ comp_ids = [torch.tensor(x, device=device) for x in comp_ids]
509
+ comp_ids = pad(comp_ids, padding_value=self.processing_class.pad_token_id)
510
+ decoded = self.processing_class.batch_decode(comp_ids, skip_special_tokens=True)
511
+ return comp_ids, decoded
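# Usage sketch (illustrative): every rank calls, e.g.,
#     ids, texts = self._vllm_generate(prompts_text, mm_data, self.num_generations)
# and receives only its own padded slice of completions, so `ids` can be concatenated
# directly onto the local prompt tensors and `texts` fed straight to the reward functions.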
512
+
513
+
514
+ def compute_loss(
515
+ self, model, inputs, return_outputs=False, num_items_in_batch=None
516
+ ):
517
+ if return_outputs:
518
+ raise ValueError("The GRPOTrainer does not support returning outputs")
519
+ # Compute the per-token log probabilities for the model
520
+
521
+
522
+ device = self.accelerator.device
523
+ prompts = [x["prompt"] for x in inputs]
524
+ # images = [x["image"] for x in inputs]
525
+ prompts_text = [
526
+ maybe_apply_chat_template(example, self.processing_class)["prompt"]
527
+ for example in inputs
528
+ ]
529
+
530
+ input_copy = copy.deepcopy(inputs[0]['prompt'])
531
+
532
+ input_copy = self.remove_none_from_data(input_copy)
533
+
534
+ data_type = inputs[0]['data_type']
535
+
536
+ if data_type == 'image':
537
+ input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
538
+ elif data_type == 'video':
539
+ input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
540
+
541
+
542
+ image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
543
+
544
+
545
+ prompt_inputs = self.processing_class(
546
+ text=copy.deepcopy(prompts_text),
547
+ images=image_inputs,
548
+ videos=video_inputs,
549
+ return_tensors="pt",
550
+ padding=True,
551
+ padding_side="left",
552
+ add_special_tokens=False,
553
+ )
554
+
555
+ mm_data = [[data_type, image_inputs if image_inputs else video_inputs]]
556
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
557
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
558
+
559
+ if self.max_prompt_length is not None:
560
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
561
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
562
+
563
+
564
+ if self.temporal:
565
+ if video_inputs:
566
+ indices = torch.randperm(video_inputs[0].size(0))
567
+ shuffled_video_inputs = [video_inputs[0][indices]]
568
+ shuffled_prompt_inputs = self.processing_class(
569
+ text=copy.deepcopy(prompts_text),
570
+ images=image_inputs,
571
+ videos=shuffled_video_inputs,
572
+ return_tensors="pt",
573
+ padding=True,
574
+ padding_side="left",
575
+ add_special_tokens=False,
576
+ )
577
+ shuffled_mm_data = [[self.accelerator.process_index, data_type, image_inputs if image_inputs else video_inputs]]
578
+ shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
579
+ shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
580
+ if self.max_prompt_length is not None:
581
+ shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
582
+ shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
583
+ else:
584
+ shuffled_mm_data = [None]
585
+
586
+
587
+
588
+ if self.args.use_vllm:
589
+ # First, have main process load weights if needed
590
+ if self.state.global_step != self._last_loaded_step:
591
+ with unwrap_model_for_generation(
592
+ self.model,
593
+ self.accelerator,
594
+ gather_deepspeed3_params=True, # TODO: fix this, self.args.ds3_gather_for_generation,
595
+ ) as unwrapped_model:
596
+ if is_compiled_module(unwrapped_model):
597
+ state_dict = unwrapped_model._orig_mod.state_dict()
598
+ else:
599
+ state_dict = unwrapped_model.state_dict()
600
+ if self.accelerator.is_main_process:
601
+ llm_model = (
602
+ self.llm.llm_engine.model_executor.driver_worker.model_runner.model
603
+ )
604
+ # import pdb
605
+ # pdb.set_trace()
606
+ llm_model.load_weights(state_dict.items())
607
+ self._last_loaded_step = self.state.global_step
608
+ '''
609
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
610
+ all_prompts_text = gather_object(prompts_text)
611
+ all_mm_data = gather_object(mm_data)
612
+ # group into pairs
613
+ all_multimodal_inputs = []
614
+
615
+ if self.temporal:
616
+ shuffled_all_mm_data_none = gather_object(shuffled_mm_data)
617
+ shuffled_all_mm_data = [x for x in shuffled_all_mm_data_none if x]
618
+ shuffled_all_multimodal_inputs = []
619
+
620
+ # 2. Refer to TobiasLee's implementation suggestions
621
+ # this is a better implementation for vLLM sampling.
622
+ for prompt, mm_item in zip(all_prompts_text, all_mm_data):
623
+ all_multimodal_inputs.append({"prompt": prompt, "multi_modal_data": {mm_item[0]: mm_item[1]}})
624
+
625
+ if self.temporal and shuffled_all_mm_data!=[]:
626
+ for mm_item in shuffled_all_mm_data:
627
+ shuffled_all_multimodal_inputs.append({"prompt": all_prompts_text[mm_item[0]], "multi_modal_data": {mm_item[1]: mm_item[2]}})
628
+
629
+ # Create sampling params with num_generations
630
+ if self.accelerator.is_main_process:
631
+ # Clone to avoid modifying original params
632
+ sampling_params = copy.deepcopy(self.sampling_params)
633
+ sampling_params.n = self.num_generations
634
+ # Single generate call with all prompts
635
+ if self.accelerator.is_main_process:
636
+ outputs = self.llm.generate(
637
+ all_multimodal_inputs,
638
+ sampling_params=sampling_params,
639
+ use_tqdm=False,
640
+ )
641
+ # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
642
+ completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]
643
+
644
+ if self.temporal and shuffled_all_mm_data!=[]:
645
+ # Clone to avoid modifying original params
646
+ shuffled_sampling_params = copy.deepcopy(self.sampling_params)
647
+ shuffled_sampling_params.n = self.num_generations // 2
648
+ # Single generate call with all prompts
649
+ if self.accelerator.is_main_process:
650
+ shuffled_outputs = self.llm.generate(
651
+ shuffled_all_multimodal_inputs,
652
+ sampling_params=shuffled_sampling_params,
653
+ use_tqdm=False,
654
+ )
655
+ # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
656
+ shuffled_completion_ids = [out.token_ids for completion in shuffled_outputs for out in completion.outputs]
657
+
658
+
659
+ else:
660
+ completion_ids = [None] * len(all_multimodal_inputs) * self.num_generations
661
+
662
+ if self.temporal and shuffled_all_mm_data!=[]:
663
+ shuffled_completion_ids = [None] * len(shuffled_all_multimodal_inputs) * (self.num_generations // 2)
664
+
665
+
666
+ # broadcast and slice
667
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
668
+ process_slice = slice(
669
+ self.accelerator.process_index * len(prompts) * self.num_generations,
670
+ (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
671
+ )
672
+ completion_ids = completion_ids[process_slice]
673
+
674
+ # Pad the completions, and concatenate them with the prompts
675
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
676
+ completion_ids = pad(
677
+ completion_ids, padding_value=self.processing_class.pad_token_id
678
+ )
679
+ '''
680
+
681
+ completion_ids, completions = self._vllm_generate(
682
+ prompts_text, # original text prompts
683
+ mm_data, # vision payload (may be empty for text-only)
684
+ self.num_generations,
685
+ )
686
+
687
+ prompt_ids = prompt_ids.repeat_interleave(self.num_generations, dim=0)
688
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
689
+
690
+ prompt_length = prompt_ids.size(1)
691
+
692
+ print('prompt_length:', prompt_length)
693
+
694
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
695
+ completion_ids = prompt_completion_ids[:, prompt_length:]
696
+ prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
697
+
698
+
699
+ '''
700
+ This is the additional code that avoids the shuffled_all_mm_data variable undefined error.
701
+ '''
702
+ if self.temporal and video_inputs:
703
+ # ❶ make the shuffled video batch (you already computed shuffled_video_inputs)
704
+ local_shuffled_mm = [[data_type, shuffled_video_inputs]]
705
+ shuffled_prompts = copy.deepcopy(prompts_text)
706
+
707
+ # ❷ generate half as many completions for each prompt
708
+ shuffled_completion_ids, _ = self._vllm_generate(
709
+ prompts_text=shuffled_prompts,
710
+ mm_data=local_shuffled_mm,
711
+ n=self.num_generations // 2,
712
+ )
713
+
714
+ # ❸ mimic the old triple-list so later broadcast logic works unchanged
715
+ shuffled_all_mm_data = [[self.accelerator.process_index,
716
+ data_type,
717
+ shuffled_video_inputs]]
718
+ # -----------------------------------------------------------------
719
+
720
+ if self.temporal and shuffled_all_mm_data!=[]:
721
+ # broadcast and slice
722
+ shuffled_completion_ids = broadcast_object_list(shuffled_completion_ids, from_process=0)
723
+ process_id_list = []
724
+ for mm_item in shuffled_all_mm_data:
725
+ process_id_list += [mm_item[0]] * len(prompts) * (self.num_generations // 2)
726
+
727
+ if video_inputs:
728
+ cur_shuffled_completion_ids = []
729
+ for i in range(len(process_id_list)):
730
+ if self.accelerator.process_index == process_id_list[i]:
731
+ cur_shuffled_completion_ids.append(shuffled_completion_ids[i])
732
+
733
+ # Pad the completions, and concatenate them with the prompts
734
+ cur_shuffled_completion_ids = [torch.tensor(ids, device=device) for ids in cur_shuffled_completion_ids]
735
+ cur_shuffled_completion_ids = pad(
736
+ cur_shuffled_completion_ids, padding_value=self.processing_class.pad_token_id
737
+ )
738
+ shuffled_completion_ids = cur_shuffled_completion_ids
739
+
740
+
741
+ else:
742
+ raise ValueError("Only vLLM generation is supported in this version ")
743
+ '''Above is additional code'''
744
+
745
+
746
+ if self.temporal and shuffled_all_mm_data!=[]:
747
+ # broadcast and slice
748
+ shuffled_completion_ids = broadcast_object_list(shuffled_completion_ids, from_process=0)
749
+ process_id_list = []
750
+ for mm_item in shuffled_all_mm_data:
751
+ process_id_list += [mm_item[0]] * len(prompts) * (self.num_generations // 2)
752
+
753
+ if video_inputs:
754
+ cur_shuffled_completion_ids = []
755
+ for i in range(len(process_id_list)):
756
+ if self.accelerator.process_index == process_id_list[i]:
757
+ cur_shuffled_completion_ids.append(shuffled_completion_ids[i])
758
+
759
+ # Pad the completions, and concatenate them with the prompts
760
+ cur_shuffled_completion_ids = [torch.tensor(ids, device=device) for ids in cur_shuffled_completion_ids]
761
+ cur_shuffled_completion_ids = pad(
762
+ cur_shuffled_completion_ids, padding_value=self.processing_class.pad_token_id
763
+ )
764
+ shuffled_completion_ids = cur_shuffled_completion_ids
765
+
766
+
767
+ else:
768
+ raise ValueError("Only vLLM generation is supported in this version ")
769
+
770
+ # The code below is the same as yifan's original implementation
771
+ # Mask everything after the first EOS token
772
+ is_eos = completion_ids == self.processing_class.eos_token_id
773
+ device = self.accelerator.device
774
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
775
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
776
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
777
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
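# Worked example (illustrative): with eos_token_id = 2 and completion_ids = [[5, 7, 2, 9]],
# eos_idx is 2 and completion_mask becomes [[1, 1, 1, 0]]: tokens up to and including
# the first EOS count toward the loss, anything after it is masked out.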
778
+
779
+
780
+
781
+ prompt_inputs.pop("input_ids")
782
+ prompt_inputs.pop("attention_mask")
783
+
784
+ if data_type == 'image':
785
+ prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
786
+ prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
787
+ # import pdb; pdb.set_trace()
788
+
789
+
790
+ if data_type == 'video':
791
+ prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
792
+ prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
793
+ if 'second_per_grid_ts' in prompt_inputs:
794
+ del prompt_inputs["second_per_grid_ts"]
795
+
796
+ # import pdb
797
+ # pdb.set_trace()
798
+
799
+ # per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
800
+ per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
801
+ # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
802
+ per_token_logps = per_token_logps[:, prompt_length - 1 :]
803
+
804
+ gc.collect()
805
+ torch.cuda.empty_cache()
806
+
807
+ with torch.inference_mode():
808
+ if self.ref_model is not None:
809
+ ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
810
+ else:
811
+ with self.accelerator.unwrap_model(model).disable_adapter():
812
+ ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
813
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
814
+
815
+ x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10) # clamp x to a safe range to keep exp() stable
816
+ per_token_kl = torch.exp(x_clamped) - x_clamped - 1
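# Illustrative note: with delta = ref_logp - logp, exp(delta) - delta - 1 is the
# non-negative k3 estimator of the per-token KL; e.g. delta = 0.5 gives about 0.149
# and delta = 0 gives exactly 0 (the two policies agree on that token).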
817
+
818
+ gc.collect()
819
+ torch.cuda.empty_cache()
820
+
821
+ if self.temporal and video_inputs:
822
+
823
+ shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
824
+ if is_conversational(inputs[0]):
825
+ shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
826
+
827
+ # Compute the rewards
828
+ shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
829
+ shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
830
+ for i, (reward_func, reward_processing_class) in enumerate(
831
+ zip(self.reward_funcs, self.reward_processing_classes)
832
+ ):
833
+ # Repeat all input columns (except "prompt" and "completion") to match the number of generations
834
+ shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
835
+ for key in shuffled_reward_kwargs:
836
+ for example in inputs:
837
+ # Repeat each value in the column `self.shuffled_num_generations` times
838
+ shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
839
+ shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
840
+ shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
841
+
842
+
843
+
844
+ # Decode the generated completions
845
+ completions = self.processing_class.batch_decode(
846
+ completion_ids, skip_special_tokens=True
847
+ )
848
+ if is_conversational(inputs[0]):
849
+ completions = [
850
+ [{"role": "assistant", "content": completion}]
851
+ for completion in completions
852
+ ]
853
+
854
+
855
+ '''Below is code for second completions generation'''
856
+ if is_conversational(inputs[0]):
857
+ first_texts = [c[0]["content"] for c in completions]
858
+ else:
859
+ first_texts = completions
860
+
861
+ # ------------------------------------------------------------
862
+ # 2️⃣ Build follow-up prompts with `extract_info`
863
+ # ------------------------------------------------------------
864
+ follow_up_prompts = [extract_info(txt) for txt in first_texts]
865
+
866
+ # ------------------------------------------------------------
867
+ # 3️⃣ SECOND-hop generation ➜ `second_completions`
868
+ # ------------------------------------------------------------
869
+ _, second_texts = self._vllm_generate(
870
+ follow_up_prompts, # new prompts (pure text)
871
+ None, # no vision payload
872
+ 1 # one follow-up per prompt
873
+ )
874
+
875
+ # pack in chat format if needed
876
+ if is_conversational(inputs[0]):
877
+ second_completions = [
878
+ [{"role": "assistant", "content": t}] for t in second_texts
879
+ ]
880
+ else:
881
+ second_completions = second_texts
882
+
883
+ '''Above is code for second completions generation'''
884
+
885
+ # Compute the rewards
886
+ prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
887
+ rewards_per_func = torch.zeros(
888
+ len(prompts), len(self.reward_funcs), device=device
889
+ )
890
+ for i, (reward_func, reward_processing_class) in enumerate(
891
+ zip(self.reward_funcs, self.reward_processing_classes)
892
+ ):
893
+ reward_kwargs = {
894
+ key: []
895
+ for key in inputs[0].keys()
896
+ if key not in ["prompt", "completion"]
897
+ }
898
+
899
+ '''Below is code for taking second generations'''
900
+ # every original example contributes `self.num_generations`
901
+ for example in inputs:
902
+ for _ in range(self.num_generations): # n times
903
+ for key in reward_kwargs:
904
+ reward_kwargs[key].append(example[key])
905
+
906
+ # -------- call the reward function --------
907
+ outputs = reward_func(
908
+ prompts=follow_up_prompts, # ⬅ extracted info
909
+ completions=second_completions, # ⬅ fresh answers
910
+ **reward_kwargs,
911
+ )
912
+ rewards_per_func[:, i] = torch.tensor(outputs, dtype=torch.float32, device=device)
913
+ '''Above is code for taking second generations'''
914
+
915
+ # for key in reward_kwargs:
916
+ # for example in inputs:
917
+ # # Repeat each value in the column for `num_generations` times
918
+ # reward_kwargs[key].extend([example[key]] * self.num_generations)
919
+ # output_reward_func = reward_func(
920
+ # prompts=prompts, completions=completions, **reward_kwargs
921
+ # )
922
+ # rewards_per_func[:, i] = torch.tensor(
923
+ # output_reward_func, dtype=torch.float32, device=device
924
+ # )
925
+
926
+
927
+ # rewards_per_func = gather(rewards_per_func)
928
+ # # Sum the rewards from all reward functions
929
+ # rewards = rewards_per_func.sum(dim=1)
930
+
931
+ # process_slice = slice(
932
+ # self.accelerator.process_index * len(prompts),
933
+ # (self.accelerator.process_index + 1) * len(prompts),
934
+ # )
935
+
936
+ # rewards = rewards[process_slice]
937
+
938
+
939
+
940
+ if self.temporal and video_inputs:
941
+ temporal_rewards_per_func = rewards_per_func.clone()
942
+
943
+ acc_mean = temporal_rewards_per_func[:, 0].mean()
944
+ shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
945
+
946
+ if acc_mean >= 0.8 * shuffled_acc_mean:
947
+ mask = temporal_rewards_per_func[:, 0] > 0.1
948
+ temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
949
+ temporal_rewards = torch.tensor([1.0]).to('cuda')
950
+ else:
951
+ temporal_rewards = torch.tensor([0.0]).to('cuda')
952
+ else:
953
+ temporal_rewards = torch.tensor([0.5]).to('cuda')
954
+
955
+ # Sum the rewards from all reward functions
956
+ if self.temporal and video_inputs:
957
+ rewards = temporal_rewards_per_func.sum(dim=1)
958
+ else:
959
+ rewards = rewards_per_func.sum(dim=1)
960
+
961
+ if self.len_control:
962
+ mem_rewards = [0] * self.num_generations
963
+ mask = rewards_per_func[:, 0] > 0.1
964
+ lenth_list = completion_mask.sum(1)
965
+ selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
966
+ # if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
967
+ # if len(selected_indices) > 1:
968
+ # selected_items = [(i, lenth_list[i]) for i in selected_indices]
969
+ # sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
970
+ # N = len(sorted_items)
971
+ # for rank, (idx, length) in enumerate(sorted_items):
972
+ # reward = 0.2 - 0.2 * (rank / N)
973
+ # rewards[idx] += reward
974
+ # mem_rewards[idx] = reward
975
+ # for idx in range(len(lenth_list)):
976
+ # if lenth_list[idx] >= 512:
977
+ # rewards[idx] -= 0.5
978
+
979
+ if len(selected_indices) > 1:
980
+ for idx in selected_indices:
981
+ if 320 <= lenth_list[idx] <= 512:
982
+ rewards[idx] += 0.2
983
+
984
+ print(rewards)
985
+ print(completion_mask.sum(1))
986
+
987
+ # Compute grouped-wise rewards
988
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
989
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
990
+
991
+ # Normalize the rewards to compute the advantages
992
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
993
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
994
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
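# Worked sketch (illustrative): with num_generations = 4 and group rewards
# [1.0, 0.0, 1.0, 0.0], the group mean is 0.5 and the (unbiased) std is about 0.577,
# so the advantages are roughly [+0.87, -0.87, +0.87, -0.87]: completions that beat
# their own group get pushed up and the rest get pushed down.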
995
+
996
+ # x - x.detach() allows for preserving gradients from x
997
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
998
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
999
+ # per_token_loss = -per_token_loss
1000
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
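# Illustrative note: exp(per_token_logps - per_token_logps.detach()) always evaluates
# to 1.0, but its gradient with respect to the parameters is grad(per_token_logps), so
# each token's advantage weights the usual policy-gradient term while the detached copy
# keeps the forward value fixed (the "x - x.detach()" trick mentioned above).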
1001
+
1002
+
1003
+ # import pdb
1004
+ # pdb.set_trace()
1005
+
1006
+ # Log the metrics
1007
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
1008
+ self._metrics["completion_length"].append(completion_length)
1009
+
1010
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
1011
+ for i, reward_func in enumerate(self.reward_funcs):
1012
+ if isinstance(reward_func, PreTrainedModel):
1013
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
1014
+ else:
1015
+ reward_func_name = reward_func.__name__
1016
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
1017
+
1018
+ gathered_rewards = self.accelerator.gather_for_metrics(rewards)
1019
+
1020
+ num_devices = gathered_rewards.size(0) // self.num_generations
1021
+ rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
1022
+ wrong_devices = (rewards_per_device <= 1).all(dim=1)
1023
+ wrong_ratio = wrong_devices.sum().item() / num_devices
1024
+
1025
+ correct_devices = (rewards_per_device >= 2).all(dim=1)
1026
+ correct_ratio = correct_devices.sum().item() / num_devices
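# Illustrative note: with 2 training processes and num_generations = 8, the gathered
# reward vector of length 16 is viewed as (2, 8); a row whose rewards are all <= 1
# counts toward `all_wrong`, and a row whose rewards are all >= 2 toward `all_correct`.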
1027
+
1028
+ self._metrics["all_wrong"].append(wrong_ratio)
1029
+ self._metrics["all_correct"].append(correct_ratio)
1030
+
1031
+ if self.temporal:
1032
+ temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
1033
+ self._metrics["temporal_rewards"].append(self.accelerator.gather_for_metrics(temporal_rewards_list).mean().item())
1034
+
1035
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
1036
+
1037
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
1038
+
1039
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
1040
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
1041
+
1042
+
1043
+ return loss
1044
+
1045
+
1046
+
1047
+
1048
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
1049
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
1050
+
1051
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
1052
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
1053
+ if next(iter(logs.keys())).startswith("eval_"):
1054
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
1055
+
1056
+ logs = {**logs, **metrics}
1057
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
1058
+ super().log(logs, start_time)
1059
+ else: # transformers<=4.46
1060
+ super().log(logs)
1061
+ self._metrics.clear()
src/r1-v/src/open_r1/trainer/vllm_grpo_trainer_modified_orig.py ADDED
@@ -0,0 +1,935 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import textwrap
17
+ from collections import defaultdict
18
+ from typing import Any, Callable, Optional, Union
19
+ from accelerate.utils.other import is_compiled_module
20
+ from accelerate.utils import broadcast_object_list, gather, gather_object
21
+ import torch
22
+ import torch.utils.data
23
+ import transformers
24
+ import warnings
25
+ from unittest.mock import patch
26
+ from datasets import Dataset, IterableDataset
27
+ from packaging import version
28
+ from transformers import (
29
+ AriaForConditionalGeneration,
30
+ AriaProcessor,
31
+ AutoModelForCausalLM,
32
+ AutoModelForSequenceClassification,
33
+ AutoProcessor,
34
+ AutoTokenizer,
35
+ GenerationConfig,
36
+ PreTrainedModel,
37
+ PreTrainedTokenizerBase,
38
+ Qwen2VLForConditionalGeneration,
39
+ Qwen2_5_VLForConditionalGeneration,
40
+ Trainer,
41
+ TrainerCallback,
42
+ is_wandb_available,
43
+ )
44
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
45
+ from transformers.utils import is_peft_available
46
+
47
+ from trl.data_utils import (
48
+ apply_chat_template,
49
+ is_conversational,
50
+ maybe_apply_chat_template,
51
+ )
52
+ from trl.import_utils import is_vllm_available
53
+
54
+ from trl.models import (
55
+ create_reference_model,
56
+ prepare_deepspeed,
57
+ unwrap_model_for_generation,
58
+ )
59
+ from trl.trainer.grpo_config import GRPOConfig
60
+ from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad
61
+ from trl import GRPOTrainer
62
+
63
+ import copy
64
+
65
+ if is_peft_available():
66
+ from peft import PeftConfig, get_peft_model
67
+
68
+ if is_vllm_available():
69
+ from vllm import LLM, SamplingParams
70
+
71
+ if is_wandb_available():
72
+ import wandb
73
+ import torch.nn as nn
74
+ from torch.utils.data import Sampler
75
+ import gc
76
+ from qwen_vl_utils import process_vision_info
77
+
78
+ # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
79
+ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
80
+ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
81
+
82
+
83
+ class Qwen2VLGRPOVLLMTrainerModifiedOrig(Trainer):
84
+ def __init__(
85
+ self,
86
+ model: Union[str, PreTrainedModel],
87
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
88
+ args: GRPOConfig = None,
89
+ script_args = None,
90
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
91
+ eval_dataset: Optional[
92
+ Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
93
+ ] = None,
94
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
95
+ reward_processing_classes: Optional[
96
+ Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
97
+ ] = None,
98
+ callbacks: Optional[list[TrainerCallback]] = None,
99
+ optimizers: tuple[
100
+ Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
101
+ ] = (None, None),
102
+ peft_config: Optional["PeftConfig"] = None,
103
+ # qwen2-vl related params
104
+ max_pixels: Optional[int] = 12845056,
105
+ min_pixels: Optional[int] = 3136,
106
+ attn_implementation: str = "flash_attention_2",
107
+ ):
108
+
109
+ # Args
110
+ if args is None:
111
+ model_name = model if isinstance(model, str) else model.config._name_or_path
112
+ model_name = model_name.split("/")[-1]
113
+ args = GRPOConfig(f"{model_name}-GRPO")
114
+
115
+ # Models
116
+ # Trained model
117
+ model_init_kwargs = args.model_init_kwargs or {}
118
+ model_init_kwargs["attn_implementation"] = attn_implementation
119
+ if isinstance(model, str):
120
+ model_id = model
121
+ torch_dtype = model_init_kwargs.get("torch_dtype")
122
+ if (
123
+ isinstance(torch_dtype, torch.dtype)
124
+ or torch_dtype == "auto"
125
+ or torch_dtype is None
126
+ ):
127
+ pass # torch_dtype is already a torch.dtype or "auto" or None
128
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
129
+ torch_dtype = getattr(torch, torch_dtype)
130
+ model_init_kwargs["torch_dtype"] = torch_dtype
131
+ else:
132
+ raise ValueError(
133
+ "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
134
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
135
+ )
136
+ # Disable caching if gradient checkpointing is enabled (not supported)
137
+ model_init_kwargs["use_cache"] = (
138
+ False
139
+ if args.gradient_checkpointing
140
+ else model_init_kwargs.get("use_cache")
141
+ )
142
+ if "Qwen2-VL" in model_id:
143
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
144
+ model, **model_init_kwargs
145
+ )
146
+ elif "Qwen2.5-VL" in model_id:
147
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
148
+ model, **model_init_kwargs
149
+ )
150
+ elif "Aria" in model_id:
151
+ model_init_kwargs.pop("use_cache")
152
+ model = AriaForConditionalGeneration.from_pretrained(
153
+ model, **model_init_kwargs
154
+ )
155
+ else:
156
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
157
+ else:
158
+ model_id = model.config._name_or_path
159
+ if args.model_init_kwargs is not None:
160
+ raise ValueError(
161
+ "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
162
+ "This argument can only be used when the `model` argument is a string."
163
+ )
164
+
165
+ if peft_config is not None:
166
+ model = get_peft_model(model, peft_config)
167
+
168
+ # Reference model
169
+ if is_deepspeed_zero3_enabled():
170
+ if "Qwen2-VL" in model_id:
171
+ self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(
172
+ model_id, **model_init_kwargs
173
+ )
174
+ elif "Qwen2.5-VL" in model_id:
175
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
176
+ model_id, **model_init_kwargs
177
+ )
178
+ elif "Aria" in model_id:
179
+ self.ref_model = AriaForConditionalGeneration.from_pretrained(
180
+ model_id, **model_init_kwargs
181
+ )
182
+ else:
183
+ self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
184
+ model_id, **model_init_kwargs
185
+ )
186
+ elif peft_config is None:
187
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
188
+ self.ref_model = create_reference_model(model)
189
+ else:
190
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
191
+ # to revert to the initial model.
192
+ self.ref_model = None
193
+
194
+ # Processing class
195
+ # if processing_class is None:
196
+ # if "Qwen" in model_id or "Aria" in model_id:
197
+ # processing_class = AutoProcessor.from_pretrained(model_id)
198
+ # pad_token_id = processing_class.tokenizer.pad_token_id
199
+ # processing_class.pad_token_id = pad_token_id
200
+ # processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
201
+ # if "Qwen" in model_id:
202
+ # processing_class.image_processor.max_pixels = max_pixels
203
+ # processing_class.image_processor.min_pixels = min_pixels
204
+ # else:
205
+ # processing_class = AutoTokenizer.from_pretrained(
206
+ # model.config._name_or_path, padding_side="left"
207
+ # )
208
+ # pad_token_id = processing_class.pad_token_id
209
+
210
+
211
+ # ────────────────────────────────────────────────────────────────
212
+ # Robust processor loading ― works for both fresh models *and* checkpoints
213
+ # ────────────────────────────────────────────────────────────────
214
+ if processing_class is None:
215
+ # 1️⃣ First try to load whatever lives in the directory we were given.
216
+ # This succeeds if you previously did `processor.save_pretrained(output_dir)`.
217
+ try:
218
+ processing_class = AutoProcessor.from_pretrained(model_id)
219
+ pad_token_id = processing_class.tokenizer.pad_token_id
220
+ except (OSError, ValueError): # no processor files found
221
+ # 2️⃣ Fall back to inspecting the *model object* instead of the path.
222
+ is_vl_model = (
223
+ hasattr(model, "vision_tower") or # Qwen-VL, InternVL, etc.
224
+ getattr(model.config, "vision_config", None) is not None or
225
+ getattr(model.config, "image_vocab_size", None) is not None
226
+ )
227
+
228
+ if is_vl_model:
229
+ # Always use the *base* model name stored in the config.
230
+ base_name = model.config._name_or_path # e.g. "Qwen/Qwen2.5-VL-7B-Instruct"
231
+ processing_class = AutoProcessor.from_pretrained(base_name)
232
+ pad_token_id = processing_class.tokenizer.pad_token_id
233
+
234
+ # Optional Qwen-specific limits
235
+ if hasattr(processing_class, "image_processor"):
236
+ processing_class.image_processor.max_pixels = max_pixels
237
+ processing_class.image_processor.min_pixels = min_pixels
238
+ else:
239
+ # Pure text model → plain tokenizer
240
+ processing_class = AutoTokenizer.from_pretrained(
241
+ model.config._name_or_path, padding_side="left"
242
+ )
243
+ pad_token_id = processing_class.pad_token_id
244
+
245
+ # 3️⃣ Harmonise attributes the rest of the trainer expects
246
+ processing_class.pad_token_id = pad_token_id
247
+ if not hasattr(processing_class, "eos_token_id"):
248
+ processing_class.eos_token_id = pad_token_id
249
+ # ────────────────────────────────────────────────────────────────
250
+
251
+ # Reward functions
252
+ if not isinstance(reward_funcs, list):
253
+ reward_funcs = [reward_funcs]
254
+ for i, reward_func in enumerate(reward_funcs):
255
+ if isinstance(reward_func, str):
256
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
257
+ reward_func, num_labels=1, **model_init_kwargs
258
+ )
259
+ self.reward_funcs = reward_funcs
260
+
261
+ # Reward processing class
262
+ if reward_processing_classes is None:
263
+ reward_processing_classes = [None] * len(reward_funcs)
264
+ elif not isinstance(reward_processing_classes, list):
265
+ reward_processing_classes = [reward_processing_classes]
266
+ else:
267
+ if len(reward_processing_classes) != len(reward_funcs):
268
+ raise ValueError(
269
+ "The number of reward processing classes must match the number of reward functions."
270
+ )
271
+
272
+ for i, (reward_processing_class, reward_func) in enumerate(
273
+ zip(reward_processing_classes, reward_funcs)
274
+ ):
275
+ if isinstance(reward_func, PreTrainedModel):
276
+ if reward_processing_class is None:
277
+ reward_processing_class = AutoTokenizer.from_pretrained(
278
+ reward_func.config._name_or_path
279
+ )
280
+ if reward_processing_class.pad_token_id is None:
281
+ reward_processing_class.pad_token = (
282
+ reward_processing_class.eos_token
283
+ )
284
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
285
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
286
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
287
+ reward_processing_classes[i] = reward_processing_class
288
+ self.reward_processing_classes = reward_processing_classes
289
+
290
+ # Data collator
291
+ def data_collator(features): # No data collation is needed in GRPO
292
+ return features
293
+
294
+ # Training arguments
295
+ self.max_prompt_length = args.max_prompt_length
296
+ self.max_completion_length = (
297
+ args.max_completion_length
298
+ ) # = |o_i| in the GRPO paper
299
+ self.num_generations = args.num_generations # = G in the GRPO paper
300
+ self.temporal = script_args.temporal
301
+ self.generation_config = GenerationConfig(
302
+ max_new_tokens=self.max_completion_length,
303
+ do_sample=True,
304
+ temperature=1, # HACK
305
+ num_return_sequences=self.num_generations,
306
+ pad_token_id=pad_token_id,
307
+ )
308
+ self.beta = args.beta
309
+
310
+ self.shuffled_num_generations = self.num_generations // 2
311
+ self.shuffled_generation_config = GenerationConfig(
312
+ max_new_tokens=self.max_completion_length,
313
+ do_sample=True,
314
+ top_p=0.95,
315
+ temperature=1, # HACK
316
+ num_return_sequences=self.shuffled_num_generations,
317
+ pad_token_id=pad_token_id,
318
+ )
319
+
320
+ self.dummy_generation_config = GenerationConfig(
321
+ max_new_tokens=1,
322
+ do_sample=True,
323
+ top_p=0.95,
324
+ temperature=1, # HACK
325
+ num_return_sequences=1,
326
+ pad_token_id=pad_token_id,
327
+ )
328
+ self.len_control = script_args.len_control
329
+ self.beta = args.beta
330
+
331
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
332
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
333
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
334
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
335
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
336
+ # This acts as a flag to indicate that the warning has already been issued.
337
+ model.warnings_issued["estimate_tokens"] = True
338
+
339
+ # Initialize the metrics
340
+ self._metrics = defaultdict(list)
341
+ self.use_vllm = args.use_vllm
342
+
343
+ super().__init__(
344
+ model=model,
345
+ args=args,
346
+ data_collator=data_collator,
347
+ train_dataset=train_dataset,
348
+ eval_dataset=eval_dataset,
349
+ processing_class=processing_class,
350
+ callbacks=callbacks,
351
+ optimizers=optimizers,
352
+ )
353
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
354
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
355
+ # self.model_accepts_loss_kwargs to False to enable scaling.
356
+ self.model_accepts_loss_kwargs = False
357
+
358
+ if self.use_vllm:
359
+ if not is_vllm_available():
360
+ raise ImportError(
361
+ "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
362
+ "`pip install vllm` to use it."
363
+ )
364
+
365
+ if self.accelerator.is_main_process:
366
+ vllm_device = self.args.vllm_device
367
+ if vllm_device == "auto":
368
+ vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
369
+ # Check that the requested device is available
370
+ if (
371
+ vllm_device.split(":")[0] == "cuda"
372
+ and int(vllm_device.split(":")[1]) >= torch.cuda.device_count()
373
+ ):
374
+ raise ValueError(
375
+ f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
376
+ "without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
377
+ "value lower than the number of GPUs available on your machine—typically, reducing it by one "
378
+ f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
379
+ )
380
+ # Check that the requested device is not also used for training
381
+ if vllm_device in {
382
+ f"cuda:{idx}" for idx in range(self.accelerator.num_processes)
383
+ }:
384
+ warnings.warn(
385
+ f"The requested device {vllm_device} is also used for training. This may lead to unexpected "
386
+ "behavior. It is recommended to use a dedicated device for vLLM."
387
+ )
388
+ # vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
389
+ # model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
390
+ # setting (profiling_patch).
391
+ world_size_patch = patch(
392
+ "torch.distributed.get_world_size", return_value=1
393
+ )
394
+ profiling_patch = patch(
395
+ "vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling",
396
+ return_value=None,
397
+ )
398
+ with world_size_patch, profiling_patch:
399
+ print("vllm is running on: ", vllm_device)
400
+ self.llm = LLM(
401
+ model=model.name_or_path,
402
+ device=vllm_device,
403
+ gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
404
+ dtype=torch.bfloat16,
405
+ # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
406
+ # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
407
+ # This is particularly useful here because we generate completions from the same prompts.
408
+ enable_prefix_caching=True,
409
+ enforce_eager=True,
410
+ mm_processor_kwargs=(
411
+ {
412
+ "max_pixels": max_pixels,
413
+ "min_pixels": min_pixels,
414
+ }
415
+ # if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id
416
+ if False
417
+ else None
418
+ ),
419
+ max_model_len=args.max_prompt_length + args.max_completion_length,
420
+ )
421
+ self.sampling_params = SamplingParams(
422
+ temperature=1.0,
423
+ top_p=0.95,
424
+ max_tokens=self.max_completion_length,
425
+ )
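+ # These sampling settings (temperature 1.0, top-p 0.95) keep the num_generations completions
+ # per prompt diverse, which the group-relative baseline used by GRPO relies on.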
426
+
427
+ self._last_loaded_step = 0  # last global step at which weights were pushed to vLLM; avoids redundant reloads during gradient accumulation
428
+
429
+ # When using vLLM, the main process is responsible for loading the model weights. This can cause process
430
+ # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
431
+ # synchronize all processes after vLLM has been fully initialized.
432
+ self.accelerator.wait_for_everyone()
433
+ else:
434
+ raise ValueError(
435
+ "GRPOVLLMTrainerModified only supports vllm generation, please set --use_vllm True"
436
+ )
437
+
438
+ if self.ref_model is not None:
439
+ if self.is_deepspeed_enabled:
440
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
441
+ else:
442
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
443
+
444
+ for i, reward_func in enumerate(self.reward_funcs):
445
+ if isinstance(reward_func, PreTrainedModel):
446
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
447
+
448
+ def _set_signature_columns_if_needed(self):
449
+ # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
450
+ # By default, this method sets `self._signature_columns` to the model's expected inputs.
451
+ # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
452
+ # Instead, we set them to the columns expected by the `training_step` method, hence the override.
453
+ if self._signature_columns is None:
454
+ self._signature_columns = ["prompt"]
455
+
456
+ # Get the per-token log probabilities for the completions for the model and the reference model
457
+ def _get_per_token_logps(self, model, input_ids, **kwargs):
458
+ # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits # (B, L, V)
459
+ # import pdb
460
+ # pdb.set_trace()
461
+ logits = model(input_ids, **kwargs).logits
462
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
463
+ input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
464
+ # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
465
+ per_token_logps = []
466
+ for logits_row, input_ids_row in zip(logits, input_ids):
467
+ log_probs = logits_row.log_softmax(dim=-1)
468
+ token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
469
+ per_token_logps.append(token_log_prob)
470
+ return torch.stack(per_token_logps)
471
+
472
+ # Trainer "prepares" the inputs before calling `compute_loss`. It converts to tensor and move to device.
473
+ # Since we preprocess the data in `compute_loss`, we need to override this method to skip this step.
474
+ def _prepare_inputs(
475
+ self, inputs: dict[str, Union[torch.Tensor, Any]]
476
+ ) -> dict[str, Union[torch.Tensor, Any]]:
477
+ return inputs
478
+
479
+ def remove_none_from_data(self, data):
480
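+ # Strip None-valued keys from each message's content entries (e.g. an unused "image" or
+ # "video" field), presumably so they do not interfere with process_vision_info below.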
+ for entry in data:
481
+ if "content" in entry and isinstance(entry["content"], list):
482
+ for sub_entry in entry["content"]:
483
+ if isinstance(sub_entry, dict):
484
+ keys_to_remove = [k for k, v in sub_entry.items() if v is None]
485
+ for k in keys_to_remove:
486
+ del sub_entry[k]
487
+ return data
488
+
489
+
490
+
491
+ def compute_loss(
492
+ self, model, inputs, return_outputs=False, num_items_in_batch=None
493
+ ):
494
+ if return_outputs:
495
+ raise ValueError("The GRPOTrainer does not support returning outputs")
496
+ # Compute the per-token log probabilities for the model
497
+
498
+
499
+ device = self.accelerator.device
500
+ prompts = [x["prompt"] for x in inputs]
501
+ # images = [x["image"] for x in inputs]
502
+ prompts_text = [
503
+ maybe_apply_chat_template(example, self.processing_class)["prompt"]
504
+ for example in inputs
505
+ ]
506
+
507
+ input_copy = copy.deepcopy(inputs[0]['prompt'])
508
+
509
+ input_copy = self.remove_none_from_data(input_copy)
510
+
511
+ data_type = inputs[0]['data_type']
512
+
513
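+ # Resolve the dataset-relative media path to an absolute one under ./Video-R1-data;
+ # the stored path is assumed to begin with a leading "." that is dropped by [1:].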
+ if data_type == 'image':
514
+ input_copy[0]['content'][0]['image'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
515
+ elif data_type == 'video':
516
+ input_copy[0]['content'][0]['video'] = os.getcwd() + "/Video-R1-data" + inputs[0]['path'][1:]
517
+
518
+
519
+ image_inputs, video_inputs, video_kwargs = process_vision_info(input_copy, return_video_kwargs=True)
520
+
521
+
522
+ prompt_inputs = self.processing_class(
523
+ text=copy.deepcopy(prompts_text),
524
+ images=image_inputs,
525
+ videos=video_inputs,
526
+ return_tensors="pt",
527
+ padding=True,
528
+ padding_side="left",
529
+ add_special_tokens=False,
530
+ )
531
+
532
+ mm_data = [[data_type, image_inputs if image_inputs else video_inputs]]
533
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
534
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
535
+
536
+ if self.max_prompt_length is not None:
537
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
538
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
539
+
540
+
541
+ if self.temporal:
542
+ if video_inputs:
543
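+ # For the temporal objective, randomly permute the video frames along the time axis and
+ # build a second prompt batch from the shuffled frames; its completions are used later to
+ # measure how much the reward depends on the correct frame order.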
+ indices = torch.randperm(video_inputs[0].size(0))
544
+ shuffled_video_inputs = [video_inputs[0][indices]]
545
+ shuffled_prompt_inputs = self.processing_class(
546
+ text=copy.deepcopy(prompts_text),
547
+ images=image_inputs,
548
+ videos=shuffled_video_inputs,
549
+ return_tensors="pt",
550
+ padding=True,
551
+ padding_side="left",
552
+ add_special_tokens=False,
553
+ )
554
+ shuffled_mm_data = [[self.accelerator.process_index, data_type, image_inputs if image_inputs else video_inputs]]
555
+ shuffled_prompt_inputs = super()._prepare_inputs(shuffled_prompt_inputs)
556
+ shuffled_prompt_ids, shuffled_prompt_mask = shuffled_prompt_inputs["input_ids"], shuffled_prompt_inputs["attention_mask"]
557
+ if self.max_prompt_length is not None:
558
+ shuffled_prompt_ids = shuffled_prompt_ids[:, -self.max_prompt_length :]
559
+ shuffled_prompt_mask = shuffled_prompt_mask[:, -self.max_prompt_length :]
560
+ else:
561
+ shuffled_mm_data = [None]
562
+
563
+
564
+
565
+ if self.args.use_vllm:
566
+ # First, have main process load weights if needed
567
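+ # Sync the updated policy weights into the vLLM engine held by the main process once per
+ # optimizer step; gradient-accumulation micro-steps share the same global_step and skip this.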
+ if self.state.global_step != self._last_loaded_step:
568
+ with unwrap_model_for_generation(
569
+ self.model,
570
+ self.accelerator,
571
+ gather_deepspeed3_params=True, # TODO: fix this, self.args.ds3_gather_for_generation,
572
+ ) as unwrapped_model:
573
+ if is_compiled_module(unwrapped_model):
574
+ state_dict = unwrapped_model._orig_mod.state_dict()
575
+ else:
576
+ state_dict = unwrapped_model.state_dict()
577
+ if self.accelerator.is_main_process:
578
+ llm_model = (
579
+ self.llm.llm_engine.model_executor.driver_worker.model_runner.model
580
+ )
581
+ # import pdb
582
+ # pdb.set_trace()
583
+ llm_model.load_weights(state_dict.items())
584
+ self._last_loaded_step = self.state.global_step
585
+
586
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
587
+ all_prompts_text = gather_object(prompts_text)
588
+ all_mm_data = gather_object(mm_data)
589
+ # group into pairs
590
+ all_multimodal_inputs = []
591
+
592
+ if self.temporal:
593
+ shuffled_all_mm_data_none = gather_object(shuffled_mm_data)
594
+ shuffled_all_mm_data = [x for x in shuffled_all_mm_data_none if x]
595
+ shuffled_all_multimodal_inputs = []
596
+
597
+ # Following TobiasLee's implementation suggestion:
598
+ # build one request per prompt, paired with its multi-modal data, which is a better fit for vLLM sampling.
599
+ for prompt, mm_item in zip(all_prompts_text, all_mm_data):
600
+ all_multimodal_inputs.append({"prompt": prompt, "multi_modal_data": {mm_item[0]: mm_item[1]}})
601
+
602
+ if self.temporal and shuffled_all_mm_data!=[]:
603
+ for mm_item in shuffled_all_mm_data:
604
+ shuffled_all_multimodal_inputs.append({"prompt": all_prompts_text[mm_item[0]], "multi_modal_data": {mm_item[1]: mm_item[2]}})
605
+
606
+ # Create sampling params with num_generations
607
+ if self.accelerator.is_main_process:
608
+ # Clone to avoid modifying original params
609
+ sampling_params = copy.deepcopy(self.sampling_params)
610
+ sampling_params.n = self.num_generations
611
+ # Single generate call with all prompts
612
+ if self.accelerator.is_main_process:
613
+ outputs = self.llm.generate(
614
+ all_multimodal_inputs,
615
+ sampling_params=sampling_params,
616
+ use_tqdm=False,
617
+ )
618
+ # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
619
+ completion_ids = [out.token_ids for completion in outputs for out in completion.outputs]
620
+
621
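+ # The shuffled-frame prompts are sampled with only num_generations // 2 completions each;
+ # these are used solely for the temporal reward comparison and never enter the policy loss.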
+ if self.temporal and shuffled_all_mm_data!=[]:
622
+ # Clone to avoid modifying original params
623
+ shuffled_sampling_params = copy.deepcopy(self.sampling_params)
624
+ shuffled_sampling_params.n = self.num_generations // 2
625
+ # Single generate call with all prompts
626
+ if self.accelerator.is_main_process:
627
+ shuffled_outputs = self.llm.generate(
628
+ shuffled_all_multimodal_inputs,
629
+ sampling_params=shuffled_sampling_params,
630
+ use_tqdm=False,
631
+ )
632
+ # Flatten outputs: [prompt1_gen1, prompt1_gen2, ..., prompt2_gen1, prompt2_gen2, ...]
633
+ shuffled_completion_ids = [out.token_ids for completion in shuffled_outputs for out in completion.outputs]
634
+
635
+
636
+ else:
637
+ completion_ids = [None] * len(all_multimodal_inputs) * self.num_generations
638
+
639
+ if self.temporal and shuffled_all_mm_data!=[]:
640
+ shuffled_completion_ids = [None] * len(shuffled_all_multimodal_inputs) * (self.num_generations // 2)
641
+
642
+
643
+ # broadcast and slice
644
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
645
+ process_slice = slice(
646
+ self.accelerator.process_index * len(prompts) * self.num_generations,
647
+ (self.accelerator.process_index + 1) * len(prompts) * self.num_generations,
648
+ )
649
+ completion_ids = completion_ids[process_slice]
650
+
651
+ # Pad the completions, and concatenate them with the prompts
652
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
653
+ completion_ids = pad(
654
+ completion_ids, padding_value=self.processing_class.pad_token_id
655
+ )
656
+ prompt_ids = prompt_ids.repeat_interleave(self.num_generations, dim=0)
657
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
658
+
659
+ prompt_length = prompt_ids.size(1)
660
+
661
+ # print('prompt_length:', prompt_length)
662
+
663
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
664
+ completion_ids = prompt_completion_ids[:, prompt_length:]
665
+ prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
666
+
667
+
668
+ if self.temporal and shuffled_all_mm_data!=[]:
669
+ # broadcast and slice
670
+ shuffled_completion_ids = broadcast_object_list(shuffled_completion_ids, from_process=0)
671
+ process_id_list = []
672
+ for mm_item in shuffled_all_mm_data:
673
+ process_id_list += [mm_item[0]] * len(prompts) * (self.num_generations // 2)
674
+
675
+ if video_inputs:
676
+ cur_shuffled_completion_ids = []
677
+ for i in range(len(process_id_list)):
678
+ if self.accelerator.process_index == process_id_list[i]:
679
+ cur_shuffled_completion_ids.append(shuffled_completion_ids[i])
680
+
681
+ # Pad the completions, and concatenate them with the prompts
682
+ cur_shuffled_completion_ids = [torch.tensor(ids, device=device) for ids in cur_shuffled_completion_ids]
683
+ cur_shuffled_completion_ids = pad(
684
+ cur_shuffled_completion_ids, padding_value=self.processing_class.pad_token_id
685
+ )
686
+ shuffled_completion_ids = cur_shuffled_completion_ids
687
+
688
+
689
+ else:
690
+ raise ValueError("Only vLLM generation is supported in this version ")
691
+
692
+ # The code below follows yifan's original implementation.
693
+ # Mask everything after the first EOS token
694
+ is_eos = completion_ids == self.processing_class.eos_token_id
695
+ device = self.accelerator.device
696
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
697
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
698
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
699
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
700
+
701
+
702
+
703
+ prompt_inputs.pop("input_ids")
704
+ prompt_inputs.pop("attention_mask")
705
+
706
+ if data_type == 'image':
707
+ prompt_inputs["pixel_values"] = prompt_inputs["pixel_values"].repeat(len(prompt_completion_ids), 1)
708
+ prompt_inputs["image_grid_thw"] = prompt_inputs["image_grid_thw"].repeat(len(prompt_completion_ids), 1)
709
+ # import pdb; pdb.set_trace()
710
+
711
+
712
+ if data_type == 'video':
713
+ prompt_inputs["pixel_values_videos"] = prompt_inputs["pixel_values_videos"].repeat(len(prompt_completion_ids), 1)
714
+ prompt_inputs["video_grid_thw"] = prompt_inputs["video_grid_thw"].repeat(len(prompt_completion_ids), 1)
715
+ if 'second_per_grid_ts' in prompt_inputs:
716
+ del prompt_inputs["second_per_grid_ts"]
717
+
718
+ # import pdb
719
+ # pdb.set_trace()
720
+
721
+ # per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw)
722
+ per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
723
+ # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
724
+ per_token_logps = per_token_logps[:, prompt_length - 1 :]
725
+
726
+ gc.collect()
727
+ torch.cuda.empty_cache()
728
+
729
+ with torch.inference_mode():
730
+ if self.ref_model is not None:
731
+ ref_per_token_logps = self._get_per_token_logps(self.ref_model, prompt_completion_ids, **prompt_inputs)
732
+ else:
733
+ with self.accelerator.unwrap_model(model).disable_adapter():
734
+ ref_per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)
735
+ ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1 :]
736
+
737
+ x_clamped = torch.clamp(ref_per_token_logps - per_token_logps, min=-10, max=10)  # clamp x to a bounded range for numerical stability
738
+ per_token_kl = torch.exp(x_clamped) - x_clamped - 1
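+ # The line above is the k3 estimator of KL(pi_theta || pi_ref): exp(x) - x - 1 with
+ # x = log pi_ref - log pi_theta, which is non-negative and unbiased for tokens sampled
+ # from the current policy; the clamp keeps exp(x) from overflowing.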
739
+
740
+ gc.collect()
741
+ torch.cuda.empty_cache()
742
+
743
+ if self.temporal and video_inputs:
744
+
745
+ shuffled_completions = self.processing_class.batch_decode(shuffled_completion_ids, skip_special_tokens=True)
746
+ if is_conversational(inputs[0]):
747
+ shuffled_completions = [[{"role": "assistant", "content": shuffled_completion}] for shuffled_completion in shuffled_completions]
748
+
749
+ # Compute the rewards
750
+ shuffled_prompts = [prompt for prompt in prompts for _ in range(self.shuffled_num_generations)]
751
+ shuffled_rewards_per_func = torch.zeros(len(shuffled_prompts), len(self.reward_funcs), device=device)
752
+ for i, (reward_func, reward_processing_class) in enumerate(
753
+ zip(self.reward_funcs, self.reward_processing_classes)
754
+ ):
755
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
756
+ shuffled_reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
757
+ for key in shuffled_reward_kwargs:
758
+ for example in inputs:
759
+ # Repeat each value in the column for `num_generations` times
760
+ shuffled_reward_kwargs[key].extend([example[key]] * self.shuffled_num_generations)
761
+ shuffled_output_reward_func = reward_func(prompts=shuffled_prompts, completions=shuffled_completions, **shuffled_reward_kwargs)
762
+ shuffled_rewards_per_func[:, i] = torch.tensor(shuffled_output_reward_func, dtype=torch.float32, device=device)
763
+
764
+
765
+
766
+ # Decode the generated completions
767
+ completions = self.processing_class.batch_decode(
768
+ completion_ids, skip_special_tokens=True
769
+ )
770
+ if is_conversational(inputs[0]):
771
+ completions = [
772
+ [{"role": "assistant", "content": completion}]
773
+ for completion in completions
774
+ ]
775
+
776
+ # Compute the rewards
777
+ prompts = [prompt for prompt in prompts for _ in range(self.num_generations)]
778
+ rewards_per_func = torch.zeros(
779
+ len(prompts), len(self.reward_funcs), device=device
780
+ )
781
+ for i, (reward_func, reward_processing_class) in enumerate(
782
+ zip(self.reward_funcs, self.reward_processing_classes)
783
+ ):
784
+ reward_kwargs = {
785
+ key: []
786
+ for key in inputs[0].keys()
787
+ if key not in ["prompt", "completion"]
788
+ }
789
+ for key in reward_kwargs:
790
+ for example in inputs:
791
+ # Repeat each value in the column for `num_generations` times
792
+ reward_kwargs[key].extend([example[key]] * self.num_generations)
793
+ output_reward_func = reward_func(
794
+ prompts=prompts, completions=completions, **reward_kwargs
795
+ )
796
+ rewards_per_func[:, i] = torch.tensor(
797
+ output_reward_func, dtype=torch.float32, device=device
798
+ )
799
+
800
+
801
+ # rewards_per_func = gather(rewards_per_func)
802
+ # # Sum the rewards from all reward functions
803
+ # rewards = rewards_per_func.sum(dim=1)
804
+
805
+ # process_slice = slice(
806
+ # self.accelerator.process_index * len(prompts),
807
+ # (self.accelerator.process_index + 1) * len(prompts),
808
+ # )
809
+
810
+ # rewards = rewards[process_slice]
811
+
812
+
813
+
814
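+ # Temporal-consistency reward: compare the accuracy reward on the original frame order with
+ # the shuffled-frame generations. If ordered-frame accuracy is at least 0.8x the shuffled one,
+ # correct completions (acc reward > 0.1) receive a +0.3 bonus and temporal_rewards is logged
+ # as 1.0; otherwise 0.0 (0.5 for batches without video).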
+ if self.temporal and video_inputs:
815
+ temporal_rewards_per_func = rewards_per_func.clone()
816
+
817
+ acc_mean = temporal_rewards_per_func[:, 0].mean()
818
+ shuffled_acc_mean = shuffled_rewards_per_func[:, 0].mean()
819
+
820
+ if acc_mean >= 0.8 * shuffled_acc_mean:
821
+ mask = temporal_rewards_per_func[:, 0] > 0.1
822
+ temporal_rewards_per_func[mask, 0] = temporal_rewards_per_func[mask, 0] + 0.3
823
+ temporal_rewards = torch.tensor([1.0]).to('cuda')
824
+ else:
825
+ temporal_rewards = torch.tensor([0.0]).to('cuda')
826
+ else:
827
+ temporal_rewards = torch.tensor([0.5]).to('cuda')
828
+
829
+ # Sum the rewards from all reward functions
830
+ if self.temporal and video_inputs:
831
+ rewards = temporal_rewards_per_func.sum(dim=1)
832
+ else:
833
+ rewards = rewards_per_func.sum(dim=1)
834
+
835
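+ # Length control: when more than one sampled completion is correct, give a +0.2 bonus to
+ # correct completions whose length lies in the [320, 1600]-token window, discouraging
+ # degenerate very short or very long chains of thought.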
+ if self.len_control:
836
+ mem_rewards = [0] * self.num_generations
837
+ mask = rewards_per_func[:, 0] > 0.1
838
+ lenth_list = completion_mask.sum(1)
839
+ selected_indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
840
+ # if len(selected_indices) > 1 and len(selected_indices) < self.num_generations:
841
+ # if len(selected_indices) > 1:
842
+ # selected_items = [(i, lenth_list[i]) for i in selected_indices]
843
+ # sorted_items = sorted(selected_items, key=lambda x: x[1], reverse=True)
844
+ # N = len(sorted_items)
845
+ # for rank, (idx, length) in enumerate(sorted_items):
846
+ # reward = 0.2 - 0.2 * (rank / N)
847
+ # rewards[idx] += reward
848
+ # mem_rewards[idx] = reward
849
+ # for idx in range(len(lenth_list)):
850
+ # if lenth_list[idx] >= 512:
851
+ # rewards[idx] -= 0.5
852
+
853
+ if len(selected_indices) > 1:
854
+ for idx in selected_indices:
855
+ if 320 <= lenth_list[idx] <= 1600:
856
+ rewards[idx] += 0.2
857
+
858
+ # print(rewards)
859
+ # print(completion_mask.sum(1))
860
+
861
+ # Compute grouped-wise rewards
862
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
863
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
864
+
865
+ # Normalize the rewards to compute the advantages
866
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
867
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
868
+ advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
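+ # Group-relative advantage (GRPO): each completion's reward is standardized against the mean
+ # and std of the num_generations completions sampled for the same prompt,
+ # i.e. A_i = (r_i - mean(r_group)) / (std(r_group) + 1e-4).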
869
+
870
+ # x - x.detach() allows for preserving gradients from x
871
+ per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
872
+ per_token_loss = -(per_token_loss - self.beta * per_token_kl)
873
+ # per_token_loss = -per_token_loss
874
+ loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
875
+
876
+
877
+ # import pdb
878
+ # pdb.set_trace()
879
+
880
+ # Log the metrics
881
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
882
+ self._metrics["completion_length"].append(completion_length)
883
+
884
+ reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
885
+ for i, reward_func in enumerate(self.reward_funcs):
886
+ if isinstance(reward_func, PreTrainedModel):
887
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
888
+ else:
889
+ reward_func_name = reward_func.__name__
890
+ self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
891
+
892
+ gathered_rewards = self.accelerator.gather_for_metrics(rewards)
893
+
894
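+ # Diagnostic metrics: view the gathered rewards as one row of num_generations completions per
+ # prompt and log the fraction of prompts where every completion scored <= 1 ("all_wrong") or
+ # >= 2 ("all_correct"); the thresholds presumably reflect the combined accuracy + format reward scale.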
+ num_devices = gathered_rewards.size(0) // self.num_generations
895
+ rewards_per_device = gathered_rewards.view(num_devices, self.num_generations)
896
+ wrong_devices = (rewards_per_device <= 1).all(dim=1)
897
+ wrong_ratio = wrong_devices.sum().item() / num_devices
898
+
899
+ correct_devices = (rewards_per_device >= 2).all(dim=1)
900
+ correct_ratio = correct_devices.sum().item() / num_devices
901
+
902
+ self._metrics["all_wrong"].append(wrong_ratio)
903
+ self._metrics["all_correct"].append(correct_ratio)
904
+
905
+ if self.temporal:
906
+ temporal_rewards_list = self.accelerator.gather_for_metrics(temporal_rewards)
907
+ self._metrics["temporal_rewards"].append(self.accelerator.gather_for_metrics(temporal_rewards_list).mean().item())
908
+
909
+ self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
910
+
911
+ self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
912
+
913
+ mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
914
+ self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
915
+
916
+
917
+ return loss
918
+
919
+
920
+
921
+
922
+ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
923
+ metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
924
+
925
+ # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
926
+ # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
927
+ if next(iter(logs.keys())).startswith("eval_"):
928
+ metrics = {f"eval_{key}": val for key, val in metrics.items()}
929
+
930
+ logs = {**logs, **metrics}
931
+ if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
932
+ super().log(logs, start_time)
933
+ else: # transformers<=4.46
934
+ super().log(logs)
935
+ self._metrics.clear()