RobertoBarrosoLuque committed
Commit 582e83e · 1 Parent(s): 69ab3a1

Add qwen 3 vl

notebooks/01-eda-and-fine-tuning.ipynb CHANGED
@@ -331,10 +331,46 @@
    "! firectl -a pyroworks get sftj bew0pztj"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "id": "28",
+  "metadata": {},
+  "source": [
+   "##### Fine tune Qwen 3 vl 8B"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "28",
+  "id": "29",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "! firectl -a pyroworks create sftj --base-model accounts/fireworks/models/qwen3-vl-8b-instruct --dataset accounts/pyroworks/datasets/fashion-catalog-train --output-model qwen3-8b-fashion-catalog --display-name \"Qwen3-8B-fashion-catalog\" --epochs 3 --learning-rate 0.0001 --early-stop --eval-auto-carveout"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "30",
+  "metadata": {},
+  "source": [
+   "##### Fine tune Qwen 3 VL 32B"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "31",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "! firectl -a pyroworks create sftj --base-model accounts/fireworks/models/qwen3-vl-32b-instruct --dataset accounts/pyroworks/datasets/fashion-catalog-train --output-model qwen3-32b-fashion-catalog --display-name \"Qwen3-32B-fashion-catalog\" --epochs 3 --learning-rate 0.0001 --early-stop --eval-auto-carveout"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "32",
   "metadata": {},
   "outputs": [],
   "source": []
notebooks/02-model-evals.ipynb CHANGED
@@ -7,9 +7,9 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from src.modules.vlm_inference import analyze_product_image\n",
-   "from src.modules.data_processing import load_test_data, image_to_base64\n",
-   "from src.modules.evals import run_inference_on_dataframe_async, evaluate_all_categories, extract_metrics\n",
+   "from modules.vlm_inference import analyze_product_image\n",
+   "from modules.data_processing import load_test_data, image_to_base64\n",
+   "from modules.evals import run_inference_on_dataframe_async, evaluate_all_categories, extract_metrics\n",
    "from dotenv import load_dotenv\n",
    "import os\n",
    "from PIL import Image\n",
@@ -137,8 +137,7 @@
   "id": "10",
   "metadata": {},
   "source": [
-   "##### Run inference on Qwen 2.5 VL 32B\n",
-   "m"
+   "##### Run inference on Qwen 2.5 VL 32B"
   ]
  },
  {
@@ -235,9 +234,65 @@
   ]
  },
  {
-  "cell_type": "markdown",
+  "cell_type": "code",
+  "execution_count": null,
   "id": "18",
   "metadata": {},
+  "outputs": [],
+  "source": [
+   "# Run with concurrent requests using await directly in Jupyter\n",
+   "df_predictions_qwen3_8B_base = await run_inference_on_dataframe_async(\n",
+   "    df_test,\n",
+   "    model=\"accounts/pyroworks/deployedModels/qwen3-vl-8b-instruct-y147m785\",\n",
+   "    provider=\"FireworksAI\",\n",
+   "    api_key=FIREWORKS_API_KEY,\n",
+   "    max_concurrent_requests=20, # Adjust based on rate limits\n",
+   ")\n",
+   "\n",
+   "results_qwen3_8B_base = evaluate_all_categories(\n",
+   "    df_ground_truth=df_test,\n",
+   "    df_predictions=df_predictions_qwen3_8B_base,\n",
+   "    categories=[\"masterCategory\", \"gender\", \"subCategory\"]\n",
+   ")"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "19",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "! firectl create deployment accounts/fireworks/models/qwen3-vl-32b-instruct --deployment-shape THROUGHPUT"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "20",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# Run with concurrent requests using await directly in Jupyter\n",
+   "df_predictions_qwen3_32B_base = await run_inference_on_dataframe_async(\n",
+   "    df_test,\n",
+   "    model=\"accounts/pyroworks/deployedModels/qwen3-vl-32b-instruct-jalntd80\",\n",
+   "    provider=\"FireworksAI\",\n",
+   "    api_key=FIREWORKS_API_KEY,\n",
+   "    max_concurrent_requests=20, # Adjust based on rate limits\n",
+   ")\n",
+   "\n",
+   "results_qwen3_32B_base = evaluate_all_categories(\n",
+   "    df_ground_truth=df_test,\n",
+   "    df_predictions=df_predictions_qwen3_32B_base,\n",
+   "    categories=[\"masterCategory\", \"gender\", \"subCategory\"]\n",
+   ")"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "21",
+  "metadata": {},
   "source": [
    "#### Run test set through fine tuned FW Qwen model\n",
    "1. Create a Lora deployment of our fine tuned model\n",
@@ -247,33 +302,41 @@
  },
  {
   "cell_type": "markdown",
-  "id": "19",
+  "id": "22",
   "metadata": {},
   "source": [
-   "#### Run evals on Qwen 32B SFT\n"
+   "#### Run evals on Qwen 32B SFT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "20",
+  "id": "23",
   "metadata": {},
   "outputs": [],
   "source": [
-   "! firectl -a pyroworks create deployment accounts/pyroworks/models/qwen-32b-fashion-catalog --min-replica-count 1 --max-replica-count 1 --accelerator-type NVIDIA_H100_80GB"
+   "!firectl -a pyroworks create deployment accounts/pyroworks/models/qwen-32b-fashion-catalog"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "24",
+  "metadata": {},
+  "source": [
+   "Deployment ID: accounts/pyroworks/deployments/c09a2c4q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "21",
+  "id": "25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run with concurrent requests using await directly in Jupyter\n",
    "df_predictions_qwen_32b_fine_tuned = await run_inference_on_dataframe_async(\n",
    "    df_test,\n",
-   "    model=\"accounts/pyroworks/deployedModels/qwen-32b-fashion-catalog-c6fhxibo\",\n",
+   "    model=\"accounts/pyroworks/deployedModels/qwen-32b-fashion-catalog-pwb1mga2\",\n",
    "    provider=\"FireworksAI\",\n",
    "    api_key=FIREWORKS_API_KEY,\n",
    "    max_concurrent_requests=20, # Adjust based on rate limits\n",
@@ -288,7 +351,7 @@
  },
  {
   "cell_type": "markdown",
-  "id": "22",
+  "id": "26",
   "metadata": {},
   "source": [
    "#### Run evals on Qwen 72B SFT"
@@ -297,7 +360,7 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "23",
+  "id": "27",
   "metadata": {},
   "outputs": [],
   "source": [
@@ -307,17 +370,17 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "24",
+  "id": "28",
   "metadata": {},
   "outputs": [],
   "source": [
-   "!firectl-admin get deployment bedocpar"
+   "!firectl get deployment bedocpar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "25",
+  "id": "29",
   "metadata": {},
   "outputs": [],
   "source": [
@@ -339,7 +402,89 @@
  },
  {
   "cell_type": "markdown",
-  "id": "26",
+  "id": "30",
+  "metadata": {},
+  "source": [
+   "#### Run evals on Qwen 3 8B SFT"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "31",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "! firectl-admin -a pyroworks create deployment accounts/pyroworks/models/qwen3-8b-fashion-catalog"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "32",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# Run with concurrent requests using await directly in Jupyter\n",
+   "df_predictions_qwen_3_8b_fine_tuned = await run_inference_on_dataframe_async(\n",
+   "    df_test,\n",
+   "    model=\"accounts/pyroworks/deployedModels/qwen3-8b-fashion-catalog-bdo0tqxe\",\n",
+   "    provider=\"FireworksAI\",\n",
+   "    api_key=FIREWORKS_API_KEY,\n",
+   "    max_concurrent_requests=20,\n",
+   ")\n",
+   "\n",
+   "results_qwen__3_8b_fine_tuned = evaluate_all_categories(\n",
+   "    df_ground_truth=df_test,\n",
+   "    df_predictions=df_predictions_qwen_3_8b_fine_tuned,\n",
+   "    categories=[\"masterCategory\", \"gender\", \"subCategory\"]\n",
+   ")"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "33",
+  "metadata": {},
+  "source": [
+   "#### Run evals on Qwen 3 32B SFT"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "34",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "! firectl -a pyroworks create deployment accounts/pyroworks/models/qwen3-32b-fashion-catalog --world-size 4 --accelerator-type NVIDIA_H200_141GB --min-replica-count 1 --max-replica-count 1"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "35",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "# Run with concurrent requests using await directly in Jupyter\n",
+   "df_predictions_qwen_3_32b_fine_tuned = await run_inference_on_dataframe_async(\n",
+   "    df_test,\n",
+   "    model=\"accounts/pyroworks/deployedModels/qwen-32b-fashion-catalog-pwb1mga2\",\n",
+   "    provider=\"FireworksAI\",\n",
+   "    api_key=FIREWORKS_API_KEY,\n",
+   "    max_concurrent_requests=20,\n",
+   ")\n",
+   "\n",
+   "results_qwen__3_32b_fine_tuned = evaluate_all_categories(\n",
+   "    df_ground_truth=df_test,\n",
+   "    df_predictions=df_predictions_qwen_3_32b_fine_tuned,\n",
+   "    categories=[\"masterCategory\", \"gender\", \"subCategory\"]\n",
+   ")"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "id": "36",
   "metadata": {},
   "source": [
    "#### Run test set through closed source model"
@@ -348,7 +493,7 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "27",
+  "id": "37",
   "metadata": {},
   "outputs": [],
   "source": [
@@ -371,7 +516,7 @@
  },
  {
   "cell_type": "markdown",
-  "id": "28",
+  "id": "38",
   "metadata": {},
   "source": [
    "### Compare eval metrics across models"
@@ -380,7 +525,7 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "29",
+  "id": "39",
   "metadata": {},
   "outputs": [],
   "source": [
@@ -401,7 +546,7 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "30",
+  "id": "40",
   "metadata": {},
   "outputs": [],
   "source": [
@@ -416,7 +561,7 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "31",
+  "id": "41",
   "metadata": {},
   "outputs": [],
   "source": [
@@ -453,7 +598,7 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "32",
+  "id": "42",
   "metadata": {},
   "outputs": [],
   "source": [
@@ -463,7 +608,7 @@
  {
   "cell_type": "code",
   "execution_count": null,
-  "id": "33",
+  "id": "43",
   "metadata": {},
   "outputs": [],
   "source": [
src/modules/evals.py CHANGED
@@ -8,9 +8,10 @@ from sklearn.metrics import (
 )
 from tqdm.asyncio import tqdm as async_tqdm
 import asyncio
-
-from src.modules.vlm_inference import analyze_product_image_async
-from src.modules.data_processing import image_to_base64
+import re
+from glob import glob
+from modules.vlm_inference import analyze_product_image_async
+from modules.data_processing import image_to_base64
 from pathlib import Path
 
 DATA_PATH = Path(__file__).parents[2] / "data"
@@ -149,7 +150,9 @@ def run_inference_on_dataframe(
         - pred_description: Predicted description
     """
     return asyncio.run(
-        run_inference_on_dataframe_async(df, model, api_key, provider, max_concurrent_requests)
+        run_inference_on_dataframe_async(
+            df, model, api_key, provider, max_concurrent_requests
+        )
     )
 
 
@@ -312,13 +315,128 @@ def extract_metrics(results_dict, model_name):
     metrics_list = []
 
     for category, metrics in results_dict.items():
-        metrics_list.append({
-            'model': model_name,
-            'category': category,
-            'accuracy': metrics['accuracy'],
-            'precision': metrics['precision'],
-            'recall': metrics['recall'],
-            'num_samples': metrics['num_samples']
-        })
-
-    return metrics_list
+        metrics_list.append(
+            {
+                "model": model_name,
+                "category": category,
+                "accuracy": metrics["accuracy"],
+                "precision": metrics["precision"],
+                "recall": metrics["recall"],
+                "num_samples": metrics["num_samples"],
+            }
+        )
+
+    return metrics_list
+
+
+def parse_model_name(filename: str) -> str:
+    """
+    Parse a human-readable model name from prediction CSV filename.
+
+    Examples:
+        df_pred_FireworksAI_qwen2-vl-72b-BASE-instruct-yaxztv7t.csv -> Qwen2-VL-72B-BASE
+        df_pred_OpenAI_gpt-5-mini-2025-08-07.csv -> GPT-5-Mini
+        df_pred_FireworksAI_qwen-72b-SFT-fashion-catalog-oueqouqs.csv -> Qwen2-VL-72B-SFT
+        df_pred_FireworksAI_qwen2p5-vl-32b-instruct-ralh0ben.csv -> Qwen2.5-VL-32B-BASE
+        df_pred_FireworksAI_qwen-32b-SFT-fashion-catalog-c6fhxibo.csv -> Qwen2.5-VL-32B-SFT
+        df_pred_FireworksAI_qwen3-vl-8b-instruct-*.csv -> Qwen3-VL-8B-BASE
+        df_pred_FireworksAI_qwen3-8b-fashion-catalog-*.csv -> Qwen3-VL-8B-SFT
+    """
+    basename = Path(filename).stem
+
+    # Remove prefix
+    name = basename.replace("df_pred_FireworksAI_", "").replace("df_pred_OpenAI_", "")
+
+    # GPT models
+    if "gpt" in name.lower():
+        return "GPT-5-Mini"
+
+    # Check if SFT (fine-tuned) model
+    is_sft = "SFT" in name or "fashion-catalog" in name
+
+    if "qwen3" in name.lower():
+        size_match = re.search(r"(\d+)b", name.lower())
+        size = size_match.group(1) if size_match else "?"
+        suffix = "SFT" if is_sft else "BASE"
+        return f"Qwen3-VL-{size}B-{suffix}"
+
+    if "qwen2p5" in name.lower() or (
+        "qwen-32b" in name.lower() and "qwen2-vl" not in name.lower()
+    ):
+        size_match = re.search(r"(\d+)b", name.lower())
+        size = size_match.group(1) if size_match else "?"
+        suffix = "SFT" if is_sft else "BASE"
+        return f"Qwen2.5-VL-{size}B-{suffix}"
+
+    if "qwen2-vl" in name.lower() or "qwen-72b" in name.lower():
+        size_match = re.search(r"(\d+)b", name.lower())
+        size = size_match.group(1) if size_match else "?"
+        suffix = "SFT" if is_sft else "BASE"
+        return f"Qwen2-VL-{size}B-{suffix}"
+
+    return name
+
+
+def compile_evaluation_results(data_path: str = None) -> pd.DataFrame:
+    """
+    Compile evaluation results from all prediction CSVs in the data directory.
+
+    Finds all df_pred_*.csv files, calculates metrics against ground truth,
+    and creates a consolidated evaluation_results.csv.
+
+    Args:
+        data_path: Path to data directory. Defaults to project's data/ folder.
+
+    Returns:
+        pd.DataFrame: Compiled evaluation results with columns:
+            model, category, accuracy, precision, recall, num_samples
+    """
+    if data_path is None:
+        data_path = Path(__file__).parents[2] / "data"
+    else:
+        data_path = Path(data_path)
+
+    # Load ground truth
+    test_csv = data_path / "test.csv"
+    df_test = pd.read_csv(test_csv)
+    print(f"Loaded {len(df_test)} ground truth samples from {test_csv}")
+
+    # Find all prediction CSVs
+    pred_files = sorted(glob(str(data_path / "df_pred_*.csv")))
+    print(f"Found {len(pred_files)} prediction files")
+
+    all_metrics = []
+
+    for pred_file in pred_files:
+        model_name = parse_model_name(pred_file)
+        print(f"\nProcessing: {Path(pred_file).name} -> {model_name}")
+
+        # Load predictions
+        df_pred = pd.read_csv(pred_file)
+
+        # Calculate metrics
+        results = evaluate_all_categories(
+            df_ground_truth=df_test,
+            df_predictions=df_pred,
+            id_col="id",
+        )
+
+        # Skip models with all errors
+        valid_results = {k: v for k, v in results.items() if "error" not in v}
+        if not valid_results:
+            print(f"  Skipping {model_name}: no valid predictions")
+            continue
+
+        # Extract metrics for this model (only valid categories)
+        metrics = extract_metrics(valid_results, model_name)
+        all_metrics.extend(metrics)
+
+    # Create final DataFrame
+    df_eval = pd.DataFrame(all_metrics)
+
+    # Save results
+    output_path = data_path / "evaluation_results.csv"
+    df_eval.to_csv(output_path, index=False)
+    print(f"\nSaved evaluation results to {output_path}")
+
+    return df_eval
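
The docstring examples double as a quick self-test for `parse_model_name`. A short usage sketch, assuming `src/` is on `sys.path` as in the notebooks (the trailing job hashes are taken from IDs appearing elsewhere in this commit; any `df_pred_*.csv` names will do):

```python
from modules.evals import parse_model_name, compile_evaluation_results

# Expected labels come straight from the docstring examples above.
assert parse_model_name("df_pred_OpenAI_gpt-5-mini-2025-08-07.csv") == "GPT-5-Mini"
assert parse_model_name("df_pred_FireworksAI_qwen3-vl-8b-instruct-y147m785.csv") == "Qwen3-VL-8B-BASE"
assert parse_model_name("df_pred_FireworksAI_qwen3-8b-fashion-catalog-bdo0tqxe.csv") == "Qwen3-VL-8B-SFT"

# Rebuild data/evaluation_results.csv from every df_pred_*.csv found on disk.
df_eval = compile_evaluation_results()
print(df_eval.head())
```
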
src/modules/vlm_inference.py CHANGED
@@ -2,7 +2,7 @@ import os
 from openai import OpenAI, AsyncOpenAI
 from pydantic import BaseModel, Field
 from typing import Optional, Literal
-from src.modules.constants import PROMPT_LIBRARY
+from modules.constants import PROMPT_LIBRARY
 
 SYSTEM_PROMPT = """
 You are an e-commerce fashion catalog assistant.
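
This one-line import fix is what lets `modules.vlm_inference` resolve when the notebooks run with `src/` on the path. For orientation, a rough sketch of the call shape these imports support, i.e. an OpenAI-compatible client pointed at Fireworks; the endpoint URL, model name, file path, and abridged prompt below are illustrative assumptions, not code from this commit:

```python
# Sketch: OpenAI-compatible chat call with a base64 image, as the notebooks'
# analyze_product_image wrapper presumably issues under the hood.
import base64
import os
from openai import OpenAI

SYSTEM_PROMPT = "You are an e-commerce fashion catalog assistant."  # abridged

client = OpenAI(
    api_key=os.environ["FIREWORKS_API_KEY"],
    base_url="https://api.fireworks.ai/inference/v1",  # assumed Fireworks endpoint
)

with open("product.jpg", "rb") as f:  # hypothetical local image
    image_b64 = base64.b64encode(f.read()).decode()

response = client.chat.completions.create(
    model="accounts/fireworks/models/qwen3-vl-8b-instruct",
    messages=[
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Categorize this product image."},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
                },
            ],
        },
    ],
)
print(response.choices[0].message.content)
```
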