Antuke committed on
Commit
c69c4af
·
1 Parent(s): eea3a1c
Files changed (37)
  1. app.py +788 -0
  2. core/args.py +72 -0
  3. core/checkpoint.py +379 -0
  4. core/transformer.py +646 -0
  5. core/transforms/image_transform.py +409 -0
  6. core/utils.py +40 -0
  7. core/vision_encoder/__init__.py +0 -0
  8. core/vision_encoder/__pycache__/__init__.cpython-312.pyc +0 -0
  9. core/vision_encoder/__pycache__/__init__.cpython-313.pyc +0 -0
  10. core/vision_encoder/__pycache__/config.cpython-312.pyc +0 -0
  11. core/vision_encoder/__pycache__/config.cpython-313.pyc +0 -0
  12. core/vision_encoder/__pycache__/pe.cpython-312.pyc +0 -0
  13. core/vision_encoder/__pycache__/pe.cpython-313.pyc +0 -0
  14. core/vision_encoder/__pycache__/pe_lora.cpython-312.pyc +0 -0
  15. core/vision_encoder/__pycache__/rope.cpython-312.pyc +0 -0
  16. core/vision_encoder/__pycache__/rope.cpython-313.pyc +0 -0
  17. core/vision_encoder/__pycache__/tokenizer.cpython-312.pyc +0 -0
  18. core/vision_encoder/__pycache__/tokenizer.cpython-313.pyc +0 -0
  19. core/vision_encoder/__pycache__/transforms.cpython-312.pyc +0 -0
  20. core/vision_encoder/__pycache__/transforms.cpython-313.pyc +0 -0
  21. core/vision_encoder/config.py +260 -0
  22. core/vision_encoder/pe.py +833 -0
  23. core/vision_encoder/rope.py +347 -0
  24. core/vision_encoder/transforms.py +86 -0
  25. core/vision_projector/base.py +26 -0
  26. core/vision_projector/mlp.py +62 -0
  27. requirements.txt +10 -0
  28. setup.py +7 -0
  29. src/model.py +809 -0
  30. utils/__pycache__/commons.cpython-313.pyc +0 -0
  31. utils/__pycache__/dataset.cpython-313.pyc +0 -0
  32. utils/__pycache__/face_detector.cpython-313.pyc +0 -0
  33. utils/__pycache__/task_config.cpython-313.pyc +0 -0
  34. utils/commons.py +158 -0
  35. utils/deploy.prototxt +1789 -0
  36. utils/face_detector.py +105 -0
  37. utils/task_config.py +21 -0
app.py ADDED
@@ -0,0 +1,788 @@
1
+ """
2
+ VLM Soft Biometrics - Gradio Interface
3
+ A web application for analyzing facial soft biometrics (age, gender, emotion) using Vision-Language Models.
4
+ """
5
+ import os
6
+ import gradio as gr
7
+ import torch
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ import base64
12
+ from io import BytesIO
13
+ import traceback # Import traceback at the top
14
+
15
+ from utils.face_detector import FaceDetector
16
+
17
+ # Class definitions
18
+ from src.model import MTLModel
19
+ from utils.commons import get_backbone_pe
20
+ from utils.task_config import Task
21
+
22
+
23
+ TASKS = [
24
+ Task(name='Age', class_labels=["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"], criterion=None),
25
+ Task(name='Gender', class_labels=["Male", "Female"], criterion=None),
26
+ Task(name='Emotion', class_labels=["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"], criterion=None)
27
+ ]
28
+ CLASSES = [
29
+ ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"],
30
+ ["M", "F"],
31
+ ["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"]
32
+ ]
33
+
34
+ # Global variables for model and detector
35
+ model = None
36
+ transform = None
37
+ detector = None
38
+ device = None
39
+ current_ckpt_dir = None
40
+ CHECKPOINTS_DIR = './checkpoints/'
41
+
42
+ def scan_checkpoints(ckpt_dir):
43
+ """Scans a directory for .pt or .pth files."""
44
+ if not os.path.exists(ckpt_dir):
45
+ print(f"Warning: Checkpoint directory not found: {ckpt_dir}")
46
+ return [], None
47
+
48
+ try:
49
+ ckpt_files = [
50
+ os.path.join(ckpt_dir, f)
51
+ for f in sorted(os.listdir(ckpt_dir)) if f.endswith(('.pt', '.pth'))
52
+ ]
53
+ except Exception as e:
54
+ print(f"Error scanning checkpoint directory {ckpt_dir}: {e}")
55
+ return [], None
56
+
57
+ # Create a list of (label, value) tuples
58
+ # label = filename (e.g., "mtlora.pt"), value = full path
59
+ choices_list = [(os.path.basename(f), f) for f in ckpt_files]
60
+
61
+ default_ckpt_path = os.path.join(ckpt_dir, 'mtlora.pt')
62
+
63
+ if default_ckpt_path in ckpt_files:
64
+ return choices_list, default_ckpt_path
65
+ elif ckpt_files:
66
+ return choices_list, ckpt_files[0] # default_ckpt_value is the full path
67
+ else:
68
+ print(f"No checkpoints found in {ckpt_dir}")
69
+ return [], None
70
+
71
+ def load_model(device, ckpt_dir='./checkpoints/mtlora.pt', pe_vision_config="PE-Core-L14-336"):
72
+ """Load and configure model."""
73
+ backbone, transform, _ = get_backbone_pe(version=pe_vision_config, apply_migration_flag=True, pretrained=False)
74
+ model = MTLModel(backbone, tasks=TASKS, use_lora=True, use_deep_head=True,
75
+ use_mtl_lora=('mtlora' in ckpt_dir),
76
+ )
77
+ print(f'loading from {ckpt_dir}')
78
+ model.load_model(filepath=ckpt_dir, map_location=device)
79
+ return model, transform
80
+
81
+ def load_model_and_update_status(ckpt_dir):
82
+ """
83
+ Wrapper function to load a model and return a status message.
84
+ This is used by the dropdown's 'change' event.
85
+ """
86
+ global model, current_ckpt_dir
87
+
88
+ if ckpt_dir is None or ckpt_dir == "":
89
+ return "No checkpoint selected."
90
+
91
+ if model is not None and ckpt_dir == current_ckpt_dir:
92
+ status = f"Model already loaded: {os.path.basename(ckpt_dir)}"
93
+ print(status)
94
+ return status
95
+
96
+ gr.Info(f"Loading model: {os.path.basename(ckpt_dir)}...")
97
+ try:
98
+ init_model(ckpt_dir=ckpt_dir, detection_confidence=0.5)
99
+ current_ckpt_dir = ckpt_dir # Set global directory on successful load
100
+ status = f"Successfully loaded: {os.path.basename(ckpt_dir)}"
101
+ gr.Info("Model loaded successfully!")
102
+ print(status)
103
+ return status
104
+ except Exception as e:
105
+ status = f"Failed to load {os.path.basename(ckpt_dir)}: {str(e)}"
106
+ print(status)
107
+ traceback.print_exc()
108
+ return status
109
+
110
+ def predict(model, image):
111
+ """Make predictions for age, gender, and emotion."""
112
+ with torch.no_grad():
113
+ results = model(image)
114
+
115
+ age_logits, gender_logits, emotion_logits = results['Age'], results['Gender'], results['Emotion']
116
+ # Get probabilities using softmax
117
+ age_probs = torch.softmax(age_logits, dim=-1)
118
+ gender_probs = torch.softmax(gender_logits, dim=-1)
119
+ emotion_probs = torch.softmax(emotion_logits, dim=-1)
120
+
121
+ ages = torch.argmax(age_logits, dim=-1).cpu().tolist()
122
+ genders = torch.argmax(gender_logits, dim=-1).cpu().tolist()
123
+ emotions = torch.argmax(emotion_logits, dim=-1).cpu().tolist()
124
+
125
+ results = []
126
+ for i in range(len(ages)):
127
+ # Get all probabilities for each class
128
+ age_all_probs = {
129
+ CLASSES[0][j]: float(age_probs[i][j].cpu().detach())
130
+ for j in range(len(CLASSES[0]))
131
+ }
132
+ gender_all_probs = {
133
+ CLASSES[1][j]: float(gender_probs[i][j].cpu().detach())
134
+ for j in range(len(CLASSES[1]))
135
+ }
136
+ emotion_all_probs = {
137
+ CLASSES[2][j]: float(emotion_probs[i][j].cpu().detach())
138
+ for j in range(len(CLASSES[2]))
139
+ }
140
+
141
+ results.append({
142
+ 'age': {
143
+ 'predicted_class': CLASSES[0][ages[i]],
144
+ 'predicted_confidence': float(age_probs[i][ages[i]].cpu().detach()),
145
+ 'all_probabilities': age_all_probs
146
+ },
147
+ 'gender': {
148
+ 'predicted_class': CLASSES[1][genders[i]],
149
+ 'predicted_confidence': float(gender_probs[i][genders[i]].cpu().detach()),
150
+ 'all_probabilities': gender_all_probs
151
+ },
152
+ 'emotion': {
153
+ 'predicted_class': CLASSES[2][emotions[i]],
154
+ 'predicted_confidence': float(emotion_probs[i][emotions[i]].cpu().detach()),
155
+ 'all_probabilities': emotion_all_probs
156
+ }
157
+ })
158
+
159
+ return results
160
+
161
+ def get_centroid_weighted_age(probs):
162
+ probs = list(probs.values())
163
+ centroids = [1, 4.5, 14.5, 24.5, 34.5, 44.5, 54.5, 64.5, 80]
164
+ age = 0
165
+ # print(probs) # DEBUG
166
+ for i,p in enumerate(probs):
167
+ age += p * centroids[i]
168
+
169
+ return age
170
+
171
+
172
+ def init_model(ckpt_dir="./checkpoints/mtlora.pt", detection_confidence=0.5):
173
+ """Initialize model and detector."""
174
+ global model, transform, detector, device
175
+
176
+ print(f"\n{'='*60}")
177
+ print(f"INITIALIZING MODEL: {ckpt_dir}")
178
+ print(f"{'='*60}")
179
+
180
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
181
+ print(f"Using device: {device}")
182
+
183
+ # Verify model weights exist
184
+ if not os.path.exists(ckpt_dir):
185
+ error_msg = f"Model weights not found: {ckpt_dir}."
186
+ print(f"ERROR: {error_msg}")
187
+ raise FileNotFoundError(error_msg)
188
+
189
+ print(f"Model weights found: {ckpt_dir}")
190
+
191
+ # Load the perception encoder
192
+ model, transform = load_model(ckpt_dir=ckpt_dir, device=device)
193
+ model.eval()
194
+ model.to(device)
195
+
196
+
197
+ detector = FaceDetector(confidence_threshold=detection_confidence)
198
+
199
+ print("✓ Model and detector initialized successfully")
200
+ print(f"{'='*60}\n")
201
+
202
+ def process_image(image, selected_checkpoint_path):
203
+ """
204
+ Process an uploaded image and return predictions with annotated image.
205
+
206
+ Args:
207
+ image: PIL Image or numpy array
208
+ selected_checkpoint_path: The path from the checkpoint dropdown
209
+
210
+ Returns:
211
+ tuple: (annotated_image, results_html)
212
+ """
213
+ if image is None:
214
+ return None, "<p style='color: red;'>Please upload an image</p>"
215
+
216
+ # Ensure model is initialized
217
+ if model is None or selected_checkpoint_path != current_ckpt_dir:
218
+ status = load_model_and_update_status(selected_checkpoint_path)
219
+ if "Failed" in status or "Error" in status:
220
+ return image, f"<p style='color: red;'>Model Error: {status}</p>"
221
+
222
+
223
+ try:
224
+ # --- 1. Prepare images for detection and drawing ---
225
+
226
+ # Convert PIL to OpenCV format (BGR) for the detector
227
+ if isinstance(image, Image.Image):
228
+ img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
229
+ else:
230
+ # Assuming it's a numpy array from Gradio webcam
231
+ img_cv = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
232
+
233
+ # Create a PIL copy to draw annotations on
234
+ img_pil_annotated = image.copy()
235
+ draw = ImageDraw.Draw(img_pil_annotated)
236
+
237
+ # --- 2. Detect faces ---
238
+ faces = detector.detect(img_cv, pad_rect=True)
239
+
240
+ if faces is None or len(faces) == 0:
241
+ return image, "<p style='color: orange;'>No faces detected in the image</p>"
242
+
243
+ # --- 3. Process detected faces ---
244
+ crops_pil = []
245
+ face_data = []
246
+
247
+ for idx, (crop, confidence, bbox) in enumerate(faces):
248
+ crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
249
+ crop_pil = Image.fromarray(crop_rgb)
250
+ crops_pil.append(crop_pil)
251
+
252
+ # Resize crop to 336x336 for display
253
+ crop_resized = crop_pil.resize((336, 336), Image.Resampling.LANCZOS)
254
+
255
+ face_data.append({
256
+ 'bbox': bbox,
257
+ 'detection_confidence': float(confidence),
258
+ 'crop_image': crop_resized # Store the resized crop
259
+ })
260
+
261
+ # --- 4. Batch transform and predict ---
262
+ crop_tensors = [transform(crop_pil) for crop_pil in crops_pil]
263
+ batch_tensor = torch.stack(crop_tensors).to(device)
264
+
265
+ # Get predictions
266
+ predictions = predict(model, batch_tensor)
267
+
268
+ # Combine face data with predictions
269
+ for face, pred in zip(face_data, predictions):
270
+ face['predictions'] = pred
271
+
272
+ # --- 5. Create annotated image (using PIL) ---
273
+ for idx, face in enumerate(face_data):
274
+ bbox = face['bbox']
275
+ pred = face['predictions']
276
+ x, y, w, h = bbox
277
+
278
+ # --- Calculate Adaptive Font (from demo.py) ---
279
+ font_size_ratio = 0.08
280
+ min_font_size = 12
281
+ max_font_size = 48
282
+ adaptive_font_size = max(min_font_size, min(int(w * font_size_ratio), max_font_size))
283
+ try:
284
+ font = ImageFont.load_default(size=adaptive_font_size)
285
+ except IOError:
286
+ font = ImageFont.load_default()
287
+
288
+ # --- Draw Bounding Box ---
289
+ draw.rectangle([(x, y), (x + w, y + h)], outline="lime", width=2)
290
+
291
+ # --- Prepare Text Lines (Top-1 Only) ---
292
+ lines_to_draw = []
293
+
294
+ # Age
295
+ age_label = pred['age']['predicted_class']
296
+ age_conf = pred['age']['predicted_confidence']
297
+ lines_to_draw.append(f"Age: {age_label} ({age_conf*100:.0f}%)")
298
+
299
+ # Gender
300
+ gen_label = pred['gender']['predicted_class']
301
+ gen_conf = pred['gender']['predicted_confidence']
302
+ lines_to_draw.append(f"Gender: {gen_label} ({gen_conf*100:.0f}%)")
303
+
304
+ # Emotion
305
+ emo_label = pred['emotion']['predicted_class']
306
+ emo_conf = pred['emotion']['predicted_confidence']
307
+ lines_to_draw.append(f"Emotion: {emo_label} ({emo_conf*100:.0f}%)")
308
+
309
+
310
+ # --- Calculate total height of the text block (from demo.py) ---
311
+ line_spacing = 10
312
+ total_text_height = 0
313
+ for line in lines_to_draw:
314
+ _left, top, _right, bottom = draw.textbbox((0, 0), line, font=font)
315
+ total_text_height += (bottom - top) + line_spacing
316
+
317
+ # --- Place text ABOVE or BELOW the box (from demo.py) ---
318
+ if y - total_text_height > 0:
319
+ # PLACE TEXT ABOVE: There is enough space
320
+ text_y = y - line_spacing
321
+ for line in reversed(lines_to_draw):
322
+ left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="ls") # anchor left-baseline
323
+ draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
324
+ draw.text((x, text_y), line, font=font, fill="white", anchor="ls")
325
+ text_y = top - line_spacing # Move y-position up for the next line
326
+ else:
327
+ # PLACE TEXT BELOW: Not enough space above, so draw downwards
328
+ text_y = y + h + line_spacing
329
+ for line in lines_to_draw:
330
+ left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="lt")
331
+ draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
332
+ draw.text((x, text_y), line, font=font, fill="white", anchor="lt")
333
+ text_y = bottom + line_spacing
334
+
335
+ # --- 6. Create HTML results ---
336
+
337
+ # Helper function to convert PIL image to base64
338
+ def pil_to_base64(img_pil):
339
+ buffered = BytesIO()
340
+ img_pil.save(buffered, format="JPEG")
341
+ img_str = base64.b64encode(buffered.getvalue()).decode()
342
+ return f"data:image/jpeg;base64,{img_str}"
343
+
344
+ # (HTML Generation code remains the same as before)
345
+ results_html = f"""
346
+ <style>
347
+ :root {{
348
+ --primary-color: #4f46e5;
349
+ --success-color: #10b981;
350
+ --text-primary: #ffffff;
351
+ --text-secondary: #9ca3af;
352
+ --background-dark: #1f2937;
353
+ --background-darker: #111827;
354
+ --border-color: #374151;
355
+ }}
356
+ .results-container {{
357
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
358
+ background: var(--background-darker);
359
+ padding: 20px;
360
+ border-radius: 12px;
361
+ color: var(--text-primary);
362
+ }}
363
+ .results-container h2 {{
364
+ color: var(--text-primary);
365
+ margin-bottom: 20px;
366
+ }}
367
+ .face-count {{
368
+ display: inline-block;
369
+ background: var(--primary-color);
370
+ color: white;
371
+ padding: 4px 12px;
372
+ border-radius: 20px;
373
+ font-size: 0.9rem;
374
+ font-weight: 500;
375
+ margin-left: 8px;
376
+ }}
377
+ .face-card {{
378
+ background: var(--background-dark);
379
+ border-radius: 8px;
380
+ padding: 20px;
381
+ margin-top: 15px;
382
+ border: 1px solid var(--border-color);
383
+ display: flex;
384
+ gap: 20px;
385
+ align-items: flex-start;
386
+ }}
387
+ .face-header {{
388
+ font-size: 1rem;
389
+ font-weight: 600;
390
+ margin-bottom: 20px;
391
+ color: var(--text-primary);
392
+ }}
393
+ .face-image-left {{
394
+ flex-shrink: 0;
395
+ width: 336px;
396
+ height: 336px;
397
+ background: var(--background-darker);
398
+ border-radius: 8px;
399
+ overflow: hidden;
400
+ border: 1px solid var(--border-color);
401
+ }}
402
+ .face-image-left img {{
403
+ width: 100%;
404
+ height: 100%;
405
+ object-fit: cover;
406
+ }}
407
+ .face-predictions-right {{
408
+ flex: 1;
409
+ display: flex;
410
+ flex-direction: column;
411
+ gap: 10px;
412
+ }}
413
+ .predictions-horizontal {{
414
+ display: flex;
415
+ flex-direction: row;
416
+ gap: 30px;
417
+ justify-content: space-between;
418
+ }}
419
+ .prediction-section {{
420
+ flex: 1;
421
+ min-width: 0;
422
+ }}
423
+ .prediction-category-label {{
424
+ font-size: 0.8rem;
425
+ font-weight: 700;
426
+ text-transform: uppercase;
427
+ letter-spacing: 0.5px;
428
+ color: var(--primary-color);
429
+ margin-bottom: 8px;
430
+ border-bottom: 2px solid var(--primary-color);
431
+ padding-bottom: 4px;
432
+ }}
433
+ .probabilities-list {{
434
+ display: flex;
435
+ flex-direction: column;
436
+ gap: 6px;
437
+ }}
438
+ .probability-item {{
439
+ display: grid;
440
+ grid-template-columns: 70px 1fr 55px;
441
+ align-items: center;
442
+ gap: 8px;
443
+ padding: 4px 6px;
444
+ border-radius: 4px;
445
+ }}
446
+ .probability-item.predicted {{
447
+ background: rgba(79, 70, 229, 0.2);
448
+ border-left: 3px solid var(--primary-color);
449
+ padding-left: 8px;
450
+ }}
451
+ .prob-class {{
452
+ font-size: 0.8rem;
453
+ font-weight: 600;
454
+ color: var(--text-primary);
455
+ word-wrap: break-word; /* Ensure long class names wrap */
456
+ }}
457
+ .probability-item.predicted .prob-class {{
458
+ color: var(--primary-color);
459
+ font-weight: 700;
460
+ }}
461
+ .prob-bar-container {{
462
+ height: 6px;
463
+ background: var(--border-color);
464
+ border-radius: 3px;
465
+ overflow: hidden;
466
+ }}
467
+ .prob-bar {{
468
+ height: 100%;
469
+ background: linear-gradient(90deg, var(--primary-color), var(--success-color));
470
+ border-radius: 3px;
471
+ transition: width 0.6s ease;
472
+ }}
473
+ .probability-item.predicted .prob-bar {{
474
+ background: var(--primary-color);
475
+ }}
476
+ .prob-percentage {{
477
+ font-size: 0.75rem;
478
+ font-weight: 500;
479
+ color: var(--text-secondary);
480
+ text-align: right;
481
+ }}
482
+ .probability-item.predicted .prob-percentage {{
483
+ color: var(--primary-color);
484
+ font-weight: 700;
485
+ }}
486
+ @media (max-width: 1200px) {{
487
+ .predictions-horizontal {{
488
+ flex-direction: column;
489
+ gap: 15px;
490
+ }}
491
+ }}
492
+ @media (max-width: 900px) {{
493
+ .face-card {{
494
+ flex-direction: column;
495
+ }}
496
+ .face-image-left {{
497
+ width: 100%;
498
+ max-width: 336px;
499
+ margin: 0 auto;
500
+ }}
501
+ .probability-item {{
502
+ grid-template-columns: 60px 1fr 50px; /* Adjust for smaller screens */
503
+ }}
504
+ .prob-class {{
505
+ font-size: 0.75rem;
506
+ }}
507
+ }}
508
+ </style>
509
+
510
+ <div class='results-container'>
511
+ <h2 style='margin-top: 0;'>Classification Results <span class='face-count'>{len(face_data)} face(s)</span></h2>
512
+ """
513
+
514
+ for idx, face in enumerate(face_data):
515
+ pred = face['predictions']
516
+ face_img_base64 = pil_to_base64(face['crop_image'])
517
+ age = get_centroid_weighted_age(pred['age']['all_probabilities'])
518
+ results_html += f"""
519
+ <div class='face-card'>
520
+ <div class='face-image-left'>
521
+ <img src='{face_img_base64}' alt='Face {idx+1}'>
522
+ </div>
523
+ <div class='face-predictions-right'>
524
+ <div class='face-header'>Face {idx+1} - Detection Confidence: {face['detection_confidence']:.1%} - Centroid Age: {int(age)}</div>
525
+ <div class='predictions-horizontal'>
526
+ <div class='prediction-section'>
527
+ <div class='prediction-category-label'>Age</div>
528
+ <div class='probabilities-list'>
529
+ """
530
+ for age_class in CLASSES[0]:
531
+ prob = pred['age']['all_probabilities'][age_class]
532
+ is_predicted = (age_class == pred['age']['predicted_class'])
533
+ predicted_class = 'predicted' if is_predicted else ''
534
+ results_html += f"""
535
+ <div class='probability-item {predicted_class}'>
536
+ <span class='prob-class'>{age_class}</span>
537
+ <div class='prob-bar-container'>
538
+ <div class='prob-bar' style='width: {prob*100}%'></div>
539
+ </div>
540
+ <span class='prob-percentage'>{prob*100:.1f}%</span>
541
+ </div>
542
+ """
543
+ results_html += f"""
544
+ </div>
545
+ </div>
546
+ <div class='prediction-section'>
547
+ <div class='prediction-category-label'>Gender</div>
548
+ <div class='probabilities-list'>
549
+ """
550
+ for gender_class in CLASSES[1]:
551
+ prob = pred['gender']['all_probabilities'][gender_class]
552
+ is_predicted = (gender_class == pred['gender']['predicted_class'])
553
+ predicted_class = 'predicted' if is_predicted else ''
554
+ results_html += f"""
555
+ <div class='probability-item {predicted_class}'>
556
+ <span class='prob-class'>{gender_class}</span>
557
+ <div class='prob-bar-container'>
558
+ <div class='prob-bar' style='width: {prob*100}%'></div>
559
+ </div>
560
+ <span class='prob-percentage'>{prob*100:.1f}%</span>
561
+ </div>
562
+ """
563
+ results_html += """
564
+ </div>
565
+ </div>
566
+ <div class='prediction-section'>
567
+ <div class='prediction-category-label'>Emotion</div>
568
+ <div class='probabilities-list'>
569
+ """
570
+ for emotion_class in CLASSES[2]:
571
+ prob = pred['emotion']['all_probabilities'][emotion_class]
572
+ is_predicted = (emotion_class == pred['emotion']['predicted_class'])
573
+ predicted_class = 'predicted' if is_predicted else ''
574
+ results_html += f"""
575
+ <div class='probability-item {predicted_class}'>
576
+ <span class='prob-class'>{emotion_class}</span>
577
+ <div class='prob-bar-container'>
578
+ <div class='prob-bar' style='width: {prob*100}%'></div>
579
+ </div>
580
+ <span class='prob-percentage'>{prob*100:.1f}%</span>
581
+ </div>
582
+ """
583
+ results_html += """
584
+ </div>
585
+ </div>
586
+ </div>
587
+ </div>
588
+ </div>
589
+ """
590
+ results_html += "</div>"
591
+
592
+ # --- 7. Return the annotated PIL image and HTML ---
593
+ return img_pil_annotated, results_html
594
+
595
+ except Exception as e:
596
+ traceback.print_exc()
597
+ return image, f"<p style='color: red;'>Error processing image: {str(e)}</p>"
598
+
599
+ def create_interface(checkpoint_list, default_checkpoint, initial_status):
600
+ """Create and configure the Gradio interface."""
601
+
602
+ # Custom CSS for better styling
603
+ custom_css = """
604
+ .gradio-container {
605
+ font-family: 'Arial', sans-serif;
606
+ }
607
+ .output-html {
608
+ max-height: none !important;
609
+ overflow-y: auto;
610
+ }
611
+ """
612
+
613
+ # Create interface
614
+ with gr.Blocks(css=custom_css, title="Face Classification System") as demo:
615
+
616
+ with gr.Row():
617
+ gr.Markdown("# Face Classification System")
618
+
619
+ # --- Model Selection ---
620
+ with gr.Row():
621
+ with gr.Column(scale=3):
622
+ checkpoint_dropdown = gr.Dropdown(
623
+ label="Select Model Checkpoint",
624
+ choices=checkpoint_list,
625
+ value=default_checkpoint,
626
+ )
627
+ with gr.Column(scale=2):
628
+ model_status_text = gr.Textbox(
629
+ label="Model Status",
630
+ value=initial_status,
631
+ interactive=False,
632
+ )
633
+
634
+ # Features | Instructions
635
+ with gr.Row():
636
+ with gr.Column(scale=1):
637
+ gr.Markdown("""
638
+ ### Features
639
+ - **Age Classification**: 9 categories (0-2, 3-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70+) + Age estimation with weighted centroid average
640
+ - **Gender Classification**: M/F
641
+ - **Emotion Recognition**: 7 categories (Surprise, Fear, Disgust, Happy, Sad, Angry, Neutral)
642
+ - **Automatic Face Detection**: Detects and analyzes multiple faces
643
+ - **Detailed Probability Distributions**: View confidence for all classes
644
+ """)
645
+
646
+ with gr.Column(scale=1):
647
+ gr.Markdown("""
648
+ ### Instructions
649
+ 1. (Optional) Select a model checkpoint from the dropdown.
650
+ 2. Upload an image or capture from webcam (or select an example below)
651
+ 3. Click "Classify Image"
652
+ 4. View detected faces with age, gender, and emotion predictions below
653
+ """)
654
+
655
+ # Upload Image | Annotated Image
656
+ with gr.Row():
657
+ with gr.Column(scale=1):
658
+ input_image = gr.Image(
659
+ label="Upload Image",
660
+ type="pil",
661
+ sources=["upload", "webcam"],
662
+ height=400
663
+ )
664
+
665
+ with gr.Column(scale=1):
666
+ output_image = gr.Image(
667
+ label="Annotated Image",
668
+ type="pil",
669
+ height=400
670
+ )
671
+
672
+ with gr.Row():
673
+ with gr.Column(scale=1):
674
+ analyze_btn = gr.Button(
675
+ "Classify Image",
676
+ variant="primary",
677
+ size="lg"
678
+ )
679
+
680
+ # Examples - after button
681
+ # Dynamically load example images from example directory
682
+ example_dir = "example"
683
+ example_images = []
684
+ if os.path.exists(example_dir):
685
+ try:
686
+ example_images = [
687
+ os.path.join(example_dir, f)
688
+ for f in sorted(os.listdir(example_dir))
689
+ if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
690
+ ]
691
+ except Exception as e:
692
+ print(f"Error reading example images from {example_dir}: {e}")
693
+
694
+ if example_images:
695
+ gr.Markdown("### 📸 Try with example images")
696
+ gr.Examples(
697
+ examples=example_images,
698
+ inputs=input_image,
699
+ cache_examples=False
700
+ )
701
+
702
+ # Results section - full width below everything
703
+ with gr.Row():
704
+ with gr.Column(scale=1):
705
+ output_html = gr.HTML(
706
+ label="Classification Results",
707
+ elem_classes="output-html"
708
+ )
709
+
710
+ # Event handlers
711
+ analyze_btn.click(
712
+ fn=process_image,
713
+ inputs=[input_image, checkpoint_dropdown], # Pass dropdown value
714
+ outputs=[output_image, output_html]
715
+ )
716
+
717
+ checkpoint_dropdown.change(
718
+ fn=load_model_and_update_status,
719
+ inputs=[checkpoint_dropdown],
720
+ outputs=[model_status_text]
721
+ )
722
+
723
+
724
+ return demo
725
+
726
+ # === Main Application Startup ===
727
+
728
+ # Initialize for Hugging Face Spaces (module-level)
729
+ print("="*60)
730
+ print("VLM SOFT BIOMETRICS - GRADIO INTERFACE")
731
+ print("="*60)
732
+
733
+ # --- 1. Scan for models first ---
734
+ checkpoint_list, default_checkpoint = scan_checkpoints(CHECKPOINTS_DIR)
735
+
736
+ if not checkpoint_list:
737
+ print(f"CRITICAL: No checkpoints found in {CHECKPOINTS_DIR}. App may not function.")
738
+ else:
739
+ print(f"Found checkpoints: {len(checkpoint_list)} file(s).")
740
+ print(f"Default checkpoint: {default_checkpoint}")
741
+
742
+ # --- 2. Try to initialize default model ---
743
+ initial_status_msg = "No default model found. Please select one."
744
+ if default_checkpoint:
745
+ print(f"\nInitializing default model: {default_checkpoint}")
746
+ # This will load the model AND set current_ckpt_dir
747
+ initial_status_msg = load_model_and_update_status(default_checkpoint)
748
+ print(initial_status_msg)
749
+ else:
750
+ print("⚠ Warning: No default model to load.")
751
+
752
+
753
+ # --- 3. Create interface FIRST (so it shows even if model fails) ---
754
+ print("Creating Gradio interface...")
755
+ demo = create_interface(checkpoint_list, default_checkpoint, initial_status_msg)
756
+ print("✓ Interface created successfully!")
757
+
758
+
759
+ if __name__ == "__main__":
760
+ import argparse
761
+
762
+ parser = argparse.ArgumentParser(description="VLM Soft Biometrics - Gradio Interface")
763
+ parser.add_argument("--ckpt_dir", type=str, default="./checkpoints/",
764
+ help="Path to the checkpoint directory (overridden by UI)")
765
+ parser.add_argument("--detection_confidence", type=float, default=0.5,
766
+ help="Confidence threshold for face detection")
767
+ parser.add_argument("--port", type=int, default=7860,
768
+ help="Port to run the Gradio app")
769
+ parser.add_argument("--share", action="store_true",
770
+ help="Create a public share link")
771
+ parser.add_argument("--server_name", type=str, default="0.0.0.0",
772
+ help="Server name/IP to bind to")
773
+ args = parser.parse_args()
774
+
775
+ # Update global config if args are provided (though UI dropdown is primary)
776
+ CHECKPOINTS_DIR = args.ckpt_dir
777
+ # Note: args.detection_confidence is not forwarded here; models loaded via the UI use init_model's hard-coded default (0.5).
778
+
779
+ print(f"\nLaunching server on {args.server_name}:{args.port}")
780
+ print(f"Monitoring checkpoint directory: {CHECKPOINTS_DIR}")
781
+ print("="*60)
782
+
783
+ demo.launch(
784
+ share=args.share,
785
+ server_name=args.server_name,
786
+ server_port=args.port,
787
+ show_error=True,
788
+ )
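For reference, a minimal standalone sketch of the centroid-weighted age estimate computed by get_centroid_weighted_age in app.py above. The bucket centroids match the ones hard-coded in the app; the example probabilities are made up for illustration and would normally come from predict().

# Expected-age sketch mirroring get_centroid_weighted_age in app.py.
AGE_BUCKETS = ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"]
CENTROIDS = [1, 4.5, 14.5, 24.5, 34.5, 44.5, 54.5, 64.5, 80]

def centroid_weighted_age(probs_by_bucket: dict) -> float:
    """Return the expected age: sum_i p_i * centroid_i over the 9 age buckets."""
    return sum(probs_by_bucket[b] * c for b, c in zip(AGE_BUCKETS, CENTROIDS))

# Illustrative probabilities only.
example = {"0-2": 0.0, "3-9": 0.0, "10-19": 0.05, "20-29": 0.6,
           "30-39": 0.3, "40-49": 0.05, "50-59": 0.0, "60-69": 0.0, "70+": 0.0}
print(f"Estimated age: {centroid_weighted_age(example):.1f}")  # 28.0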
core/args.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import logging
4
+ from typing import Type, TypeVar
5
+
6
+ from omegaconf import DictConfig, ListConfig, OmegaConf
7
+
8
+ logger = logging.getLogger()
9
+
10
+ T = TypeVar("T")
11
+
12
+
13
+ def set_struct_recursively(cfg, strict: bool = True):
14
+ # Set struct mode for the current level
15
+ OmegaConf.set_struct(cfg, strict)
16
+
17
+ # Traverse through nested dictionaries and lists
18
+ if isinstance(cfg, DictConfig):
19
+ for key, value in cfg.items():
20
+ if isinstance(value, (DictConfig, ListConfig)):
21
+ set_struct_recursively(value, strict)
22
+ elif isinstance(cfg, ListConfig):
23
+ for item in cfg:
24
+ if isinstance(item, (DictConfig, ListConfig)):
25
+ set_struct_recursively(item, strict)
26
+
27
+
28
+ def flatten_dict(d, parent_key="", sep="_"):
29
+ items = []
30
+ for k, v in d.items():
31
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
32
+ if isinstance(v, dict):
33
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
34
+ else:
35
+ items.append((new_key, v))
36
+ return dict(items)
37
+
38
+
39
+ def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
40
+ """
41
+ Converts a dictionary to a dataclass instance, recursively for nested structures.
42
+ """
43
+ base = OmegaConf.structured(cls())
44
+ OmegaConf.set_struct(base, strict)
45
+ override = OmegaConf.create(data)
46
+ return OmegaConf.to_object(OmegaConf.merge(base, override))
47
+
48
+
49
+ def dataclass_to_dict(dataclass_instance: T) -> dict:
50
+ """
51
+ Converts a dataclass instance to a dictionary, recursively for nested structures.
52
+ """
53
+ if isinstance(dataclass_instance, dict):
54
+ return dataclass_instance
55
+
56
+ return OmegaConf.to_container(
57
+ OmegaConf.structured(dataclass_instance), resolve=True
58
+ )
59
+
60
+
61
+ def load_config_file(config_file, dataclass_cls: Type[T]) -> T:
62
+ config = OmegaConf.to_container(OmegaConf.load(config_file), resolve=True)
63
+ return dataclass_from_dict(dataclass_cls, config)
64
+
65
+
66
+ def dump_config(config, path, log_config=True):
67
+ yaml_dump = OmegaConf.to_yaml(OmegaConf.structured(config))
68
+ with open(path, "w") as f:
69
+ if log_config:
70
+ logger.info("Using the following config for this run:")
71
+ logger.info(yaml_dump)
72
+ f.write(yaml_dump)
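A small usage sketch of the config helpers in core/args.py above. The TrainConfig and OptimizerConfig dataclasses are hypothetical; only dataclass_from_dict and dump_config come from this file, and they rely on standard OmegaConf behaviour (structured configs, strict merge, to_object).

from dataclasses import dataclass, field

from core.args import dataclass_from_dict, dump_config


@dataclass
class OptimizerConfig:  # hypothetical nested config
    lr: float = 1e-4
    weight_decay: float = 0.01


@dataclass
class TrainConfig:  # hypothetical top-level config
    epochs: int = 10
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)


# Override only the fields present in the dict; unknown keys raise in strict mode.
cfg = dataclass_from_dict(TrainConfig, {"epochs": 20, "optimizer": {"lr": 3e-4}})
assert cfg.epochs == 20 and cfg.optimizer.lr == 3e-4

dump_config(cfg, "config.yaml", log_config=False)  # writes the merged config as YAML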
core/checkpoint.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ from dataclasses import dataclass, field
8
+ from pathlib import Path
9
+ from typing import List, Optional, Tuple
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torch.distributed.checkpoint as dcp
14
+ import torch.nn as nn
15
+ import torch.optim.optimizer
16
+ from omegaconf import OmegaConf
17
+ from torch.distributed._tensor import DeviceMesh
18
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
19
+ from torch.distributed.checkpoint.state_dict import (get_model_state_dict,
20
+ get_state_dict,
21
+ set_state_dict)
22
+
23
+ from core.distributed import get_is_master
24
+
25
+ logger = logging.getLogger("CHECKPOINT")
26
+
27
+ FOLDER_NAME = "{:010d}"
28
+ RE_FOLDER = r"\d{10}"
29
+
30
+ RE_CKPT = r"__\d_\d\.distcp"
31
+
32
+ CONSOLIDATE_FOLDER = "consolidated"
33
+ CONSOLIDATE_NAME = "consolidated.pth"
34
+
35
+ CONFIG_NAME = "params.json"
36
+ TRAIN_STATE_NAME = "train_state_{:05d}.json"
37
+ RE_DIGITS = re.compile(r"\d+")
38
+
39
+
40
+ @dataclass
41
+ class SaveEvery:
42
+ every: int = 1000
43
+ keep: int = 0
44
+
45
+
46
+ @dataclass
47
+ class CheckpointArgs:
48
+ dump: SaveEvery = field(default_factory=SaveEvery)
49
+ eval: SaveEvery = field(default_factory=SaveEvery)
50
+ path: Optional[str] = None
51
+ init_ckpt_path: Optional[str] = None
52
+ vision_model_path: Optional[str] = None
53
+ is_consolidated_model: bool = False
54
+ continue_training_from_init: bool = False
55
+
56
+
57
+ def _get_key_step(name: str):
58
+ return int(re.findall(RE_DIGITS, name)[-1])
59
+
60
+
61
+ def consolidate_checkpoints(ckpt_dir: str):
62
+ """
63
+ Consolidates all FSDP checkpoints in a directory to a single file
64
+ Consolidate checkpoint is saved in a subdirectory of ckpt_dir
65
+
66
+ Parameters:
67
+ ckpt_dir: str - path to the directory containing the checkpoints
68
+
69
+ Returns the path to the consolidated checkpoint
70
+ """
71
+ consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
72
+ if not (consolidate_path / CONSOLIDATE_NAME).exists():
73
+ consolidate_path.mkdir(exist_ok=True)
74
+ logger.info(f"Consolidating to: {str(consolidate_path)}")
75
+ dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
76
+ (consolidate_path / CONFIG_NAME).write_text(
77
+ (Path(ckpt_dir) / CONFIG_NAME).read_text()
78
+ )
79
+ logger.info("Consolidated !")
80
+ return consolidate_path
81
+
82
+
83
+ def load_from_checkpoint(
84
+ ckpt_dir: str,
85
+ model: nn.Module,
86
+ optimizer: Optional[torch.optim.Optimizer] = None,
87
+ model_key: str = "model",
88
+ optim_key: str = "optim",
89
+ ):
90
+ if not (Path(ckpt_dir) / ".metadata").exists():
91
+ raise ValueError(
92
+ "Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
93
+ )
94
+
95
+ state_dict = {}
96
+ if optimizer is not None:
97
+ state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer)
98
+ else:
99
+ state_dict[model_key] = get_model_state_dict(model)
100
+ if model_key == "": # If only loading a model directly, the key should be empty
101
+ state_dict = state_dict.pop(model_key)
102
+
103
+ dcp.load(state_dict, checkpoint_id=ckpt_dir)
104
+
105
+
106
+ class CheckpointManager:
107
+ def __init__(self, args: CheckpointArgs):
108
+ self.path = args.path
109
+ self.dump_every = args.dump
110
+ self.eval_every = args.eval
111
+ self.init_ckpt_path = args.init_ckpt_path
112
+ self.continue_training_from_init = args.continue_training_from_init
113
+
114
+ assert os.path.exists(
115
+ self.path
116
+ ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
117
+
118
+ self.existing_saves = self.get_existing_saves()
119
+
120
+ def get_existing_saves(self) -> List[Path]:
121
+ folders = [
122
+ p
123
+ for p in Path(self.path).iterdir()
124
+ if p.is_dir() and re.match(RE_FOLDER, p.name)
125
+ ]
126
+ folders.sort(key=lambda p: _get_key_step(p.name))
127
+ return folders
128
+
129
+ def clean_up(self):
130
+ logger.info("Cleaning up checkpoints...")
131
+ dump_folders = []
132
+ eval_folders = []
133
+ other_folders = []
134
+ for p in self.existing_saves:
135
+ is_dump = _get_key_step(p.name) % self.dump_every.every == 0
136
+ is_eval = _get_key_step(p.name) % self.eval_every.every == 0
137
+ if is_dump:
138
+ dump_folders.append(p)
139
+ if is_eval:
140
+ eval_folders.append(p)
141
+ if not (is_dump or is_eval):
142
+ other_folders.append(p)
143
+
144
+ logger.info(f"Dump folders: {dump_folders}")
145
+ logger.info(f"Eval folders: {eval_folders}")
146
+ logger.info(f"Other folders: {other_folders}")
147
+
148
+ if self.dump_every.keep > 0:
149
+ dump_folders = dump_folders[-self.dump_every.keep :]
150
+ if self.eval_every.keep > 0:
151
+ eval_folders = eval_folders[-self.eval_every.keep :]
152
+
153
+ folder_to_keep = set(other_folders + dump_folders + eval_folders)
154
+ folder_to_remove = set(self.existing_saves) - folder_to_keep
155
+
156
+ logger.info(f"Removing folders: {folder_to_remove}")
157
+
158
+ if dist.get_rank() == 0:
159
+ for folder in folder_to_remove:
160
+ for file in folder.iterdir():
161
+ if file.is_file():
162
+ file.unlink()
163
+ elif file.is_dir():
164
+ assert file.name in [CONSOLIDATE_FOLDER]
165
+ for f in file.iterdir():
166
+ f.unlink()
167
+ file.rmdir()
168
+ folder.rmdir()
169
+
170
+ dist.barrier()
171
+
172
+ self.existing_saves = list(folder_to_keep)
173
+ self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
174
+
175
+ def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
176
+ path = None
177
+ for p in reversed(self.existing_saves):
178
+ if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
179
+ path = p
180
+ break
181
+ return path
182
+
183
+ def _create_folder(self, base_path: Path, folder_name: str) -> Path:
184
+ folder = base_path / folder_name
185
+ if get_is_master():
186
+ folder.mkdir(parents=False, exist_ok=True)
187
+ if dist.is_initialized():
188
+ dist.barrier()
189
+ return folder
190
+
191
+ def _get_dp_tp_mesh(
192
+ self, device_mesh: Optional[DeviceMesh] = None
193
+ ) -> Tuple[int, int]:
194
+ dp_rank = 0
195
+ tp_rank = 0
196
+ if device_mesh is not None:
197
+ if "dp_replicate" in device_mesh.mesh_dim_names:
198
+ dp_rank = device_mesh.get_local_rank("dp_replicate")
199
+ if "dp_shard" in device_mesh.mesh_dim_names:
200
+ dp_rank = dp_rank * device_mesh[
201
+ "dp_replicate"
202
+ ].size() + device_mesh.get_local_rank("dp_shard")
203
+ if "tp" in device_mesh.mesh_dim_names:
204
+ tp_rank = device_mesh.get_local_rank("tp")
205
+ return dp_rank, tp_rank
206
+
207
+ @torch.no_grad()
208
+ def get_state_dict(
209
+ self,
210
+ model,
211
+ optimizer,
212
+ ):
213
+ model_sd, optim_sd = get_state_dict(model, optimizer)
214
+ return {"model": model_sd, "optim": optim_sd}
215
+
216
+ def save(
217
+ self,
218
+ model,
219
+ optimizer,
220
+ train_state,
221
+ config,
222
+ device_mesh: Optional[DeviceMesh] = None,
223
+ ) -> bool:
224
+
225
+ # When creating directory check if only rank0 or is there other solution
226
+ path = Path(self.path)
227
+ curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
228
+ logger.info(f"Saving to: {str(curr_save_dir)}")
229
+
230
+ if dist.is_initialized():
231
+ dist.barrier()
232
+
233
+ logger.info("Saving...")
234
+ state_dict = self.get_state_dict(model, optimizer)
235
+ dcp.save(state_dict, checkpoint_id=curr_save_dir)
236
+ logger.info("State dict saved!")
237
+
238
+ if dist.is_initialized():
239
+ dist.barrier()
240
+
241
+ if get_is_master():
242
+ with open(curr_save_dir / CONFIG_NAME, "w") as f:
243
+ json.dump(
244
+ OmegaConf.to_container(OmegaConf.structured(config), resolve=True),
245
+ f,
246
+ indent=4,
247
+ )
248
+
249
+ # Add json dump here
250
+ dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
251
+ if tp_rank == 0:
252
+ train_state_name = TRAIN_STATE_NAME.format(dp_rank)
253
+ logger.info(
254
+ f"Saving train state to: {str(curr_save_dir / train_state_name)}"
255
+ )
256
+ # logger.info(f"train_state.state_dict()={train_state.state_dict()}")
257
+ with open(curr_save_dir / train_state_name, "w") as f:
258
+ json.dump(train_state.state_dict(), f)
259
+ logger.info("Train state saved !")
260
+
261
+ self.existing_saves.append(curr_save_dir)
262
+
263
+ self.clean_up()
264
+
265
+ if dist.is_initialized():
266
+ dist.barrier()
267
+ return True
268
+
269
+ @torch.no_grad()
270
+ def load(
271
+ self,
272
+ model: nn.Module,
273
+ optimizer,
274
+ train_state,
275
+ device_mesh: DeviceMesh,
276
+ path: Optional[Path] = None,
277
+ ):
278
+ dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
279
+ # Loading tries to load the provided path, if not available the last saved step and finally from the init path
280
+ path = path or self.get_last_step_path(dp_rank=dp_rank)
281
+ # If none of those are available don't do anything
282
+ if path is None:
283
+ # If no checkpoints exist do nothing
284
+ return
285
+
286
+ # Only load train state if it's provided, the files exist and we're not loading from init path
287
+ train_state_name = TRAIN_STATE_NAME.format(dp_rank)
288
+ logger.info("Reloading train state")
289
+ with open(path / train_state_name, "r") as f:
290
+ train_state_dict = json.load(f)
291
+ train_state.load_state_dict(train_state_dict)
292
+ logger.info("Train state reloaded")
293
+
294
+ logger.info(f"Loading from: {str(path)}")
295
+ state_dict = self.get_state_dict(
296
+ model=model,
297
+ optimizer=optimizer,
298
+ )
299
+ dcp.load(state_dict, checkpoint_id=path)
300
+ logger.info("State dict loaded.")
301
+
302
+ logger.info("Reloading model and optim")
303
+
304
+ set_state_dict(
305
+ model,
306
+ optimizer,
307
+ model_state_dict=state_dict["model"],
308
+ optim_state_dict=state_dict["optim"],
309
+ )
310
+ logger.info("Model and optim reloaded")
311
+
312
+ @classmethod
313
+ def instantiate_and_make_dir(cls, args: CheckpointArgs):
314
+ if get_is_master():
315
+ os.makedirs(args.path, exist_ok=True)
316
+ dist.barrier()
317
+
318
+ return cls(args)
319
+
320
+
321
+ def get_consolidated_ckpt_path(ckpt_dir: Path, mp_rank: int = 0, mp_size: int = 1):
322
+ if mp_size == 1:
323
+ assert mp_rank == 0
324
+ no_rank_path = ckpt_dir / "consolidated.pth"
325
+ if no_rank_path.exists():
326
+ return no_rank_path
327
+ return ckpt_dir / f"consolidated.{mp_rank:02d}.pth"
328
+
329
+
330
+ def load_consolidated_checkpoint(
331
+ model: nn.Module,
332
+ consolidated_path: str,
333
+ vision_model_path: Optional[str] = None,
334
+ ):
335
+ """
336
+ Loads a consolidated checkpoint into the model.
337
+ This version supports both:
338
+ - a single file named 'consolidated.pth'
339
+ - multiple parts named like 'consolidated.00.pth', 'consolidated.01.pth', etc.
340
+ """
341
+ ckpt_path = Path(consolidated_path)
342
+ cp_file = get_consolidated_ckpt_path(ckpt_path, mp_rank=0, mp_size=1)
343
+ if cp_file.exists():
344
+ # Use the single file
345
+ st_dict = torch.load(cp_file, weights_only=True)
346
+ if "model" in st_dict:
347
+ st_dict = st_dict["model"]
348
+ else:
349
+ # Fall back to multi-part consolidated files (e.g. consolidated.00.pth, consolidated.01.pth, …)
350
+ checkpoint_files = sorted(ckpt_path.glob("consolidated.*.pth"))
351
+ if not checkpoint_files:
352
+ raise FileNotFoundError(
353
+ f"No consolidated checkpoint file found in {ckpt_path}."
354
+ )
355
+ st_dict = {}
356
+ for ckpt_file in checkpoint_files:
357
+ part = torch.load(ckpt_file, weights_only=True)
358
+ # If the checkpoint part is wrapped with "model", unwrap it
359
+ if "model" in part:
360
+ part = part["model"]
361
+ # Merge the state dicts (assumes the keys are all unique or will correctly overwrite)
362
+ st_dict.update(part)
363
+
364
+ model.vision_projector.init_tensors()
365
+ model.vision_model.init_tensors()
366
+ model.rope_embeddings.reset_parameters()
367
+
368
+ if vision_model_path is not None:
369
+ model.vision_model.load_ckpt(vision_model_path)
370
+
371
+ missing_keys, unexpected_keys = model.load_state_dict(st_dict, strict=False)
372
+ missing_keys = [k for k in missing_keys if "tied_module.weight" not in k]
373
+ if vision_model_path is not None:
374
+ # vision_model is already loaded separately
375
+ missing_keys = [k for k in missing_keys if "vision_model." not in k]
376
+ if len(missing_keys) > 0:
377
+ logger.warning(f"Missing keys when reloading: {missing_keys}")
378
+ if len(unexpected_keys) > 0:
379
+ logger.warning(f"Unexpected keys when reloading: {unexpected_keys}")
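A rough single-process sketch of how the consolidation helpers above might be used to turn a DCP checkpoint folder into a plain torch state dict. The path is a placeholder, and the structure of the saved dict ({"model", "optim"}) follows CheckpointManager.get_state_dict.

import torch

from core.checkpoint import CONSOLIDATE_NAME, consolidate_checkpoints

ckpt_dir = "./dump/0000001000"  # placeholder: a step folder written by CheckpointManager.save
consolidated_dir = consolidate_checkpoints(ckpt_dir)  # writes consolidated/consolidated.pth + params.json
state = torch.load(consolidated_dir / CONSOLIDATE_NAME, map_location="cpu", weights_only=True)
model_sd = state.get("model", state)  # dcp_to_torch_save preserves the saved {"model", "optim"} layout
# model.load_state_dict(model_sd, strict=False)  # then load into your nn.Module as needed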
core/transformer.py ADDED
@@ -0,0 +1,646 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from enum import Enum
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
12
+ flex_attention)
13
+ from xformers.ops import AttentionBias, fmha
14
+
15
+ from core import probe
16
+
17
+
18
+ class InitStdFactor(Enum):
19
+ DISABLED = "disabled" # Init std is divided by 1.0
20
+ GLOBAL_DEPTH = "global_depth" # Init std is divided by sqrt(2*n_layers)
21
+ CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth)
22
+ DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096
23
+
24
+
25
+ @dataclass
26
+ class BaseTransformerArgs:
27
+ dim: int = 512
28
+ n_layers: int = 8
29
+ head_dim: Optional[int] = None
30
+ n_heads: Optional[int] = None
31
+ n_kv_heads: Optional[int] = None
32
+
33
+ ffn_dim_multiplier: Optional[float] = None
34
+
35
+ multiple_of: int = 256
36
+
37
+ norm_eps: float = 1e-5
38
+
39
+ rope_theta: float = 10000.0
40
+
41
+ old_context_len: int = 8192
42
+ rope_scale_factor: int = 1
43
+ low_freq_factor: int = 1
44
+ high_freq_factor: int = 32
45
+
46
+ init_base_std: Optional[float] = None
47
+ init_std_factor: str = "disabled"
48
+
49
+ max_seqlen: int = 1024
50
+
51
+
52
+ def cross_entropy(pred, target, **kwargs):
53
+ return F.nll_loss(
54
+ F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
55
+ target.flatten(end_dim=-1),
56
+ **kwargs,
57
+ )
58
+
59
+
60
+ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
61
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
62
+ assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
63
+ bs, slen, n_kv_heads, head_dim = x.shape
64
+ if n_rep == 1:
65
+ return x
66
+ return (
67
+ x[:, :, :, None, :]
68
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
69
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
70
+ )
71
+
72
+
73
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
74
+ """
75
+ Reshape frequency tensor for broadcasting it with another tensor.
76
+
77
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
78
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
79
+
80
+ Args:
81
+ freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
82
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
83
+ seq_dim (int): Sequence dimension index.
84
+
85
+ Returns:
86
+ torch.Tensor: Reshaped frequency tensor.
87
+ """
88
+ ndim = x.ndim
89
+ assert 0 <= seq_dim < ndim
90
+ assert freqs_cis.shape == (
91
+ x.shape[seq_dim],
92
+ x.shape[-3],
93
+ 2,
94
+ 2,
95
+ ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
96
+ shape = [
97
+ d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
98
+ ] + [2, 2]
99
+ return freqs_cis.view(*shape)
100
+
101
+
102
+ def apply_rotary_emb(
103
+ xq: torch.Tensor,
104
+ xk: torch.Tensor,
105
+ seq_dim: int,
106
+ freqs_cis: torch.Tensor,
107
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
108
+ xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
109
+ xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2) # B S H D -> B S H D/2 1 2
110
+ freqs_cis = reshape_for_broadcast(
111
+ freqs_cis, xq_, seq_dim
112
+ ).float() # S D/2 2 2 -> 1 S 1 D/2 2 2
113
+ xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
114
+ xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
115
+ return xq_out.type_as(xq), xk_out.type_as(xk)
116
+
117
+
118
+ def causal_mask(b, h, q_idx, kv_idx):
119
+ return q_idx >= kv_idx
120
+
121
+
122
+ def lengths_to_start_ids(lengths):
123
+ doc_start = lengths.cumsum(0)
124
+ doc_start = doc_start.roll(1)
125
+ doc_start[0] = 0
126
+ return doc_start
127
+
128
+
129
+ def lengths_to_local_ids(lengths):
130
+ assert lengths.ndim == 1
131
+ nb_seqs = lengths.size(0)
132
+ total_seqlen = lengths.sum()
133
+ # This gives the document id of each token
134
+ doc_id = torch.repeat_interleave(lengths)
135
+ # Compute document start for each document
136
+ doc_start = lengths_to_start_ids(lengths)
137
+ # Compute document start for each token
138
+ doc_start = doc_start[doc_id]
139
+ # Compute the position of each token within each document
140
+ tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start
141
+
142
+ return doc_id, tok_id
143
+
144
+
145
+ def generate_doc_mask_mod(
146
+ mask_mod: _mask_mod_signature,
147
+ lengths: torch.Tensor,
148
+ kv_lengths: Optional[torch.Tensor] = None,
149
+ ) -> _mask_mod_signature:
150
+ """Generates mask mods that apply to inputs to flex attention in the sequence stacked
151
+ format.
152
+
153
+ Args:
154
+ mask_mod: The mask mod to apply to the documents
155
+ lengths: Lengths of each document
156
+
157
+ Note:
158
+ What is the sequence stacked format? When assembling batches of inputs, we
159
+ take multiple sequences and stack them together to form 1 large sequence. We then
160
+ use masking to ensure that the attention scores are only applied to tokens within
161
+ the same document.
162
+
163
+ Example:
164
+
165
+ - Square mask
166
+ doc_mask lengths
167
+ a a b b b c c 2 3 2
168
+ a 1 0 0 0 0 0 0
169
+ a 1 1 0 0 0 0 0
170
+ b 0 0 1 0 0 0 0
171
+ b 0 0 1 1 0 0 0
172
+ b 0 0 1 1 1 0 0
173
+ c 0 0 0 0 0 1 0
174
+ c 0 0 0 0 0 1 1
175
+
176
+ """
177
+ kv_lengths = kv_lengths if kv_lengths is not None else lengths
178
+ q_document_id, q_token_id = lengths_to_local_ids(lengths)
179
+ kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths)
180
+ q_max_idx = lengths.sum() - 1
181
+ kv_max_idx = kv_lengths.sum() - 1
182
+
183
+ def doc_mask_mod(b, h, q_idx, kv_idx):
184
+ q_idx_cap = torch.minimum(q_max_idx, q_idx)
185
+ kv_idx_cap = torch.minimum(kv_max_idx, kv_idx)
186
+ valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx)
187
+ same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap]
188
+ q_logical = q_token_id[q_idx_cap]
189
+ kv_logical = kv_token_id[kv_idx_cap]
190
+ inner_mask = mask_mod(b, h, q_logical, kv_logical)
191
+ return same_doc & inner_mask & valid_idx
192
+
193
+ return doc_mask_mod
194
+
195
+
196
+ # Rotary embedding as in xformers; check whether the torchtrain implementation is better. It might also be useful to make this work with batch*seqlen collapsed.
197
+ class RotaryEmbedding(torch.nn.Module):
198
+ """
199
+ RotaryEmbedding Module
200
+ """
201
+
202
+ def __init__(
203
+ self,
204
+ theta: float,
205
+ head_dim: int,
206
+ max_seqlen: int = 1024,
207
+ scale_factor: int = 1,
208
+ low_freq_factor: int = 1,
209
+ high_freq_factor: int = 32,
210
+ old_context_len: int = 8192,
211
+ ):
212
+ super().__init__()
213
+
214
+ self.theta = theta
215
+ self.head_dim = head_dim
216
+ self.max_seqlen = max_seqlen
217
+ self.scale_factor = scale_factor
218
+ self.low_freq_factor = low_freq_factor
219
+ self.high_freq_factor = high_freq_factor
220
+ self.old_context_len = old_context_len
221
+ if scale_factor != 1:
222
+ self.low_freq_wavelen = old_context_len / low_freq_factor
223
+ self.high_freq_wavelen = old_context_len / high_freq_factor
224
+ assert self.low_freq_wavelen >= self.high_freq_wavelen
225
+
226
+ def reset_parameters(self):
227
+ self.register_buffer(
228
+ "freqs_cis",
229
+ self.precompute_freqs_cis(
230
+ dim=self.head_dim, end=self.max_seqlen, theta=self.theta
231
+ ),
232
+ persistent=False,
233
+ )
234
+
235
+ def apply_scaling(self, freqs):
236
+ if self.scale_factor == 1:
237
+ return freqs
238
+ new_freqs = []
239
+ for freq in freqs:
240
+ wavelen = 2 * math.pi / freq
241
+ if wavelen < self.high_freq_wavelen:
242
+ new_freqs.append(freq)
243
+ elif wavelen > self.low_freq_wavelen:
244
+ new_freqs.append(freq / self.scale_factor)
245
+ else:
246
+ assert self.low_freq_wavelen != self.high_freq_wavelen
247
+ smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
248
+ self.high_freq_factor - self.low_freq_factor
249
+ )
250
+ new_freqs.append(
251
+ (1 - smooth) * freq / self.scale_factor + smooth * freq
252
+ )
253
+ return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
254
+
255
+ def precompute_freqs_cis(
256
+ self,
257
+ dim: int,
258
+ end: int,
259
+ theta: float = 10000.0,
260
+ ):
261
+ """
262
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
263
+
264
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
265
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
266
+ The returned tensor contains complex values in complex64 data type.
267
+
268
+ Args:
269
+ dim (int): Dimension of the frequency tensor.
270
+ end (int): End index for precomputing frequencies.
271
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
272
+
273
+ Returns:
274
+ torch.Tensor: Precomputed frequency tensor with complex exponentials.
275
+ """
276
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
277
+ freqs = self.apply_scaling(freqs)
278
+
279
+ t = torch.arange(end, device=freqs.device)
280
+ freqs = torch.outer(t, freqs).float()
281
+
282
+ cos, sin = freqs.cos(), freqs.sin()
283
+
284
+ return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)
285
+
286
+ def forward(
287
+ self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
288
+ ):
289
+ """
290
+ Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions
291
+ Args:
292
+ seqlen (int): Contiguous sequence length
293
+ tok_idx (torch.Tensor[int]): Position indices of each token; if provided, this overrides seqlen
294
+
295
+ Returns:
296
+ torch.Tensor: freqs_cis for the requested positions
297
+ """
298
+ test = (seqlen is not None) or (tok_idx is not None)
299
+ assert test, "Should provide at least seqlen or tok_idx"
300
+ if tok_idx is not None:
301
+ return self.freqs_cis[tok_idx]
302
+ elif seqlen is not None:
303
+ return self.freqs_cis[0:seqlen]
304
+
305
+
306
+ class RMSNorm(nn.Module):
307
+ """
308
+ Initialize the RMSNorm normalization layer.
309
+
310
+ Args:
311
+ dim (int): The dimension of the input tensor.
312
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
313
+
314
+ Attributes:
315
+ eps (float): A small value added to the denominator for numerical stability.
316
+ weight (nn.Parameter): Learnable scaling parameter.
317
+
318
+ """
319
+
320
+ def __init__(self, dim: int, eps: float = 1e-6):
321
+ super().__init__()
322
+ self.eps = eps
323
+ self.weight = nn.Parameter(torch.ones(dim))
324
+
325
+ def _norm(self, x: torch.Tensor):
326
+ return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
327
+
328
+ def forward(self, x: torch.Tensor):
329
+ x = probe.log_stats(x, "resid")
330
+ output = self._norm(x.float())
331
+ return (output * self.weight.float()).type_as(x)
332
+
333
+ def reset_parameters(self):
334
+ torch.nn.init.ones_(self.weight) # type: ignore
335
+
336
+
337
+ class TiedLinear(nn.Module):
338
+ def __init__(self, tied_module: nn.Module) -> None:
339
+ super().__init__()
340
+ self.tied_module = tied_module
341
+ if not hasattr(tied_module, "weight"):
342
+ raise AttributeError(
343
+ "Provided module does not have attribute 'weight'. Please check your tied_module."
344
+ )
345
+
346
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
347
+ return F.linear(x, self.tied_module.weight)
348
+
349
+
350
+ class Attention(nn.Module):
351
+ def __init__(
352
+ self,
353
+ dim: int,
354
+ head_dim: int,
355
+ n_heads: int,
356
+ n_kv_heads: int,
357
+ rope_theta: float,
358
+ ):
359
+ super().__init__()
360
+
361
+ self.dim = dim
362
+ self.head_dim = head_dim
363
+ self.rope_theta = rope_theta
364
+
365
+ self.n_heads = n_heads
366
+ self.n_kv_heads = n_kv_heads
367
+ self.heads_per_group = self.n_heads // self.n_kv_heads
368
+
369
+ self.wq = nn.Linear(
370
+ dim,
371
+ n_heads * head_dim,
372
+ bias=False,
373
+ )
374
+ self.wk = nn.Linear(
375
+ dim,
376
+ n_kv_heads * head_dim,
377
+ bias=False,
378
+ )
379
+ self.wv = nn.Linear(
380
+ dim,
381
+ n_kv_heads * head_dim,
382
+ bias=False,
383
+ )
384
+
385
+ self.wo = nn.Linear(
386
+ n_heads * head_dim,
387
+ dim,
388
+ bias=False,
389
+ )
390
+
391
+ def forward(
392
+ self,
393
+ x: torch.Tensor,
394
+ freq_cis: torch.Tensor,
395
+ tok_idx: Optional[torch.Tensor] = None,
396
+ mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
397
+ attn_impl: str = "sdpa",
398
+ ) -> torch.Tensor:
399
+ # B S D
400
+ bsz, seq_len, dim = x.shape
401
+ xq = self.wq(x.view_as(x))
402
+ xk = self.wk(x.view_as(x))
403
+ xv = self.wv(x.view_as(x))
404
+
405
+ output_shape = xq.shape
406
+ # B S D -> B S H D
407
+ xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
408
+ xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
409
+ xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
410
+
411
+ xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])
412
+
413
+ # This condition helps us be easily compatible
414
+ # with inference by adding a pluggable KVCache
415
+ if hasattr(self, "kv_cache"):
416
+ xk, xv = self.kv_cache.update(xk, xv, tok_idx)
417
+
418
+ xk = repeat_kv(xk, self.heads_per_group, dim=2)
419
+ xv = repeat_kv(xv, self.heads_per_group, dim=2)
420
+
421
+ if attn_impl == "flex_attention":
422
+ assert mask is None or isinstance(mask, BlockMask)
423
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
424
+ output = flex_attention(xq, xk, xv, block_mask=mask)
425
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
426
+
427
+ elif attn_impl == "fmha":
428
+ assert mask is None or isinstance(mask, AttentionBias)
429
+ output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
430
+ # This uses B S H D instead of B H S D of pytorch
431
+
432
+ elif attn_impl == "sdpa":
433
+ xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
434
+ assert mask is None or isinstance(mask, (str, torch.Tensor))
435
+ is_causal = (mask == "causal") if isinstance(mask, str) else False
436
+ mask = mask if isinstance(mask, torch.Tensor) else None
437
+ output = F.scaled_dot_product_attention(
438
+ xq,
439
+ xk,
440
+ xv,
441
+ is_causal=is_causal,
442
+ attn_mask=mask,
443
+ )
444
+ output = output.transpose(1, 2).contiguous() # B H S D -> B S H D
445
+ else:
446
+ raise NotImplementedError(
447
+ f"Attention implementation {attn_impl} not supported"
448
+ )
449
+
450
+ output = self.wo(output.reshape(output_shape))
451
+
452
+ return output
453
+
454
+ def reset_parameters(self, init_std=None, factor=1.0):
455
+ init_std = init_std or (self.dim ** (-0.5))
456
+
457
+ for w in [self.wq, self.wk, self.wv]:
458
+ nn.init.trunc_normal_(
459
+ w.weight,
460
+ mean=0.0,
461
+ std=init_std,
462
+ a=-3 * init_std,
463
+ b=3 * init_std,
464
+ )
465
+
466
+ nn.init.trunc_normal_(
467
+ self.wo.weight,
468
+ mean=0.0,
469
+ std=init_std / factor,
470
+ a=-3 * init_std,
471
+ b=3 * init_std,
472
+ )
473
+
474
+
475
+ class FeedForward(nn.Module):
476
+ def __init__(
477
+ self,
478
+ dim: int,
479
+ hidden_dim: int,
480
+ multiple_of: int,
481
+ ffn_dim_multiplier: Optional[float],
482
+ mp_size: int = 1,
483
+ ):
484
+ super().__init__()
485
+
486
+ hidden_dim = int(2 * hidden_dim / 3)
487
+ if ffn_dim_multiplier is not None:
488
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
489
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
490
+ assert hidden_dim % mp_size == 0
491
+
492
+ self.dim = dim
493
+ self.hidden_dim = hidden_dim
494
+
495
+ self.w1 = nn.Linear(
496
+ dim,
497
+ hidden_dim,
498
+ bias=False,
499
+ )
500
+ self.w3 = nn.Linear(
501
+ dim,
502
+ hidden_dim,
503
+ bias=False,
504
+ )
505
+ self.w2 = nn.Linear(
506
+ hidden_dim,
507
+ dim,
508
+ bias=False,
509
+ )
510
+
511
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
512
+ # B S D
513
+ x1 = self.w1(x.view_as(x))
514
+ x3 = self.w3(x.view_as(x))
515
+ output = self.w2(F.silu(x1) * x3)
516
+ return output
517
+
518
+ def reset_parameters(self, init_std=None, factor=1.0):
519
+ in_init_std = init_std or (self.dim ** (-0.5))
520
+ out_init_std = init_std or (self.hidden_dim ** (-0.5))
521
+ in_init_std = in_init_std
522
+ out_init_std = out_init_std / factor
523
+ for w in [self.w1, self.w3]:
524
+ nn.init.trunc_normal_(
525
+ w.weight,
526
+ mean=0.0,
527
+ std=in_init_std,
528
+ a=-3 * in_init_std,
529
+ b=3 * in_init_std,
530
+ )
531
+ nn.init.trunc_normal_(
532
+ self.w2.weight,
533
+ mean=0.0,
534
+ std=out_init_std,
535
+ a=-3 * out_init_std,
536
+ b=3 * out_init_std,
537
+ )
538
+
539
+
540
+ class TransformerBlock(nn.Module):
541
+ def __init__(self, args: BaseTransformerArgs):
542
+ super().__init__()
543
+
544
+ assert (args.head_dim is not None) or (
545
+ args.n_heads is not None
546
+ ), "Should specify at least head_dim or n_heads"
547
+ self.head_dim = args.head_dim or args.dim // args.n_heads
548
+ self.n_heads = args.n_heads or args.dim // args.head_dim
549
+ self.n_kv_heads = args.n_kv_heads or self.n_heads
550
+
551
+ assert args.n_heads % self.n_kv_heads == 0
552
+ assert args.dim % args.n_heads == 0
553
+
554
+ self.attention = Attention(
555
+ dim=args.dim,
556
+ head_dim=self.head_dim,
557
+ n_heads=self.n_heads,
558
+ n_kv_heads=self.n_kv_heads,
559
+ rope_theta=args.rope_theta,
560
+ )
561
+ self.feed_forward = FeedForward(
562
+ dim=args.dim,
563
+ hidden_dim=4 * args.dim,
564
+ multiple_of=args.multiple_of,
565
+ ffn_dim_multiplier=args.ffn_dim_multiplier,
566
+ )
567
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
568
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
569
+
570
+ def forward(
571
+ self,
572
+ x: torch.Tensor,
573
+ freq_cis: torch.Tensor,
574
+ tok_idx: Optional[torch.Tensor] = None,
575
+ mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
576
+ attn_impl: str = "sdpa",
577
+ ) -> torch.Tensor:
578
+
579
+ h = x + self.attention(
580
+ self.attention_norm(x),
581
+ freq_cis,
582
+ tok_idx=tok_idx,
583
+ mask=mask,
584
+ attn_impl=attn_impl,
585
+ )
586
+ out = h + self.feed_forward(self.ffn_norm(h))
587
+ return out
588
+
589
+ def init_weights(self, init_std=None, factor=1.0):
590
+ self.attention.reset_parameters(init_std, factor)
591
+ self.attention_norm.reset_parameters()
592
+
593
+ self.feed_forward.reset_parameters(init_std, factor)
594
+ self.ffn_norm.reset_parameters()
595
+
596
+
597
+ class BaseTransformer(nn.Module):
598
+ def __init__(self, args: BaseTransformerArgs):
599
+ super().__init__()
600
+ self.dim = args.dim
601
+ self.init_base_std = args.init_base_std
602
+ self.init_std_factor = InitStdFactor(args.init_std_factor)
603
+ self.max_seqlen = args.max_seqlen
604
+ self.rope_embeddings = RotaryEmbedding(
605
+ theta=args.rope_theta,
606
+ head_dim=args.head_dim or args.dim // args.n_heads,
607
+ max_seqlen=args.max_seqlen,
608
+ scale_factor=args.rope_scale_factor,
609
+ low_freq_factor=args.low_freq_factor,
610
+ high_freq_factor=args.high_freq_factor,
611
+ old_context_len=args.old_context_len,
612
+ )
613
+
614
+ self.layers = nn.ModuleList()
615
+ for _ in range(args.n_layers):
616
+ self.layers.append(TransformerBlock(args))
617
+
618
+ def forward(
619
+ self,
620
+ h,
621
+ tok_idx: Optional[torch.Tensor] = None,
622
+ mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
623
+ attn_impl: str = "sdpa",
624
+ ):
625
+
626
+ freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)
627
+
628
+ for i, layer in enumerate(self.layers):
629
+ h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
630
+ return h
631
+
632
+ def reset_parameters(self):
633
+ # Either use fixed base std or sqrt model dim
634
+ self.rope_embeddings.reset_parameters()
635
+
636
+ def init_weights(self):
637
+ self.reset_parameters()
638
+ for depth, layer in enumerate(self.layers):
639
+ factor = {
640
+ InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
641
+ InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
642
+ InitStdFactor.DIM_RATIO: self.dim / 4096,
643
+ InitStdFactor.DISABLED: 1.0,
644
+ }[self.init_std_factor]
645
+
646
+ layer.init_weights(self.init_base_std, factor)
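
The freqs_cis buffer built by precompute_freqs_cis above stores one 2x2 rotation matrix per (position, frequency) pair rather than a complex tensor. A standalone sketch of that layout, mirroring the code above with illustrative values (head_dim=64, max_seqlen=16, theta=10000, no scaling) and without importing the repo:

import torch

dim, end, theta = 64, 16, 10000.0
# Inverse frequencies for every pair of channels, as in precompute_freqs_cis.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
angles = torch.outer(torch.arange(end), freqs).float()
cos, sin = angles.cos(), angles.sin()
# Stack [cos, -sin, sin, cos] and view the result as 2x2 rotation matrices.
freqs_cis = torch.stack((cos, -sin, sin, cos), dim=-1).view(end, dim // 2, 2, 2)
print(freqs_cis.shape)  # torch.Size([16, 32, 2, 2])
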
core/transforms/image_transform.py ADDED
@@ -0,0 +1,409 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import math
4
+ from functools import reduce
5
+ from logging import getLogger
6
+ from typing import Any, Callable, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torchvision.transforms as tv
11
+ from PIL import Image
12
+ from torchvision.transforms import functional as F
13
+ from torchvision.transforms.functional import InterpolationMode
14
+
15
+ logger = getLogger()
16
+
17
+
18
+ MEAN = (0.5, 0.5, 0.5)
19
+ STD = (0.5, 0.5, 0.5)
20
+
21
+
22
+ def get_image_transform(
23
+ vision_input_type: str = "vanilla",
24
+ image_res: int = 336,
25
+ max_num_tiles: int = 1,
26
+ normalize_img: bool = True,
27
+ ) -> Callable:
28
+
29
+ if vision_input_type == "thumb+tile":
30
+ transforms = VariableSizeImageTransform(
31
+ size=image_res,
32
+ max_num_tiles=max_num_tiles,
33
+ normalize_img=normalize_img,
34
+ use_thumbnail="before",
35
+ )
36
+ else:
37
+ transforms = ImageTransform(
38
+ size=image_res,
39
+ normalize_img=normalize_img,
40
+ )
41
+
42
+ logger.info(
43
+ f"Initalized transforms with: vision_input_type: '{vision_input_type}' and max_num_tiles: {max_num_tiles}."
44
+ )
45
+
46
+ return transforms
47
+
48
+
49
+ class ImageTransform(object):
50
+ """
51
+ Image transform resizes the input to a fixed square of the given size (bicubic), converts it to a tensor, and normalizes it.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ size: int = 336,
57
+ normalize_img: bool = True,
58
+ ) -> None:
59
+ self.size = size
60
+ self._mean = MEAN
61
+ self._std = STD
62
+
63
+ logger.info(f"ImageTransform size: {self.size}")
64
+
65
+ self.to_tensor = tv.ToTensor()
66
+ self.normalize = (
67
+ tv.Normalize(
68
+ mean=self._mean,
69
+ std=self._std,
70
+ inplace=True,
71
+ )
72
+ if normalize_img
73
+ else lambda x: x
74
+ )
75
+
76
+ def __call__(self, image: Image.Image):
77
+ w, h = image.size
78
+ image = F.resize(
79
+ image, (self.size, self.size), interpolation=InterpolationMode.BICUBIC
80
+ )
81
+ image = self.to_tensor(image)
82
+ image = self.normalize(image)
83
+
84
+ # Add chunk dim to make it compatible with existing dataloaders
85
+ image = image.view(1, 3, self.size, self.size)
86
+ return image, (w, h)
87
+
88
+ def _transform_torch_tensor(self, image: torch.Tensor):
89
+ h, w = image.shape[-2:] # Image shape (C, H, W) or (N, C, H, W)
90
+ image = F.resize(
91
+ image, size=(self.size, self.size), interpolation=InterpolationMode.BICUBIC
92
+ )
93
+ image = (
94
+ image.to(torch.float32) / 255.0
95
+ ) # Convert to float and scale to [0, 1] range
96
+ image = self.normalize(image)
97
+ return image, (w, h)
98
+
99
+
100
+ class VariableSizeImageTransform(object):
101
+ """
102
+ The variable size image transform will resize the image dynamically
103
+ based on the image aspect ratio and the number of image chunks we allow.
104
+
105
+ The algorithm will not upsample low-res images to fit a certain aspect
106
+ ratio, because that leads to a significant degradation in image quality.
107
+
108
+ For example, if an input image is of size 300x800, and we want to allow
109
+ a maximum of 16 image chunks, it will find the closest aspect ratio that
110
+ is allowed within 16 image chunks, i.e., 2:5 = 2 horizontal patches and
111
+ 5 vertical patches, giving a total of 10 chunks.
112
+
113
+ The image will then be resized to products of the base size (default is
114
+ 224px because MetaCLIP takes that), so in this case it will be resized to
115
+ 2*224:5*224 = 448:1120, where we maintain the original aspect ratio and
116
+ pad with the mean value for the rest. This approach minimizes the amount
117
+ of padding required for any arbitrary resolution.
118
+
119
+ The final output will therefore be of shape (11, 3, 224, 224), where 10
120
+ patches are coming from the resizing and chunking, and the first patch
121
+ is a downsampled version of the image that preserves aspect ratios.
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ size: int = 336,
127
+ normalize_img: bool = True,
128
+ max_num_tiles: int = 1,
129
+ use_thumbnail: str = "no",
130
+ area_limit: bool = False,
131
+ ) -> None:
132
+ self.size = size
133
+ self._mean = MEAN
134
+ self._std = STD
135
+
136
+ logger.info(f"VariableSizeImageTransform size: {self.size}")
137
+
138
+ self.to_tensor = tv.ToTensor()
139
+ self.normalize = (
140
+ tv.Normalize(
141
+ mean=self._mean,
142
+ std=self._std,
143
+ inplace=True,
144
+ )
145
+ if normalize_img
146
+ else lambda x: x
147
+ )
148
+ self.area_limit = area_limit
149
+ self.max_num_tiles = max_num_tiles
150
+ self.use_thumbnail = use_thumbnail
151
+ if self.use_thumbnail != "no":
152
+ self.thumbnail_transform = ImageTransform(
153
+ size=self.size,
154
+ normalize_img=normalize_img,
155
+ )
156
+
157
+ @staticmethod
158
+ def _factors(n: int):
159
+ """Return all factors of a number."""
160
+ return set(
161
+ reduce(
162
+ list.__add__,
163
+ ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
164
+ )
165
+ )
166
+
167
+ def _find_supported_aspect_ratios(self):
168
+ """
169
+ This function computes all the allowed aspect ratios for a fixed
170
+ number of input chunks.
171
+
172
+ For example, with `num_tiles=5`, it will return:
173
+ {
174
+ 0.2: [(1, 5)],
175
+ 5.0: [(5, 1)],
176
+ 0.25: [(1, 4)],
177
+ 1.0: [(2, 2), (1, 1)],
178
+ 4.0: [(4, 1)],
179
+ 0.3333333333333333: [(1, 3)],
180
+ 3.0: [(3, 1)],
181
+ 0.5: [(1, 2)],
182
+ 2.0: [(2, 1)]
183
+ }
184
+ """
185
+ asp_dict = {}
186
+ for chunk_size in range(self.max_num_tiles, 0, -1):
187
+ _factors = sorted(VariableSizeImageTransform._factors(chunk_size))
188
+ _asp_ratios = [(x, chunk_size // x) for x in _factors]
189
+ for ratio in _asp_ratios:
190
+ k = ratio[0] / ratio[1]
191
+ if k not in asp_dict:
192
+ asp_dict[k] = [ratio]
193
+ else:
194
+ asp_dict[k].append(ratio)
195
+ return asp_dict
196
+
197
+ def _find_closest_aspect_ratio(self, img_width: int, img_height: int) -> Tuple:
198
+ """
199
+ Given an image width, height and target number of chunks
200
+ this function will find the closest supported aspect ratio.
201
+ """
202
+ tgt_ar = img_width / img_height
203
+ asp_dict = self._find_supported_aspect_ratios()
204
+ cl_d, cl_p = 1e23, None
205
+ if tgt_ar >= 1:
206
+ cl_p = min(
207
+ [k for k in asp_dict.keys() if k <= tgt_ar],
208
+ key=lambda x: abs(x - tgt_ar),
209
+ )
210
+ v = asp_dict[cl_p]
211
+ # select width
212
+ widths = [(idx, self.size * vv[0]) for idx, vv in enumerate(v)]
213
+ tgt_idx = max(widths, key=lambda x: x[1])[0]
214
+ else:
215
+ cl_p = min(
216
+ [k for k in asp_dict.keys() if k > tgt_ar],
217
+ key=lambda x: abs(1 / x - 1 / tgt_ar),
218
+ )
219
+ v = asp_dict[cl_p]
220
+ # select height
221
+ heights = [(idx, self.size * vv[1]) for idx, vv in enumerate(v)]
222
+ tgt_idx = max(heights, key=lambda x: x[1])[0]
223
+ out = v[tgt_idx]
224
+ return out
225
+
226
+ def _resize(
227
+ self, image: Image.Image, target_width: int, target_height: int
228
+ ) -> Image.Image:
229
+ # Resize longer edge to given size.
230
+ w, h = image.size
231
+ scale = w / h
232
+
233
+ if scale > 1.0:
234
+ # width > height
235
+ new_w = target_width
236
+ new_h = math.floor(new_w / scale)
237
+ else:
238
+ # height >= width
239
+ new_h = target_height
240
+ new_w = math.floor(new_h * scale)
241
+
242
+ image = F.resize(image, (new_h, new_w))
243
+ return image
244
+
245
+ def _pad(self, image: Image.Image, new_width: int, new_height: int) -> Image.Image:
246
+ mean_per_channel = tuple(
247
+ np.clip(np.array(image).mean(axis=(0, 1)), 0, 255).astype(np.uint8)
248
+ )
249
+ new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
250
+ new_im.paste(image)
251
+ return new_im
252
+
253
+ def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
254
+ # Split image into number of required tiles (width x height)
255
+ num_channels, height, width = image.size()
256
+ image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
257
+ # Permute dimensions to reorder the axes
258
+ image = image.permute(1, 3, 0, 2, 4).contiguous()
259
+ # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
260
+ image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
261
+ return image
262
+
263
+ def _get_image_height_width(
264
+ self, image_width: int, image_height: int, target_width: int, target_height: int
265
+ ) -> Tuple[int, int]:
266
+ """
267
+ Given image width, height and target width, height for the canvas, return the dimensions of how the image would be resized
268
+ with aspect ratio preservation.
269
+ """
270
+ scale = image_width / image_height
271
+
272
+ if scale > 1.0:
273
+ # Width is larger than height
274
+
275
+ # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
276
+ rescaling_factor = min(
277
+ target_width / image_width, target_height / image_height
278
+ )
279
+
280
+ # Set new width to target width and height to the rescaled height.
281
+ new_w = rescaling_factor * image_width
282
+ new_h = math.floor(new_w / scale)
283
+
284
+ else:
285
+ # Height is larger than width
286
+
287
+ # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
288
+ rescaling_factor = min(
289
+ target_width / image_width, target_height / image_height
290
+ )
291
+
292
+ # Set new height to target height and width to the rescaled width.
293
+ new_h = rescaling_factor * image_height
294
+ new_w = math.floor(new_h * scale)
295
+
296
+ return new_w, new_h
297
+
298
+ def _fit_image_to_canvas(
299
+ self, img_width: int, img_height: int, area_limit: bool
300
+ ) -> Any:
301
+ """
302
+ Given an image width, height and target number of chunks this function will see if the image
303
+ can be fit into any of the canvases that can be build from arranging the tiles in a grid.
304
+ If the image can be fit onto several canvases, it will return the canvas where the shorter edge
305
+ of the image will be largest.
306
+
307
+ If area_limit is set to True, the tie-breaking prefers the canvas where area is less than 2x the original area.
308
+ """
309
+ # Initialize the optimal canvas to None. If no canvas is found where image fits, function returns None.
310
+ optimal_canvas = None
311
+ optimal_image_width_height = None
312
+
313
+ scale = img_width / img_height
314
+
315
+ # Gather all potential supported image resolutions and iterate through them to find best match
316
+ potential_arrangements = [
317
+ item
318
+ for sublist in self._find_supported_aspect_ratios().values()
319
+ for item in sublist
320
+ ]
321
+ for n_w, n_h in potential_arrangements:
322
+ # Compute the canvas size
323
+ canvas_width, canvas_height = n_w * self.size, n_h * self.size
324
+
325
+ # Check if image can fit into the canvas without downsampling
326
+ if canvas_width >= img_width and canvas_height >= img_height:
327
+ # If we did not find a good canvas yet, we will use the current one
328
+ if optimal_canvas is None:
329
+ # Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling
330
+ optimal_canvas = (n_w, n_h)
331
+ optimal_image_width_height = self._get_image_height_width(
332
+ image_width=img_width,
333
+ image_height=img_height,
334
+ target_width=n_w * self.size,
335
+ target_height=n_h * self.size,
336
+ )
337
+ else:
338
+ # If we already found an optimal canvas before, we will check if the shorter edge of the image will be larger than the current optimal canvas.
339
+ # This means we can potentially upsample the image resolution which is beneficial to performance.
340
+ image_width_height = self._get_image_height_width(
341
+ image_width=img_width,
342
+ image_height=img_height,
343
+ target_width=n_w * self.size,
344
+ target_height=n_h * self.size,
345
+ )
346
+ if area_limit:
347
+ # Prioritize aspect ratio, and choose best within area limit when tied.
348
+ curr_scale = image_width_height[0] / image_width_height[1]
349
+ optim_scale = (
350
+ optimal_image_width_height[0]
351
+ / optimal_image_width_height[1]
352
+ )
353
+ if abs(scale - curr_scale) < abs(scale - optim_scale):
354
+ # 1. optimize aspect ratio
355
+ optimal_canvas = (n_w, n_h)
356
+ optimal_image_width_height = image_width_height
357
+ elif abs(scale - curr_scale) == abs(scale - optim_scale):
358
+ # 2. optimize area
359
+ if (
360
+ image_width_height[0] * image_width_height[1]
361
+ < 2 * img_width * img_height
362
+ ):
363
+ # 2.1 area is less than 2x the original area
364
+ optimal_canvas = (n_w, n_h)
365
+ optimal_image_width_height = image_width_height
366
+ else:
367
+ # NOTE: L3V dynamic tiling. Prioritize the biggest canvas.
368
+ if (
369
+ scale < 1.0
370
+ and (image_width_height[0] >= optimal_image_width_height[0])
371
+ ) or (
372
+ scale >= 1.0
373
+ and (image_width_height[1] >= optimal_image_width_height[1])
374
+ ):
375
+ optimal_canvas = (n_w, n_h)
376
+ optimal_image_width_height = image_width_height
377
+ return optimal_canvas
378
+
379
+ def __call__(self, image: Image.Image) -> Tuple[Any, Any]:
380
+ assert isinstance(image, Image.Image), type(image)
381
+ if self.use_thumbnail != "no":
382
+ thumbnail = self.thumbnail_transform(image)[0]
383
+
384
+ w, h = image.size
385
+ # Check if the image can be fit to the canvas without downsampling
386
+ ar = self._fit_image_to_canvas(
387
+ img_width=w, img_height=h, area_limit=self.area_limit
388
+ )
389
+ if ar is None:
390
+ # If we did not find a canvas, we have to find the closest aspect ratio and downsample the image
391
+ ar = self._find_closest_aspect_ratio(img_width=w, img_height=h)
392
+
393
+ image = F.resize(
394
+ image,
395
+ (ar[1] * self.size, ar[0] * self.size), # (h, w)
396
+ interpolation=InterpolationMode.BICUBIC,
397
+ )
398
+ image = self._pad(image, ar[0] * self.size, ar[1] * self.size)
399
+ image = self.to_tensor(image)
400
+ image = self.normalize(image)
401
+ image = self._split(image, ar[0], ar[1]) # type: ignore
402
+ if self.use_thumbnail == "before":
403
+ image = torch.cat((thumbnail, image), dim=0)
404
+ elif self.use_thumbnail == "after":
405
+ image = torch.cat((image, thumbnail), dim=0)
406
+ elif self.use_thumbnail == "both":
407
+ image = torch.cat((thumbnail, image, thumbnail), dim=0)
408
+
409
+ return image, ar
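
A usage sketch for the tiling transform defined above (assumes Pillow and torchvision are installed and the repo root is on PYTHONPATH; the 800x300 blank image is just a stand-in for a real photo):

from PIL import Image
from core.transforms.image_transform import get_image_transform

transform = get_image_transform(
    vision_input_type="thumb+tile", image_res=448, max_num_tiles=4
)
image = Image.new("RGB", (800, 300))
tiles, ar = transform(image)
# `ar` is the chosen (tiles_wide, tiles_high) grid; with use_thumbnail="before"
# the thumbnail is prepended, so `tiles` has shape (tiles_wide * tiles_high + 1, 3, 448, 448).
print(tiles.shape, ar)
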
core/utils.py ADDED
@@ -0,0 +1,40 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from functools import partial
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+
8
+
9
+ @dataclass
10
+ class InitArgs:
11
+ use_gaussian: bool = True # gaussian vs uniform
12
+ coeff_std: Optional[float] = None # std coeff multiplier
13
+ no_init: bool = False
14
+
15
+
16
+ def get_init_fn(
17
+ args: InitArgs, input_dim: int, init_depth: Optional[int]
18
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
19
+ """
20
+ Init functions.
21
+ """
22
+ if args.no_init:
23
+ return lambda x: x
24
+
25
+ # standard deviation
26
+ std = 1 / math.sqrt(input_dim)
27
+ std = std if args.coeff_std is None else (args.coeff_std * std)
28
+
29
+ # rescale with depth
30
+ if init_depth is not None:
31
+ std = std / math.sqrt(2 * init_depth)
32
+
33
+ # gaussian vs uniform
34
+ if args.use_gaussian:
35
+ return partial(
36
+ torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
37
+ )
38
+ else:
39
+ bound = math.sqrt(3) * std # ensure the standard deviation is `std`
40
+ return partial(torch.nn.init.uniform_, a=-bound, b=bound)
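
A small sketch of how get_init_fn is meant to be used (the 1024-wide weight and depth of 12 are illustrative values, not taken from any config in this commit):

import torch
from core.utils import InitArgs, get_init_fn

args = InitArgs(use_gaussian=True, coeff_std=None)
init_fn = get_init_fn(args, input_dim=1024, init_depth=12)
w = torch.empty(1024, 1024)
init_fn(w)  # in-place truncated normal, clipped to +/- 3 std
# std is 1/sqrt(1024) rescaled by 1/sqrt(2 * 12), roughly 0.0064 before truncation.
print(w.std())
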
core/vision_encoder/__init__.py ADDED
File without changes
core/vision_encoder/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (159 Bytes). View file
 
core/vision_encoder/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (218 Bytes). View file
 
core/vision_encoder/__pycache__/config.cpython-312.pyc ADDED
Binary file (4.31 kB). View file
 
core/vision_encoder/__pycache__/config.cpython-313.pyc ADDED
Binary file (4.26 kB). View file
 
core/vision_encoder/__pycache__/pe.cpython-312.pyc ADDED
Binary file (42.9 kB). View file
 
core/vision_encoder/__pycache__/pe.cpython-313.pyc ADDED
Binary file (38.1 kB). View file
 
core/vision_encoder/__pycache__/pe_lora.cpython-312.pyc ADDED
Binary file (35.5 kB). View file
 
core/vision_encoder/__pycache__/rope.cpython-312.pyc ADDED
Binary file (14.6 kB). View file
 
core/vision_encoder/__pycache__/rope.cpython-313.pyc ADDED
Binary file (14.6 kB). View file
 
core/vision_encoder/__pycache__/tokenizer.cpython-312.pyc ADDED
Binary file (17.2 kB). View file
 
core/vision_encoder/__pycache__/tokenizer.cpython-313.pyc ADDED
Binary file (17.3 kB). View file
 
core/vision_encoder/__pycache__/transforms.cpython-312.pyc ADDED
Binary file (3.51 kB). View file
 
core/vision_encoder/__pycache__/transforms.cpython-313.pyc ADDED
Binary file (3.3 kB). View file
 
core/vision_encoder/config.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ """
4
+ Include all available vision encoder configurations.
5
+ """
6
+
7
+ from dataclasses import dataclass, replace
8
+
9
+ from typing import Optional
10
+
11
+ from huggingface_hub import hf_hub_download
12
+
13
+
14
+
15
+ def fetch_pe_checkpoint(name: str, path: Optional[str] = None):
16
+ path = path or f"hf://facebook/{name}:{name}.pt"
17
+
18
+ if path.startswith("hf://"):
19
+ # Load from huggingface
20
+ path = path[len("hf://"):]
21
+ repo, file = path.split(":")
22
+
23
+ return hf_hub_download(repo_id=repo, filename=file)
24
+ else:
25
+ return path
26
+
27
+
28
+
29
+
30
+ @dataclass
31
+ class PEConfig:
32
+ """ Vision Tower Config. """
33
+ patch_size: int
34
+ width: int
35
+ layers: int
36
+ heads: int
37
+ mlp_ratio: float
38
+ output_dim: Optional[int]
39
+
40
+ ls_init_value: float = None
41
+ drop_path: float = 0.0
42
+
43
+ image_size: int = 224
44
+ use_abs_posemb: bool = True
45
+ use_cls_token: bool = False
46
+ use_rope2d: bool = True
47
+
48
+ pool_type: str = "attn"
49
+ attn_pooler_heads: int = 8
50
+
51
+ use_ln_pre: bool = True
52
+ use_ln_post: bool = True
53
+
54
+
55
+ @dataclass
56
+ class PETextConfig:
57
+ """ Text Tower Config. """
58
+ context_length: int
59
+ width: int
60
+ heads: int
61
+ layers: int
62
+
63
+ output_dim: int
64
+
65
+ mlp_ratio: float = 4.0
66
+ vocab_size: int = 49408
67
+
68
+
69
+
70
+
71
+ PE_VISION_CONFIG = {}
72
+ PE_TEXT_CONFIG = {}
73
+
74
+
75
+
76
+ #########################################
77
+ # PE CORE #
78
+ #########################################
79
+
80
+ PE_VISION_CONFIG["PE-Core-G14-448"] = PEConfig(
81
+ image_size=448,
82
+ patch_size=14,
83
+ width=1536,
84
+ layers=50,
85
+ heads=16,
86
+ mlp_ratio=8960 / 1536,
87
+ pool_type="attn",
88
+ output_dim=1280,
89
+ use_cls_token=False,
90
+ )
91
+ PE_TEXT_CONFIG["PE-Core-G14-448"] = PETextConfig(
92
+ context_length=72,
93
+ width=1280,
94
+ heads=20,
95
+ layers=24,
96
+ output_dim=1280
97
+ )
98
+
99
+
100
+ PE_VISION_CONFIG["PE-Core-L14-336"] = PEConfig(
101
+ image_size=336,
102
+ patch_size=14,
103
+ width=1024,
104
+ layers=24,
105
+ heads=16,
106
+ mlp_ratio=4.0,
107
+ pool_type="attn",
108
+ output_dim=1024,
109
+ use_cls_token=True,
110
+ )
111
+ PE_TEXT_CONFIG["PE-Core-L14-336"] = PETextConfig(
112
+ context_length=32,
113
+ width=1024,
114
+ heads=16,
115
+ layers=24,
116
+ output_dim=1024
117
+ )
118
+
119
+
120
+ PE_VISION_CONFIG["PE-Core-B16-224"] = PEConfig(
121
+ image_size=224,
122
+ patch_size=16,
123
+ width=768,
124
+ layers=12,
125
+ heads=12,
126
+ mlp_ratio=4.0,
127
+ pool_type="attn",
128
+ output_dim=1024,
129
+ use_cls_token=True,
130
+ )
131
+ PE_TEXT_CONFIG["PE-Core-B16-224"] = PE_TEXT_CONFIG["PE-Core-L14-336"]
132
+
133
+
134
+
135
+
136
+ PE_VISION_CONFIG["PE-Core-S16-384"] = PEConfig(
137
+ image_size=384,
138
+ patch_size=16,
139
+ width=384,
140
+ layers=12,
141
+ heads=6,
142
+ mlp_ratio=4.0,
143
+ pool_type="attn",
144
+ output_dim=512,
145
+ use_cls_token=True,
146
+ )
147
+ PE_TEXT_CONFIG["PE-Core-S16-384"] = PETextConfig(
148
+ context_length=32,
149
+ width=512,
150
+ heads=8,
151
+ layers=12,
152
+ output_dim=512
153
+ )
154
+
155
+
156
+
157
+ PE_VISION_CONFIG["PE-Core-T16-384"] = PEConfig(
158
+ image_size=384,
159
+ patch_size=16,
160
+ width=192,
161
+ layers=12,
162
+ heads=3,
163
+ mlp_ratio=4.0,
164
+ pool_type="attn",
165
+ output_dim=512,
166
+ use_cls_token=True,
167
+ )
168
+ PE_TEXT_CONFIG["PE-Core-T16-384"] = PE_TEXT_CONFIG["PE-Core-S16-384"]
169
+
170
+
171
+
172
+
173
+
174
+
175
+
176
+ #########################################
177
+ # PE Lang #
178
+ #########################################
179
+
180
+ PE_VISION_CONFIG["PE-Lang-G14-448"] = replace(
181
+ PE_VISION_CONFIG["PE-Core-G14-448"],
182
+ image_size=448,
183
+ pool_type="none",
184
+ use_ln_post=False,
185
+ output_dim=None,
186
+ ls_init_value=0.1,
187
+ layers=47,
188
+ )
189
+
190
+ PE_VISION_CONFIG["PE-Lang-L14-448"] = replace(
191
+ PE_VISION_CONFIG["PE-Core-L14-336"],
192
+ image_size=448,
193
+ pool_type="none",
194
+ use_ln_post=False,
195
+ output_dim=None,
196
+ ls_init_value=0.1,
197
+ layers=23
198
+ )
199
+
200
+
201
+ # Stage 2 checkpoints for PLM-8B and PLM-3B respectively. Pretrained with tiling.
202
+ # Use these checkpoints if you're building a model that uses tiling downstream!
203
+ PE_VISION_CONFIG["PE-Lang-G14-448-Tiling"] = PE_VISION_CONFIG["PE-Lang-G14-448"]
204
+ PE_VISION_CONFIG["PE-Lang-L14-448-Tiling"] = PE_VISION_CONFIG["PE-Lang-L14-448"]
205
+
206
+
207
+
208
+
209
+
210
+
211
+
212
+
213
+ #########################################
214
+ # PE Spatial #
215
+ #########################################
216
+
217
+ PE_VISION_CONFIG["PE-Spatial-G14-448"] = replace(
218
+ PE_VISION_CONFIG["PE-Core-G14-448"],
219
+ image_size=448,
220
+ pool_type="none",
221
+ use_ln_post=False,
222
+ output_dim=None,
223
+ ls_init_value=0.1,
224
+ )
225
+
226
+ # No layerscale on the smaller spatial models
227
+ PE_VISION_CONFIG["PE-Spatial-L14-448"] = replace(
228
+ PE_VISION_CONFIG["PE-Core-L14-336"],
229
+ image_size=448,
230
+ pool_type="none",
231
+ use_ln_post=False,
232
+ output_dim=None,
233
+ )
234
+
235
+
236
+ PE_VISION_CONFIG["PE-Spatial-B16-512"] = replace(
237
+ PE_VISION_CONFIG["PE-Core-B16-224"],
238
+ image_size=512,
239
+ pool_type="none",
240
+ use_ln_post=False,
241
+ output_dim=None,
242
+ )
243
+
244
+
245
+ PE_VISION_CONFIG["PE-Spatial-S16-512"] = replace(
246
+ PE_VISION_CONFIG["PE-Core-S16-384"],
247
+ image_size=512,
248
+ pool_type="none",
249
+ use_ln_post=False,
250
+ output_dim=None,
251
+ )
252
+
253
+
254
+ PE_VISION_CONFIG["PE-Spatial-T16-512"] = replace(
255
+ PE_VISION_CONFIG["PE-Core-T16-384"],
256
+ image_size=512,
257
+ pool_type="none",
258
+ use_ln_post=False,
259
+ output_dim=None,
260
+ )
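
A usage sketch for the config registry and checkpoint resolver above (the local checkpoint path is hypothetical; only hf://repo:file style paths trigger a download through hf_hub_download):

from core.vision_encoder.config import PE_VISION_CONFIG, fetch_pe_checkpoint

cfg = PE_VISION_CONFIG["PE-Core-L14-336"]
print(cfg.image_size, cfg.patch_size, cfg.output_dim)  # 336 14 1024
# Local paths are returned unchanged; with path=None the default is "hf://facebook/<name>:<name>.pt".
ckpt = fetch_pe_checkpoint("PE-Core-L14-336", path="checkpoints/PE-Core-L14-336.pt")
print(ckpt)  # checkpoints/PE-Core-L14-336.pt
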
core/vision_encoder/pe.py ADDED
@@ -0,0 +1,833 @@
1
+ from collections import OrderedDict
2
+ from dataclasses import asdict
3
+ from functools import partial
4
+ from logging import getLogger
5
+ from typing import Callable, Optional, Literal
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from timm.layers import DropPath
12
+ from torch.nn import functional as F
13
+ from torch.nn.init import constant_, xavier_uniform_
14
+ from torch.nn.parameter import Parameter
15
+ from torch.utils.checkpoint import checkpoint
16
+ import types
17
+ from core.vision_encoder.rope import Rope2D
18
+ from core.vision_encoder.config import PEConfig, PETextConfig, PE_VISION_CONFIG, PE_TEXT_CONFIG, fetch_pe_checkpoint
19
+
20
+
21
+
22
+ logger = getLogger()
23
+
24
+
25
+
26
+ class LayerScale(nn.Module):
27
+ def __init__(self, dim, init_values=1e-5, inplace=False):
28
+ super().__init__()
29
+ self.inplace = inplace
30
+ self.dim = dim
31
+ self.init_values = init_values
32
+
33
+ def forward(self, x):
34
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
35
+
36
+ def init_tensors(self):
37
+ self.gamma = nn.Parameter(self.init_values * torch.ones(self.dim))
38
+
39
+
40
+ class AttentionPooling(nn.Module):
41
+ def __init__(
42
+ self,
43
+ embed_dim: int,
44
+ num_heads: int,
45
+ num_probe: int = 1,
46
+ mlp_ratio: int = 4,
47
+ act_layer: Callable = nn.GELU,
48
+ norm_layer: Callable = nn.LayerNorm,
49
+ ):
50
+ super().__init__()
51
+ self.embed_dim = embed_dim
52
+ self.num_heads = num_heads
53
+ self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
54
+ self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)
55
+ self.layernorm = norm_layer(embed_dim)
56
+ self.mlp_width = int(embed_dim * mlp_ratio)
57
+ self.mlp = nn.Sequential(
58
+ OrderedDict(
59
+ [
60
+ ("c_fc", nn.Linear(self.embed_dim, self.mlp_width)),
61
+ ("gelu", act_layer()),
62
+ ("c_proj", nn.Linear(self.mlp_width, self.embed_dim)),
63
+ ]
64
+ )
65
+ )
66
+ self._is_converted = False
67
+
68
+ def forward(self, x: torch.Tensor):
69
+ # This is the original forward method that will be replaced.
70
+ batch, _, _ = x.shape
71
+ q = self.probe.repeat((batch, 1, 1)).to(x.dtype)
72
+ x_attn = self.attn(q, x, x, need_weights=False)[0]
73
+ x = x_attn + self.mlp(self.layernorm(x_attn))
74
+ return x
75
+
76
+
77
+
78
+ class SelfAttention(nn.Module):
79
+ r"""
80
+ Implements sequence-packed attention and RoPE
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ embed_dim: int,
86
+ num_heads: int,
87
+ rope: Optional[nn.Module] = None,
88
+ ):
89
+ super(SelfAttention, self).__init__()
90
+ self.embed_dim = embed_dim
91
+
92
+ self.num_heads = num_heads
93
+ self.head_dim = embed_dim // num_heads
94
+ assert (
95
+ self.head_dim * num_heads == self.embed_dim
96
+ ), "embed_dim must be divisible by num_heads"
97
+
98
+ # To make this compatible with nn.MultiheadAttention
99
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
100
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
101
+
102
+ self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
103
+
104
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
105
+
106
+ self.rope = rope
107
+ self.scale = self.head_dim ** (-0.5)
108
+
109
+ def init_tensors(self):
110
+ xavier_uniform_(self.in_proj_weight)
111
+ constant_(self.in_proj_bias, 0.0)
112
+ constant_(self.out_proj.bias, 0.0)
113
+
114
+
115
+ def del_muda(self):
116
+ del self.in_proj_weight
117
+ del self.in_proj_bias
118
+
119
+ def migrate_weights(self):
120
+ """
121
+ MUST be called *after* loading the state_dict.
122
+ This copies the weights from the old Parameters to the new nn.Linear layer.
123
+ """
124
+ # Use torch.no_grad() to ensure this is done without tracking gradients
125
+ with torch.no_grad():
126
+ self.in_proj.weight.copy_(self.in_proj_weight)
127
+ self.in_proj.bias.copy_(self.in_proj_bias)
128
+
129
+ # del self.in_proj_weight
130
+ # del self.in_proj_bias
131
+ # print("Migration complete. Old parameters have been removed.")
132
+
133
+ def forward(self, x, attn_mask=None, need_weights=False):
134
+ batch, seq, embed_dim = x.shape
135
+
136
+ #proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
137
+ proj = self.in_proj(x)
138
+ # reshape to (3, E) rather than (E, 3); this is deliberate for better memory coalescing and keeps the same order as chunk()
139
+ proj = (
140
+ proj.unflatten(-1, (3, embed_dim))
141
+ .unsqueeze(0)
142
+ .transpose(0, -2)
143
+ .squeeze(-2)
144
+ .contiguous()
145
+ )
146
+ q, k, v = proj[0], proj[1], proj[2]
147
+
148
+ # Use "q_" so that we don't accidentally quit in pdb :)
149
+ q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
150
+ k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
151
+ v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
152
+
153
+ if self.rope:
154
+ q, k = self.rope(q, k)
155
+
156
+ if not need_weights:
157
+ # Original efficient path
158
+ attn = F.scaled_dot_product_attention(
159
+ q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
160
+ )
161
+ attn = rearrange(attn, "b h s d -> b s (h d)")
162
+ return self.out_proj(attn)
163
+ else:
164
+ # Path to get attention weights
165
+ q_scaled = q * self.scale
166
+ # attn_weights shape: (batch, num_heads, seq_len, seq_len)
167
+ attn_weights = torch.matmul(q_scaled, k.transpose(-2, -1))
168
+
169
+ if attn_mask is not None:
170
+ attn_weights += attn_mask
171
+
172
+ attn_weights = F.softmax(attn_weights, dim=-1)
173
+
174
+ attn_output = torch.matmul(attn_weights, v)
175
+ attn_output = rearrange(attn_output, "b h s d -> b s (h d)")
176
+
177
+ output = self.out_proj(attn_output)
178
+ return output, attn_weights
179
+
180
+
181
+ class ResidualAttentionBlock(nn.Module):
182
+ def __init__(
183
+ self,
184
+ d_model: int,
185
+ n_head: int,
186
+ mlp_ratio: float = 4.0,
187
+ ls_init_value: float = None,
188
+ act_layer: Callable = nn.GELU,
189
+ norm_layer: Callable = nn.LayerNorm,
190
+ drop_path: float = 0.0,
191
+ rope: Optional[nn.Module] = None,
192
+ ):
193
+ super().__init__()
194
+
195
+ if rope:
196
+ self.attn = SelfAttention(d_model, n_head, rope=rope)
197
+ else:
198
+ self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
199
+
200
+ self.ls_1 = (
201
+ LayerScale(d_model, ls_init_value)
202
+ if ls_init_value is not None
203
+ else nn.Identity()
204
+ )
205
+ self.ls_2 = (
206
+ LayerScale(d_model, ls_init_value)
207
+ if ls_init_value is not None
208
+ else nn.Identity()
209
+ )
210
+
211
+ self.ln_1 = norm_layer(d_model)
212
+ self.ln_2 = norm_layer(d_model)
213
+
214
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
215
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
216
+
217
+ mlp_width = int(d_model * mlp_ratio)
218
+ self.mlp = nn.Sequential(
219
+ OrderedDict(
220
+ [
221
+ ("c_fc", nn.Linear(d_model, mlp_width)),
222
+ ("gelu", act_layer()),
223
+ ("c_proj", nn.Linear(mlp_width, d_model)),
224
+ ]
225
+ )
226
+ )
227
+
228
+ def _call_attn(
229
+ self,
230
+ q_x: torch.Tensor,
231
+ attn_mask: Optional[torch.Tensor] = None,
232
+ need_weights: bool = False,
233
+ ):
234
+
235
+ if attn_mask is not None:
236
+ if not attn_mask.dtype == torch.bool:
237
+ attn_mask = attn_mask.to(q_x.dtype)
238
+
239
+ if isinstance(self.attn, SelfAttention):
240
+ # Pass the flag to your custom SelfAttention
241
+ return self.attn(q_x, attn_mask=attn_mask, need_weights=need_weights)
242
+ else:
243
+ # Standard nn.MultiheadAttention
244
+ return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=need_weights)[0]
245
+
246
+ def forward(
247
+ self,
248
+ x: torch.Tensor,
249
+ attn_mask: Optional[torch.Tensor] = None,
250
+ need_weights: bool = False,
251
+ ):
252
+ attn_result = self._call_attn(self.ln_1(x), attn_mask=attn_mask, need_weights=need_weights)
253
+
254
+ attn_weights = None
255
+ if need_weights:
256
+ # Unpack the output and the weights
257
+ attn_output, attn_weights = attn_result
258
+ else:
259
+ attn_output = attn_result
260
+
261
+ x = x + self.drop_path1(self.ls_1(attn_output))
262
+ x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
263
+
264
+ if need_weights:
265
+ return x, attn_weights # Return weights
266
+
267
+ return x
268
+
269
+ def del_muda(self):
270
+ self.attn.del_muda()
271
+
272
+ class Transformer(nn.Module):
273
+ def __init__(
274
+ self,
275
+ width: int,
276
+ layers: int,
277
+ heads: int,
278
+ mlp_ratio: float = 4.0,
279
+ ls_init_value: float = None,
280
+ act_layer: Callable = nn.GELU,
281
+ norm_layer: Callable = nn.LayerNorm,
282
+ drop_path: float = 0.0,
283
+ rope: Optional[nn.Module] = None,
284
+ ):
285
+ super().__init__()
286
+ self.width = width
287
+ self.layers = layers
288
+ self.grad_checkpointing = False
289
+
290
+ self.resblocks = nn.ModuleList(
291
+ [
292
+ ResidualAttentionBlock(
293
+ width,
294
+ heads,
295
+ mlp_ratio,
296
+ ls_init_value=ls_init_value,
297
+ act_layer=act_layer,
298
+ norm_layer=norm_layer,
299
+ drop_path=drop_path,
300
+ rope=rope,
301
+ )
302
+ for _ in range(layers)
303
+ ]
304
+ )
305
+
306
+ @torch.jit.ignore
307
+ def set_grad_checkpointing(self, enable=True):
308
+ self.grad_checkpointing = enable
309
+
310
+ @torch.jit.ignore
311
+ def truncate(self, layer_idx: int):
312
+ """ Delete layers so the last layer is the given layer index. """
313
+ self.layers = ((self.layers + layer_idx) % self.layers) + 1
314
+ self.resblocks = nn.ModuleList(self.resblocks[:self.layers])
315
+
316
+ def del_muda(self):
317
+ for resblock in self.resblocks:
318
+ resblock.del_muda()
319
+
320
+ def forward(
321
+ self,
322
+ x: torch.Tensor,
323
+ attn_mask: Optional[torch.Tensor] = None,
324
+ layer_idx: int = -1,
325
+ need_weights: bool = False, # Add need_weights flag
326
+ ):
327
+ stop_idx = (self.layers + layer_idx) % self.layers
328
+
329
+ attention_maps = [] # List to store maps from each layer
330
+
331
+ for i, r in enumerate(self.resblocks):
332
+ if self.grad_checkpointing and not torch.jit.is_scripting():
333
+ if need_weights:
334
+ raise ValueError("Cannot get attention maps with gradient checkpointing enabled.")
335
+ x = checkpoint(r, x, attn_mask, use_reentrant=False)
336
+ else:
337
+ if need_weights:
338
+ x, attn_map = r(x, attn_mask=attn_mask, need_weights=True)
339
+ attention_maps.append(attn_map)
340
+ else:
341
+ x = r(x, attn_mask=attn_mask, need_weights=False)
342
+
343
+ if i == stop_idx:
344
+ break
345
+
346
+ if need_weights:
347
+ return x, attention_maps # Return the list of maps
348
+
349
+ return x
350
+
351
+
352
+ class VisionTransformer(nn.Module):
353
+ def __init__(
354
+ self,
355
+ patch_size: int,
356
+ width: int,
357
+ layers: int,
358
+ heads: int,
359
+ mlp_ratio: float,
360
+ act_layer: Callable = nn.GELU,
361
+ norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
362
+ use_ln_pre: bool = True,
363
+ use_ln_post: bool = True,
364
+ ls_init_value: float = None,
365
+ drop_path: float = 0.0,
366
+ image_size: int = 448, # Pretrain image size only; you can pass in any image size
367
+ use_abs_posemb: bool = True,
368
+ use_rope2d: bool = True,
369
+ use_cls_token: bool = False,
370
+ output_dim: Optional[int] = 1280,
371
+ attn_pooler_heads: int = 8,
372
+ pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
373
+ ):
374
+ super().__init__()
375
+ assert pool_type in ("attn", "tok", "avg", "none")
376
+ self.pool_type = pool_type
377
+ self.patch_size = patch_size
378
+
379
+ self.output_dim = output_dim or width
380
+ self.proj_dim = output_dim
381
+ self.heads = heads
382
+ self.width = width
383
+ self.layers = layers
384
+
385
+ self.use_abs_posemb = use_abs_posemb
386
+ self.use_cls_token = use_cls_token
387
+ self.use_rope2d = use_rope2d
388
+ self.image_size = image_size
389
+
390
+ self.conv1 = nn.Conv2d(
391
+ in_channels=3,
392
+ out_channels=width,
393
+ kernel_size=patch_size,
394
+ stride=patch_size,
395
+ bias=False,
396
+ )
397
+ self.rope = (
398
+ Rope2D(
399
+ dim=width // heads,
400
+ use_cls_token=self.use_cls_token,
401
+ )
402
+ if self.use_rope2d
403
+ else None
404
+ )
405
+
406
+ self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity()
407
+ self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity()
408
+
409
+ self.transformer = Transformer(
410
+ width,
411
+ layers,
412
+ heads,
413
+ mlp_ratio,
414
+ ls_init_value=ls_init_value,
415
+ act_layer=act_layer,
416
+ norm_layer=norm_layer,
417
+ drop_path=drop_path,
418
+ rope=self.rope,
419
+ )
420
+
421
+ if pool_type == "attn":
422
+ self.attn_pool = AttentionPooling(
423
+ embed_dim=width,
424
+ num_heads=attn_pooler_heads,
425
+ act_layer=act_layer,
426
+ norm_layer=norm_layer,
427
+ )
428
+ else:
429
+ self.attn_pool = None
430
+
431
+ self.init_tensors()
432
+
433
+
434
+ def del_muda(self):
435
+ self.transformer.del_muda()
436
+
437
+ def delete_attn_pool(self):
438
+ del self.attn_pool
439
+
440
+
441
+ def init_tensors(self):
442
+ def init_submodule_tensors(module):
443
+ for name, child in module.named_children():
444
+ if hasattr(child, "init_tensors"):
445
+ logger.debug(f"Initializing tensors for submodule: {name}")
446
+ child.init_tensors()
447
+ init_submodule_tensors(child)
448
+
449
+ init_submodule_tensors(self)
450
+ self.rope.init_tensors()
451
+
452
+ # class embeddings and positional embeddings
453
+ init_scale = self.width**-0.5
454
+
455
+ if self.use_cls_token:
456
+ self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))
457
+
458
+ if self.use_abs_posemb:
459
+ self.posemb_grid_size = self.image_size // self.patch_size
460
+ self.positional_embedding = nn.Parameter(
461
+ init_scale
462
+ * torch.randn(
463
+ int(self.use_cls_token) + self.posemb_grid_size**2, self.width
464
+ )
465
+ )
466
+
467
+ if self.proj_dim is not None:
468
+ self.proj = nn.Parameter(
469
+ init_scale * torch.randn(self.width, self.proj_dim)
470
+ )
471
+
472
+
473
+ def load_ckpt(self, ckpt_path: str, verbose: bool = True):
474
+ _sd = torch.load(ckpt_path, weights_only=True)
475
+ if "state_dict" in _sd:
476
+ _sd = _sd["state_dict"]
477
+ elif "weights" in _sd:
478
+ _sd = _sd["weights"]
479
+
480
+ # for backwards compatibility
481
+ _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
482
+ if any(k.startswith("visual.") for k in _sd):
483
+ _sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k}
484
+
485
+ m, u = self.load_state_dict(_sd, strict=False)
486
+
487
+ if verbose or (m or u):
488
+ logger.info(f"Missing keys for loading vision encoder: {m}")
489
+ logger.info(f"Unexpected keys for loading vision encoder: {u}")
490
+ print(f"Missing keys for loading vision encoder: {m}")
491
+ print(f"Unexpected keys for loading vision encoder: {u}")
492
+
493
+
494
+ def truncate(self, layer_idx: int):
495
+ """ Delete layers so the last layer is the given layer index. """
496
+ self.transformer.truncate(layer_idx)
497
+ self.layers = self.transformer.layers
498
+
499
+
500
+ @classmethod
501
+ def from_config(
502
+ cls,
503
+ name: str,
504
+ pretrained: bool = False,
505
+ checkpoint_path: Optional[str] = None,
506
+ **kwdargs
507
+ ):
508
+ if name not in PE_VISION_CONFIG:
509
+ raise RuntimeError(f"{name} not found in configs.")
510
+
511
+ args = asdict(PE_VISION_CONFIG[name])
512
+ args.update(kwdargs)
513
+
514
+ model = cls(**args)
515
+ if pretrained:
516
+ model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))
517
+
518
+ return model
519
+
520
+ @classmethod
521
+ def available_configs(cls):
522
+ return list(PE_VISION_CONFIG.keys())
523
+
524
+
525
+ @torch.jit.ignore
526
+ def set_grad_checkpointing(self, enable=True):
527
+ self.transformer.set_grad_checkpointing(enable=enable)
528
+
529
+ def _sample_abs_posemb(self, grid_h: int, grid_w: int):
530
+ """Interpolates the absolute position embedding if necessary."""
531
+ if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
532
+ return self.positional_embedding[None, ...]
533
+
534
+ pos_embed = self.positional_embedding
535
+ if self.use_cls_token:
536
+ cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
537
+
538
+ pos_embed = (
539
+ pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
540
+ .permute(0, 3, 1, 2)
541
+ .contiguous()
542
+ )
543
+ pos_embed = F.interpolate(
544
+ pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
545
+ )
546
+ pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
547
+
548
+ if self.use_cls_token:
549
+ pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)
550
+
551
+ return pos_embed[None, ...]
552
+
553
+ def _pool(self, x: torch.Tensor):
554
+ if self.pool_type == "tok":
555
+ return x[:, 0]
556
+ elif self.pool_type == "avg":
557
+ return x.mean(dim=1)
558
+ elif self.pool_type == "attn":
559
+ return self.attn_pool(x).squeeze(1)
560
+ elif self.pool_type == "none":
561
+ return x
562
+ else:
563
+ raise NotImplementedError
564
+
565
+ def forward_features(
566
+ self,
567
+ x: torch.Tensor,
568
+ norm: bool = False,
569
+ layer_idx: int = -1,
570
+ strip_cls_token: bool = False,
571
+ need_weights: bool = False, # Add need_weights flag
572
+ ):
573
+ batch, _, h, w = x.shape
574
+ grid_h, grid_w = h // self.patch_size, w // self.patch_size
575
+
576
+ x = self.conv1(x)
577
+ x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)
578
+
579
+ if self.use_cls_token:
580
+ x = torch.cat(
581
+ [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
582
+ dim=1,
583
+ )
584
+
585
+ if self.use_abs_posemb:
586
+ x = x + self._sample_abs_posemb(grid_h, grid_w)
587
+
588
+ if self.use_rope2d:
589
+ self.rope.update_grid(x.device, grid_h, grid_w)
590
+
591
+ x = self.ln_pre(x)
592
+
593
+ # Get output from the transformer
594
+ transformer_output = self.transformer(x, layer_idx=layer_idx, need_weights=need_weights)
595
+
596
+ attention_maps = None
597
+ if need_weights:
598
+ x, attention_maps = transformer_output
599
+ else:
600
+ x = transformer_output
601
+
602
+ if norm:
603
+ x = self.ln_post(x)
604
+
605
+ if strip_cls_token and self.use_cls_token:
606
+ x = x[:, 1:, :]
607
+
608
+ if need_weights:
609
+ return x, attention_maps # Return maps
610
+
611
+ return x
612
+
613
+ def forward(self, x: torch.Tensor, **kwargs):
614
+ x = self.forward_features(x, norm=True, **kwargs)
615
+ x = self._pool(x)
616
+
617
+ if self.proj_dim is not None:
618
+ x = x @ self.proj
619
+
620
+ return x
621
+
622
+
623
+
624
+ class TextTransformer(nn.Module):
625
+ def __init__(
626
+ self,
627
+ context_length: int = 72,
628
+ vocab_size: int = 49408,
629
+ width: int = 512,
630
+ heads: int = 8,
631
+ layers: int = 12,
632
+ mlp_ratio: float = 4.0,
633
+ ls_init_value: float = None,
634
+ output_dim: int = 1280,
635
+ no_causal_mask: bool = False,
636
+ pad_id: int = 0,
637
+ pool_type: str = "argmax",
638
+ proj_bias: bool = False,
639
+ act_layer: Callable = nn.GELU,
640
+ norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
641
+ output_tokens: bool = False,
642
+ use_ln_post: bool = True,
643
+ ):
644
+ super().__init__()
645
+ assert pool_type in ("first", "last", "argmax", "none")
646
+ self.pool_type = pool_type
647
+ self.output_tokens = output_tokens
648
+ self.num_pos = self.context_length = context_length
649
+ self.vocab_size = vocab_size
650
+ self.width = width
651
+ self.output_dim = output_dim
652
+ self.heads = heads
653
+ self.pad_id = pad_id
654
+ self.layers = layers
655
+
656
+ self.token_embedding = nn.Embedding(vocab_size, width)
657
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
658
+
659
+ self.transformer = Transformer(
660
+ width=width,
661
+ layers=layers,
662
+ heads=heads,
663
+ mlp_ratio=mlp_ratio,
664
+ ls_init_value=ls_init_value,
665
+ act_layer=act_layer,
666
+ norm_layer=norm_layer,
667
+ )
668
+
669
+ self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
670
+
671
+ if no_causal_mask:
672
+ self.attn_mask = None
673
+ else:
674
+ self.register_buffer(
675
+ "attn_mask", self.build_causal_mask(), persistent=False
676
+ )
677
+
678
+ if pool_type == "attn" or pool_type == "attn_eos":
679
+ self.attn_pool = AttentionPooling(
680
+ embed_dim=width,
681
+ num_heads=heads,
682
+ act_layer=act_layer,
683
+ norm_layer=norm_layer,
684
+ )
685
+ else: # argmax
686
+ self.attn_pool = None
687
+
688
+ if proj_bias:
689
+ self.text_projection = nn.Linear(width, output_dim)
690
+ else:
691
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
692
+
693
+ def build_causal_mask(self):
694
+ # lazily create causal attention mask, with full attention between the tokens
695
+ # pytorch uses additive attention mask; fill with -inf
696
+ mask = torch.empty(self.num_pos, self.num_pos)
697
+ mask.fill_(float("-inf"))
698
+ mask.triu_(1) # zero out the lower diagonal
699
+ return mask
700
+
701
+ def load_ckpt(self, ckpt_path: str, verbose: bool = True):
702
+ _sd = torch.load(ckpt_path, weights_only=True)
703
+ if "state_dict" in _sd:
704
+ _sd = _sd["state_dict"]
705
+ elif "weights" in _sd:
706
+ _sd = _sd["weights"]
707
+
708
+ _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
709
+
710
+ m, u = self.load_state_dict(_sd, strict=False)
711
+
712
+ if verbose or (m or u):
713
+ logger.info(f"Missing keys for loading model: {m}")
714
+ logger.info(f"Unexpected keys for loading model: {u}")
715
+ print(f"Missing keys for loading model: {m}")
716
+ print(f"Unexpected keys for loading model: {u}")
717
+
718
+ def build_cls_mask(self, text):
719
+ cls_mask = (text != self.pad_id).unsqueeze(1)
720
+ cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
721
+ additive_mask = torch.empty(cls_mask.shape, device=cls_mask.device)
722
+ additive_mask.fill_(0)
723
+ additive_mask.masked_fill_(~cls_mask, float("-inf"))
724
+ additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
725
+ return additive_mask
726
+
727
+ def text_global_pool(
728
+ self, x, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
729
+ ):
730
+ if pool_type == "first":
731
+ pooled, tokens = x[:, 0], x[:, 1:]
732
+ elif pool_type == "last":
733
+ pooled, tokens = x[:, -1], x[:, :-1]
734
+ elif pool_type == "argmax":
735
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
736
+ assert text is not None
737
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
738
+ else:
739
+ pooled = tokens = x
740
+
741
+ return pooled, tokens
742
+
743
+ def forward(self, text):
744
+ seq_len = text.shape[1]
745
+ x = self.token_embedding(
746
+ text
747
+ )
748
+ attn_mask = self.attn_mask
749
+ if attn_mask is not None:
750
+ attn_mask = attn_mask[:seq_len, :seq_len]
751
+
752
+ x = x + self.positional_embedding[:seq_len]
753
+ x = self.transformer(x, attn_mask=attn_mask)
754
+
755
+ x = self.ln_final(x)
756
+ pooled, tokens = self.text_global_pool(x, text, pool_type=self.pool_type)
757
+
758
+ if self.text_projection is not None:
759
+ if isinstance(self.text_projection, nn.Linear):
760
+ pooled = self.text_projection(pooled)
761
+ else:
762
+ pooled = pooled @ self.text_projection
763
+
764
+ if self.output_tokens:
765
+ return pooled, tokens
766
+
767
+ return pooled
768
+
769
+
770
+
771
+
772
+ class CLIP(TextTransformer):
773
+ def __init__(
774
+ self,
775
+ vision_cfg: PEConfig,
776
+ text_cfg: PETextConfig,
777
+ init_logit_scale: float = np.log(1 / 0.07)
778
+ ):
779
+ super(CLIP, self).__init__(**asdict(text_cfg))
780
+ self.visual = VisionTransformer(**asdict(vision_cfg))
781
+ self.image_size = self.visual.image_size # For ease of use
782
+ self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
783
+
784
+
785
+ def encode_image(self, image, normalize: bool = False):
786
+ x = self.visual(image)
787
+ return F.normalize(x, dim=-1) if normalize else x
788
+
789
+ def encode_video(self, video, normalize: bool = False): # b n c h w
790
+ b, n, c, h, w = video.shape
791
+ frms = video.reshape(b * n, c, h, w)
792
+ frm_feats = self.encode_image(frms, normalize=normalize)
793
+ video_feats = frm_feats.reshape(b, n, -1)
794
+ video_feats = video_feats.mean(dim=1)
795
+ return video_feats
796
+
797
+ def encode_text(self, text, normalize: bool = False):
798
+ x = super().forward(text)
799
+ return F.normalize(x, dim=-1) if normalize else x
800
+
801
+ def forward(
802
+ self,
803
+ image: Optional[torch.Tensor] = None,
804
+ text: Optional[torch.Tensor] = None,
805
+ ):
806
+ image_features = (
807
+ self.encode_image(image, normalize=True) if image is not None else None
808
+ )
809
+ text_features = (
810
+ self.encode_text(text, normalize=True) if text is not None else None
811
+ )
812
+ return image_features, text_features, self.logit_scale.exp()
813
+
814
+
815
+ @classmethod
816
+ def from_config(
817
+ cls,
818
+ name: str,
819
+ pretrained: bool = False,
820
+ checkpoint_path: Optional[str] = None # To load your own
821
+ ):
822
+ if name not in PE_VISION_CONFIG or name not in PE_TEXT_CONFIG:
823
+ raise RuntimeError(f"{name} not found in configs.")
824
+
825
+ model = cls(PE_VISION_CONFIG[name], PE_TEXT_CONFIG[name])
826
+ if pretrained:
827
+ model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))
828
+
829
+ return model
830
+
831
+ @classmethod
832
+ def available_configs(cls):
833
+ return [k for k in PE_VISION_CONFIG if k in PE_TEXT_CONFIG]
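
A minimal usage sketch for the CLIP wrapper above (illustrative only, not part of this commit; the config name below is an assumption and should be replaced with an entry from CLIP.available_configs()):

from core.vision_encoder.pe import CLIP
from core.vision_encoder.transforms import get_image_transform_fix

name = "PE-Core-L14-336"                      # assumed config name; check CLIP.available_configs()
model = CLIP.from_config(name, pretrained=True).eval()
transform = get_image_transform_fix(model.image_size)

# image: a PIL.Image, text: a pre-tokenized LongTensor of shape (batch, context_length)
# img_feat = model.encode_image(transform(image).unsqueeze(0), normalize=True)
# txt_feat = model.encode_text(text, normalize=True)
# logits = model.logit_scale.exp() * img_feat @ txt_feat.T
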
core/vision_encoder/rope.py ADDED
@@ -0,0 +1,347 @@
1
+ from math import pi
2
+ from typing import Literal, Optional, Union
3
+
4
+ import torch
5
+ from einops import rearrange, repeat
6
+ from torch import Tensor, broadcast_tensors, einsum, nn
7
+ from torch.amp import autocast
8
+ from torch.nn import Module
9
+
10
+ # helper functions
11
+
12
+
13
+ def exists(val):
14
+ return val is not None
15
+
16
+
17
+ def default(val, d):
18
+ return val if exists(val) else d
19
+
20
+
21
+ # broadcat, as tortoise-tts was using it
22
+
23
+
24
+ def broadcat(tensors, dim=-1):
25
+ broadcasted_tensors = broadcast_tensors(*tensors)
26
+ return torch.cat(broadcasted_tensors, dim=dim)
27
+
28
+
29
+ # rotary embedding helper functions
30
+
31
+
32
+ def rotate_half(x):
33
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
34
+ x1, x2 = x.unbind(dim=-1)
35
+ x = torch.stack((-x2, x1), dim=-1)
36
+ return rearrange(x, "... d r -> ... (d r)")
37
+
38
+
39
+ @autocast("cuda", enabled=False)
40
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
41
+ dtype = t.dtype
42
+
43
+ if t.ndim == 3:
44
+ seq_len = t.shape[seq_dim]
45
+ freqs = freqs[-seq_len:]
46
+
47
+ rot_dim = freqs.shape[-1]
48
+ end_index = start_index + rot_dim
49
+
50
+ assert (
51
+ rot_dim <= t.shape[-1]
52
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
53
+
54
+ t_left, t, t_right = (
55
+ t[..., :start_index],
56
+ t[..., start_index:end_index],
57
+ t[..., end_index:],
58
+ )
59
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
60
+ out = torch.cat((t_left, t, t_right), dim=-1)
61
+
62
+ return out.type(dtype)
63
+
64
+
65
+ # learned rotation helpers
66
+
67
+
68
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
69
+ if exists(freq_ranges):
70
+ rotations = einsum("..., f -> ... f", rotations, freq_ranges)
71
+ rotations = rearrange(rotations, "... r f -> ... (r f)")
72
+
73
+ rotations = repeat(rotations, "... n -> ... (n r)", r=2)
74
+ return apply_rotary_emb(rotations, t, start_index=start_index)
75
+
76
+
77
+ # classes
78
+
79
+
80
+ class RotaryEmbedding(Module):
81
+ def __init__(
82
+ self,
83
+ dim,
84
+ custom_freqs: Optional[Tensor] = None,
85
+ freqs_for: Union[
86
+ Literal["lang"], Literal["pixel"], Literal["constant"]
87
+ ] = "lang",
88
+ theta=10000,
89
+ max_freq=10,
90
+ num_freqs=1,
91
+ learned_freq=False,
92
+ use_xpos=False,
93
+ xpos_scale_base=512,
94
+ interpolate_factor=1.0,
95
+ theta_rescale_factor=1.0,
96
+ seq_before_head_dim=False,
97
+ cache_if_possible=True,
98
+ ):
99
+ super().__init__()
100
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
101
+ # has some connection to NTK literature
102
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
103
+
104
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
105
+
106
+ self.freqs_for = freqs_for
107
+
108
+ if exists(custom_freqs):
109
+ freqs = custom_freqs
110
+ elif freqs_for == "lang":
111
+ freqs = 1.0 / (
112
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
113
+ )
114
+ elif freqs_for == "pixel":
115
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
116
+ elif freqs_for == "constant":
117
+ freqs = torch.ones(num_freqs).float()
118
+
119
+ self.cache_if_possible = cache_if_possible
120
+
121
+ self.tmp_store("cached_freqs", None)
122
+ self.tmp_store("cached_scales", None)
123
+
124
+ self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
125
+
126
+ self.learned_freq = learned_freq
127
+
128
+ # dummy for device
129
+
130
+ self.tmp_store("dummy", torch.tensor(0))
131
+
132
+ # default sequence dimension
133
+
134
+ self.seq_before_head_dim = seq_before_head_dim
135
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
136
+
137
+ # interpolation factors
138
+
139
+ assert interpolate_factor >= 1.0
140
+ self.interpolate_factor = interpolate_factor
141
+
142
+ # xpos
143
+
144
+ self.use_xpos = use_xpos
145
+ if not use_xpos:
146
+ self.tmp_store("scale", None)
147
+ return
148
+
149
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
150
+
151
+ self.scale_base = xpos_scale_base
152
+ self.tmp_store("scale", scale)
153
+
154
+ # add apply_rotary_emb as static method
155
+
156
+ self.apply_rotary_emb = staticmethod(apply_rotary_emb)
157
+
158
+ @property
159
+ def device(self):
160
+ return self.dummy.device
161
+
162
+ def tmp_store(self, key, value):
163
+ self.register_buffer(key, value, persistent=False)
164
+
165
+ def get_seq_pos(self, seq_len, device, dtype, offset=0):
166
+ return (
167
+ torch.arange(seq_len, device=device, dtype=dtype) + offset
168
+ ) / self.interpolate_factor
169
+
170
+ def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
171
+ seq_dim = default(seq_dim, self.default_seq_dim)
172
+
173
+ assert (
174
+ not self.use_xpos
175
+ ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
176
+
177
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
178
+
179
+ freqs = self.forward(
180
+ self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset),
181
+ seq_len=seq_len,
182
+ offset=offset,
183
+ )
184
+
185
+ if seq_dim == -3:
186
+ freqs = rearrange(freqs, "n d -> n 1 d")
187
+
188
+ return apply_rotary_emb(freqs, t, seq_dim=seq_dim)
189
+
190
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
191
+ seq_dim = default(seq_dim, self.default_seq_dim)
192
+
193
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
194
+ assert q_len <= k_len
195
+
196
+ rotated_q = self.rotate_queries_or_keys(
197
+ q, seq_dim=seq_dim, offset=k_len - q_len + offset
198
+ )
199
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)
200
+
201
+ rotated_q = rotated_q.type(q.dtype)
202
+ rotated_k = rotated_k.type(k.dtype)
203
+
204
+ return rotated_q, rotated_k
205
+
206
+ def rotate_queries_and_keys(self, q, k, seq_dim=None):
207
+ seq_dim = default(seq_dim, self.default_seq_dim)
208
+
209
+ assert self.use_xpos
210
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
211
+
212
+ seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
213
+
214
+ freqs = self.forward(seq, seq_len=seq_len)
215
+ scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
216
+
217
+ if seq_dim == -3:
218
+ freqs = rearrange(freqs, "n d -> n 1 d")
219
+ scale = rearrange(scale, "n d -> n 1 d")
220
+
221
+ rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
222
+ rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
223
+
224
+ rotated_q = rotated_q.type(q.dtype)
225
+ rotated_k = rotated_k.type(k.dtype)
226
+
227
+ return rotated_q, rotated_k
228
+
229
+ def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
230
+ assert self.use_xpos
231
+
232
+ should_cache = self.cache_if_possible and exists(seq_len)
233
+
234
+ if (
235
+ should_cache
236
+ and exists(self.cached_scales)
237
+ and (seq_len + offset) <= self.cached_scales.shape[0]
238
+ ):
239
+ return self.cached_scales[offset : (offset + seq_len)]
240
+
241
+ scale = 1.0
242
+ if self.use_xpos:
243
+ power = (t - len(t) // 2) / self.scale_base
244
+ scale = self.scale ** rearrange(power, "n -> n 1")
245
+ scale = torch.cat((scale, scale), dim=-1)
246
+
247
+ if should_cache:
248
+ self.tmp_store("cached_scales", scale)
249
+
250
+ return scale
251
+
252
+ def get_axial_freqs(self, *dims):
253
+ Colon = slice(None)
254
+ all_freqs = []
255
+
256
+ for ind, dim in enumerate(dims):
257
+ if self.freqs_for == "pixel":
258
+ pos = torch.linspace(-1, 1, steps=dim, device=self.device)
259
+ else:
260
+ pos = torch.arange(dim, device=self.device)
261
+
262
+ freqs = self.forward(pos, seq_len=dim)
263
+
264
+ all_axis = [None] * len(dims)
265
+ all_axis[ind] = Colon
266
+
267
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
268
+ all_freqs.append(freqs[new_axis_slice])
269
+
270
+ all_freqs = broadcast_tensors(*all_freqs)
271
+ return torch.cat(all_freqs, dim=-1)
272
+
273
+ @autocast("cuda", enabled=False)
274
+ def forward(self, t: Tensor, seq_len=None, offset=0):
275
+ should_cache = (
276
+ self.cache_if_possible
277
+ and not self.learned_freq
278
+ and exists(seq_len)
279
+ and self.freqs_for != "pixel"
280
+ )
281
+
282
+ if (
283
+ should_cache
284
+ and exists(self.cached_freqs)
285
+ and (offset + seq_len) <= self.cached_freqs.shape[0]
286
+ ):
287
+ return self.cached_freqs[offset : (offset + seq_len)].detach()
288
+
289
+ freqs = self.freqs
290
+
291
+ freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
292
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
293
+
294
+ if should_cache:
295
+ self.tmp_store("cached_freqs", freqs.detach())
296
+
297
+ return freqs
298
+
299
+
300
+
301
+
302
+
303
+ class Rope2D:
304
+ """ Helper class to apply RoPE2D as well as interpolate on the fly. """
305
+
306
+ def __init__(self, dim, use_cls_token=False):
307
+ self.dim = dim
308
+ self.use_cls_token = use_cls_token
309
+ self.grid_size = None
310
+ self.freq = None
311
+
312
+ def init_tensors(self):
313
+ self.rope = RotaryEmbedding(self.dim // 2)
314
+
315
+ def update_grid(self, device, grid_h, grid_w):
316
+ if self.grid_size != (grid_h, grid_w):
317
+ self.grid_size = (grid_h, grid_w)
318
+
319
+ self.rope = self.rope.to(device)
320
+
321
+ if self.use_cls_token:
322
+ # +1 to leave space for the cls token to be (0, 0)
323
+ grid_y_range = torch.arange(grid_h, device=device) + 1
324
+ grid_x_range = torch.arange(grid_w, device=device) + 1
325
+ else:
326
+ grid_y_range = torch.arange(grid_h, device=device)
327
+ grid_x_range = torch.arange(grid_w, device=device)
328
+
329
+ freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
330
+ freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
331
+ freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
332
+
333
+ if self.use_cls_token:
334
+ freq = torch.cat(
335
+ [torch.zeros(1, freq.shape[-1], device=device), freq], dim=0
336
+ )
337
+
338
+ self.freq = freq[None, ...]
339
+
340
+ self.freq = self.freq.to(device)
341
+
342
+ def __call__(self, q, k):
343
+ # batch, heads, seq, dim = q.shape
344
+ q = apply_rotary_emb(self.freq[:, None, :, :], q)
345
+ k = apply_rotary_emb(self.freq[:, None, :, :], k)
346
+
347
+ return q, k
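
A minimal sketch of how Rope2D is driven by the vision encoder above (illustrative only, not part of this commit; shapes follow the (batch, heads, seq, dim) comment in __call__):

import torch
from core.vision_encoder.rope import Rope2D

head_dim, grid_h, grid_w = 64, 16, 16
rope = Rope2D(head_dim, use_cls_token=True)
rope.init_tensors()
rope.update_grid(torch.device("cpu"), grid_h, grid_w)

seq = grid_h * grid_w + 1                     # +1 for the cls token, which sits at position (0, 0)
q = torch.randn(2, 8, seq, head_dim)
k = torch.randn(2, 8, seq, head_dim)
q, k = rope(q, k)                             # rotary embeddings applied per 2-D patch position
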
core/vision_encoder/transforms.py ADDED
@@ -0,0 +1,86 @@
1
+ import torchvision.transforms as T
2
+
3
+
4
+
5
+ def get_image_transform(
6
+ image_size: int,
7
+ center_crop: bool = False,
8
+ interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR # We used bilinear during training
9
+ ):
10
+ if center_crop:
11
+ crop = [
12
+ T.Resize(image_size, interpolation=interpolation),
13
+ T.CenterCrop(image_size)
14
+ ]
15
+ else:
16
+ # "Squash": most versatile
17
+ crop = [
18
+ T.Resize((image_size, image_size), interpolation=interpolation)
19
+ ]
20
+
21
+ return T.Compose(crop + [
22
+ T.Lambda(lambda x: x.convert("RGB")),
23
+ T.ToTensor(),
24
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
25
+ ])
26
+
27
+
28
+
29
+
30
+
31
+ import torchvision.transforms as T
32
+
33
+ from PIL import Image
34
+
35
+ def get_image_transform(
36
+ image_size: int,
37
+ center_crop: bool = False,
38
+ interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR # We used bilinear during training
39
+ ):
40
+ if center_crop:
41
+ crop = [
42
+ T.Resize(image_size, interpolation=interpolation),
43
+ T.CenterCrop(image_size)
44
+ ]
45
+ else:
46
+ # "Squash": most versatile
47
+ crop = [
48
+ T.Resize((image_size, image_size), interpolation=interpolation)
49
+ ]
50
+
51
+ return T.Compose(crop + [
52
+ T.Lambda(lambda x: x.convert("RGB")),
53
+ T.ToTensor(),
54
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
55
+ ])
56
+
57
+
58
+ def _convert_to_rgb(image: Image.Image) -> Image.Image:
59
+ """Converts a PIL Image to RGB format."""
60
+ return image.convert("RGB")
61
+
62
+
63
+ def get_image_transform_fix(
64
+ image_size: int,
65
+ center_crop: bool = False,
66
+ interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR
67
+ ):
68
+ if center_crop:
69
+ crop = [
70
+ T.Resize(image_size, interpolation=interpolation),
71
+ T.CenterCrop(image_size)
72
+ ]
73
+ else:
74
+ # "Squash": most versatile
75
+ crop = [
76
+ T.Resize((image_size, image_size), interpolation=interpolation)
77
+ ]
78
+
79
+ return T.Compose(crop + [
80
+ T.Lambda(_convert_to_rgb),
81
+ T.ToTensor(),
82
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
83
+ ])
84
+
85
+ def get_text_tokenizer(context_length: int):
86
+ return SimpleTokenizer(context_length=context_length)
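
Note that get_image_transform is defined twice verbatim above; get_image_transform_fix swaps the inline lambda for the module-level _convert_to_rgb, which keeps the transform picklable for multi-worker DataLoaders (this rationale is inferred from the "_fix" naming). SimpleTokenizer is referenced by get_text_tokenizer but never imported in this file; it is presumably provided by the tokenizer module alongside this one. A minimal usage sketch (illustrative only, not part of this commit):

from PIL import Image
from core.vision_encoder.transforms import get_image_transform_fix

transform = get_image_transform_fix(image_size=448, center_crop=False)
img = Image.new("RGB", (640, 480))
tensor = transform(img)   # tensor of shape (3, 448, 448), normalized to [-1, 1]
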
core/vision_projector/base.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ from torch import nn
6
+
7
+
8
+ class BaseProjector(nn.Module, ABC):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.adaptive_avg_pool = None
12
+
13
+ @abstractmethod
14
+ def setup_projector(self):
15
+ """
16
+ Setup the vision_projector attribute in subclasses.
17
+ """
18
+ pass
19
+
20
+ def forward(self, x):
21
+ x = x.permute(1, 0, 2) # NLD -> LND
22
+ x = self.projector(x)
23
+ x = x.permute(1, 0, 2)
24
+ if self.adaptive_avg_pool is not None:
25
+ x = self.adaptive_avg_pool(x)
26
+ return x
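
A minimal subclass sketch (illustrative only, not part of this commit) showing the contract BaseProjector expects: setup_projector defines self.projector, and forward applies it token-wise with an optional pooling step:

import torch
from torch import nn
from core.vision_projector.base import BaseProjector

class LinearProjector(BaseProjector):
    def __init__(self, in_dim: int = 1024, out_dim: int = 4096):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.setup_projector()

    def setup_projector(self):
        self.projector = nn.Linear(self.in_dim, self.out_dim)

tokens = torch.randn(2, 256, 1024)            # (batch, vision tokens, vision width)
print(LinearProjector()(tokens).shape)        # torch.Size([2, 256, 4096])
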
core/vision_projector/mlp.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from core.utils import get_init_fn
9
+ from core.vision_projector.base import BaseProjector
10
+
11
+
12
+ class AdaptiveAvgPooling(nn.Module):
13
+ def __init__(self, pooling_ratio=2):
14
+ super(AdaptiveAvgPooling, self).__init__()
15
+ self.pooling_ratio = pooling_ratio
16
+
17
+ def forward(self, x):
18
+ b, num_tokens, c = x.shape
19
+ h = int(math.sqrt(num_tokens))
20
+ assert h * h == num_tokens
21
+
22
+ shape = (h // self.pooling_ratio, h // self.pooling_ratio)
23
+ x = x.permute(0, 2, 1).reshape(b, -1, h, h)
24
+ x = F.adaptive_avg_pool2d(x, shape)
25
+ x = x.flatten(2).transpose(1, 2)
26
+
27
+ return x
28
+
29
+
30
+ class MLPProjector(BaseProjector):
31
+ def __init__(self, args):
32
+ super().__init__()
33
+ self.setup_projector(args)
34
+ self.pooling_ratio = args.pooling_ratio
35
+ self.adaptive_avg_pool = AdaptiveAvgPooling(pooling_ratio=args.pooling_ratio)
36
+ self.remove_vision_class_token = args.remove_vision_class_token
37
+
38
+ def init_tensors(self):
39
+ self.init_method(self.projector[0].weight)
40
+ self.init_method(self.projector[0].bias)
41
+ self.init_method(self.projector[2].weight)
42
+ self.init_method(self.projector[2].bias)
43
+
44
+ def setup_projector(self, args):
45
+ self.init_method = get_init_fn(args.mlp_init, args.dim, init_depth=None)
46
+ input_size = args.vision_model["width"]
47
+ output_size = args.dim
48
+ self.projector = nn.Sequential(
49
+ nn.Linear(
50
+ in_features=input_size,
51
+ out_features=output_size,
52
+ bias=True,
53
+ dtype=torch.get_default_dtype(),
54
+ ),
55
+ nn.GELU(),
56
+ nn.Linear(
57
+ in_features=output_size,
58
+ out_features=output_size,
59
+ bias=True,
60
+ dtype=torch.get_default_dtype(),
61
+ ),
62
+ )
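
A quick shape check for AdaptiveAvgPooling above (illustrative only, not part of this commit): a square grid of vision tokens is average-pooled by pooling_ratio along each side, so 16x16 = 256 tokens become 8x8 = 64:

import torch
from core.vision_projector.mlp import AdaptiveAvgPooling

tokens = torch.randn(2, 256, 1024)            # (batch, 16x16 tokens, width)
pool = AdaptiveAvgPooling(pooling_ratio=2)
print(pool(tokens).shape)                     # torch.Size([2, 64, 1024])
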
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch
2
+ torchvision
3
+ opencv-python-headless
4
+ Pillow
5
+ numpy
6
+ einops
7
+ peft
8
+ python-dotenv
9
+ tqdm
10
+ gradio
setup.py ADDED
@@ -0,0 +1,7 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name='pe_adaptation',
5
+ version='0.1',
6
+ packages=find_packages(),
7
+ )
src/model.py ADDED
@@ -0,0 +1,809 @@
1
+ from einops import rearrange
2
+ from torch.nn import functional as F
3
+ from dotenv import load_dotenv
4
+ import os
5
+ import sys
6
+
7
+
8
+
9
+ from core.vision_encoder.pe import SelfAttention, AttentionPooling
10
+ import torch.nn as nn
11
+ from typing import Dict, List
12
+ from utils.task_config import Task
13
+ import torch
14
+ from typing import Optional, Union, Mapping, OrderedDict
15
+ from src.dlora import *
16
+ from peft import PeftModel, get_peft_model, LoraConfig
17
+ DROPOUT_P = 0.5
18
+
19
+
20
+ class MTLModel(nn.Module):
21
+ def __init__(self, backbone, tasks: List[Task],
22
+ rank: int = 64,
23
+ use_lora: bool = True,
24
+ truncate_idx: int = 22,
25
+ last_lora_layers: int = -99,
26
+ lora_dropout: float = 0.5,
27
+ use_mtl_lora :bool = False,
28
+ use_deep_head:bool = False,
29
+ use_batch_norm:bool = True,
30
+ use_mtl_attn_pool: bool = True,
31
+ use_dora:bool = True):
32
+
33
+ super().__init__()
34
+ self.use_mtl_attn_pool=use_mtl_attn_pool
35
+ self.tasks = tasks
36
+ self.use_mtl_lora = use_mtl_lora
37
+ self.use_deep_head= use_deep_head
38
+ self.use_lora = use_lora
39
+ self.use_mtlora = use_mtl_lora
40
+ output_dim = backbone.output_dim
41
+ # log_vars is for uncertainty weighting
42
+ self.log_vars = nn.Parameter(torch.zeros(len(tasks)))
43
+ task_names = [task.name for task in tasks]
44
+ self.backbone = backbone
45
+ width = backbone.width
46
+ heads = backbone.heads
47
+ rope = backbone.rope
48
+
49
+ if self.use_mtl_lora:
50
+ # save the last residual attention block; its weights seed the new MTL-LoRA version
51
+ orig_last_block = backbone.transformer.resblocks[-1]
52
+ self.ln_post = backbone.ln_post
53
+
54
+ # save the attention pooling block; its weights seed the task-specific attention pooling layers
55
+ orig_attn_pool = backbone.attn_pool.to('cuda')
56
+
57
+ self.backbone.truncate(layer_idx=truncate_idx) # the 23rd block becomes the last one (its index is 22)
58
+
59
+ # mtl block that produces t-task specific features maps, plus a shared one
60
+ self.mtl_layer = MTLoRAResidualAttentionBlock(
61
+ d_model=width,
62
+ n_head=heads,
63
+ rope=rope,
64
+ r={'shared': rank, **{name: rank for name in task_names}},
65
+ tasks=task_names,
66
+ shared_mode='matrix' ,
67
+ lora_shared_scale=0.0 # We do not use the shared matrix, so its scale is set to 0
68
+ )
69
+
70
+
71
+ self.mtl_layer.load_from_original_block(orig_last_block)
72
+ print("MTL-LoRA final block created and initialized from pretrained weights.")
73
+
74
+
75
+ if self.use_mtl_attn_pool:
76
+ self.attn_pool = MTLoRAAttentionPooling(
77
+ embed_dim=width,
78
+ num_heads=8,
79
+ tasks=task_names,
80
+ r={'shared': rank, **{name: rank for name in task_names}},
81
+ lora_dropout=lora_dropout,
82
+ lora_task_scale=1.0,
83
+ lora_shared_scale=0.0
84
+ )
85
+ self.attn_pool.load_from_original(orig_attn_pool)
86
+ else:
87
+ self.task_specific_attn_pool = nn.ModuleDict({
88
+ task.name: AttentionPooling(embed_dim=width, num_heads=8)
89
+ for task in self.tasks
90
+ })
91
+ for task in self.tasks:
92
+ self.task_specific_attn_pool[task.name].load_state_dict(orig_attn_pool.state_dict())
93
+ print("Task-specific Attention Pooling layers created and initialized.")
94
+
95
+
96
+ del self.backbone.attn_pool
97
+
98
+
99
+
100
+ if use_lora:
101
+ # You can modify this list if you want to target only attention layers or mlp layers
102
+ target_layers = ["attn.in_proj", "attn.out_proj", "mlp.c_fc", "mlp.c_proj"]
103
+ target_modules = []
104
+ for name, param in self.backbone.named_modules():
105
+ if not isinstance(param, nn.Linear):
106
+ continue
107
+ is_target_layer = any(s in name for s in target_layers)
108
+ if is_target_layer:
109
+ if "attn_pool" in name:
110
+ target_modules.append(name)
111
+ elif "transformer.resblocks" in name:
112
+ layer_idx = int(name.split('.')[2])
113
+ if layer_idx >= last_lora_layers:
114
+ target_modules.append(name)
115
+
116
+ lora_config = LoraConfig(
117
+ r=rank,
118
+ lora_alpha=rank,
119
+ target_modules= target_modules,
120
+ use_dora=use_dora,
121
+ lora_dropout=lora_dropout,
122
+ bias = "none"
123
+ )
124
+
125
+ self.backbone = get_peft_model(self.backbone,lora_config)
126
+ print("PEFT LoRA module added")
127
+
128
+
129
+ if self.use_deep_head == False:
130
+ self.prediction_layers = nn.ModuleDict({
131
+ task.name: nn.Sequential(
132
+ nn.BatchNorm1d(backbone.output_dim) if use_batch_norm else nn.Identity(),
133
+ nn.Dropout(p=DROPOUT_P),
134
+ nn.Linear( backbone.output_dim, len(task.class_labels))
135
+ )
136
+ for task in self.tasks
137
+ })
138
+ print("Task-specific prediction heads created.")
139
+ else:
140
+ self.prediction_layers = nn.ModuleDict({
141
+ task.name: nn.Sequential(
142
+ nn.BatchNorm1d(backbone.output_dim) if use_batch_norm else nn.Identity(),
143
+ nn.Dropout(p=DROPOUT_P),
144
+ nn.Linear(backbone.output_dim, backbone.output_dim),
145
+ nn.GELU(),
146
+ nn.Linear(backbone.output_dim, len(task.class_labels)),
147
+ )
148
+ for task in self.tasks
149
+ })
150
+ print("Task-specific prediction deep-heads created.")
151
+
152
+
153
+ self.backbone.del_muda()
154
+
155
+
156
+
157
+ def enable_gradient_checkpointing(self):
158
+ """Call this method after setting up parameter requires_grad"""
159
+ backbone_has_trainable = any(param.requires_grad for param in self.backbone.parameters())
160
+ if backbone_has_trainable:
161
+ self.backbone.set_grad_checkpointing()
162
+ print("Gradient checkpointing enabled for backbone (has trainable parameters)")
163
+ else:
164
+ print("Gradient checkpointing not enabled - backbone has no trainable parameters")
165
+
166
+ def forward(self, x: torch.Tensor):
167
+ if self.use_mtl_lora:
168
+ return self._forward_mtl_block(x)
169
+ else:
170
+ return self._forward_shared(x)
171
+
172
+ def _forward_shared(self, x: torch.Tensor):
173
+ logits = {}
174
+
175
+ #if self.attention_specific_pool == True:
176
+ # features = self.backbone.forward_features(x, norm=True, strip_cls_token=False)
177
+ # for task in self.tasks:
178
+ #
179
+ # pooled_feat = self.task_specific_attn_pool[task_name](features)
180
+ # pooled_feat = pooled_feat.squeeze(1)
181
+ # logits[task_name] = self.prediction_layers[task_name](pooled_feat)
182
+ #else:
183
+ features = self.backbone(x)
184
+ # print(features.shape)
185
+ for task in self.tasks:
186
+ logits[task.name] = self.prediction_layers[task.name](features)
187
+
188
+
189
+ return logits
190
+
191
+ def _forward_mtl_block(self, x: torch.Tensor, return_feat=False, feat_to_return="None"):
192
+ # Shared feature map from the backbone
193
+ # norm=False, because normalization is "trained" on the feature map of the output of the last ResidualAttentionBlock
194
+ # so we will normalize the task specific feature map, instead of the shared one
195
+ # strip_cls_token=False, because in the PE paper it has been shown to be beneficial to keep it
196
+ features = self.backbone.forward_features(x, norm=False, strip_cls_token=False)
197
+
198
+ # Equal for each task, as our mtl layer follows a task-agnostic layer
199
+ task_features_input = {task.name: features for task in self.tasks}
200
+
201
+ # Also returns a shared feature map, which is discarded.
202
+ # task_features is a dictionary: the key is the task name, and the value is a tensor of shape (batch_size, n_tokens, d_model)
203
+ # representing that task's feature map
204
+ _, task_features = self.mtl_layer(features, x_tasks=task_features_input)
205
+
206
+ normalized_task_features = {
207
+ task.name: self.ln_post(task_features[task.name])
208
+ for task in self.tasks
209
+ }
210
+
211
+ if self.use_mtl_attn_pool:
212
+ pooled_features = self.attn_pool(normalized_task_features)
213
+ else:
214
+ pooled_features = {}
215
+ for task in self.tasks:
216
+ feat = normalized_task_features[task.name]
217
+ pooled_features[task.name] = self.task_specific_attn_pool[task.name](feat)
218
+
219
+ # this stuff is for pca/tsne visualization
220
+ if return_feat:
221
+ if feat_to_return == "Age":
222
+ return pooled_features['Age']
223
+ elif feat_to_return == "Emotion":
224
+ return pooled_features['Emotion']
225
+ elif feat_to_return == "Gender":
226
+ return pooled_features['Gender']
227
+
228
+
229
+ logits = {}
230
+ for task in self.tasks:
231
+ # Squeeze the pooling dimension (1)
232
+ pooled_feat = pooled_features[task.name].squeeze(1) # (batch, 1, d_model) -> (batch, d_model)
233
+ logits[task.name] = self.prediction_layers[task.name](pooled_feat)
234
+
235
+ return logits
236
+
237
+ def save_whole_model(self, filepath: str):
238
+ print(f"Saving model state_dict to {filepath}")
239
+ torch.save(self.state_dict(), filepath)
240
+
241
+ def load_model(self, filepath:str,map_location='cuda'):
242
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
243
+ if self.use_lora or self.use_mtlora:
244
+ self.backbone.merge_and_unload()
245
+ self.to(device)
246
+ state_dict = torch.load(filepath, map_location=map_location)
247
+ self.load_state_dict(state_dict, strict=True)
248
+
249
+ def save_adapters_peft(self, save_directory: str):
250
+
251
+ print(f"Saving adapters to directory: {save_directory}")
252
+ os.makedirs(save_directory, exist_ok=True)
253
+
254
+ custom_layers_state_dict = {
255
+ 'prediction_layers': self.prediction_layers.state_dict()
256
+ }
257
+
258
+ if self.use_lora:
259
+ self.backbone.save_pretrained(save_directory)
260
+
261
+ if self.use_mtlora:
262
+ custom_layers_state_dict['mtl_layer'] = self.mtl_layer.state_dict()
263
+ #custom_layers_state_dict['task_specific_attn_pooling'] = self.task_specific_attn_pool.state_dict()
264
+ custom_layers_state_dict['mtl_attn_pool'] = self.attn_pool.state_dict()
265
+
266
+
267
+ torch.save(custom_layers_state_dict, os.path.join(save_directory, 'custom_layers.pt'))
268
+ print("Successfully saved PEFT backbone and custom task heads.")
269
+
270
+ def load_heads(self, filepaths: List[str],device='cuda'):
271
+
272
+ for ckpt in filepaths:
273
+ checkpoint = torch.load(ckpt, map_location=device)
274
+ model_state_dict = self.state_dict()
275
+
276
+ if "prediction_layers" in checkpoint:
277
+ for loaded_key, value in checkpoint["prediction_layers"].items():
278
+ new_key = loaded_key
279
+
280
+ # Remap prefix: 'heads.emotion.' -> 'prediction_layers.Emotion.'
281
+ if new_key.startswith('heads.emotion.'):
282
+ new_key = new_key.replace('heads.emotion.', 'prediction_layers.Emotion.')
283
+
284
+ if new_key.startswith('heads.age.'):
285
+ new_key = new_key.replace('heads.age.', 'prediction_layers.Age.')
286
+
287
+ if new_key.startswith('heads.gender.'):
288
+ new_key = new_key.replace('heads.gender.', 'prediction_layers.Gender.')
289
+
290
+ # Remap final layer index for deep head: '.5.' -> '.4.'
291
+ if '.5.' in new_key:
292
+ new_key = new_key.replace('.5.', '.4.')
293
+
294
+ if new_key in model_state_dict:
295
+ if model_state_dict[new_key].shape == value.shape:
296
+ model_state_dict[new_key].copy_(value)
297
+
298
+ def load_adapters_peft(self, load_directory: str, custom_head_name:str = 'custom_layers.pt'):
299
+
300
+ print(f"Loading adapters from directory: {load_directory}")
301
+ if self.use_lora:
302
+ self.backbone = self.backbone.merge_and_unload()
303
+ self.backbone = PeftModel.from_pretrained(self.backbone, load_directory)
304
+
305
+ custom_layers_path = os.path.join(load_directory, custom_head_name)
306
+ if not os.path.exists(custom_layers_path):
307
+ raise FileNotFoundError(f"Custom task heads file not found at {custom_layers_path}")
308
+
309
+ checkpoint = torch.load(custom_layers_path, map_location=("cuda" if torch.cuda.is_available() else "cpu"))
310
+
311
+ self.prediction_layers.load_state_dict(checkpoint['prediction_layers'])
312
+
313
+ if self.use_mtlora:
314
+ try:
315
+ self.mtl_layer.load_state_dict(checkpoint['mtl_layer'][0])
316
+ except KeyError:
317
+ self.mtl_layer.load_state_dict(checkpoint['mtl_layer'])
318
+ self.attn_pool.load_state_dict(checkpoint['mtl_attn_pool'])
319
+
320
+ print("Successfully loaded PEFT backbone and custom task heads.")
321
+
322
+ def save_trained(self, filepath: str):
323
+
324
+ trainable_param_names = {name for name, param in self.named_parameters() if param.requires_grad}
325
+ trainable_module_paths = {'.'.join(name.split('.')[:-1]) for name in trainable_param_names}
326
+
327
+ state_to_save = {}
328
+ full_state_dict = self.state_dict()
329
+
330
+ for key, value in full_state_dict.items():
331
+ if key in trainable_param_names:
332
+ state_to_save[key] = value
333
+ continue
334
+
335
+
336
+ current_module_path = '.'.join(key.split('.')[:-1])
337
+ if current_module_path in trainable_module_paths:
338
+ state_to_save[key] = value
339
+
340
+ print(f"Saving {len(state_to_save)} state entries (parameters and buffers) to {filepath}")
341
+ torch.save(state_to_save, filepath)
342
+
343
+
344
+ def load_trained_legacy(self, filepath: str, device='cuda'):
345
+ """The training of some checkpoint where done with a different model class,
346
+ so there is the need of remapping the key names, so they match with this new model class"""
347
+ print(f"Loading trained states from structured checkpoint: {filepath}")
348
+
349
+ checkpoint = torch.load(filepath, map_location=device)
350
+
351
+ model_state_dict = self.state_dict()
352
+
353
+ loaded_keys_count = 0
354
+ skipped_keys = []
355
+ remapped_keys_examples = {}
356
+
357
+ if "backbone_state_dict" in checkpoint:
358
+ print("\n--- Processing Backbone Weights ---")
359
+ for loaded_key, value in checkpoint["backbone_state_dict"].items():
360
+ new_key = loaded_key
361
+
362
+ if new_key.startswith('strategy.backbone.'):
363
+ new_key = new_key.replace('strategy.backbone.', 'backbone.')
364
+
365
+ if 'attn.in_proj_weight' in new_key and 'attn.in_proj.weight' not in new_key:
366
+ new_key = new_key.replace('attn.in_proj_weight', 'attn.in_proj.weight')
367
+ if 'attn.in_proj_bias' in new_key and 'attn.in_proj.bias' not in new_key:
368
+ new_key = new_key.replace('attn.in_proj_bias', 'attn.in_proj.bias')
369
+
370
+ if new_key in model_state_dict:
371
+ if model_state_dict[new_key].shape == value.shape:
372
+ model_state_dict[new_key].copy_(value)
373
+ loaded_keys_count += 1
374
+ if loaded_key != new_key and len(remapped_keys_examples) < 5:
375
+ remapped_keys_examples[loaded_key] = new_key
376
+ else:
377
+ skipped_keys.append(f"{loaded_key} (Shape Mismatch: Model {model_state_dict[new_key].shape} vs Ckpt {value.shape})")
378
+ else:
379
+ skipped_keys.append(f"{loaded_key} (as {new_key}) -> Not found in model")
380
+
381
+ if "prediction_layers" in checkpoint:
382
+ print("\n--- Processing Prediction Head Weights ---")
383
+ for loaded_key, value in checkpoint["prediction_layers"].items():
384
+ new_key = loaded_key
385
+
386
+ if new_key.startswith('heads.emotion.'):
387
+ new_key = new_key.replace('heads.emotion.', 'prediction_layers.Emotion.')
388
+
389
+ if new_key.startswith('heads.age.'):
390
+ new_key = new_key.replace('heads.age.', 'prediction_layers.Age.')
391
+
392
+ if new_key.startswith('heads.gender.'):
393
+ new_key = new_key.replace('heads.gender.', 'prediction_layers.Gender.')
394
+
395
+ if '.5.' in new_key:
396
+ new_key = new_key.replace('.5.', '.4.')
397
+
398
+ # Validate, load, and update trackers
399
+ if new_key in model_state_dict:
400
+ if model_state_dict[new_key].shape == value.shape:
401
+ model_state_dict[new_key].copy_(value)
402
+ loaded_keys_count += 1
403
+ if loaded_key != new_key and len(remapped_keys_examples) < 10:
404
+ remapped_keys_examples[loaded_key] = new_key
405
+ else:
406
+ skipped_keys.append(f"{loaded_key} (Shape Mismatch: Model {model_state_dict[new_key].shape} vs Ckpt {value.shape})")
407
+ else:
408
+ skipped_keys.append(f"{loaded_key} (as {new_key}) -> Not found in model")
409
+
410
+ if "attn_pool" in checkpoint:
411
+ print("\n--- Processing Attention Pool Weights ---")
412
+ for loaded_key, value in checkpoint["attn_pool"].items():
413
+ # The attn_pool keys in the source file also have the 'strategy.backbone' prefix
414
+ new_key = loaded_key.replace('strategy.backbone.attn_pool.', 'backbone.attn_pool.')
415
+
416
+ # Validate, load, and update trackers
417
+ if new_key in model_state_dict:
418
+ if model_state_dict[new_key].shape == value.shape:
419
+ model_state_dict[new_key].copy_(value)
420
+ loaded_keys_count += 1
421
+ if loaded_key != new_key and len(remapped_keys_examples) < 15:
422
+ remapped_keys_examples[loaded_key] = new_key
423
+ else:
424
+ skipped_keys.append(f"{loaded_key} (Shape Mismatch: Model {model_state_dict[new_key].shape} vs Ckpt {value.shape})")
425
+ else:
426
+ skipped_keys.append(f"{loaded_key} (as {new_key}) -> Not found in model")
427
+
428
+
429
+
430
+
431
+ if loaded_keys_count == 0:
432
+ print('Loaded 0 keys from the structured checkpoint; falling back to a plain state_dict load')
433
+ self.load_state_dict(torch.load(filepath, map_location=device), strict=False)
434
+
435
+
436
+
437
+ class MTLoRAResidualAttentionBlock(nn.Module):
438
+ """Adaptation of Perception Encoder ResidualAttentionBlock with MTLora, to produce t-task specific feature-maps and a shared feature map"""
439
+ def __init__(
440
+ self,
441
+ d_model: int,
442
+ n_head: int,
443
+ mlp_ratio: float = 4.0,
444
+ ls_init_value: float = None,
445
+ act_layer = nn.GELU,
446
+ norm_layer = nn.LayerNorm,
447
+ drop_path: float = 0.0,
448
+ rope: Optional[nn.Module] = None,
449
+ r: Union[int, Mapping[str, int]] = 0,
450
+ lora_shared_scale: float = 1.0,
451
+ lora_task_scale: float = 1.0,
452
+ lora_dropout: float = DROPOUT_P,
453
+ tasks=None,
454
+ trainable_scale_shared=False,
455
+ trainable_scale_per_task=False,
456
+ shared_mode: str = 'matrix',
457
+ ):
458
+ super().__init__()
459
+ self.tasks = tasks
460
+ self.num_heads = n_head
461
+ self.head_dim = d_model // n_head
462
+ self.scale = self.head_dim ** -0.5
463
+ self.rope = rope
464
+
465
+ task_scales = {t: lora_task_scale for t in tasks}
466
+
467
+
468
+ # MultiTask Lora for QKV matrices
469
+ # (MTLoRAQKV does not actually compute attention, but returns the shared QKV matrices and the task-specific QKV matrices)
470
+ self.attn = MTLoRAQKV(
471
+ in_features=d_model,
472
+ out_features=d_model,
473
+ r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
474
+ lora_dropout=lora_dropout, tasks=tasks, trainable_scale_shared=trainable_scale_shared,
475
+ trainable_scale_per_task=trainable_scale_per_task, shared_mode=shared_mode
476
+ )
477
+
478
+ # MultiTask Lora for projection matrices in mha
479
+ self.out_proj = MTLoRALinear(
480
+ in_features=d_model,
481
+ out_features=d_model,
482
+ r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
483
+ lora_dropout=lora_dropout, tasks=tasks, trainable_scale_shared=trainable_scale_shared,
484
+ trainable_scale_per_task=trainable_scale_per_task, shared_mode=shared_mode
485
+ )
486
+
487
+ self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
488
+ self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
489
+
490
+ self.ln_1 = norm_layer(d_model)
491
+ self.ln_2 = norm_layer(d_model)
492
+
493
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
494
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
495
+
496
+ # LoRA-enabled MLP
497
+ mlp_width = int(d_model * mlp_ratio)
498
+ self.mlp = nn.Sequential(
499
+ OrderedDict([
500
+ ("c_fc", MTLoRALinear(
501
+ d_model, mlp_width, r=r, lora_shared_scale=lora_shared_scale,
502
+ lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks,
503
+ trainable_scale_shared=trainable_scale_shared, trainable_scale_per_task=trainable_scale_per_task,
504
+ shared_mode=shared_mode
505
+ )),
506
+ ("gelu", act_layer()),
507
+ ("c_proj", MTLoRALinear(
508
+ mlp_width, d_model, r=r, lora_shared_scale=lora_shared_scale,
509
+ lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks,
510
+ trainable_scale_shared=trainable_scale_shared, trainable_scale_per_task=trainable_scale_per_task,
511
+ shared_mode=shared_mode
512
+ )),
513
+ ])
514
+ )
515
+
516
+ def _call_attn(
517
+ self,
518
+ x_shared: torch.Tensor,
519
+ attn_mask: Optional[torch.Tensor] = None,
520
+ x_tasks: Optional[Dict[str, torch.Tensor]] = None,
521
+ ):
522
+ # s is the number of patches/tokens, sequence length
523
+ proj, proj_tasks = self.attn(x_shared, x_tasks) # proj is (b s 3*d_model), proj_tasks is dict of (b s 3*d_model), one entry per task
524
+
525
+ def compute_attention(projection_tensor):
526
+ # Reshape Q, K, V
527
+ # projection_tensor is (b s 3*d_model), need to split and rearrange
528
+ _, s, _ = projection_tensor.shape
529
+ # output_features from MTLoRAQKV is d_model, so 3 * d_model
530
+ split_size = self.attn.q.linear.out_features # This should be d_model
531
+
532
+ # Unflatten into (b s 3 d_model) then transpose to get (3 b s d_model)
533
+ q, k, v = projection_tensor.unflatten(-1, (3, split_size)).permute(2, 0, 1, 3).contiguous()
534
+ # Rearrange for multi-head attention (b h s d)
535
+ q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
536
+ k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
537
+ v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
538
+
539
+ if self.rope:
540
+ q, k = self.rope(q, k)
541
+
542
+ attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
543
+ return rearrange(attn_output, "b h s d -> b s (h d)")
544
+
545
+ # Process shared path
546
+ attn_result = compute_attention(proj)
547
+
548
+ # Process task-specific paths
549
+ attn_tasks_results = {}
550
+ if proj_tasks:
551
+ for task, task_proj in proj_tasks.items():
552
+ attn_tasks_results[task] = compute_attention(task_proj)
553
+
554
+ # Apply output projection
555
+ # out_proj is an MTLoRALinear, so its forward expects (x, x_tasks)
556
+ shared_out, tasks_out = self.out_proj(attn_result, x_tasks=attn_tasks_results if attn_tasks_results else None)
557
+
558
+ return shared_out, tasks_out
559
+
560
+ def forward(
561
+ self,
562
+ x: torch.Tensor,
563
+ attn_mask: Optional[torch.Tensor] = None,
564
+ x_tasks: Optional[Dict[str, torch.Tensor]] = None,
565
+ ):
566
+ # Attention block
567
+ norm_x = self.ln_1(x)
568
+ norm_x_tasks = {task: self.ln_1(x_tasks[task]) for task in self.tasks} if x_tasks else None
569
+
570
+ attn_out, attn_tasks_out = self._call_attn(norm_x, attn_mask=attn_mask, x_tasks=norm_x_tasks)
571
+
572
+ x = x + self.drop_path1(self.ls_1(attn_out))
573
+ if attn_tasks_out and x_tasks:
574
+ for task in self.tasks:
575
+ x_tasks[task] = x_tasks[task] + self.drop_path1(self.ls_1(attn_tasks_out[task]))
576
+
577
+ # MLP block
578
+ norm_x = self.ln_2(x)
579
+ norm_x_tasks = {task: self.ln_2(x_tasks[task]) for task in self.tasks} if x_tasks else None
580
+
581
+ # The MTLoRALinear forward needs to be called directly for the sequential MLP
582
+ mlp_fc_out, mlp_fc_tasks_out = self.mlp.c_fc(norm_x, norm_x_tasks)
583
+ gelu_out = self.mlp.gelu(mlp_fc_out)
584
+ gelu_tasks_out = {task: self.mlp.gelu(mlp_fc_tasks_out[task]) for task in self.tasks} if mlp_fc_tasks_out else None
585
+
586
+ mlp_proj_out, mlp_proj_tasks_out = self.mlp.c_proj(gelu_out, gelu_tasks_out)
587
+
588
+ x = x + self.drop_path2(self.ls_2(mlp_proj_out))
589
+ if mlp_proj_tasks_out and x_tasks:
590
+ for task in self.tasks:
591
+ x_tasks[task] = x_tasks[task] + self.drop_path2(self.ls_2(mlp_proj_tasks_out[task]))
592
+
593
+ return x, x_tasks
594
+
595
+ def load_from_original_block(self, original_block):
596
+ """
597
+ Initializes the weights of this block from a pre-trained ResidualAttentionBlock.
598
+ The LoRA-specific parameters are reset to their initial state.
599
+ """
600
+ with torch.no_grad():
601
+ # Copy LayerNorm and LayerScale weights
602
+ self.ln_1.load_state_dict(original_block.ln_1.state_dict())
603
+ self.ln_2.load_state_dict(original_block.ln_2.state_dict())
604
+ self.ls_1.load_state_dict(original_block.ls_1.state_dict())
605
+ self.ls_2.load_state_dict(original_block.ls_2.state_dict())
606
+
607
+ # Copy MLP weights into the .linear attribute of the MTLoRALinear layers
608
+ self.mlp.c_fc.linear.load_state_dict(original_block.mlp.c_fc.state_dict())
609
+ self.mlp.c_proj.linear.load_state_dict(original_block.mlp.c_proj.state_dict())
610
+
611
+ # Copy Attention weights
612
+ # Both SelfAttention and nn.MultiheadAttention store QKV weights combined
613
+ if isinstance(original_block.attn, SelfAttention):
614
+ # Using migrate_weights ensures the Parameters are copied to the Linear layer first
615
+ # Then we can extract from the Linear layer
616
+ original_block.attn.migrate_weights() # Ensure weights are in .in_proj and .out_proj
617
+
618
+ # Split the combined weight and bias tensors into Q, K, V from .in_proj
619
+ qkv_weight = original_block.attn.in_proj.weight
620
+ qkv_bias = original_block.attn.in_proj.bias
621
+
622
+ q_w, k_w, v_w = qkv_weight.chunk(3)
623
+ q_b, k_b, v_b = qkv_bias.chunk(3)
624
+
625
+ # Load into the .linear attributes of the MTLoRAQKV module
626
+ self.attn.q.linear.weight.copy_(q_w)
627
+ self.attn.q.linear.bias.copy_(q_b)
628
+
629
+ self.attn.k.linear.weight.copy_(k_w)
630
+ self.attn.k.linear.bias.copy_(k_b)
631
+
632
+ self.attn.v.linear.weight.copy_(v_w)
633
+ self.attn.v.linear.bias.copy_(v_b)
634
+
635
+ # Load the output projection weights
636
+ self.out_proj.linear.load_state_dict(original_block.attn.out_proj.state_dict())
637
+ elif isinstance(original_block.attn, nn.MultiheadAttention):
638
+ self.attn.q.linear.weight.copy_(original_block.attn.in_proj_weight[:self.attn.q.linear.out_features, :])
639
+ self.attn.q.linear.bias.copy_(original_block.attn.in_proj_bias[:self.attn.q.linear.out_features])
640
+
641
+ self.attn.k.linear.weight.copy_(original_block.attn.in_proj_weight[self.attn.q.linear.out_features:2*self.attn.q.linear.out_features, :])
642
+ self.attn.k.linear.bias.copy_(original_block.attn.in_proj_bias[self.attn.q.linear.out_features:2*self.attn.q.linear.out_features])
643
+
644
+ self.attn.v.linear.weight.copy_(original_block.attn.in_proj_weight[2*self.attn.q.linear.out_features:3*self.attn.q.linear.out_features, :])
645
+ self.attn.v.linear.bias.copy_(original_block.attn.in_proj_bias[2*self.attn.q.linear.out_features:3*self.attn.q.linear.out_features])
646
+
647
+ self.out_proj.linear.weight.copy_(original_block.attn.out_proj.weight)
648
+ self.out_proj.linear.bias.copy_(original_block.attn.out_proj.bias)
649
+
650
+ else:
651
+ raise TypeError(f"Unsupported attention module type in original_block: {type(original_block.attn)}")
652
+
653
+
654
+ # After loading pretrained weights, re-initialize LoRA-specific parameters
655
+ # This ensures that at the start of finetuning, the LoRA adjustment is zero.
656
+ self.attn.reset_parameters()
657
+ self.out_proj.reset_parameters()
658
+ self.mlp.c_fc.reset_parameters()
659
+ self.mlp.c_proj.reset_parameters()
660
+
661
+ print("Successfully loaded weights from original ResidualAttentionBlock and reset LoRA parameters.")
662
+
663
+
664
+ class MTLoRAAttentionPooling(nn.Module):
665
+ """
666
+ A MT-LoRA equivalent of the AttentionPooling transformer block.
667
+
668
+ This module replicates the full original architecture:
669
+ 1. Task-specific probes for attention pooling.
670
+ 2. MT-LoRA enabled Q/K/V and Output projections.
671
+ 3. A LayerNorm layer.
672
+ 4. An MLP block with MT-LoRA enabled linear layers.
673
+ 5. A final residual connection, matching the original's structure.
674
+ """
675
+ def __init__(
676
+ self,
677
+ embed_dim: int,
678
+ num_heads: int,
679
+ tasks: List[str],
680
+ r: Union[int, Mapping[str, int]] = 0,
681
+ lora_shared_scale: float = 1.0,
682
+ lora_task_scale: float = 1.0,
683
+ lora_dropout: float = 0.0,
684
+ mlp_ratio: int = 4,
685
+ act_layer = nn.GELU,
686
+ norm_layer = nn.LayerNorm,
687
+ ):
688
+ super().__init__()
689
+ self.tasks = tasks
690
+ self.num_heads = num_heads
691
+
692
+ self.probe = nn.ParameterDict({
693
+ task: nn.Parameter(torch.randn(1, 1, embed_dim))
694
+ for task in tasks
695
+ })
696
+
697
+ task_scales = {t: lora_task_scale for t in tasks}
698
+
699
+ self.q_proj = MTLoRALinear(
700
+ embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
701
+ lora_dropout=lora_dropout, tasks=tasks
702
+ )
703
+ self.k_proj = MTLoRALinear(
704
+ embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
705
+ lora_dropout=lora_dropout, tasks=tasks
706
+ )
707
+ self.v_proj = MTLoRALinear(
708
+ embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
709
+ lora_dropout=lora_dropout, tasks=tasks
710
+ )
711
+ self.out_proj = MTLoRALinear(
712
+ embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
713
+ lora_dropout=lora_dropout, tasks=tasks
714
+ )
715
+
716
+ self.layernorm = norm_layer(embed_dim)
717
+ mlp_width = int(embed_dim * mlp_ratio)
718
+ self.mlp = nn.Sequential(
719
+ OrderedDict([
720
+ ("c_fc", MTLoRALinear(
721
+ embed_dim, mlp_width, r=r, lora_shared_scale=lora_shared_scale,
722
+ lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks
723
+ )),
724
+ ("gelu", nn.GELU()),
725
+ ("c_proj", MTLoRALinear(
726
+ mlp_width, embed_dim, r=r, lora_shared_scale=lora_shared_scale,
727
+ lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks
728
+ )),
729
+ ])
730
+ )
731
+
732
+ def load_from_original(self, original_pool: AttentionPooling):
733
+ """Initializes all weights from the pretrained AttentionPooling block."""
734
+ with torch.no_grad():
735
+ original_attn = original_pool.attn
736
+
737
+ for task in self.tasks:
738
+ self.probe[task].copy_(original_pool.probe)
739
+
740
+ q_w, k_w, v_w = original_attn.in_proj_weight.chunk(3)
741
+ q_b, k_b, v_b = original_attn.in_proj_bias.chunk(3)
742
+
743
+ self.q_proj.linear.weight.copy_(q_w)
744
+ self.q_proj.linear.bias.copy_(q_b)
745
+ self.k_proj.linear.weight.copy_(k_w)
746
+ self.k_proj.linear.bias.copy_(k_b)
747
+ self.v_proj.linear.weight.copy_(v_w)
748
+ self.v_proj.linear.bias.copy_(v_b)
749
+
750
+ self.out_proj.linear.load_state_dict(original_attn.out_proj.state_dict())
751
+
752
+ self.layernorm.load_state_dict(original_pool.layernorm.state_dict())
753
+
754
+ self.mlp.c_fc.linear.load_state_dict(original_pool.mlp.c_fc.state_dict())
755
+ self.mlp.c_proj.linear.load_state_dict(original_pool.mlp.c_proj.state_dict())
756
+
757
+ self.q_proj.reset_parameters()
758
+ self.k_proj.reset_parameters()
759
+ self.v_proj.reset_parameters()
760
+ self.out_proj.reset_parameters()
761
+ self.mlp.c_fc.reset_parameters()
762
+ self.mlp.c_proj.reset_parameters()
763
+ print("Full MT-LoRA Attention Pooling block created and initialized from pretrained weights.")
764
+
765
+ def forward(self, x_tasks: Dict[str, torch.Tensor]):
766
+ """
767
+ Forward pass that correctly handles unique inputs for each task.
768
+
769
+ In this version, K and V are calculated inside the loop based on
770
+ the task-specific input 'x', and each task has its own probe.
771
+ """
772
+
773
+
774
+ final_outputs = {}
775
+ for task, x in x_tasks.items():
776
+ B, N, C = x.shape
777
+ probe = self.probe[task].repeat(B, 1, 1)
778
+
779
+
780
+ _, q_task_dict = self.q_proj(probe, x_tasks={task: probe})
781
+ q = q_task_dict[task]
782
+
783
+ _, k_task_dict = self.k_proj(x, x_tasks={task: x})
784
+ k = k_task_dict[task]
785
+
786
+ _, v_task_dict = self.v_proj(x, x_tasks={task: x})
787
+ v = v_task_dict[task]
788
+
789
+ q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
790
+ k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
791
+ v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)
792
+
793
+ attn_out = F.scaled_dot_product_attention(q, k, v)
794
+ attn_out_rearranged = rearrange(attn_out, 'b h n d -> b n (h d)')
795
+
796
+ _, out_proj_dict = self.out_proj(attn_out_rearranged, x_tasks={task: attn_out_rearranged})
797
+ x_attn = out_proj_dict[task]
798
+
799
+ norm_attn = self.layernorm(x_attn)
800
+
801
+ _, fc_task_dict = self.mlp.c_fc(norm_attn, x_tasks={task: norm_attn})
802
+ gelu_out = self.mlp.gelu(fc_task_dict[task])
803
+ _, proj_task_dict = self.mlp.c_proj(gelu_out, x_tasks={task: gelu_out})
804
+ mlp_out = proj_task_dict[task]
805
+
806
+ final_outputs[task] = x_attn + mlp_out
807
+
808
+ return final_outputs
809
+
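
MTLModel.log_vars is declared "for uncertainty weighting", but the training loop is not part of this file; below is a minimal sketch of how such a loss is typically formed (Kendall-style homoscedastic weighting; treat the exact formulation as an assumption, not this repo's training code):

import torch
import torch.nn.functional as F

def uncertainty_weighted_loss(model, logits, targets):
    # logits/targets: dicts keyed by task.name, as produced by MTLModel.forward
    total = torch.zeros((), device=model.log_vars.device)
    for i, task in enumerate(model.tasks):
        task_loss = F.cross_entropy(logits[task.name], targets[task.name])
        precision = torch.exp(-model.log_vars[i])   # exp(-log_var) down-weights noisy tasks
        total = total + precision * task_loss + model.log_vars[i]
    return total
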
utils/__pycache__/commons.cpython-313.pyc ADDED
Binary file (8.06 kB)

utils/__pycache__/dataset.cpython-313.pyc ADDED
Binary file (26.6 kB)

utils/__pycache__/face_detector.cpython-313.pyc ADDED
Binary file (6.43 kB)

utils/__pycache__/task_config.cpython-313.pyc ADDED
Binary file (1.19 kB)

utils/commons.py ADDED
@@ -0,0 +1,158 @@
 
1
+ """Contains functions used for loading and logging models."""
2
+
3
+ import sys
4
+ import os
5
+ from transformers import AutoModel, AutoProcessor
6
+
8
+ import core.vision_encoder.pe as pe
9
+ import core.vision_encoder.transforms as transforms_pe
10
+ from core.vision_encoder.config import PE_VISION_CONFIG
11
+ import torchvision.transforms as transforms
12
+ from PIL import Image
13
+ import requests
14
+
15
+
16
+ def print_trainable_params(model):
17
+ total_params = sum(p.numel() for p in model.parameters())
18
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
19
+ percent = (trainable_params / total_params * 100) if total_params > 0 else 0
20
+ print("\n--- Summary ---")
21
+ print(f"Trainable parameters: {trainable_params:,}")
22
+ print(f"Total parameters: {total_params:,}")
23
+ print(f"Percentage: {percent:.2f}%")
24
+
25
+
26
+ def get_backbone_pe(version, print_info=False, apply_migration_flag=False, pretrained=True):
28
+ """
29
+ Load a PE ViT model; return the model, its image transform, and the output embedding dimension (size of the last token's embedding).
30
+ """
30
+ print(f'Loading {version}...')
31
+ backbone = pe.VisionTransformer.from_config(version, pretrained=pretrained)
32
+ backbone_config = PE_VISION_CONFIG[version]
33
+ transform = transforms_pe.get_image_transform_fix(image_size=backbone_config.image_size)
34
+
35
+ print("\nYou can ignore the Missing keys list above.")
36
+ print(f"Applying migration = {apply_migration_flag}")
37
+
38
+ if print_info:
39
+ attnpool = backbone.attn_pool
40
+ print(f'embed_dim={attnpool.embed_dim}\nnum_heads={attnpool.num_heads}')
41
+ print(f'OUTPUT DIM = {backbone_config.output_dim}')
42
+
43
+ def apply_migration(m):
44
+ if isinstance(m, pe.SelfAttention):
45
+ m.migrate_weights()
46
+
47
+ if apply_migration_flag: # no migration should be applied when testing/resuming
48
+ print('[MIGRATION] Migrating weights for PEFT compatibility')
49
+ backbone.apply(apply_migration)
50
+
51
+ return backbone, transform, backbone_config.output_dim
52
+
53
+
54
+ def get_backbone_dinov3(model_name: str="facebook/dinov3-vitb16-pretrain-lvd1689m", print_info=False):
55
+ print(f"Loading Hugging Face model: {model_name}")
56
+ processor = AutoProcessor.from_pretrained(model_name)
57
+
58
+ # Extract image processing configuration from the loaded processor
59
+ image_processor_config = processor
60
+ image_size = image_processor_config.size['height']
61
+ image_mean = image_processor_config.image_mean
62
+ image_std = image_processor_config.image_std
63
+
64
+ transform = transforms.Compose([
65
+ transforms.Lambda(_convert_to_rgb),
66
+ transforms.Resize((image_size, image_size), antialias=True),
67
+ transforms.ToTensor(),
68
+ transforms.Normalize(mean=image_mean, std=image_std)
69
+ ])
70
+
71
+ # Load the model and return only the vision backbone
72
+ vision_model = AutoModel.from_pretrained(model_name)
73
+
74
+ if print_info:
75
+ print(f'\nVISION CONFIGS:\n{vision_model.config}')
76
+ print(f'\n\n\n{vision_model}')
77
+
78
+
79
+ return vision_model, transform, vision_model.config.hidden_size
80
+
81
+
82
+ def get_backbone_siglip2(model_name: str = 'google/siglip2-base-patch16-224', print_info=False):
83
+ """
84
+ Load a SigLIP2 ViT model; return the vision backbone, its image transform, and the output embedding dimension (size of the last token's embedding).
85
+ """
86
+ print(f"Loading Hugging Face model: {model_name}")
87
+ processor = AutoProcessor.from_pretrained(model_name)
88
+
89
+
90
+ # Extract image processing configuration from the loaded processor
91
+ image_processor_config = processor.image_processor
92
+ image_size = image_processor_config.size['height']
93
+ image_mean = image_processor_config.image_mean
94
+ image_std = image_processor_config.image_std
95
+
96
+ transform = transforms.Compose([
97
+ transforms.Lambda(_convert_to_rgb),
98
+ transforms.Resize((image_size, image_size), antialias=True),
99
+ transforms.ToTensor(),
100
+ transforms.Normalize(mean=image_mean, std=image_std)
101
+ ])
102
+
103
+ # Load the model and return only the vision backbone
104
+ model = AutoModel.from_pretrained(model_name)
105
+ vision_model = model.vision_model
106
+
107
+ if print_info:
108
+ print(f'\nVISION CONFIGS:\n{vision_model.config}')
109
+ print(f'\n\n***************MHAP\n{vision_model.head}')
110
+
111
+
112
+ return vision_model, transform, vision_model.config.hidden_size
113
+
114
+ def _convert_to_rgb(image: Image.Image) -> Image.Image:
115
+ """Converts a PIL Image to RGB format."""
116
+ return image.convert("RGB")
117
+
118
+
119
+ def get_backbone(version: str, apply_migration: bool = False):
120
+ """
121
+ Returns a vision transformer backbone.
122
+ Args:
123
+ version: Name of the backbone to use (PE-Core, SigLIP2, or DINOv3)
124
+ apply_migration: if True, migrate PE attention weights for PEFT compatibility (PE backbones only)
125
+ """
126
+ if 'PE-Core-' in version:
127
+ return get_backbone_pe(version, False, apply_migration)
128
+ elif 'siglip2' in version:
129
+ print('[LOADING SIGLIP2]')
130
+ return get_backbone_siglip2(version)
131
+ elif 'dinov3' in version:
132
+ return get_backbone_dinov3(version)
133
+
134
+
135
+
136
+ def send_telegram_message(message: str):
137
+ """Sends a message to a Telegram chat using credentials from the environment."""
139
+ # Read the bot token from the environment; the chat id is currently hard-coded.
139
+ token = os.getenv("BOT_TOKEN")
140
+ chat_id = "1220514183"
141
+
142
+ if not token or not chat_id:
143
+ # Silently fail if credentials are not set
144
+ return
145
+
146
+ api_url = f"https://api.telegram.org/bot{token}/sendMessage"
147
+ payload = {
148
+ 'chat_id': chat_id,
149
+ 'text': message,
150
+ 'parse_mode': 'Markdown' # For nice formatting (bold, italics, etc.)
151
+ }
152
+
153
+ try:
154
+ response = requests.post(api_url, data=payload, timeout=10)
155
+ response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
156
+ except requests.exceptions.RequestException as e:
157
+ # Don't crash the training loop if Telegram is down
158
+ print(f"\nWarning: Could not send Telegram message. Error: {e}")
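A minimal sketch of how these helpers compose; the checkpoint name is the SigLIP2 default used above, and downloading the weights from the Hugging Face Hub is assumed to work:

from utils.commons import get_backbone, print_trainable_params

backbone, transform, embed_dim = get_backbone('google/siglip2-base-patch16-224')
print(f'Output embedding dimension: {embed_dim}')
print_trainable_params(backbone)  # prints trainable vs. total parameter counts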
utils/deploy.prototxt ADDED
@@ -0,0 +1,1789 @@
1
+ input: "data"
2
+ input_shape {
3
+ dim: 1
4
+ dim: 3
5
+ dim: 300
6
+ dim: 300
7
+ }
8
+
9
+ layer {
10
+ name: "data_bn"
11
+ type: "BatchNorm"
12
+ bottom: "data"
13
+ top: "data_bn"
14
+ param {
15
+ lr_mult: 0.0
16
+ }
17
+ param {
18
+ lr_mult: 0.0
19
+ }
20
+ param {
21
+ lr_mult: 0.0
22
+ }
23
+ }
24
+ layer {
25
+ name: "data_scale"
26
+ type: "Scale"
27
+ bottom: "data_bn"
28
+ top: "data_bn"
29
+ param {
30
+ lr_mult: 1.0
31
+ decay_mult: 1.0
32
+ }
33
+ param {
34
+ lr_mult: 2.0
35
+ decay_mult: 1.0
36
+ }
37
+ scale_param {
38
+ bias_term: true
39
+ }
40
+ }
41
+ layer {
42
+ name: "conv1_h"
43
+ type: "Convolution"
44
+ bottom: "data_bn"
45
+ top: "conv1_h"
46
+ param {
47
+ lr_mult: 1.0
48
+ decay_mult: 1.0
49
+ }
50
+ param {
51
+ lr_mult: 2.0
52
+ decay_mult: 1.0
53
+ }
54
+ convolution_param {
55
+ num_output: 32
56
+ pad: 3
57
+ kernel_size: 7
58
+ stride: 2
59
+ weight_filler {
60
+ type: "msra"
61
+ variance_norm: FAN_OUT
62
+ }
63
+ bias_filler {
64
+ type: "constant"
65
+ value: 0.0
66
+ }
67
+ }
68
+ }
69
+ layer {
70
+ name: "conv1_bn_h"
71
+ type: "BatchNorm"
72
+ bottom: "conv1_h"
73
+ top: "conv1_h"
74
+ param {
75
+ lr_mult: 0.0
76
+ }
77
+ param {
78
+ lr_mult: 0.0
79
+ }
80
+ param {
81
+ lr_mult: 0.0
82
+ }
83
+ }
84
+ layer {
85
+ name: "conv1_scale_h"
86
+ type: "Scale"
87
+ bottom: "conv1_h"
88
+ top: "conv1_h"
89
+ param {
90
+ lr_mult: 1.0
91
+ decay_mult: 1.0
92
+ }
93
+ param {
94
+ lr_mult: 2.0
95
+ decay_mult: 1.0
96
+ }
97
+ scale_param {
98
+ bias_term: true
99
+ }
100
+ }
101
+ layer {
102
+ name: "conv1_relu"
103
+ type: "ReLU"
104
+ bottom: "conv1_h"
105
+ top: "conv1_h"
106
+ }
107
+ layer {
108
+ name: "conv1_pool"
109
+ type: "Pooling"
110
+ bottom: "conv1_h"
111
+ top: "conv1_pool"
112
+ pooling_param {
113
+ kernel_size: 3
114
+ stride: 2
115
+ }
116
+ }
117
+ layer {
118
+ name: "layer_64_1_conv1_h"
119
+ type: "Convolution"
120
+ bottom: "conv1_pool"
121
+ top: "layer_64_1_conv1_h"
122
+ param {
123
+ lr_mult: 1.0
124
+ decay_mult: 1.0
125
+ }
126
+ convolution_param {
127
+ num_output: 32
128
+ bias_term: false
129
+ pad: 1
130
+ kernel_size: 3
131
+ stride: 1
132
+ weight_filler {
133
+ type: "msra"
134
+ }
135
+ bias_filler {
136
+ type: "constant"
137
+ value: 0.0
138
+ }
139
+ }
140
+ }
141
+ layer {
142
+ name: "layer_64_1_bn2_h"
143
+ type: "BatchNorm"
144
+ bottom: "layer_64_1_conv1_h"
145
+ top: "layer_64_1_conv1_h"
146
+ param {
147
+ lr_mult: 0.0
148
+ }
149
+ param {
150
+ lr_mult: 0.0
151
+ }
152
+ param {
153
+ lr_mult: 0.0
154
+ }
155
+ }
156
+ layer {
157
+ name: "layer_64_1_scale2_h"
158
+ type: "Scale"
159
+ bottom: "layer_64_1_conv1_h"
160
+ top: "layer_64_1_conv1_h"
161
+ param {
162
+ lr_mult: 1.0
163
+ decay_mult: 1.0
164
+ }
165
+ param {
166
+ lr_mult: 2.0
167
+ decay_mult: 1.0
168
+ }
169
+ scale_param {
170
+ bias_term: true
171
+ }
172
+ }
173
+ layer {
174
+ name: "layer_64_1_relu2"
175
+ type: "ReLU"
176
+ bottom: "layer_64_1_conv1_h"
177
+ top: "layer_64_1_conv1_h"
178
+ }
179
+ layer {
180
+ name: "layer_64_1_conv2_h"
181
+ type: "Convolution"
182
+ bottom: "layer_64_1_conv1_h"
183
+ top: "layer_64_1_conv2_h"
184
+ param {
185
+ lr_mult: 1.0
186
+ decay_mult: 1.0
187
+ }
188
+ convolution_param {
189
+ num_output: 32
190
+ bias_term: false
191
+ pad: 1
192
+ kernel_size: 3
193
+ stride: 1
194
+ weight_filler {
195
+ type: "msra"
196
+ }
197
+ bias_filler {
198
+ type: "constant"
199
+ value: 0.0
200
+ }
201
+ }
202
+ }
203
+ layer {
204
+ name: "layer_64_1_sum"
205
+ type: "Eltwise"
206
+ bottom: "layer_64_1_conv2_h"
207
+ bottom: "conv1_pool"
208
+ top: "layer_64_1_sum"
209
+ }
210
+ layer {
211
+ name: "layer_128_1_bn1_h"
212
+ type: "BatchNorm"
213
+ bottom: "layer_64_1_sum"
214
+ top: "layer_128_1_bn1_h"
215
+ param {
216
+ lr_mult: 0.0
217
+ }
218
+ param {
219
+ lr_mult: 0.0
220
+ }
221
+ param {
222
+ lr_mult: 0.0
223
+ }
224
+ }
225
+ layer {
226
+ name: "layer_128_1_scale1_h"
227
+ type: "Scale"
228
+ bottom: "layer_128_1_bn1_h"
229
+ top: "layer_128_1_bn1_h"
230
+ param {
231
+ lr_mult: 1.0
232
+ decay_mult: 1.0
233
+ }
234
+ param {
235
+ lr_mult: 2.0
236
+ decay_mult: 1.0
237
+ }
238
+ scale_param {
239
+ bias_term: true
240
+ }
241
+ }
242
+ layer {
243
+ name: "layer_128_1_relu1"
244
+ type: "ReLU"
245
+ bottom: "layer_128_1_bn1_h"
246
+ top: "layer_128_1_bn1_h"
247
+ }
248
+ layer {
249
+ name: "layer_128_1_conv1_h"
250
+ type: "Convolution"
251
+ bottom: "layer_128_1_bn1_h"
252
+ top: "layer_128_1_conv1_h"
253
+ param {
254
+ lr_mult: 1.0
255
+ decay_mult: 1.0
256
+ }
257
+ convolution_param {
258
+ num_output: 128
259
+ bias_term: false
260
+ pad: 1
261
+ kernel_size: 3
262
+ stride: 2
263
+ weight_filler {
264
+ type: "msra"
265
+ }
266
+ bias_filler {
267
+ type: "constant"
268
+ value: 0.0
269
+ }
270
+ }
271
+ }
272
+ layer {
273
+ name: "layer_128_1_bn2"
274
+ type: "BatchNorm"
275
+ bottom: "layer_128_1_conv1_h"
276
+ top: "layer_128_1_conv1_h"
277
+ param {
278
+ lr_mult: 0.0
279
+ }
280
+ param {
281
+ lr_mult: 0.0
282
+ }
283
+ param {
284
+ lr_mult: 0.0
285
+ }
286
+ }
287
+ layer {
288
+ name: "layer_128_1_scale2"
289
+ type: "Scale"
290
+ bottom: "layer_128_1_conv1_h"
291
+ top: "layer_128_1_conv1_h"
292
+ param {
293
+ lr_mult: 1.0
294
+ decay_mult: 1.0
295
+ }
296
+ param {
297
+ lr_mult: 2.0
298
+ decay_mult: 1.0
299
+ }
300
+ scale_param {
301
+ bias_term: true
302
+ }
303
+ }
304
+ layer {
305
+ name: "layer_128_1_relu2"
306
+ type: "ReLU"
307
+ bottom: "layer_128_1_conv1_h"
308
+ top: "layer_128_1_conv1_h"
309
+ }
310
+ layer {
311
+ name: "layer_128_1_conv2"
312
+ type: "Convolution"
313
+ bottom: "layer_128_1_conv1_h"
314
+ top: "layer_128_1_conv2"
315
+ param {
316
+ lr_mult: 1.0
317
+ decay_mult: 1.0
318
+ }
319
+ convolution_param {
320
+ num_output: 128
321
+ bias_term: false
322
+ pad: 1
323
+ kernel_size: 3
324
+ stride: 1
325
+ weight_filler {
326
+ type: "msra"
327
+ }
328
+ bias_filler {
329
+ type: "constant"
330
+ value: 0.0
331
+ }
332
+ }
333
+ }
334
+ layer {
335
+ name: "layer_128_1_conv_expand_h"
336
+ type: "Convolution"
337
+ bottom: "layer_128_1_bn1_h"
338
+ top: "layer_128_1_conv_expand_h"
339
+ param {
340
+ lr_mult: 1.0
341
+ decay_mult: 1.0
342
+ }
343
+ convolution_param {
344
+ num_output: 128
345
+ bias_term: false
346
+ pad: 0
347
+ kernel_size: 1
348
+ stride: 2
349
+ weight_filler {
350
+ type: "msra"
351
+ }
352
+ bias_filler {
353
+ type: "constant"
354
+ value: 0.0
355
+ }
356
+ }
357
+ }
358
+ layer {
359
+ name: "layer_128_1_sum"
360
+ type: "Eltwise"
361
+ bottom: "layer_128_1_conv2"
362
+ bottom: "layer_128_1_conv_expand_h"
363
+ top: "layer_128_1_sum"
364
+ }
365
+ layer {
366
+ name: "layer_256_1_bn1"
367
+ type: "BatchNorm"
368
+ bottom: "layer_128_1_sum"
369
+ top: "layer_256_1_bn1"
370
+ param {
371
+ lr_mult: 0.0
372
+ }
373
+ param {
374
+ lr_mult: 0.0
375
+ }
376
+ param {
377
+ lr_mult: 0.0
378
+ }
379
+ }
380
+ layer {
381
+ name: "layer_256_1_scale1"
382
+ type: "Scale"
383
+ bottom: "layer_256_1_bn1"
384
+ top: "layer_256_1_bn1"
385
+ param {
386
+ lr_mult: 1.0
387
+ decay_mult: 1.0
388
+ }
389
+ param {
390
+ lr_mult: 2.0
391
+ decay_mult: 1.0
392
+ }
393
+ scale_param {
394
+ bias_term: true
395
+ }
396
+ }
397
+ layer {
398
+ name: "layer_256_1_relu1"
399
+ type: "ReLU"
400
+ bottom: "layer_256_1_bn1"
401
+ top: "layer_256_1_bn1"
402
+ }
403
+ layer {
404
+ name: "layer_256_1_conv1"
405
+ type: "Convolution"
406
+ bottom: "layer_256_1_bn1"
407
+ top: "layer_256_1_conv1"
408
+ param {
409
+ lr_mult: 1.0
410
+ decay_mult: 1.0
411
+ }
412
+ convolution_param {
413
+ num_output: 256
414
+ bias_term: false
415
+ pad: 1
416
+ kernel_size: 3
417
+ stride: 2
418
+ weight_filler {
419
+ type: "msra"
420
+ }
421
+ bias_filler {
422
+ type: "constant"
423
+ value: 0.0
424
+ }
425
+ }
426
+ }
427
+ layer {
428
+ name: "layer_256_1_bn2"
429
+ type: "BatchNorm"
430
+ bottom: "layer_256_1_conv1"
431
+ top: "layer_256_1_conv1"
432
+ param {
433
+ lr_mult: 0.0
434
+ }
435
+ param {
436
+ lr_mult: 0.0
437
+ }
438
+ param {
439
+ lr_mult: 0.0
440
+ }
441
+ }
442
+ layer {
443
+ name: "layer_256_1_scale2"
444
+ type: "Scale"
445
+ bottom: "layer_256_1_conv1"
446
+ top: "layer_256_1_conv1"
447
+ param {
448
+ lr_mult: 1.0
449
+ decay_mult: 1.0
450
+ }
451
+ param {
452
+ lr_mult: 2.0
453
+ decay_mult: 1.0
454
+ }
455
+ scale_param {
456
+ bias_term: true
457
+ }
458
+ }
459
+ layer {
460
+ name: "layer_256_1_relu2"
461
+ type: "ReLU"
462
+ bottom: "layer_256_1_conv1"
463
+ top: "layer_256_1_conv1"
464
+ }
465
+ layer {
466
+ name: "layer_256_1_conv2"
467
+ type: "Convolution"
468
+ bottom: "layer_256_1_conv1"
469
+ top: "layer_256_1_conv2"
470
+ param {
471
+ lr_mult: 1.0
472
+ decay_mult: 1.0
473
+ }
474
+ convolution_param {
475
+ num_output: 256
476
+ bias_term: false
477
+ pad: 1
478
+ kernel_size: 3
479
+ stride: 1
480
+ weight_filler {
481
+ type: "msra"
482
+ }
483
+ bias_filler {
484
+ type: "constant"
485
+ value: 0.0
486
+ }
487
+ }
488
+ }
489
+ layer {
490
+ name: "layer_256_1_conv_expand"
491
+ type: "Convolution"
492
+ bottom: "layer_256_1_bn1"
493
+ top: "layer_256_1_conv_expand"
494
+ param {
495
+ lr_mult: 1.0
496
+ decay_mult: 1.0
497
+ }
498
+ convolution_param {
499
+ num_output: 256
500
+ bias_term: false
501
+ pad: 0
502
+ kernel_size: 1
503
+ stride: 2
504
+ weight_filler {
505
+ type: "msra"
506
+ }
507
+ bias_filler {
508
+ type: "constant"
509
+ value: 0.0
510
+ }
511
+ }
512
+ }
513
+ layer {
514
+ name: "layer_256_1_sum"
515
+ type: "Eltwise"
516
+ bottom: "layer_256_1_conv2"
517
+ bottom: "layer_256_1_conv_expand"
518
+ top: "layer_256_1_sum"
519
+ }
520
+ layer {
521
+ name: "layer_512_1_bn1"
522
+ type: "BatchNorm"
523
+ bottom: "layer_256_1_sum"
524
+ top: "layer_512_1_bn1"
525
+ param {
526
+ lr_mult: 0.0
527
+ }
528
+ param {
529
+ lr_mult: 0.0
530
+ }
531
+ param {
532
+ lr_mult: 0.0
533
+ }
534
+ }
535
+ layer {
536
+ name: "layer_512_1_scale1"
537
+ type: "Scale"
538
+ bottom: "layer_512_1_bn1"
539
+ top: "layer_512_1_bn1"
540
+ param {
541
+ lr_mult: 1.0
542
+ decay_mult: 1.0
543
+ }
544
+ param {
545
+ lr_mult: 2.0
546
+ decay_mult: 1.0
547
+ }
548
+ scale_param {
549
+ bias_term: true
550
+ }
551
+ }
552
+ layer {
553
+ name: "layer_512_1_relu1"
554
+ type: "ReLU"
555
+ bottom: "layer_512_1_bn1"
556
+ top: "layer_512_1_bn1"
557
+ }
558
+ layer {
559
+ name: "layer_512_1_conv1_h"
560
+ type: "Convolution"
561
+ bottom: "layer_512_1_bn1"
562
+ top: "layer_512_1_conv1_h"
563
+ param {
564
+ lr_mult: 1.0
565
+ decay_mult: 1.0
566
+ }
567
+ convolution_param {
568
+ num_output: 128
569
+ bias_term: false
570
+ pad: 1
571
+ kernel_size: 3
572
+ stride: 1 # 2
573
+ weight_filler {
574
+ type: "msra"
575
+ }
576
+ bias_filler {
577
+ type: "constant"
578
+ value: 0.0
579
+ }
580
+ }
581
+ }
582
+ layer {
583
+ name: "layer_512_1_bn2_h"
584
+ type: "BatchNorm"
585
+ bottom: "layer_512_1_conv1_h"
586
+ top: "layer_512_1_conv1_h"
587
+ param {
588
+ lr_mult: 0.0
589
+ }
590
+ param {
591
+ lr_mult: 0.0
592
+ }
593
+ param {
594
+ lr_mult: 0.0
595
+ }
596
+ }
597
+ layer {
598
+ name: "layer_512_1_scale2_h"
599
+ type: "Scale"
600
+ bottom: "layer_512_1_conv1_h"
601
+ top: "layer_512_1_conv1_h"
602
+ param {
603
+ lr_mult: 1.0
604
+ decay_mult: 1.0
605
+ }
606
+ param {
607
+ lr_mult: 2.0
608
+ decay_mult: 1.0
609
+ }
610
+ scale_param {
611
+ bias_term: true
612
+ }
613
+ }
614
+ layer {
615
+ name: "layer_512_1_relu2"
616
+ type: "ReLU"
617
+ bottom: "layer_512_1_conv1_h"
618
+ top: "layer_512_1_conv1_h"
619
+ }
620
+ layer {
621
+ name: "layer_512_1_conv2_h"
622
+ type: "Convolution"
623
+ bottom: "layer_512_1_conv1_h"
624
+ top: "layer_512_1_conv2_h"
625
+ param {
626
+ lr_mult: 1.0
627
+ decay_mult: 1.0
628
+ }
629
+ convolution_param {
630
+ num_output: 256
631
+ bias_term: false
632
+ pad: 2 # 1
633
+ kernel_size: 3
634
+ stride: 1
635
+ dilation: 2
636
+ weight_filler {
637
+ type: "msra"
638
+ }
639
+ bias_filler {
640
+ type: "constant"
641
+ value: 0.0
642
+ }
643
+ }
644
+ }
645
+ layer {
646
+ name: "layer_512_1_conv_expand_h"
647
+ type: "Convolution"
648
+ bottom: "layer_512_1_bn1"
649
+ top: "layer_512_1_conv_expand_h"
650
+ param {
651
+ lr_mult: 1.0
652
+ decay_mult: 1.0
653
+ }
654
+ convolution_param {
655
+ num_output: 256
656
+ bias_term: false
657
+ pad: 0
658
+ kernel_size: 1
659
+ stride: 1 # 2
660
+ weight_filler {
661
+ type: "msra"
662
+ }
663
+ bias_filler {
664
+ type: "constant"
665
+ value: 0.0
666
+ }
667
+ }
668
+ }
669
+ layer {
670
+ name: "layer_512_1_sum"
671
+ type: "Eltwise"
672
+ bottom: "layer_512_1_conv2_h"
673
+ bottom: "layer_512_1_conv_expand_h"
674
+ top: "layer_512_1_sum"
675
+ }
676
+ layer {
677
+ name: "last_bn_h"
678
+ type: "BatchNorm"
679
+ bottom: "layer_512_1_sum"
680
+ top: "layer_512_1_sum"
681
+ param {
682
+ lr_mult: 0.0
683
+ }
684
+ param {
685
+ lr_mult: 0.0
686
+ }
687
+ param {
688
+ lr_mult: 0.0
689
+ }
690
+ }
691
+ layer {
692
+ name: "last_scale_h"
693
+ type: "Scale"
694
+ bottom: "layer_512_1_sum"
695
+ top: "layer_512_1_sum"
696
+ param {
697
+ lr_mult: 1.0
698
+ decay_mult: 1.0
699
+ }
700
+ param {
701
+ lr_mult: 2.0
702
+ decay_mult: 1.0
703
+ }
704
+ scale_param {
705
+ bias_term: true
706
+ }
707
+ }
708
+ layer {
709
+ name: "last_relu"
710
+ type: "ReLU"
711
+ bottom: "layer_512_1_sum"
712
+ top: "fc7"
713
+ }
714
+
715
+ layer {
716
+ name: "conv6_1_h"
717
+ type: "Convolution"
718
+ bottom: "fc7"
719
+ top: "conv6_1_h"
720
+ param {
721
+ lr_mult: 1
722
+ decay_mult: 1
723
+ }
724
+ param {
725
+ lr_mult: 2
726
+ decay_mult: 0
727
+ }
728
+ convolution_param {
729
+ num_output: 128
730
+ pad: 0
731
+ kernel_size: 1
732
+ stride: 1
733
+ weight_filler {
734
+ type: "xavier"
735
+ }
736
+ bias_filler {
737
+ type: "constant"
738
+ value: 0
739
+ }
740
+ }
741
+ }
742
+ layer {
743
+ name: "conv6_1_relu"
744
+ type: "ReLU"
745
+ bottom: "conv6_1_h"
746
+ top: "conv6_1_h"
747
+ }
748
+ layer {
749
+ name: "conv6_2_h"
750
+ type: "Convolution"
751
+ bottom: "conv6_1_h"
752
+ top: "conv6_2_h"
753
+ param {
754
+ lr_mult: 1
755
+ decay_mult: 1
756
+ }
757
+ param {
758
+ lr_mult: 2
759
+ decay_mult: 0
760
+ }
761
+ convolution_param {
762
+ num_output: 256
763
+ pad: 1
764
+ kernel_size: 3
765
+ stride: 2
766
+ weight_filler {
767
+ type: "xavier"
768
+ }
769
+ bias_filler {
770
+ type: "constant"
771
+ value: 0
772
+ }
773
+ }
774
+ }
775
+ layer {
776
+ name: "conv6_2_relu"
777
+ type: "ReLU"
778
+ bottom: "conv6_2_h"
779
+ top: "conv6_2_h"
780
+ }
781
+ layer {
782
+ name: "conv7_1_h"
783
+ type: "Convolution"
784
+ bottom: "conv6_2_h"
785
+ top: "conv7_1_h"
786
+ param {
787
+ lr_mult: 1
788
+ decay_mult: 1
789
+ }
790
+ param {
791
+ lr_mult: 2
792
+ decay_mult: 0
793
+ }
794
+ convolution_param {
795
+ num_output: 64
796
+ pad: 0
797
+ kernel_size: 1
798
+ stride: 1
799
+ weight_filler {
800
+ type: "xavier"
801
+ }
802
+ bias_filler {
803
+ type: "constant"
804
+ value: 0
805
+ }
806
+ }
807
+ }
808
+ layer {
809
+ name: "conv7_1_relu"
810
+ type: "ReLU"
811
+ bottom: "conv7_1_h"
812
+ top: "conv7_1_h"
813
+ }
814
+ layer {
815
+ name: "conv7_2_h"
816
+ type: "Convolution"
817
+ bottom: "conv7_1_h"
818
+ top: "conv7_2_h"
819
+ param {
820
+ lr_mult: 1
821
+ decay_mult: 1
822
+ }
823
+ param {
824
+ lr_mult: 2
825
+ decay_mult: 0
826
+ }
827
+ convolution_param {
828
+ num_output: 128
829
+ pad: 1
830
+ kernel_size: 3
831
+ stride: 2
832
+ weight_filler {
833
+ type: "xavier"
834
+ }
835
+ bias_filler {
836
+ type: "constant"
837
+ value: 0
838
+ }
839
+ }
840
+ }
841
+ layer {
842
+ name: "conv7_2_relu"
843
+ type: "ReLU"
844
+ bottom: "conv7_2_h"
845
+ top: "conv7_2_h"
846
+ }
847
+ layer {
848
+ name: "conv8_1_h"
849
+ type: "Convolution"
850
+ bottom: "conv7_2_h"
851
+ top: "conv8_1_h"
852
+ param {
853
+ lr_mult: 1
854
+ decay_mult: 1
855
+ }
856
+ param {
857
+ lr_mult: 2
858
+ decay_mult: 0
859
+ }
860
+ convolution_param {
861
+ num_output: 64
862
+ pad: 0
863
+ kernel_size: 1
864
+ stride: 1
865
+ weight_filler {
866
+ type: "xavier"
867
+ }
868
+ bias_filler {
869
+ type: "constant"
870
+ value: 0
871
+ }
872
+ }
873
+ }
874
+ layer {
875
+ name: "conv8_1_relu"
876
+ type: "ReLU"
877
+ bottom: "conv8_1_h"
878
+ top: "conv8_1_h"
879
+ }
880
+ layer {
881
+ name: "conv8_2_h"
882
+ type: "Convolution"
883
+ bottom: "conv8_1_h"
884
+ top: "conv8_2_h"
885
+ param {
886
+ lr_mult: 1
887
+ decay_mult: 1
888
+ }
889
+ param {
890
+ lr_mult: 2
891
+ decay_mult: 0
892
+ }
893
+ convolution_param {
894
+ num_output: 128
895
+ pad: 1
896
+ kernel_size: 3
897
+ stride: 1
898
+ weight_filler {
899
+ type: "xavier"
900
+ }
901
+ bias_filler {
902
+ type: "constant"
903
+ value: 0
904
+ }
905
+ }
906
+ }
907
+ layer {
908
+ name: "conv8_2_relu"
909
+ type: "ReLU"
910
+ bottom: "conv8_2_h"
911
+ top: "conv8_2_h"
912
+ }
913
+ layer {
914
+ name: "conv9_1_h"
915
+ type: "Convolution"
916
+ bottom: "conv8_2_h"
917
+ top: "conv9_1_h"
918
+ param {
919
+ lr_mult: 1
920
+ decay_mult: 1
921
+ }
922
+ param {
923
+ lr_mult: 2
924
+ decay_mult: 0
925
+ }
926
+ convolution_param {
927
+ num_output: 64
928
+ pad: 0
929
+ kernel_size: 1
930
+ stride: 1
931
+ weight_filler {
932
+ type: "xavier"
933
+ }
934
+ bias_filler {
935
+ type: "constant"
936
+ value: 0
937
+ }
938
+ }
939
+ }
940
+ layer {
941
+ name: "conv9_1_relu"
942
+ type: "ReLU"
943
+ bottom: "conv9_1_h"
944
+ top: "conv9_1_h"
945
+ }
946
+ layer {
947
+ name: "conv9_2_h"
948
+ type: "Convolution"
949
+ bottom: "conv9_1_h"
950
+ top: "conv9_2_h"
951
+ param {
952
+ lr_mult: 1
953
+ decay_mult: 1
954
+ }
955
+ param {
956
+ lr_mult: 2
957
+ decay_mult: 0
958
+ }
959
+ convolution_param {
960
+ num_output: 128
961
+ pad: 1
962
+ kernel_size: 3
963
+ stride: 1
964
+ weight_filler {
965
+ type: "xavier"
966
+ }
967
+ bias_filler {
968
+ type: "constant"
969
+ value: 0
970
+ }
971
+ }
972
+ }
973
+ layer {
974
+ name: "conv9_2_relu"
975
+ type: "ReLU"
976
+ bottom: "conv9_2_h"
977
+ top: "conv9_2_h"
978
+ }
979
+ layer {
980
+ name: "conv4_3_norm"
981
+ type: "Normalize"
982
+ bottom: "layer_256_1_bn1"
983
+ top: "conv4_3_norm"
984
+ norm_param {
985
+ across_spatial: false
986
+ scale_filler {
987
+ type: "constant"
988
+ value: 20
989
+ }
990
+ channel_shared: false
991
+ }
992
+ }
993
+ layer {
994
+ name: "conv4_3_norm_mbox_loc"
995
+ type: "Convolution"
996
+ bottom: "conv4_3_norm"
997
+ top: "conv4_3_norm_mbox_loc"
998
+ param {
999
+ lr_mult: 1
1000
+ decay_mult: 1
1001
+ }
1002
+ param {
1003
+ lr_mult: 2
1004
+ decay_mult: 0
1005
+ }
1006
+ convolution_param {
1007
+ num_output: 16
1008
+ pad: 1
1009
+ kernel_size: 3
1010
+ stride: 1
1011
+ weight_filler {
1012
+ type: "xavier"
1013
+ }
1014
+ bias_filler {
1015
+ type: "constant"
1016
+ value: 0
1017
+ }
1018
+ }
1019
+ }
1020
+ layer {
1021
+ name: "conv4_3_norm_mbox_loc_perm"
1022
+ type: "Permute"
1023
+ bottom: "conv4_3_norm_mbox_loc"
1024
+ top: "conv4_3_norm_mbox_loc_perm"
1025
+ permute_param {
1026
+ order: 0
1027
+ order: 2
1028
+ order: 3
1029
+ order: 1
1030
+ }
1031
+ }
1032
+ layer {
1033
+ name: "conv4_3_norm_mbox_loc_flat"
1034
+ type: "Flatten"
1035
+ bottom: "conv4_3_norm_mbox_loc_perm"
1036
+ top: "conv4_3_norm_mbox_loc_flat"
1037
+ flatten_param {
1038
+ axis: 1
1039
+ }
1040
+ }
1041
+ layer {
1042
+ name: "conv4_3_norm_mbox_conf"
1043
+ type: "Convolution"
1044
+ bottom: "conv4_3_norm"
1045
+ top: "conv4_3_norm_mbox_conf"
1046
+ param {
1047
+ lr_mult: 1
1048
+ decay_mult: 1
1049
+ }
1050
+ param {
1051
+ lr_mult: 2
1052
+ decay_mult: 0
1053
+ }
1054
+ convolution_param {
1055
+ num_output: 8 # 84
1056
+ pad: 1
1057
+ kernel_size: 3
1058
+ stride: 1
1059
+ weight_filler {
1060
+ type: "xavier"
1061
+ }
1062
+ bias_filler {
1063
+ type: "constant"
1064
+ value: 0
1065
+ }
1066
+ }
1067
+ }
1068
+ layer {
1069
+ name: "conv4_3_norm_mbox_conf_perm"
1070
+ type: "Permute"
1071
+ bottom: "conv4_3_norm_mbox_conf"
1072
+ top: "conv4_3_norm_mbox_conf_perm"
1073
+ permute_param {
1074
+ order: 0
1075
+ order: 2
1076
+ order: 3
1077
+ order: 1
1078
+ }
1079
+ }
1080
+ layer {
1081
+ name: "conv4_3_norm_mbox_conf_flat"
1082
+ type: "Flatten"
1083
+ bottom: "conv4_3_norm_mbox_conf_perm"
1084
+ top: "conv4_3_norm_mbox_conf_flat"
1085
+ flatten_param {
1086
+ axis: 1
1087
+ }
1088
+ }
1089
+ layer {
1090
+ name: "conv4_3_norm_mbox_priorbox"
1091
+ type: "PriorBox"
1092
+ bottom: "conv4_3_norm"
1093
+ bottom: "data"
1094
+ top: "conv4_3_norm_mbox_priorbox"
1095
+ prior_box_param {
1096
+ min_size: 30.0
1097
+ max_size: 60.0
1098
+ aspect_ratio: 2
1099
+ flip: true
1100
+ clip: false
1101
+ variance: 0.1
1102
+ variance: 0.1
1103
+ variance: 0.2
1104
+ variance: 0.2
1105
+ step: 8
1106
+ offset: 0.5
1107
+ }
1108
+ }
1109
+ layer {
1110
+ name: "fc7_mbox_loc"
1111
+ type: "Convolution"
1112
+ bottom: "fc7"
1113
+ top: "fc7_mbox_loc"
1114
+ param {
1115
+ lr_mult: 1
1116
+ decay_mult: 1
1117
+ }
1118
+ param {
1119
+ lr_mult: 2
1120
+ decay_mult: 0
1121
+ }
1122
+ convolution_param {
1123
+ num_output: 24
1124
+ pad: 1
1125
+ kernel_size: 3
1126
+ stride: 1
1127
+ weight_filler {
1128
+ type: "xavier"
1129
+ }
1130
+ bias_filler {
1131
+ type: "constant"
1132
+ value: 0
1133
+ }
1134
+ }
1135
+ }
1136
+ layer {
1137
+ name: "fc7_mbox_loc_perm"
1138
+ type: "Permute"
1139
+ bottom: "fc7_mbox_loc"
1140
+ top: "fc7_mbox_loc_perm"
1141
+ permute_param {
1142
+ order: 0
1143
+ order: 2
1144
+ order: 3
1145
+ order: 1
1146
+ }
1147
+ }
1148
+ layer {
1149
+ name: "fc7_mbox_loc_flat"
1150
+ type: "Flatten"
1151
+ bottom: "fc7_mbox_loc_perm"
1152
+ top: "fc7_mbox_loc_flat"
1153
+ flatten_param {
1154
+ axis: 1
1155
+ }
1156
+ }
1157
+ layer {
1158
+ name: "fc7_mbox_conf"
1159
+ type: "Convolution"
1160
+ bottom: "fc7"
1161
+ top: "fc7_mbox_conf"
1162
+ param {
1163
+ lr_mult: 1
1164
+ decay_mult: 1
1165
+ }
1166
+ param {
1167
+ lr_mult: 2
1168
+ decay_mult: 0
1169
+ }
1170
+ convolution_param {
1171
+ num_output: 12 # 126
1172
+ pad: 1
1173
+ kernel_size: 3
1174
+ stride: 1
1175
+ weight_filler {
1176
+ type: "xavier"
1177
+ }
1178
+ bias_filler {
1179
+ type: "constant"
1180
+ value: 0
1181
+ }
1182
+ }
1183
+ }
1184
+ layer {
1185
+ name: "fc7_mbox_conf_perm"
1186
+ type: "Permute"
1187
+ bottom: "fc7_mbox_conf"
1188
+ top: "fc7_mbox_conf_perm"
1189
+ permute_param {
1190
+ order: 0
1191
+ order: 2
1192
+ order: 3
1193
+ order: 1
1194
+ }
1195
+ }
1196
+ layer {
1197
+ name: "fc7_mbox_conf_flat"
1198
+ type: "Flatten"
1199
+ bottom: "fc7_mbox_conf_perm"
1200
+ top: "fc7_mbox_conf_flat"
1201
+ flatten_param {
1202
+ axis: 1
1203
+ }
1204
+ }
1205
+ layer {
1206
+ name: "fc7_mbox_priorbox"
1207
+ type: "PriorBox"
1208
+ bottom: "fc7"
1209
+ bottom: "data"
1210
+ top: "fc7_mbox_priorbox"
1211
+ prior_box_param {
1212
+ min_size: 60.0
1213
+ max_size: 111.0
1214
+ aspect_ratio: 2
1215
+ aspect_ratio: 3
1216
+ flip: true
1217
+ clip: false
1218
+ variance: 0.1
1219
+ variance: 0.1
1220
+ variance: 0.2
1221
+ variance: 0.2
1222
+ step: 16
1223
+ offset: 0.5
1224
+ }
1225
+ }
1226
+ layer {
1227
+ name: "conv6_2_mbox_loc"
1228
+ type: "Convolution"
1229
+ bottom: "conv6_2_h"
1230
+ top: "conv6_2_mbox_loc"
1231
+ param {
1232
+ lr_mult: 1
1233
+ decay_mult: 1
1234
+ }
1235
+ param {
1236
+ lr_mult: 2
1237
+ decay_mult: 0
1238
+ }
1239
+ convolution_param {
1240
+ num_output: 24
1241
+ pad: 1
1242
+ kernel_size: 3
1243
+ stride: 1
1244
+ weight_filler {
1245
+ type: "xavier"
1246
+ }
1247
+ bias_filler {
1248
+ type: "constant"
1249
+ value: 0
1250
+ }
1251
+ }
1252
+ }
1253
+ layer {
1254
+ name: "conv6_2_mbox_loc_perm"
1255
+ type: "Permute"
1256
+ bottom: "conv6_2_mbox_loc"
1257
+ top: "conv6_2_mbox_loc_perm"
1258
+ permute_param {
1259
+ order: 0
1260
+ order: 2
1261
+ order: 3
1262
+ order: 1
1263
+ }
1264
+ }
1265
+ layer {
1266
+ name: "conv6_2_mbox_loc_flat"
1267
+ type: "Flatten"
1268
+ bottom: "conv6_2_mbox_loc_perm"
1269
+ top: "conv6_2_mbox_loc_flat"
1270
+ flatten_param {
1271
+ axis: 1
1272
+ }
1273
+ }
1274
+ layer {
1275
+ name: "conv6_2_mbox_conf"
1276
+ type: "Convolution"
1277
+ bottom: "conv6_2_h"
1278
+ top: "conv6_2_mbox_conf"
1279
+ param {
1280
+ lr_mult: 1
1281
+ decay_mult: 1
1282
+ }
1283
+ param {
1284
+ lr_mult: 2
1285
+ decay_mult: 0
1286
+ }
1287
+ convolution_param {
1288
+ num_output: 12 # 126
1289
+ pad: 1
1290
+ kernel_size: 3
1291
+ stride: 1
1292
+ weight_filler {
1293
+ type: "xavier"
1294
+ }
1295
+ bias_filler {
1296
+ type: "constant"
1297
+ value: 0
1298
+ }
1299
+ }
1300
+ }
1301
+ layer {
1302
+ name: "conv6_2_mbox_conf_perm"
1303
+ type: "Permute"
1304
+ bottom: "conv6_2_mbox_conf"
1305
+ top: "conv6_2_mbox_conf_perm"
1306
+ permute_param {
1307
+ order: 0
1308
+ order: 2
1309
+ order: 3
1310
+ order: 1
1311
+ }
1312
+ }
1313
+ layer {
1314
+ name: "conv6_2_mbox_conf_flat"
1315
+ type: "Flatten"
1316
+ bottom: "conv6_2_mbox_conf_perm"
1317
+ top: "conv6_2_mbox_conf_flat"
1318
+ flatten_param {
1319
+ axis: 1
1320
+ }
1321
+ }
1322
+ layer {
1323
+ name: "conv6_2_mbox_priorbox"
1324
+ type: "PriorBox"
1325
+ bottom: "conv6_2_h"
1326
+ bottom: "data"
1327
+ top: "conv6_2_mbox_priorbox"
1328
+ prior_box_param {
1329
+ min_size: 111.0
1330
+ max_size: 162.0
1331
+ aspect_ratio: 2
1332
+ aspect_ratio: 3
1333
+ flip: true
1334
+ clip: false
1335
+ variance: 0.1
1336
+ variance: 0.1
1337
+ variance: 0.2
1338
+ variance: 0.2
1339
+ step: 32
1340
+ offset: 0.5
1341
+ }
1342
+ }
1343
+ layer {
1344
+ name: "conv7_2_mbox_loc"
1345
+ type: "Convolution"
1346
+ bottom: "conv7_2_h"
1347
+ top: "conv7_2_mbox_loc"
1348
+ param {
1349
+ lr_mult: 1
1350
+ decay_mult: 1
1351
+ }
1352
+ param {
1353
+ lr_mult: 2
1354
+ decay_mult: 0
1355
+ }
1356
+ convolution_param {
1357
+ num_output: 24
1358
+ pad: 1
1359
+ kernel_size: 3
1360
+ stride: 1
1361
+ weight_filler {
1362
+ type: "xavier"
1363
+ }
1364
+ bias_filler {
1365
+ type: "constant"
1366
+ value: 0
1367
+ }
1368
+ }
1369
+ }
1370
+ layer {
1371
+ name: "conv7_2_mbox_loc_perm"
1372
+ type: "Permute"
1373
+ bottom: "conv7_2_mbox_loc"
1374
+ top: "conv7_2_mbox_loc_perm"
1375
+ permute_param {
1376
+ order: 0
1377
+ order: 2
1378
+ order: 3
1379
+ order: 1
1380
+ }
1381
+ }
1382
+ layer {
1383
+ name: "conv7_2_mbox_loc_flat"
1384
+ type: "Flatten"
1385
+ bottom: "conv7_2_mbox_loc_perm"
1386
+ top: "conv7_2_mbox_loc_flat"
1387
+ flatten_param {
1388
+ axis: 1
1389
+ }
1390
+ }
1391
+ layer {
1392
+ name: "conv7_2_mbox_conf"
1393
+ type: "Convolution"
1394
+ bottom: "conv7_2_h"
1395
+ top: "conv7_2_mbox_conf"
1396
+ param {
1397
+ lr_mult: 1
1398
+ decay_mult: 1
1399
+ }
1400
+ param {
1401
+ lr_mult: 2
1402
+ decay_mult: 0
1403
+ }
1404
+ convolution_param {
1405
+ num_output: 12 # 126
1406
+ pad: 1
1407
+ kernel_size: 3
1408
+ stride: 1
1409
+ weight_filler {
1410
+ type: "xavier"
1411
+ }
1412
+ bias_filler {
1413
+ type: "constant"
1414
+ value: 0
1415
+ }
1416
+ }
1417
+ }
1418
+ layer {
1419
+ name: "conv7_2_mbox_conf_perm"
1420
+ type: "Permute"
1421
+ bottom: "conv7_2_mbox_conf"
1422
+ top: "conv7_2_mbox_conf_perm"
1423
+ permute_param {
1424
+ order: 0
1425
+ order: 2
1426
+ order: 3
1427
+ order: 1
1428
+ }
1429
+ }
1430
+ layer {
1431
+ name: "conv7_2_mbox_conf_flat"
1432
+ type: "Flatten"
1433
+ bottom: "conv7_2_mbox_conf_perm"
1434
+ top: "conv7_2_mbox_conf_flat"
1435
+ flatten_param {
1436
+ axis: 1
1437
+ }
1438
+ }
1439
+ layer {
1440
+ name: "conv7_2_mbox_priorbox"
1441
+ type: "PriorBox"
1442
+ bottom: "conv7_2_h"
1443
+ bottom: "data"
1444
+ top: "conv7_2_mbox_priorbox"
1445
+ prior_box_param {
1446
+ min_size: 162.0
1447
+ max_size: 213.0
1448
+ aspect_ratio: 2
1449
+ aspect_ratio: 3
1450
+ flip: true
1451
+ clip: false
1452
+ variance: 0.1
1453
+ variance: 0.1
1454
+ variance: 0.2
1455
+ variance: 0.2
1456
+ step: 64
1457
+ offset: 0.5
1458
+ }
1459
+ }
1460
+ layer {
1461
+ name: "conv8_2_mbox_loc"
1462
+ type: "Convolution"
1463
+ bottom: "conv8_2_h"
1464
+ top: "conv8_2_mbox_loc"
1465
+ param {
1466
+ lr_mult: 1
1467
+ decay_mult: 1
1468
+ }
1469
+ param {
1470
+ lr_mult: 2
1471
+ decay_mult: 0
1472
+ }
1473
+ convolution_param {
1474
+ num_output: 16
1475
+ pad: 1
1476
+ kernel_size: 3
1477
+ stride: 1
1478
+ weight_filler {
1479
+ type: "xavier"
1480
+ }
1481
+ bias_filler {
1482
+ type: "constant"
1483
+ value: 0
1484
+ }
1485
+ }
1486
+ }
1487
+ layer {
1488
+ name: "conv8_2_mbox_loc_perm"
1489
+ type: "Permute"
1490
+ bottom: "conv8_2_mbox_loc"
1491
+ top: "conv8_2_mbox_loc_perm"
1492
+ permute_param {
1493
+ order: 0
1494
+ order: 2
1495
+ order: 3
1496
+ order: 1
1497
+ }
1498
+ }
1499
+ layer {
1500
+ name: "conv8_2_mbox_loc_flat"
1501
+ type: "Flatten"
1502
+ bottom: "conv8_2_mbox_loc_perm"
1503
+ top: "conv8_2_mbox_loc_flat"
1504
+ flatten_param {
1505
+ axis: 1
1506
+ }
1507
+ }
1508
+ layer {
1509
+ name: "conv8_2_mbox_conf"
1510
+ type: "Convolution"
1511
+ bottom: "conv8_2_h"
1512
+ top: "conv8_2_mbox_conf"
1513
+ param {
1514
+ lr_mult: 1
1515
+ decay_mult: 1
1516
+ }
1517
+ param {
1518
+ lr_mult: 2
1519
+ decay_mult: 0
1520
+ }
1521
+ convolution_param {
1522
+ num_output: 8 # 84
1523
+ pad: 1
1524
+ kernel_size: 3
1525
+ stride: 1
1526
+ weight_filler {
1527
+ type: "xavier"
1528
+ }
1529
+ bias_filler {
1530
+ type: "constant"
1531
+ value: 0
1532
+ }
1533
+ }
1534
+ }
1535
+ layer {
1536
+ name: "conv8_2_mbox_conf_perm"
1537
+ type: "Permute"
1538
+ bottom: "conv8_2_mbox_conf"
1539
+ top: "conv8_2_mbox_conf_perm"
1540
+ permute_param {
1541
+ order: 0
1542
+ order: 2
1543
+ order: 3
1544
+ order: 1
1545
+ }
1546
+ }
1547
+ layer {
1548
+ name: "conv8_2_mbox_conf_flat"
1549
+ type: "Flatten"
1550
+ bottom: "conv8_2_mbox_conf_perm"
1551
+ top: "conv8_2_mbox_conf_flat"
1552
+ flatten_param {
1553
+ axis: 1
1554
+ }
1555
+ }
1556
+ layer {
1557
+ name: "conv8_2_mbox_priorbox"
1558
+ type: "PriorBox"
1559
+ bottom: "conv8_2_h"
1560
+ bottom: "data"
1561
+ top: "conv8_2_mbox_priorbox"
1562
+ prior_box_param {
1563
+ min_size: 213.0
1564
+ max_size: 264.0
1565
+ aspect_ratio: 2
1566
+ flip: true
1567
+ clip: false
1568
+ variance: 0.1
1569
+ variance: 0.1
1570
+ variance: 0.2
1571
+ variance: 0.2
1572
+ step: 100
1573
+ offset: 0.5
1574
+ }
1575
+ }
1576
+ layer {
1577
+ name: "conv9_2_mbox_loc"
1578
+ type: "Convolution"
1579
+ bottom: "conv9_2_h"
1580
+ top: "conv9_2_mbox_loc"
1581
+ param {
1582
+ lr_mult: 1
1583
+ decay_mult: 1
1584
+ }
1585
+ param {
1586
+ lr_mult: 2
1587
+ decay_mult: 0
1588
+ }
1589
+ convolution_param {
1590
+ num_output: 16
1591
+ pad: 1
1592
+ kernel_size: 3
1593
+ stride: 1
1594
+ weight_filler {
1595
+ type: "xavier"
1596
+ }
1597
+ bias_filler {
1598
+ type: "constant"
1599
+ value: 0
1600
+ }
1601
+ }
1602
+ }
1603
+ layer {
1604
+ name: "conv9_2_mbox_loc_perm"
1605
+ type: "Permute"
1606
+ bottom: "conv9_2_mbox_loc"
1607
+ top: "conv9_2_mbox_loc_perm"
1608
+ permute_param {
1609
+ order: 0
1610
+ order: 2
1611
+ order: 3
1612
+ order: 1
1613
+ }
1614
+ }
1615
+ layer {
1616
+ name: "conv9_2_mbox_loc_flat"
1617
+ type: "Flatten"
1618
+ bottom: "conv9_2_mbox_loc_perm"
1619
+ top: "conv9_2_mbox_loc_flat"
1620
+ flatten_param {
1621
+ axis: 1
1622
+ }
1623
+ }
1624
+ layer {
1625
+ name: "conv9_2_mbox_conf"
1626
+ type: "Convolution"
1627
+ bottom: "conv9_2_h"
1628
+ top: "conv9_2_mbox_conf"
1629
+ param {
1630
+ lr_mult: 1
1631
+ decay_mult: 1
1632
+ }
1633
+ param {
1634
+ lr_mult: 2
1635
+ decay_mult: 0
1636
+ }
1637
+ convolution_param {
1638
+ num_output: 8 # 84
1639
+ pad: 1
1640
+ kernel_size: 3
1641
+ stride: 1
1642
+ weight_filler {
1643
+ type: "xavier"
1644
+ }
1645
+ bias_filler {
1646
+ type: "constant"
1647
+ value: 0
1648
+ }
1649
+ }
1650
+ }
1651
+ layer {
1652
+ name: "conv9_2_mbox_conf_perm"
1653
+ type: "Permute"
1654
+ bottom: "conv9_2_mbox_conf"
1655
+ top: "conv9_2_mbox_conf_perm"
1656
+ permute_param {
1657
+ order: 0
1658
+ order: 2
1659
+ order: 3
1660
+ order: 1
1661
+ }
1662
+ }
1663
+ layer {
1664
+ name: "conv9_2_mbox_conf_flat"
1665
+ type: "Flatten"
1666
+ bottom: "conv9_2_mbox_conf_perm"
1667
+ top: "conv9_2_mbox_conf_flat"
1668
+ flatten_param {
1669
+ axis: 1
1670
+ }
1671
+ }
1672
+ layer {
1673
+ name: "conv9_2_mbox_priorbox"
1674
+ type: "PriorBox"
1675
+ bottom: "conv9_2_h"
1676
+ bottom: "data"
1677
+ top: "conv9_2_mbox_priorbox"
1678
+ prior_box_param {
1679
+ min_size: 264.0
1680
+ max_size: 315.0
1681
+ aspect_ratio: 2
1682
+ flip: true
1683
+ clip: false
1684
+ variance: 0.1
1685
+ variance: 0.1
1686
+ variance: 0.2
1687
+ variance: 0.2
1688
+ step: 300
1689
+ offset: 0.5
1690
+ }
1691
+ }
1692
+ layer {
1693
+ name: "mbox_loc"
1694
+ type: "Concat"
1695
+ bottom: "conv4_3_norm_mbox_loc_flat"
1696
+ bottom: "fc7_mbox_loc_flat"
1697
+ bottom: "conv6_2_mbox_loc_flat"
1698
+ bottom: "conv7_2_mbox_loc_flat"
1699
+ bottom: "conv8_2_mbox_loc_flat"
1700
+ bottom: "conv9_2_mbox_loc_flat"
1701
+ top: "mbox_loc"
1702
+ concat_param {
1703
+ axis: 1
1704
+ }
1705
+ }
1706
+ layer {
1707
+ name: "mbox_conf"
1708
+ type: "Concat"
1709
+ bottom: "conv4_3_norm_mbox_conf_flat"
1710
+ bottom: "fc7_mbox_conf_flat"
1711
+ bottom: "conv6_2_mbox_conf_flat"
1712
+ bottom: "conv7_2_mbox_conf_flat"
1713
+ bottom: "conv8_2_mbox_conf_flat"
1714
+ bottom: "conv9_2_mbox_conf_flat"
1715
+ top: "mbox_conf"
1716
+ concat_param {
1717
+ axis: 1
1718
+ }
1719
+ }
1720
+ layer {
1721
+ name: "mbox_priorbox"
1722
+ type: "Concat"
1723
+ bottom: "conv4_3_norm_mbox_priorbox"
1724
+ bottom: "fc7_mbox_priorbox"
1725
+ bottom: "conv6_2_mbox_priorbox"
1726
+ bottom: "conv7_2_mbox_priorbox"
1727
+ bottom: "conv8_2_mbox_priorbox"
1728
+ bottom: "conv9_2_mbox_priorbox"
1729
+ top: "mbox_priorbox"
1730
+ concat_param {
1731
+ axis: 2
1732
+ }
1733
+ }
1734
+
1735
+ layer {
1736
+ name: "mbox_conf_reshape"
1737
+ type: "Reshape"
1738
+ bottom: "mbox_conf"
1739
+ top: "mbox_conf_reshape"
1740
+ reshape_param {
1741
+ shape {
1742
+ dim: 0
1743
+ dim: -1
1744
+ dim: 2
1745
+ }
1746
+ }
1747
+ }
1748
+ layer {
1749
+ name: "mbox_conf_softmax"
1750
+ type: "Softmax"
1751
+ bottom: "mbox_conf_reshape"
1752
+ top: "mbox_conf_softmax"
1753
+ softmax_param {
1754
+ axis: 2
1755
+ }
1756
+ }
1757
+ layer {
1758
+ name: "mbox_conf_flatten"
1759
+ type: "Flatten"
1760
+ bottom: "mbox_conf_softmax"
1761
+ top: "mbox_conf_flatten"
1762
+ flatten_param {
1763
+ axis: 1
1764
+ }
1765
+ }
1766
+
1767
+ layer {
1768
+ name: "detection_out"
1769
+ type: "DetectionOutput"
1770
+ bottom: "mbox_loc"
1771
+ bottom: "mbox_conf_flatten"
1772
+ bottom: "mbox_priorbox"
1773
+ top: "detection_out"
1774
+ include {
1775
+ phase: TEST
1776
+ }
1777
+ detection_output_param {
1778
+ num_classes: 2
1779
+ share_location: true
1780
+ background_label_id: 0
1781
+ nms_param {
1782
+ nms_threshold: 0.45
1783
+ top_k: 400
1784
+ }
1785
+ code_type: CENTER_SIZE
1786
+ keep_top_k: 200
1787
+ confidence_threshold: 0.01
1788
+ }
1789
+ }
utils/face_detector.py ADDED
@@ -0,0 +1,105 @@
1
+ """Face detector, used only for the demo to crop faces, as datasets have already been face-cropped"""
2
+
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+
7
+
8
+ DEFAULT_FACE_DETECTOR = "utils/res10_300x300_ssd_iter_140000_fp16.caffemodel"
9
+ DEFAULT_DEPLOY = "utils/deploy.prototxt"
10
+
11
+ def enclosing_square(rect):
12
+ # Build a square that encloses the input rectangle
13
+ x, y, w, h = rect
14
+ side = max(w, h)
15
+ # Center the square on the original bounding box
16
+ cx = x + w // 2
17
+ cy = y + h // 2
18
+ x_new = cx - side // 2
19
+ y_new = cy - side // 2
20
+ return (x_new, y_new, side, side)
21
+
22
+
23
+ def cut(frame, roi):
24
+ pA = (int(roi[0]), int(roi[1]))
25
+ pB = (int(roi[0] + roi[2]), int(roi[1] + roi[3])) # pB will be an internal point
26
+ W, H = frame.shape[1], frame.shape[0]
27
+ A0 = pA[0] if pA[0] >= 0 else 0
28
+ A1 = pA[1] if pA[1] >= 0 else 0
29
+ data = frame[A1:pB[1], A0:pB[0]]
30
+ if pB[0] < W and pB[1] < H and pA[0] >= 0 and pA[1] >= 0:
31
+ return data
32
+ w, h = int(roi[2]), int(roi[3])
33
+ img = np.zeros((h, w, frame.shape[2]), dtype=np.uint8)
34
+ offX = int(-roi[0]) if roi[0] < 0 else 0
35
+ offY = int(-roi[1]) if roi[1] < 0 else 0
36
+ np.copyto(img[offY:offY + data.shape[0], offX:offX + data.shape[1]], data)
37
+ return img
38
+
39
+ class FaceDetector:
40
+ """Face detector to spot faces inside a picture."""
41
+ def __init__(self, face_detector = DEFAULT_FACE_DETECTOR, deploy=DEFAULT_DEPLOY, confidence_threshold=0.8):
42
+ self.detector = cv2.dnn.readNetFromCaffe(deploy, face_detector)
43
+ self.confidence_threshold = confidence_threshold
44
+
45
+ def detect(self, image, pad_rect=True):
46
+ blob = cv2.dnn.blobFromImage(image, 1.0, (300, 300), [104, 117, 123], False, False)
47
+ frameHeight, frameWidth, channels = image.shape
48
+ self.detector.setInput(blob)
49
+ detections = self.detector.forward()
50
+
51
+ faces_result = []
52
+ for i in range(detections.shape[2]):
53
+ confidence = detections[0, 0, i, 2]
54
+ if confidence > self.confidence_threshold:
55
+ x1 = int(detections[0, 0, i, 3] * frameWidth)
56
+ y1 = int(detections[0, 0, i, 4] * frameHeight)
57
+ x2 = int(detections[0, 0, i, 5] * frameWidth)
58
+ y2 = int(detections[0, 0, i, 6] * frameHeight)
59
+ f = (x1, y1, x2 - x1, y2 - y1) # bbox: (x, y, w, h)
60
+ if f[2] > 1 and f[3] > 1:
61
+ rect = enclosing_square(f) if pad_rect else f
62
+ img_crop = cut(image, rect)
63
+ if img_crop.shape[0] > 0 and img_crop.shape[1] > 0:
64
+ faces_result.append((img_crop, confidence, rect)) # use the (square) rect as the final bbox
65
+ if len(faces_result) == 0:
66
+ return None
67
+ return faces_result
68
+
69
+ if __name__ == "__main__":
70
+ input_folder = "src/demo_images"
71
+ output_crop_folder = "./test/detector/crop"
72
+ output_bbox_folder = "./test/detector/bbox"
73
+
74
+ os.makedirs(output_crop_folder, exist_ok=True)
75
+ os.makedirs(output_bbox_folder, exist_ok=True)
76
+
77
+ face_detector = FaceDetector(confidence_threshold=0.8)
78
+
79
+ image_files = sorted([
80
+ f for f in os.listdir(input_folder)
81
+ if os.path.isfile(os.path.join(input_folder, f))
82
+ ])
83
+
84
+ for img_file in image_files:
85
+ img_path = os.path.join(input_folder, img_file)
86
+ img = cv2.imread(img_path)
87
+ if img is None:
88
+ continue
89
+
90
+ faces = face_detector.detect(img, pad_rect=True)
91
+ base_name = os.path.splitext(os.path.basename(img_path))[0]
92
+
93
+ if faces is not None:
94
+ # Save the face crops
95
+ for idx, (crop, confidence, bbox) in enumerate(faces):
96
+ crop_path = os.path.join(output_crop_folder, f"{base_name}_face{idx}.jpg")
97
+ cv2.imwrite(crop_path, crop)
98
+
99
+ # Save the original image with the square bounding boxes drawn
100
+ img_bbox = img.copy()
101
+ for idx, (_, _, bbox) in enumerate(faces):
102
+ x, y, w, h = bbox # bbox is already square
103
+ cv2.rectangle(img_bbox, (x, y), (x + w, y + h), (0, 0, 255), 2) # red in BGR
104
+ bbox_path = os.path.join(output_bbox_folder, f"{base_name}_bbox.jpg")
105
+ cv2.imwrite(bbox_path, img_bbox)
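A small sketch of how a detected crop could be handed to one of the backbone transforms from utils/commons.py; the image path and the backbone checkpoint are illustrative assumptions:

import cv2
from PIL import Image
from utils.commons import get_backbone
from utils.face_detector import FaceDetector

backbone, transform, embed_dim = get_backbone('google/siglip2-base-patch16-224')  # checkpoint assumed
detector = FaceDetector(confidence_threshold=0.8)

img = cv2.imread('src/demo_images/example.jpg')  # path assumed
faces = detector.detect(img, pad_rect=True)
if faces is not None:
    crop, confidence, bbox = faces[0]
    # OpenCV loads images as BGR; convert to RGB before building the PIL image
    pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
    batch = transform(pil_crop).unsqueeze(0)  # (1, 3, H, W) tensor for the backbone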
utils/task_config.py ADDED
@@ -0,0 +1,21 @@
1
+ """Task class definition."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Type
5
+ import torch.nn as nn
6
+
7
+
8
+
9
+ @dataclass
10
+ class Task:
11
+ """Encapsulates all configuration for a single task."""
12
+ name: str
13
+ class_labels: List[str]
14
+ criterion: Type[nn.Module]
15
+ weight: float = 1.0
16
+ use_weighted_loss: bool = False
17
+
18
+ @property
19
+ def num_classes(self) -> int:
20
+ return len(self.class_labels)
21
+
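A quick illustration of how a Task could be declared; the labels and loss below are assumptions rather than the repository's actual configuration:

import torch.nn as nn
from utils.task_config import Task

gender_task = Task(
    name='gender',
    class_labels=['male', 'female'],
    criterion=nn.CrossEntropyLoss,  # the criterion class itself, not an instance
    weight=1.0,
)
print(gender_task.num_classes)  # -> 2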