init
- app.py +788 -0
- core/args.py +72 -0
- core/checkpoint.py +379 -0
- core/transformer.py +646 -0
- core/transforms/image_transform.py +409 -0
- core/utils.py +40 -0
- core/vision_encoder/__init__.py +0 -0
- core/vision_encoder/__pycache__/__init__.cpython-312.pyc +0 -0
- core/vision_encoder/__pycache__/__init__.cpython-313.pyc +0 -0
- core/vision_encoder/__pycache__/config.cpython-312.pyc +0 -0
- core/vision_encoder/__pycache__/config.cpython-313.pyc +0 -0
- core/vision_encoder/__pycache__/pe.cpython-312.pyc +0 -0
- core/vision_encoder/__pycache__/pe.cpython-313.pyc +0 -0
- core/vision_encoder/__pycache__/pe_lora.cpython-312.pyc +0 -0
- core/vision_encoder/__pycache__/rope.cpython-312.pyc +0 -0
- core/vision_encoder/__pycache__/rope.cpython-313.pyc +0 -0
- core/vision_encoder/__pycache__/tokenizer.cpython-312.pyc +0 -0
- core/vision_encoder/__pycache__/tokenizer.cpython-313.pyc +0 -0
- core/vision_encoder/__pycache__/transforms.cpython-312.pyc +0 -0
- core/vision_encoder/__pycache__/transforms.cpython-313.pyc +0 -0
- core/vision_encoder/config.py +260 -0
- core/vision_encoder/pe.py +833 -0
- core/vision_encoder/rope.py +347 -0
- core/vision_encoder/transforms.py +86 -0
- core/vision_projector/base.py +26 -0
- core/vision_projector/mlp.py +62 -0
- requirements.txt +10 -0
- setup.py +7 -0
- src/model.py +809 -0
- utils/__pycache__/commons.cpython-313.pyc +0 -0
- utils/__pycache__/dataset.cpython-313.pyc +0 -0
- utils/__pycache__/face_detector.cpython-313.pyc +0 -0
- utils/__pycache__/task_config.cpython-313.pyc +0 -0
- utils/commons.py +158 -0
- utils/deploy.prototxt +1789 -0
- utils/face_detector.py +105 -0
- utils/task_config.py +21 -0
app.py
ADDED
@@ -0,0 +1,788 @@
"""
VLM Soft Biometrics - Gradio Interface
A web application for analyzing facial soft biometrics (age, gender, emotion) using Vision-Language Models.
"""
import os
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import base64
from io import BytesIO
import traceback  # Import traceback at the top

from utils.face_detector import FaceDetector

# Class definitions
from src.model import MTLModel
from utils.commons import get_backbone_pe
from utils.task_config import Task


TASKS = [
    Task(name='Age', class_labels=["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"], criterion=None),
    Task(name='Gender', class_labels=["Male", "Female"], criterion=None),
    Task(name='Emotion', class_labels=["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"], criterion=None)
]
CLASSES = [
    ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"],
    ["M", "F"],
    ["Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral"]
]

# Global variables for model and detector
model = None
transform = None
detector = None
device = None
current_ckpt_dir = None
CHECKPOINTS_DIR = './checkpoints/'


def scan_checkpoints(ckpt_dir):
    """Scans a directory for .pt or .pth files."""
    if not os.path.exists(ckpt_dir):
        print(f"Warning: Checkpoint directory not found: {ckpt_dir}")
        return [], None

    try:
        ckpt_files = [
            os.path.join(ckpt_dir, f)
            for f in sorted(os.listdir(ckpt_dir))
            if f.lower().endswith(('.pt', '.pth'))  # only checkpoint files, as the docstring states
        ]
    except Exception as e:
        print(f"Error scanning checkpoint directory {ckpt_dir}: {e}")
        return [], None

    # Create a list of (label, value) tuples
    # label = filename (e.g., "mtlora.pt"), value = full path
    choices_list = [(os.path.basename(f), f) for f in ckpt_files]

    default_ckpt_path = os.path.join(ckpt_dir, 'mtlora.pt')

    if default_ckpt_path in ckpt_files:
        return choices_list, default_ckpt_path
    elif ckpt_files:
        return choices_list, ckpt_files[0]  # default_ckpt_value is the full path
    else:
        print(f"No checkpoints found in {ckpt_dir}")
        return [], None


def load_model(device, ckpt_dir='./checkpoints/mtlora.pt', pe_vision_config="PE-Core-L14-336"):
    """Load and configure the model."""
    backbone, transform, _ = get_backbone_pe(version=pe_vision_config, apply_migration_flag=True, pretrained=False)
    model = MTLModel(backbone, tasks=TASKS, use_lora=True, use_deep_head=True,
                     use_mtl_lora=('mtlora' in ckpt_dir),
                     )
    print(f'loading from {ckpt_dir}')
    model.load_model(filepath=ckpt_dir, map_location=device)
    return model, transform


def load_model_and_update_status(ckpt_dir):
    """
    Wrapper function to load a model and return a status message.
    This is used by the dropdown's 'change' event.
    """
    global model, current_ckpt_dir

    if ckpt_dir is None or ckpt_dir == "":
        return "No checkpoint selected."

    if model is not None and ckpt_dir == current_ckpt_dir:
        status = f"Model already loaded: {os.path.basename(ckpt_dir)}"
        print(status)
        return status

    gr.Info(f"Loading model: {os.path.basename(ckpt_dir)}...")
    try:
        init_model(ckpt_dir=ckpt_dir, detection_confidence=0.5)
        current_ckpt_dir = ckpt_dir  # Set global directory on successful load
        status = f"Successfully loaded: {os.path.basename(ckpt_dir)}"
        gr.Info("Model loaded successfully!")
        print(status)
        return status
    except Exception as e:
        status = f"Failed to load {os.path.basename(ckpt_dir)}: {str(e)}"
        print(status)
        traceback.print_exc()
        return status


def predict(model, image):
    """Make predictions for age, gender, and emotion."""
    with torch.no_grad():
        results = model(image)

    age_logits, gender_logits, emotion_logits = results['Age'], results['Gender'], results['Emotion']
    # Get probabilities using softmax
    age_probs = torch.softmax(age_logits, dim=-1)
    gender_probs = torch.softmax(gender_logits, dim=-1)
    emotion_probs = torch.softmax(emotion_logits, dim=-1)

    ages = torch.argmax(age_logits, dim=-1).cpu().tolist()
    genders = torch.argmax(gender_logits, dim=-1).cpu().tolist()
    emotions = torch.argmax(emotion_logits, dim=-1).cpu().tolist()

    results = []
    for i in range(len(ages)):
        # Get all probabilities for each class
        age_all_probs = {
            CLASSES[0][j]: float(age_probs[i][j].cpu().detach())
            for j in range(len(CLASSES[0]))
        }
        gender_all_probs = {
            CLASSES[1][j]: float(gender_probs[i][j].cpu().detach())
            for j in range(len(CLASSES[1]))
        }
        emotion_all_probs = {
            CLASSES[2][j]: float(emotion_probs[i][j].cpu().detach())
            for j in range(len(CLASSES[2]))
        }

        results.append({
            'age': {
                'predicted_class': CLASSES[0][ages[i]],
                'predicted_confidence': float(age_probs[i][ages[i]].cpu().detach()),
                'all_probabilities': age_all_probs
            },
            'gender': {
                'predicted_class': CLASSES[1][genders[i]],
                'predicted_confidence': float(gender_probs[i][genders[i]].cpu().detach()),
                'all_probabilities': gender_all_probs
            },
            'emotion': {
                'predicted_class': CLASSES[2][emotions[i]],
                'predicted_confidence': float(emotion_probs[i][emotions[i]].cpu().detach()),
                'all_probabilities': emotion_all_probs
            }
        })

    return results


def get_centroid_weighted_age(probs):
    probs = list(probs.values())
    centroids = [1, 4.5, 14.5, 24.5, 34.5, 44.5, 54.5, 64.5, 80]
    age = 0
    # print(probs)  # DEBUG
    for i, p in enumerate(probs):
        age += p * centroids[i]

    return age


def init_model(ckpt_dir="./checkpoints/mtlora.pt", detection_confidence=0.5):
    """Initialize model and detector."""
    global model, transform, detector, device

    print(f"\n{'='*60}")
    print(f"INITIALIZING MODEL: {ckpt_dir}")
    print(f"{'='*60}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Verify model weights exist
    if not os.path.exists(ckpt_dir):
        error_msg = f"Model weights not found: {ckpt_dir}."
        print(f"ERROR: {error_msg}")
        raise FileNotFoundError(error_msg)

    print(f"Model weights found: {ckpt_dir}")

    # Load the perception encoder
    model, transform = load_model(ckpt_dir=ckpt_dir, device=device)
    model.eval()
    model.to(device)

    detector = FaceDetector(confidence_threshold=detection_confidence)

    print("✓ Model and detector initialized successfully")
    print(f"{'='*60}\n")


def process_image(image, selected_checkpoint_path):
    """
    Process an uploaded image and return predictions with annotated image.

    Args:
        image: PIL Image or numpy array
        selected_checkpoint_path: The path from the checkpoint dropdown

    Returns:
        tuple: (annotated_image, results_html)
    """
    if image is None:
        return None, "<p style='color: red;'>Please upload an image</p>"

    # Ensure model is initialized
    if model is None or selected_checkpoint_path != current_ckpt_dir:
        status = load_model_and_update_status(selected_checkpoint_path)
        if "Failed" in status or "Error" in status:
            return image, f"<p style='color: red;'>Model Error: {status}</p>"

    try:
        # --- 1. Prepare images for detection and drawing ---

        # Convert PIL to OpenCV format (BGR) for the detector
        if isinstance(image, Image.Image):
            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        else:
            # Assuming it's a numpy array from the Gradio webcam; convert to PIL so it can be drawn on
            image = Image.fromarray(image)
            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        # Create a PIL copy to draw annotations on
        img_pil_annotated = image.copy()
        draw = ImageDraw.Draw(img_pil_annotated)

        # --- 2. Detect faces ---
        faces = detector.detect(img_cv, pad_rect=True)

        if faces is None or len(faces) == 0:
            return image, "<p style='color: orange;'>No faces detected in the image</p>"

        # --- 3. Process detected faces ---
        crops_pil = []
        face_data = []

        for idx, (crop, confidence, bbox) in enumerate(faces):
            crop_rgb = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            crop_pil = Image.fromarray(crop_rgb)
            crops_pil.append(crop_pil)

            # Resize crop to 336x336 for display
            crop_resized = crop_pil.resize((336, 336), Image.Resampling.LANCZOS)

            face_data.append({
                'bbox': bbox,
                'detection_confidence': float(confidence),
                'crop_image': crop_resized  # Store the resized crop
            })

        # --- 4. Batch transform and predict ---
        crop_tensors = [transform(crop_pil) for crop_pil in crops_pil]
        batch_tensor = torch.stack(crop_tensors).to(device)

        # Get predictions
        predictions = predict(model, batch_tensor)

        # Combine face data with predictions
        for face, pred in zip(face_data, predictions):
            face['predictions'] = pred

        # --- 5. Create annotated image (using PIL) ---
        for idx, face in enumerate(face_data):
            bbox = face['bbox']
            pred = face['predictions']
            x, y, w, h = bbox

            # --- Calculate Adaptive Font (from demo.py) ---
            font_size_ratio = 0.08
            min_font_size = 12
            max_font_size = 48
            adaptive_font_size = max(min_font_size, min(int(w * font_size_ratio), max_font_size))
            try:
                font = ImageFont.load_default(size=adaptive_font_size)
            except (IOError, TypeError):  # older Pillow versions don't accept size=
                font = ImageFont.load_default()

            # --- Draw Bounding Box ---
            draw.rectangle([(x, y), (x + w, y + h)], outline="lime", width=2)

            # --- Prepare Text Lines (Top-1 Only) ---
            lines_to_draw = []

            # Age
            age_label = pred['age']['predicted_class']
            age_conf = pred['age']['predicted_confidence']
            lines_to_draw.append(f"Age: {age_label} ({age_conf*100:.0f}%)")

            # Gender
            gen_label = pred['gender']['predicted_class']
            gen_conf = pred['gender']['predicted_confidence']
            lines_to_draw.append(f"Gender: {gen_label} ({gen_conf*100:.0f}%)")

            # Emotion
            emo_label = pred['emotion']['predicted_class']
            emo_conf = pred['emotion']['predicted_confidence']
            lines_to_draw.append(f"Emotion: {emo_label} ({emo_conf*100:.0f}%)")

            # --- Calculate total height of the text block (from demo.py) ---
            line_spacing = 10
            total_text_height = 0
            for line in lines_to_draw:
                _left, top, _right, bottom = draw.textbbox((0, 0), line, font=font)
                total_text_height += (bottom - top) + line_spacing

            # --- Place text ABOVE or BELOW the box (from demo.py) ---
            if y - total_text_height > 0:
                # PLACE TEXT ABOVE: There is enough space
                text_y = y - line_spacing
                for line in reversed(lines_to_draw):
                    left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="ls")  # anchor left-baseline
                    draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
                    draw.text((x, text_y), line, font=font, fill="white", anchor="ls")
                    text_y = top - line_spacing  # Move y-position up for the next line
            else:
                # PLACE TEXT BELOW: Not enough space above, so draw downwards
                text_y = y + h + line_spacing
                for line in lines_to_draw:
                    left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="lt")
                    draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
                    draw.text((x, text_y), line, font=font, fill="white", anchor="lt")
                    text_y = bottom + line_spacing

        # --- 6. Create HTML results ---

        # Helper function to convert PIL image to base64
        def pil_to_base64(img_pil):
            buffered = BytesIO()
            img_pil.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue()).decode()
            return f"data:image/jpeg;base64,{img_str}"

        # (HTML generation code remains the same as before)
        results_html = f"""
        <style>
            :root {{
                --primary-color: #4f46e5;
                --success-color: #10b981;
                --text-primary: #ffffff;
                --text-secondary: #9ca3af;
                --background-dark: #1f2937;
                --background-darker: #111827;
                --border-color: #374151;
            }}
            .results-container {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
                background: var(--background-darker);
                padding: 20px;
                border-radius: 12px;
                color: var(--text-primary);
            }}
            .results-container h2 {{
                color: var(--text-primary);
                margin-bottom: 20px;
            }}
            .face-count {{
                display: inline-block;
                background: var(--primary-color);
                color: white;
                padding: 4px 12px;
                border-radius: 20px;
                font-size: 0.9rem;
                font-weight: 500;
                margin-left: 8px;
            }}
            .face-card {{
                background: var(--background-dark);
                border-radius: 8px;
                padding: 20px;
                margin-top: 15px;
                border: 1px solid var(--border-color);
                display: flex;
                gap: 20px;
                align-items: flex-start;
            }}
            .face-header {{
                font-size: 1rem;
                font-weight: 600;
                margin-bottom: 20px;
                color: var(--text-primary);
            }}
            .face-image-left {{
                flex-shrink: 0;
                width: 336px;
                height: 336px;
                background: var(--background-darker);
                border-radius: 8px;
                overflow: hidden;
                border: 1px solid var(--border-color);
            }}
            .face-image-left img {{
                width: 100%;
                height: 100%;
                object-fit: cover;
            }}
            .face-predictions-right {{
                flex: 1;
                display: flex;
                flex-direction: column;
                gap: 10px;
            }}
            .predictions-horizontal {{
                display: flex;
                flex-direction: row;
                gap: 30px;
                justify-content: space-between;
            }}
            .prediction-section {{
                flex: 1;
                min-width: 0;
            }}
            .prediction-category-label {{
                font-size: 0.8rem;
                font-weight: 700;
                text-transform: uppercase;
                letter-spacing: 0.5px;
                color: var(--primary-color);
                margin-bottom: 8px;
                border-bottom: 2px solid var(--primary-color);
                padding-bottom: 4px;
            }}
            .probabilities-list {{
                display: flex;
                flex-direction: column;
                gap: 6px;
            }}
            .probability-item {{
                display: grid;
                grid-template-columns: 70px 1fr 55px;
                align-items: center;
                gap: 8px;
                padding: 4px 6px;
                border-radius: 4px;
            }}
            .probability-item.predicted {{
                background: rgba(79, 70, 229, 0.2);
                border-left: 3px solid var(--primary-color);
                padding-left: 8px;
            }}
            .prob-class {{
                font-size: 0.8rem;
                font-weight: 600;
                color: var(--text-primary);
                word-wrap: break-word; /* Ensure long class names wrap */
            }}
            .probability-item.predicted .prob-class {{
                color: var(--primary-color);
                font-weight: 700;
            }}
            .prob-bar-container {{
                height: 6px;
                background: var(--border-color);
                border-radius: 3px;
                overflow: hidden;
            }}
            .prob-bar {{
                height: 100%;
                background: linear-gradient(90deg, var(--primary-color), var(--success-color));
                border-radius: 3px;
                transition: width 0.6s ease;
            }}
            .probability-item.predicted .prob-bar {{
                background: var(--primary-color);
            }}
            .prob-percentage {{
                font-size: 0.75rem;
                font-weight: 500;
                color: var(--text-secondary);
                text-align: right;
            }}
            .probability-item.predicted .prob-percentage {{
                color: var(--primary-color);
                font-weight: 700;
            }}
            @media (max-width: 1200px) {{
                .predictions-horizontal {{
                    flex-direction: column;
                    gap: 15px;
                }}
            }}
            @media (max-width: 900px) {{
                .face-card {{
                    flex-direction: column;
                }}
                .face-image-left {{
                    width: 100%;
                    max-width: 336px;
                    margin: 0 auto;
                }}
                .probability-item {{
                    grid-template-columns: 60px 1fr 50px; /* Adjust for smaller screens */
                }}
                .prob-class {{
                    font-size: 0.75rem;
                }}
            }}
        </style>

        <div class='results-container'>
            <h2 style='margin-top: 0;'>Classification Results <span class='face-count'>{len(face_data)} face(s)</span></h2>
        """

        for idx, face in enumerate(face_data):
            pred = face['predictions']
            face_img_base64 = pil_to_base64(face['crop_image'])
            age = get_centroid_weighted_age(pred['age']['all_probabilities'])
            results_html += f"""
            <div class='face-card'>
                <div class='face-image-left'>
                    <img src='{face_img_base64}' alt='Face {idx+1}'>
                </div>
                <div class='face-predictions-right'>
                    <div class='face-header'>Face {idx+1} - Detection Confidence: {face['detection_confidence']:.1%} - Centroid Age: {int(age)}</div>
                    <div class='predictions-horizontal'>
                        <div class='prediction-section'>
                            <div class='prediction-category-label'>Age</div>
                            <div class='probabilities-list'>
            """
            for age_class in CLASSES[0]:
                prob = pred['age']['all_probabilities'][age_class]
                is_predicted = (age_class == pred['age']['predicted_class'])
                predicted_class = 'predicted' if is_predicted else ''
                results_html += f"""
                <div class='probability-item {predicted_class}'>
                    <span class='prob-class'>{age_class}</span>
                    <div class='prob-bar-container'>
                        <div class='prob-bar' style='width: {prob*100}%'></div>
                    </div>
                    <span class='prob-percentage'>{prob*100:.1f}%</span>
                </div>
                """
            results_html += f"""
                            </div>
                        </div>
                        <div class='prediction-section'>
                            <div class='prediction-category-label'>Gender</div>
                            <div class='probabilities-list'>
            """
            for gender_class in CLASSES[1]:
                prob = pred['gender']['all_probabilities'][gender_class]
                is_predicted = (gender_class == pred['gender']['predicted_class'])
                predicted_class = 'predicted' if is_predicted else ''
                results_html += f"""
                <div class='probability-item {predicted_class}'>
                    <span class='prob-class'>{gender_class}</span>
                    <div class='prob-bar-container'>
                        <div class='prob-bar' style='width: {prob*100}%'></div>
                    </div>
                    <span class='prob-percentage'>{prob*100:.1f}%</span>
                </div>
                """
            results_html += """
                            </div>
                        </div>
                        <div class='prediction-section'>
                            <div class='prediction-category-label'>Emotion</div>
                            <div class='probabilities-list'>
            """
            for emotion_class in CLASSES[2]:
                prob = pred['emotion']['all_probabilities'][emotion_class]
                is_predicted = (emotion_class == pred['emotion']['predicted_class'])
                predicted_class = 'predicted' if is_predicted else ''
                results_html += f"""
                <div class='probability-item {predicted_class}'>
                    <span class='prob-class'>{emotion_class}</span>
                    <div class='prob-bar-container'>
                        <div class='prob-bar' style='width: {prob*100}%'></div>
                    </div>
                    <span class='prob-percentage'>{prob*100:.1f}%</span>
                </div>
                """
            results_html += """
                            </div>
                        </div>
                    </div>
                </div>
            </div>
            """
        results_html += "</div>"

        # --- 7. Return the annotated PIL image and HTML ---
        return img_pil_annotated, results_html

    except Exception as e:
        traceback.print_exc()
        return image, f"<p style='color: red;'>Error processing image: {str(e)}</p>"


def create_interface(checkpoint_list, default_checkpoint, initial_status):
    """Create and configure the Gradio interface."""

    # Custom CSS for better styling
    custom_css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    .output-html {
        max-height: none !important;
        overflow-y: auto;
    }
    """

    # Create interface
    with gr.Blocks(css=custom_css, title="Face Classification System") as demo:

        with gr.Row():
            gr.Markdown("# Face Classification System")

        # --- Model Selection ---
        with gr.Row():
            with gr.Column(scale=3):
                checkpoint_dropdown = gr.Dropdown(
                    label="Select Model Checkpoint",
                    choices=checkpoint_list,
                    value=default_checkpoint,
                )
            with gr.Column(scale=2):
                model_status_text = gr.Textbox(
                    label="Model Status",
                    value=initial_status,
                    interactive=False,
                )

        # Features | Instructions
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("""
                ### Features
                - **Age Classification**: 9 categories (0-2, 3-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70+) + Age estimation with weighted centroid average
                - **Gender Classification**: M/F
                - **Emotion Recognition**: 7 categories (Surprise, Fear, Disgust, Happy, Sad, Angry, Neutral)
                - **Automatic Face Detection**: Detects and analyzes multiple faces
                - **Detailed Probability Distributions**: View confidence for all classes
                """)

            with gr.Column(scale=1):
                gr.Markdown("""
                ### Instructions
                1. (Optional) Select a model checkpoint from the dropdown.
                2. Upload an image or capture from webcam (or select an example below)
                3. Click "Classify Image"
                4. View detected faces with age, gender, and emotion predictions below
                """)

        # Upload Image | Annotated Image
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                    sources=["upload", "webcam"],
                    height=400
                )

            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Annotated Image",
                    type="pil",
                    height=400
                )

        with gr.Row():
            with gr.Column(scale=1):
                analyze_btn = gr.Button(
                    "Classify Image",
                    variant="primary",
                    size="lg"
                )

        # Examples - after button
        # Dynamically load example images from the example directory
        example_dir = "example"
        example_images = []
        if os.path.exists(example_dir):
            try:
                example_images = [
                    os.path.join(example_dir, f)
                    for f in sorted(os.listdir(example_dir))
                    if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
                ]
            except Exception as e:
                print(f"Error reading example images from {example_dir}: {e}")

        if example_images:
            gr.Markdown("### 📸 Try with example images")
            gr.Examples(
                examples=example_images,
                inputs=input_image,
                cache_examples=False
            )

        # Results section - full width below everything
        with gr.Row():
            with gr.Column(scale=1):
                output_html = gr.HTML(
                    label="Classification Results",
                    elem_classes="output-html"
                )

        # Event handlers
        analyze_btn.click(
            fn=process_image,
            inputs=[input_image, checkpoint_dropdown],  # Pass dropdown value
            outputs=[output_image, output_html]
        )

        checkpoint_dropdown.change(
            fn=load_model_and_update_status,
            inputs=[checkpoint_dropdown],
            outputs=[model_status_text]
        )

    return demo

# === Main Application Startup ===

# Initialize for Hugging Face Spaces (module-level)
print("="*60)
print("VLM SOFT BIOMETRICS - GRADIO INTERFACE")
print("="*60)

# --- 1. Scan for models first ---
checkpoint_list, default_checkpoint = scan_checkpoints(CHECKPOINTS_DIR)

if not checkpoint_list:
    print(f"CRITICAL: No checkpoints found in {CHECKPOINTS_DIR}. App may not function.")
else:
    print(f"Found checkpoints: {len(checkpoint_list)} file(s).")
    print(f"Default checkpoint: {default_checkpoint}")

# --- 2. Try to initialize default model ---
initial_status_msg = "No default model found. Please select one."
if default_checkpoint:
    print(f"\nInitializing default model: {default_checkpoint}")
    # This will load the model AND set current_ckpt_dir
    initial_status_msg = load_model_and_update_status(default_checkpoint)
    print(initial_status_msg)
else:
    print("⚠ Warning: No default model to load.")

# --- 3. Create interface FIRST (so it shows even if model fails) ---
print("Creating Gradio interface...")
demo = create_interface(checkpoint_list, default_checkpoint, initial_status_msg)
print("✓ Interface created successfully!")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="VLM Soft Biometrics - Gradio Interface")
    parser.add_argument("--ckpt_dir", type=str, default="./checkpoints/",
                        help="Path to the checkpoint directory (overridden by UI)")
    parser.add_argument("--detection_confidence", type=float, default=0.5,
                        help="Confidence threshold for face detection")
    parser.add_argument("--port", type=int, default=7860,
                        help="Port to run the Gradio app")
    parser.add_argument("--share", action="store_true",
                        help="Create a public share link")
    parser.add_argument("--server_name", type=str, default="0.0.0.0",
                        help="Server name/IP to bind to")
    args = parser.parse_args()

    # Update global config if args are provided (though the UI dropdown is primary)
    CHECKPOINTS_DIR = args.ckpt_dir
    # Note: detection_confidence is passed to init_model, so it's handled.

    print(f"\nLaunching server on {args.server_name}:{args.port}")
    print(f"Monitoring checkpoint directory: {CHECKPOINTS_DIR}")
    print("="*60)

    demo.launch(
        share=args.share,
        server_name=args.server_name,
        server_port=args.port,
        show_error=True,
    )
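For reference, a minimal, self-contained sketch of the centroid-weighted age estimate computed by get_centroid_weighted_age above: the nine age-bin probabilities act as weights over fixed bin centroids, so the app reports an expected age alongside the top-1 bin. The probability values below are invented for illustration, not model output.

# Illustrative only: example probabilities are made up; centroids match the list in app.py.
CENTROIDS = [1, 4.5, 14.5, 24.5, 34.5, 44.5, 54.5, 64.5, 80]
example_probs = [0.00, 0.01, 0.04, 0.55, 0.30, 0.07, 0.02, 0.01, 0.00]  # sums to 1.0
estimated_age = sum(p * c for p, c in zip(example_probs, CENTROIDS))
print(f"Centroid-weighted age: {estimated_age:.1f}")  # ~29.3 for these values

With checkpoints in place, the app itself is launched with the flags defined above, e.g. python app.py --port 7860 --share.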
core/args.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from typing import Type, TypeVar

from omegaconf import DictConfig, ListConfig, OmegaConf

logger = logging.getLogger()

T = TypeVar("T")


def set_struct_recursively(cfg, strict: bool = True):
    # Set struct mode for the current level
    OmegaConf.set_struct(cfg, strict)

    # Traverse through nested dictionaries and lists
    if isinstance(cfg, DictConfig):
        for key, value in cfg.items():
            if isinstance(value, (DictConfig, ListConfig)):
                set_struct_recursively(value, strict)
    elif isinstance(cfg, ListConfig):
        for item in cfg:
            if isinstance(item, (DictConfig, ListConfig)):
                set_struct_recursively(item, strict)


def flatten_dict(d, parent_key="", sep="_"):
    items = []
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
    """
    Converts a dictionary to a dataclass instance, recursively for nested structures.
    """
    base = OmegaConf.structured(cls())
    OmegaConf.set_struct(base, strict)
    override = OmegaConf.create(data)
    return OmegaConf.to_object(OmegaConf.merge(base, override))


def dataclass_to_dict(dataclass_instance: T) -> dict:
    """
    Converts a dataclass instance to a dictionary, recursively for nested structures.
    """
    if isinstance(dataclass_instance, dict):
        return dataclass_instance

    return OmegaConf.to_container(
        OmegaConf.structured(dataclass_instance), resolve=True
    )


def load_config_file(config_file, dataclass_cls: Type[T]) -> T:
    config = OmegaConf.to_container(OmegaConf.load(config_file), resolve=True)
    return dataclass_from_dict(dataclass_cls, config)


def dump_config(config, path, log_config=True):
    yaml_dump = OmegaConf.to_yaml(OmegaConf.structured(config))
    with open(path, "w") as f:
        if log_config:
            logger.info("Using the following config for this run:")
            logger.info(yaml_dump)
        f.write(yaml_dump)
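A possible usage sketch for these helpers (the TrainArgs and OptimArgs dataclasses below are hypothetical, defined only for illustration; only the core.args functions are from this repo): defaults come from the dataclass, a plain dict or YAML file supplies overrides, and OmegaConf merges them back into a typed instance.

# Hypothetical config dataclasses, used only to demonstrate the round trip.
from dataclasses import dataclass, field

from core.args import dataclass_from_dict, dataclass_to_dict


@dataclass
class OptimArgs:
    lr: float = 1e-4
    weight_decay: float = 0.0


@dataclass
class TrainArgs:
    batch_size: int = 32
    optim: OptimArgs = field(default_factory=OptimArgs)


overrides = {"batch_size": 64, "optim": {"lr": 3e-4}}
args = dataclass_from_dict(TrainArgs, overrides)  # typed TrainArgs with merged values
cfg_dict = dataclass_to_dict(args)                # plain nested dict, ready for dump_config-style output
print(args, cfg_dict)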
core/checkpoint.py
ADDED
@@ -0,0 +1,379 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import re
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
import torch.distributed.checkpoint as dcp
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.optim.optimizer
|
| 16 |
+
from omegaconf import OmegaConf
|
| 17 |
+
from torch.distributed._tensor import DeviceMesh
|
| 18 |
+
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
|
| 19 |
+
from torch.distributed.checkpoint.state_dict import (get_model_state_dict,
|
| 20 |
+
get_state_dict,
|
| 21 |
+
set_state_dict)
|
| 22 |
+
|
| 23 |
+
from core.distributed import get_is_master
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger("CHECKPOINT")
|
| 26 |
+
|
| 27 |
+
FOLDER_NAME = "{:010d}"
|
| 28 |
+
RE_FOLDER = r"\d{10}"
|
| 29 |
+
|
| 30 |
+
RE_CKPT = r"__\d_\d\.distcp"
|
| 31 |
+
|
| 32 |
+
CONSOLIDATE_FOLDER = "consolidated"
|
| 33 |
+
CONSOLIDATE_NAME = "consolidated.pth"
|
| 34 |
+
|
| 35 |
+
CONFIG_NAME = "params.json"
|
| 36 |
+
TRAIN_STATE_NAME = "train_state_{:05d}.json"
|
| 37 |
+
RE_DIGITS = re.compile(r"\d+")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
|
| 41 |
+
class SaveEvery:
|
| 42 |
+
every: int = 1000
|
| 43 |
+
keep: int = 0
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class CheckpointArgs:
|
| 48 |
+
dump: SaveEvery = field(default_factory=SaveEvery)
|
| 49 |
+
eval: SaveEvery = field(default_factory=SaveEvery)
|
| 50 |
+
path: Optional[str] = None
|
| 51 |
+
init_ckpt_path: Optional[str] = None
|
| 52 |
+
vision_model_path: Optional[str] = None
|
| 53 |
+
is_consolidated_model: bool = False
|
| 54 |
+
continue_training_from_init: bool = False
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _get_key_step(name: str):
|
| 58 |
+
return int(re.findall(RE_DIGITS, name)[-1])
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def consolidate_checkpoints(ckpt_dir: str):
|
| 62 |
+
"""
|
| 63 |
+
Consolidates all FSDP checkpoints in a directory to a single file
|
| 64 |
+
Consolidate checkpoint is saved in a subdirectory of ckpt_dir
|
| 65 |
+
|
| 66 |
+
Parameters:
|
| 67 |
+
ckpt_dir: str - path to the directory containing the checkpoints
|
| 68 |
+
|
| 69 |
+
Returns the path to the consolidated checkpoint
|
| 70 |
+
"""
|
| 71 |
+
consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
|
| 72 |
+
if not (consolidate_path / CONSOLIDATE_NAME).exists():
|
| 73 |
+
consolidate_path.mkdir(exist_ok=True)
|
| 74 |
+
logger.info(f"Consolidating to: {str(consolidate_path)}")
|
| 75 |
+
dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
|
| 76 |
+
(consolidate_path / CONFIG_NAME).write_text(
|
| 77 |
+
(Path(ckpt_dir) / CONFIG_NAME).read_text()
|
| 78 |
+
)
|
| 79 |
+
logger.info("Consolidated !")
|
| 80 |
+
return consolidate_path
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def load_from_checkpoint(
|
| 84 |
+
ckpt_dir: str,
|
| 85 |
+
model: nn.Module,
|
| 86 |
+
optimizer: Optional[torch.optim.Optimizer] = None,
|
| 87 |
+
model_key: str = "model",
|
| 88 |
+
optim_key: str = "optim",
|
| 89 |
+
):
|
| 90 |
+
if not (Path(ckpt_dir) / ".metadata").exists():
|
| 91 |
+
raise ValueError(
|
| 92 |
+
"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
state_dict = {}
|
| 96 |
+
if optimizer is not None:
|
| 97 |
+
state_dict[model_key], state_dict[optim_key] = get_state_dict(model, optimizer)
|
| 98 |
+
else:
|
| 99 |
+
state_dict[model_key] = get_model_state_dict(model)
|
| 100 |
+
if model_key == "": # If only loading a model directly, the key should be empty
|
| 101 |
+
state_dict = state_dict.pop(model_key)
|
| 102 |
+
|
| 103 |
+
dcp.load(state_dict, checkpoint_id=ckpt_dir)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class CheckpointManager:
|
| 107 |
+
def __init__(self, args: CheckpointArgs):
|
| 108 |
+
self.path = args.path
|
| 109 |
+
self.dump_every = args.dump
|
| 110 |
+
self.eval_every = args.eval
|
| 111 |
+
self.init_ckpt_path = args.init_ckpt_path
|
| 112 |
+
self.continue_training_from_init = args.continue_training_from_init
|
| 113 |
+
|
| 114 |
+
assert os.path.exists(
|
| 115 |
+
self.path
|
| 116 |
+
), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
|
| 117 |
+
|
| 118 |
+
self.existing_saves = self.get_existing_saves()
|
| 119 |
+
|
| 120 |
+
def get_existing_saves(self) -> List[Path]:
|
| 121 |
+
folders = [
|
| 122 |
+
p
|
| 123 |
+
for p in Path(self.path).iterdir()
|
| 124 |
+
if p.is_dir() and re.match(RE_FOLDER, p.name)
|
| 125 |
+
]
|
| 126 |
+
folders.sort(key=lambda p: _get_key_step(p.name))
|
| 127 |
+
return folders
|
| 128 |
+
|
| 129 |
+
def clean_up(self):
|
| 130 |
+
logger.info("Cleaning up checkpoints...")
|
| 131 |
+
dump_folders = []
|
| 132 |
+
eval_folders = []
|
| 133 |
+
other_folders = []
|
| 134 |
+
for p in self.existing_saves:
|
| 135 |
+
is_dump = _get_key_step(p.name) % self.dump_every.every == 0
|
| 136 |
+
is_eval = _get_key_step(p.name) % self.eval_every.every == 0
|
| 137 |
+
if is_dump:
|
| 138 |
+
dump_folders.append(p)
|
| 139 |
+
if is_eval:
|
| 140 |
+
eval_folders.append(p)
|
| 141 |
+
if not (is_dump or is_eval):
|
| 142 |
+
other_folders.append(p)
|
| 143 |
+
|
| 144 |
+
logger.info(f"Dump folders: {dump_folders}")
|
| 145 |
+
logger.info(f"Eval folders: {eval_folders}")
|
| 146 |
+
logger.info(f"Other folders: {other_folders}")
|
| 147 |
+
|
| 148 |
+
if self.dump_every.keep > 0:
|
| 149 |
+
dump_folders = dump_folders[-self.dump_every.keep :]
|
| 150 |
+
if self.eval_every.keep > 0:
|
| 151 |
+
eval_folders = eval_folders[-self.eval_every.keep :]
|
| 152 |
+
|
| 153 |
+
folder_to_keep = set(other_folders + dump_folders + eval_folders)
|
| 154 |
+
folder_to_remove = set(self.existing_saves) - folder_to_keep
|
| 155 |
+
|
| 156 |
+
logger.info(f"Removing folders: {folder_to_remove}")
|
| 157 |
+
|
| 158 |
+
if dist.get_rank() == 0:
|
| 159 |
+
for folder in folder_to_remove:
|
| 160 |
+
for file in folder.iterdir():
|
| 161 |
+
if file.is_file():
|
| 162 |
+
file.unlink()
|
| 163 |
+
elif file.is_dir():
|
| 164 |
+
assert file.name in [CONSOLIDATE_FOLDER]
|
| 165 |
+
for f in file.iterdir():
|
| 166 |
+
f.unlink()
|
| 167 |
+
file.rmdir()
|
| 168 |
+
folder.rmdir()
|
| 169 |
+
|
| 170 |
+
dist.barrier()
|
| 171 |
+
|
| 172 |
+
self.existing_saves = list(folder_to_keep)
|
| 173 |
+
self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
|
| 174 |
+
|
| 175 |
+
def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
|
| 176 |
+
path = None
|
| 177 |
+
for p in reversed(self.existing_saves):
|
| 178 |
+
if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
|
| 179 |
+
path = p
|
| 180 |
+
break
|
| 181 |
+
return path
|
| 182 |
+
|
| 183 |
+
def _create_folder(self, base_path: Path, folder_name: str) -> Path:
|
| 184 |
+
folder = base_path / folder_name
|
| 185 |
+
if get_is_master():
|
| 186 |
+
folder.mkdir(parents=False, exist_ok=True)
|
| 187 |
+
if dist.is_initialized():
|
| 188 |
+
dist.barrier()
|
| 189 |
+
return folder
|
| 190 |
+
|
| 191 |
+
def _get_dp_tp_mesh(
|
| 192 |
+
self, device_mesh: Optional[DeviceMesh] = None
|
| 193 |
+
) -> Tuple[int, int]:
|
| 194 |
+
dp_rank = 0
|
| 195 |
+
tp_rank = 0
|
| 196 |
+
if device_mesh is not None:
|
| 197 |
+
if "dp_replicate" in device_mesh.mesh_dim_names:
|
| 198 |
+
dp_rank = device_mesh.get_local_rank("dp_replicate")
|
| 199 |
+
if "dp_shard" in device_mesh.mesh_dim_names:
|
| 200 |
+
dp_rank = dp_rank * device_mesh[
|
| 201 |
+
"dp_replicate"
|
| 202 |
+
].size() + device_mesh.get_local_rank("dp_shard")
|
| 203 |
+
if "tp" in device_mesh.mesh_dim_names:
|
| 204 |
+
tp_rank = device_mesh.get_local_rank("tp")
|
| 205 |
+
return dp_rank, tp_rank
|
| 206 |
+
|
| 207 |
+
@torch.no_grad()
|
| 208 |
+
def get_state_dict(
|
| 209 |
+
self,
|
| 210 |
+
model,
|
| 211 |
+
optimizer,
|
| 212 |
+
):
|
| 213 |
+
model_sd, optim_sd = get_state_dict(model, optimizer)
|
| 214 |
+
return {"model": model_sd, "optim": optim_sd}
|
| 215 |
+
|
| 216 |
+
def save(
|
| 217 |
+
self,
|
| 218 |
+
model,
|
| 219 |
+
optimizer,
|
| 220 |
+
train_state,
|
| 221 |
+
config,
|
| 222 |
+
device_mesh: Optional[DeviceMesh] = None,
|
| 223 |
+
) -> bool:
|
| 224 |
+
|
| 225 |
+
# When creating directory check if only rank0 or is there other solution
|
| 226 |
+
path = Path(self.path)
|
| 227 |
+
curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
|
| 228 |
+
logger.info(f"Saving to: {str(curr_save_dir)}")
|
| 229 |
+
|
| 230 |
+
if dist.is_initialized():
|
| 231 |
+
dist.barrier()
|
| 232 |
+
|
| 233 |
+
logger.info("Saving...")
|
| 234 |
+
state_dict = self.get_state_dict(model, optimizer)
|
| 235 |
+
dcp.save(state_dict, checkpoint_id=curr_save_dir)
|
| 236 |
+
logger.info("State dict saved!")
|
| 237 |
+
|
| 238 |
+
if dist.is_initialized():
|
| 239 |
+
dist.barrier()
|
| 240 |
+
|
| 241 |
+
if get_is_master():
|
| 242 |
+
with open(curr_save_dir / CONFIG_NAME, "w") as f:
|
| 243 |
+
json.dump(
|
| 244 |
+
OmegaConf.to_container(OmegaConf.structured(config), resolve=True),
|
| 245 |
+
f,
|
| 246 |
+
indent=4,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# Add json dump here
|
| 250 |
+
dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
|
| 251 |
+
if tp_rank == 0:
|
| 252 |
+
train_state_name = TRAIN_STATE_NAME.format(dp_rank)
|
| 253 |
+
logger.info(
|
| 254 |
+
f"Saving train state to: {str(curr_save_dir / train_state_name)}"
|
| 255 |
+
)
|
| 256 |
+
# logger.info(f"train_state.state_dict()={train_state.state_dict()}")
|
| 257 |
+
with open(curr_save_dir / train_state_name, "w") as f:
|
| 258 |
+
json.dump(train_state.state_dict(), f)
|
| 259 |
+
logger.info("Train state saved !")
|
| 260 |
+
|
| 261 |
+
self.existing_saves.append(curr_save_dir)
|
| 262 |
+
|
| 263 |
+
self.clean_up()
|
| 264 |
+
|
| 265 |
+
        if dist.is_initialized():
            dist.barrier()
        return True

    @torch.no_grad()
    def load(
        self,
        model: nn.Module,
        optimizer,
        train_state,
        device_mesh: DeviceMesh,
        path: Optional[Path] = None,
    ):
        dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
        # Loading tries the provided path first, then the last saved step, and finally the init path
        path = path or self.get_last_step_path(dp_rank=dp_rank)
        # If none of those are available don't do anything
        if path is None:
            # If no checkpoints exist do nothing
            return

        # Only load train state if it's provided, the files exist and we're not loading from init path
        train_state_name = TRAIN_STATE_NAME.format(dp_rank)
        logger.info("Reloading train state")
        with open(path / train_state_name, "r") as f:
            train_state_dict = json.load(f)
        train_state.load_state_dict(train_state_dict)
        logger.info("Train state reloaded")

        logger.info(f"Loading from: {str(path)}")
        state_dict = self.get_state_dict(
            model=model,
            optimizer=optimizer,
        )
        dcp.load(state_dict, checkpoint_id=path)
        logger.info("State dict loaded.")

        logger.info("Reloading model and optim")

        set_state_dict(
            model,
            optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )
        logger.info("Model and optim reloaded")

    @classmethod
    def instantiate_and_make_dir(cls, args: CheckpointArgs):
        if get_is_master():
            os.makedirs(args.path, exist_ok=True)
        dist.barrier()

        return cls(args)


def get_consolidated_ckpt_path(ckpt_dir: Path, mp_rank: int = 0, mp_size: int = 1):
    if mp_size == 1:
        assert mp_rank == 0
        no_rank_path = ckpt_dir / "consolidated.pth"
        if no_rank_path.exists():
            return no_rank_path
    return ckpt_dir / f"consolidated.{mp_rank:02d}.pth"


def load_consolidated_checkpoint(
    model: nn.Module,
    consolidated_path: str,
    vision_model_path: Optional[str] = None,
):
    """
    Loads a consolidated checkpoint into the model.
    This version supports both:
      - a single file named 'consolidated.pth'
      - multiple parts named like 'consolidated.00.pth', 'consolidated.01.pth', etc.
    """
    ckpt_path = Path(consolidated_path)
    cp_file = get_consolidated_ckpt_path(ckpt_path, mp_rank=0, mp_size=1)
    if cp_file.exists():
        # Use the single file
        st_dict = torch.load(cp_file, weights_only=True)
        if "model" in st_dict:
            st_dict = st_dict["model"]
    else:
        # Fall back to multi-part consolidated files (e.g. consolidated.00.pth, consolidated.01.pth, …)
        checkpoint_files = sorted(ckpt_path.glob("consolidated.*.pth"))
        if not checkpoint_files:
            raise FileNotFoundError(
                f"No consolidated checkpoint file found in {ckpt_path}."
            )
        st_dict = {}
        for ckpt_file in checkpoint_files:
            part = torch.load(ckpt_file, weights_only=True)
            # If the checkpoint part is wrapped with "model", unwrap it
            if "model" in part:
                part = part["model"]
            # Merge the state dicts (assumes the keys are all unique or will correctly overwrite)
            st_dict.update(part)

    model.vision_projector.init_tensors()
    model.vision_model.init_tensors()
    model.rope_embeddings.reset_parameters()

    if vision_model_path is not None:
        model.vision_model.load_ckpt(vision_model_path)

    missing_keys, unexpected_keys = model.load_state_dict(st_dict, strict=False)
    missing_keys = [k for k in missing_keys if "tied_module.weight" not in k]
    if vision_model_path is not None:
        # vision_model is already loaded separately
        missing_keys = [k for k in missing_keys if "vision_model." not in k]
    if len(missing_keys) > 0:
        logger.warning(f"Missing keys when reloading: {missing_keys}")
    if len(unexpected_keys) > 0:
        logger.warning(f"Unexpected keys when reloading: {unexpected_keys}")
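A minimal usage sketch of `load_consolidated_checkpoint` for orientation; the `build_model()` helper and the checkpoint directory below are hypothetical placeholders (not names defined in this repo), and the model is assumed to expose the `vision_projector`, `vision_model`, and `rope_embeddings` attributes used above:

    # Hypothetical sketch: load a consolidated checkpoint into a freshly built model.
    model = build_model()                                  # placeholder constructor, not in this repo
    load_consolidated_checkpoint(
        model,
        consolidated_path="checkpoints/plm",               # dir holding consolidated*.pth (assumed layout)
        vision_model_path=None,                            # or a separate PE vision checkpoint
    )
    model.eval()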
core/transformer.py
ADDED
@@ -0,0 +1,646 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import math
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
                                                flex_attention)
from xformers.ops import AttentionBias, fmha

from core import probe


class InitStdFactor(Enum):
    DISABLED = "disabled"  # Init std is divided by 1.0
    GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*n_layers)
    CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
    DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096


@dataclass
class BaseTransformerArgs:
    dim: int = 512
    n_layers: int = 8
    head_dim: Optional[int] = None
    n_heads: Optional[int] = None
    n_kv_heads: Optional[int] = None

    ffn_dim_multiplier: Optional[float] = None

    multiple_of: int = 256

    norm_eps: float = 1e-5

    rope_theta: float = 10000.0

    old_context_len: int = 8192
    rope_scale_factor: int = 1
    low_freq_factor: int = 1
    high_freq_factor: int = 32

    init_base_std: Optional[float] = None
    init_std_factor: str = "disabled"

    max_seqlen: int = 1024


def cross_entropy(pred, target, **kwargs):
    return F.nll_loss(
        F.log_softmax(pred.flatten(end_dim=-2).float(), -1),
        target.flatten(end_dim=-1),
        **kwargs,
    )


def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    assert dim == 2, "Only dim=2 is supported. Check the implementation for other dims."
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        seq_dim (int): Sequence dimension index.

    Returns:
        torch.Tensor: Reshaped frequency tensor.
    """
    ndim = x.ndim
    assert 0 <= seq_dim < ndim
    assert freqs_cis.shape == (
        x.shape[seq_dim],
        x.shape[-3],
        2,
        2,
    ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    shape = [
        d if i == seq_dim or i == ndim - 3 else 1 for i, d in enumerate(x.shape[:-2])
    ] + [2, 2]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)  # B S H D -> B S H D/2 1 2
    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)  # B S H D -> B S H D/2 1 2
    freqs_cis = reshape_for_broadcast(
        freqs_cis, xq_, seq_dim
    ).float()  # S D/2 2 2 -> 1 S 1 D/2 2 2
    xq_out = (xq_ * freqs_cis).sum(5).flatten(3)
    xk_out = (xk_ * freqs_cis).sum(5).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def lengths_to_start_ids(lengths):
    doc_start = lengths.cumsum(0)
    doc_start = doc_start.roll(1)
    doc_start[0] = 0
    return doc_start


def lengths_to_local_ids(lengths):
    assert lengths.ndim == 1
    nb_seqs = lengths.size(0)
    total_seqlen = lengths.sum()
    # This gives the document id of each token
    doc_id = torch.repeat_interleave(lengths)
    # Compute document start for each document
    doc_start = lengths_to_start_ids(lengths)
    # Compute document start for each token
    doc_start = doc_start[doc_id]
    # Compute the position of each token within each document
    tok_id = torch.arange(total_seqlen, device=lengths.device) - doc_start

    return doc_id, tok_id


def generate_doc_mask_mod(
    mask_mod: _mask_mod_signature,
    lengths: torch.Tensor,
    kv_lengths: Optional[torch.Tensor] = None,
) -> _mask_mod_signature:
    """Generates mask mods that apply to inputs to flex attention in the sequence stacked
    format.

    Args:
        mask_mod: The mask mod to apply to the documents
        lengths: Lengths of each document

    Note:
        What is the sequence stacked format? When assembling batches of inputs, we
        take multiple sequences and stack them together to form 1 large sequence. We then
        use masking to ensure that the attention scores are only applied to tokens within
        the same document.

    Example:

    - Square mask
        doc_mask           lengths
        a a b b b c c      2 3 2
      a 1 0 0 0 0 0 0
      a 1 1 0 0 0 0 0
      b 0 0 1 0 0 0 0
      b 0 0 1 1 0 0 0
      b 0 0 1 1 1 0 0
      c 0 0 0 0 0 1 0
      c 0 0 0 0 0 1 1

    """
    kv_lengths = kv_lengths if kv_lengths is not None else lengths
    q_document_id, q_token_id = lengths_to_local_ids(lengths)
    kv_document_id, kv_token_id = lengths_to_local_ids(kv_lengths)
    q_max_idx = lengths.sum() - 1
    kv_max_idx = kv_lengths.sum() - 1

    def doc_mask_mod(b, h, q_idx, kv_idx):
        q_idx_cap = torch.minimum(q_max_idx, q_idx)
        kv_idx_cap = torch.minimum(kv_max_idx, kv_idx)
        valid_idx = (q_idx <= q_max_idx) & (kv_idx <= kv_max_idx)
        same_doc = q_document_id[q_idx_cap] == kv_document_id[kv_idx_cap]
        q_logical = q_token_id[q_idx_cap]
        kv_logical = kv_token_id[kv_idx_cap]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask & valid_idx

    return doc_mask_mod


# Rotary embedding as in xformers; check whether the torchtrain implementation is better.
# Also might be useful to make it work with batch*seqlen collapsed.
class RotaryEmbedding(torch.nn.Module):
    """
    RotaryEmbedding Module
    """

    def __init__(
        self,
        theta: float,
        head_dim: int,
        max_seqlen: int = 1024,
        scale_factor: int = 1,
        low_freq_factor: int = 1,
        high_freq_factor: int = 32,
        old_context_len: int = 8192,
    ):
        super().__init__()

        self.theta = theta
        self.head_dim = head_dim
        self.max_seqlen = max_seqlen
        self.scale_factor = scale_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.old_context_len = old_context_len
        if scale_factor != 1:
            self.low_freq_wavelen = old_context_len / low_freq_factor
            self.high_freq_wavelen = old_context_len / high_freq_factor
            assert self.low_freq_wavelen >= self.high_freq_wavelen

    def reset_parameters(self):
        self.register_buffer(
            "freqs_cis",
            self.precompute_freqs_cis(
                dim=self.head_dim, end=self.max_seqlen, theta=self.theta
            ),
            persistent=False,
        )

    def apply_scaling(self, freqs):
        if self.scale_factor == 1:
            return freqs
        new_freqs = []
        for freq in freqs:
            wavelen = 2 * math.pi / freq
            if wavelen < self.high_freq_wavelen:
                new_freqs.append(freq)
            elif wavelen > self.low_freq_wavelen:
                new_freqs.append(freq / self.scale_factor)
            else:
                assert self.low_freq_wavelen != self.high_freq_wavelen
                smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
                    self.high_freq_factor - self.low_freq_factor
                )
                new_freqs.append(
                    (1 - smooth) * freq / self.scale_factor + smooth * freq
                )
        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

    def precompute_freqs_cis(
        self,
        dim: int,
        end: int,
        theta: float = 10000.0,
    ):
        """
        Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

        This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
        and the end index 'end'. The 'theta' parameter scales the frequencies.
        The returned tensor contains complex values in complex64 data type.

        Args:
            dim (int): Dimension of the frequency tensor.
            end (int): End index for precomputing frequencies.
            theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

        Returns:
            torch.Tensor: Precomputed frequency tensor with complex exponentials.
        """
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        freqs = self.apply_scaling(freqs)

        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()

        cos, sin = freqs.cos(), freqs.sin()

        return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)

    def forward(
        self, seqlen: Optional[int] = None, tok_idx: Optional[torch.Tensor] = None
    ):
        """
        Return freqs_cis corresponding to consecutive seqlen positions or the corresponding tok_idx positions
        Args:
            seqlen (int): Contiguous sequence length
            tok_idx (torch.Tensor[int]): Position indices of each token; this overrides seqlen

        Returns:
            Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis
        """
        test = (seqlen is not None) or (tok_idx is not None)
        assert test, "Should provide at least seqlen or tok_idx"
        if tok_idx is not None:
            return self.freqs_cis[tok_idx]
        elif seqlen is not None:
            return self.freqs_cis[0:seqlen]


class RMSNorm(nn.Module):
    """
    Initialize the RMSNorm normalization layer.

    Args:
        dim (int): The dimension of the input tensor.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

    Attributes:
        eps (float): A small value added to the denominator for numerical stability.
        weight (nn.Parameter): Learnable scaling parameter.

    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        x = probe.log_stats(x, "resid")
        output = self._norm(x.float())
        return (output * self.weight.float()).type_as(x)

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)  # type: ignore


class TiedLinear(nn.Module):
    def __init__(self, tied_module: nn.Module) -> None:
        super().__init__()
        self.tied_module = tied_module
        if not hasattr(tied_module, "weight"):
            raise AttributeError(
                "Provided module does not have attribute 'weight'. Please check your tied_module."
            )

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.tied_module.weight)


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        rope_theta: float,
    ):
        super().__init__()

        self.dim = dim
        self.head_dim = head_dim
        self.rope_theta = rope_theta

        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.heads_per_group = self.n_heads // self.n_kv_heads

        self.wq = nn.Linear(
            dim,
            n_heads * head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            dim,
            n_kv_heads * head_dim,
            bias=False,
        )

        self.wo = nn.Linear(
            n_heads * head_dim,
            dim,
            bias=False,
        )

    def forward(
        self,
        x: torch.Tensor,
        freq_cis: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
        attn_impl: str = "sdpa",
    ) -> torch.Tensor:
        # B S D
        bsz, seq_len, dim = x.shape
        xq = self.wq(x.view_as(x))
        xk = self.wk(x.view_as(x))
        xv = self.wv(x.view_as(x))

        output_shape = xq.shape
        # B S D -> B S H D
        xq = xq.view(bsz, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, 1, freq_cis[0:seq_len])

        # This condition helps us be easily compatible
        # with inference by adding a pluggable KVCache
        if hasattr(self, "kv_cache"):
            xk, xv = self.kv_cache.update(xk, xv, tok_idx)

        xk = repeat_kv(xk, self.heads_per_group, dim=2)
        xv = repeat_kv(xv, self.heads_per_group, dim=2)

        if attn_impl == "flex_attention":
            assert mask is None or isinstance(mask, BlockMask)
            xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
            output = flex_attention(xq, xk, xv, block_mask=mask)
            output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

        elif attn_impl == "fmha":
            assert mask is None or isinstance(mask, AttentionBias)
            output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
            # This uses B S H D instead of B H S D of pytorch

        elif attn_impl == "sdpa":
            xq, xk, xv = map(lambda e: e.transpose(1, 2), (xq, xk, xv))
            assert mask is None or isinstance(mask, (str, torch.Tensor))
            is_causal = (mask == "causal") if isinstance(mask, str) else False
            mask = mask if isinstance(mask, torch.Tensor) else None
            output = F.scaled_dot_product_attention(
                xq,
                xk,
                xv,
                is_causal=is_causal,
                attn_mask=mask,
            )
            output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D
        else:
            raise NotImplementedError(
                f"Attention implementation {attn_impl} not supported"
            )

        output = self.wo(output.reshape(output_shape))

        return output

    def reset_parameters(self, init_std=None, factor=1.0):
        init_std = init_std or (self.dim ** (-0.5))

        for w in [self.wq, self.wk, self.wv]:
            nn.init.trunc_normal_(
                w.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

        nn.init.trunc_normal_(
            self.wo.weight,
            mean=0.0,
            std=init_std / factor,
            a=-3 * init_std,
            b=3 * init_std,
        )


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        mp_size: int = 1,
    ):
        super().__init__()

        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        assert hidden_dim % mp_size == 0

        self.dim = dim
        self.hidden_dim = hidden_dim

        self.w1 = nn.Linear(
            dim,
            hidden_dim,
            bias=False,
        )
        self.w3 = nn.Linear(
            dim,
            hidden_dim,
            bias=False,
        )
        self.w2 = nn.Linear(
            hidden_dim,
            dim,
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # B S D
        x1 = self.w1(x.view_as(x))
        x3 = self.w3(x.view_as(x))
        output = self.w2(F.silu(x1) * x3)
        return output

    def reset_parameters(self, init_std=None, factor=1.0):
        in_init_std = init_std or (self.dim ** (-0.5))
        out_init_std = init_std or (self.hidden_dim ** (-0.5))
        in_init_std = in_init_std
        out_init_std = out_init_std / factor
        for w in [self.w1, self.w3]:
            nn.init.trunc_normal_(
                w.weight,
                mean=0.0,
                std=in_init_std,
                a=-3 * in_init_std,
                b=3 * in_init_std,
            )
        nn.init.trunc_normal_(
            self.w2.weight,
            mean=0.0,
            std=out_init_std,
            a=-3 * out_init_std,
            b=3 * out_init_std,
        )


class TransformerBlock(nn.Module):
    def __init__(self, args: BaseTransformerArgs):
        super().__init__()

        assert (args.head_dim is not None) or (
            args.n_heads is not None
        ), "Should specify at least head_dim or n_heads"
        self.head_dim = args.head_dim or args.dim // args.n_heads
        self.n_heads = args.n_heads or args.dim // args.head_dim
        self.n_kv_heads = args.n_kv_heads or self.n_heads

        assert args.n_heads % self.n_kv_heads == 0
        assert args.dim % args.n_heads == 0

        self.attention = Attention(
            dim=args.dim,
            head_dim=self.head_dim,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            rope_theta=args.rope_theta,
        )
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freq_cis: torch.Tensor,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
        attn_impl: str = "sdpa",
    ) -> torch.Tensor:

        h = x + self.attention(
            self.attention_norm(x),
            freq_cis,
            tok_idx=tok_idx,
            mask=mask,
            attn_impl=attn_impl,
        )
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

    def init_weights(self, init_std=None, factor=1.0):
        self.attention.reset_parameters(init_std, factor)
        self.attention_norm.reset_parameters()

        self.feed_forward.reset_parameters(init_std, factor)
        self.ffn_norm.reset_parameters()


class BaseTransformer(nn.Module):
    def __init__(self, args: BaseTransformerArgs):
        super().__init__()
        self.dim = args.dim
        self.init_base_std = args.init_base_std
        self.init_std_factor = InitStdFactor(args.init_std_factor)
        self.max_seqlen = args.max_seqlen
        self.rope_embeddings = RotaryEmbedding(
            theta=args.rope_theta,
            head_dim=args.head_dim or args.dim // args.n_heads,
            max_seqlen=args.max_seqlen,
            scale_factor=args.rope_scale_factor,
            low_freq_factor=args.low_freq_factor,
            high_freq_factor=args.high_freq_factor,
            old_context_len=args.old_context_len,
        )

        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))

    def forward(
        self,
        h,
        tok_idx: Optional[torch.Tensor] = None,
        mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
        attn_impl: str = "sdpa",
    ):

        freq_cis = self.rope_embeddings(seqlen=self.max_seqlen, tok_idx=tok_idx)

        for i, layer in enumerate(self.layers):
            h = layer(h, freq_cis, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
        return h

    def reset_parameters(self):
        # Either use fixed base std or sqrt model dim
        self.rope_embeddings.reset_parameters()

    def init_weights(self):
        self.reset_parameters()
        for depth, layer in enumerate(self.layers):
            factor = {
                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
                InitStdFactor.DIM_RATIO: self.dim / 4096,
                InitStdFactor.DISABLED: 1.0,
            }[self.init_std_factor]

            layer.init_weights(self.init_base_std, factor)
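A minimal sketch of how the classes above fit together; the hyperparameters and shapes are illustrative only, and it assumes the module is importable as core.transformer with its dependencies (xformers, flex_attention, core.probe) available:

    # Illustrative sketch: build a tiny BaseTransformer and run pre-embedded inputs through it.
    import torch
    from core.transformer import BaseTransformer, BaseTransformerArgs

    args = BaseTransformerArgs(dim=512, n_layers=2, n_heads=8, max_seqlen=128)
    model = BaseTransformer(args)
    model.init_weights()                      # registers the RoPE freqs_cis buffer and inits each block

    h = torch.randn(2, 128, args.dim)         # already-embedded tokens: (batch, seq, dim)
    out = model(h, mask="causal", attn_impl="sdpa")
    assert out.shape == (2, 128, args.dim)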
core/transforms/image_transform.py
ADDED
@@ -0,0 +1,409 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import math
from functools import reduce
from logging import getLogger
from typing import Any, Callable, Tuple

import numpy as np
import torch
import torchvision.transforms as tv
from PIL import Image
from torchvision.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode

logger = getLogger()


MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)


def get_image_transform(
    vision_input_type: str = "vanilla",
    image_res: int = 336,
    max_num_tiles: int = 1,
    normalize_img: bool = True,
) -> Tuple[Callable, int]:

    if vision_input_type == "thumb+tile":
        transforms = VariableSizeImageTransform(
            size=image_res,
            max_num_tiles=max_num_tiles,
            normalize_img=normalize_img,
            use_thumbnail="before",
        )
    else:
        transforms = ImageTransform(
            size=image_res,
            normalize_img=normalize_img,
        )

    logger.info(
        f"Initialized transforms with: vision_input_type: '{vision_input_type}' and max_num_tiles: {max_num_tiles}."
    )

    return transforms


class ImageTransform(object):
    """
    Image transform will resize the longer edge to a given size and pad the shorter edge with mean pixel value of the image.
    """

    def __init__(
        self,
        size: int = 336,
        normalize_img: bool = True,
    ) -> None:
        self.size = size
        self._mean = MEAN
        self._std = STD

        logger.info(f"ImageTransform size: {self.size}")

        self.to_tensor = tv.ToTensor()
        self.normalize = (
            tv.Normalize(
                mean=self._mean,
                std=self._std,
                inplace=True,
            )
            if normalize_img
            else lambda x: x
        )

    def __call__(self, image: Image.Image):
        w, h = image.size
        image = F.resize(
            image, (self.size, self.size), interpolation=InterpolationMode.BICUBIC
        )
        image = self.to_tensor(image)
        image = self.normalize(image)

        # Add chunk dim to make it compatible with existing dataloaders
        image = image.view(1, 3, self.size, self.size)
        return image, (w, h)

    def _transform_torch_tensor(self, image: torch.Tensor):
        h, w = image.shape[-2:]  # Image shape (C, H, W) or (N, C, H, W)
        image = F.resize(
            image, size=(self.size, self.size), interpolation=InterpolationMode.BICUBIC
        )
        image = (
            image.to(torch.float32) / 255.0
        )  # Convert to float and scale to [0, 1] range
        image = self.normalize(image)
        return image, (w, h)


class VariableSizeImageTransform(object):
    """
    The variable size image transform will resize the image dynamically
    based on the image aspect ratio and the number of image chunks we allow.

    The algorithm will not upsample low-res images to fit a certain aspect
    ratio, because that leads to a significant degradation in image quality.

    For example, if an input image is of size 300x800, and we want to allow
    a maximum of 16 image chunks, it will find the closest aspect ratio that
    is allowed within 16 image chunks, i.e., 2:5 = 2 horizontal patches and
    5 vertical patches, giving a total of 10 chunks.

    The image will then be resized to products of the base size (default is
    224px because MetaCLIP takes that), so in this case it will be resized to
    2*224:5*224 = 448:1120, where we maintain the original aspect ratio and
    pad with the mean value for the rest. This approach minimizes the amount
    of padding required for any arbitrary resolution.

    The final output will therefore be of shape (11, 3, 224, 224), where 10
    patches are coming from the resizing and chunking, and the first patch
    is a downsampled version of the image that preserves aspect ratios.
    """

    def __init__(
        self,
        size: int = 336,
        normalize_img: bool = True,
        max_num_tiles: int = 1,
        use_thumbnail: str = "no",
        area_limit: bool = False,
    ) -> None:
        self.size = size
        self._mean = MEAN
        self._std = STD

        logger.info(f"VariableSizeImageTransform size: {self.size}")

        self.to_tensor = tv.ToTensor()
        self.normalize = (
            tv.Normalize(
                mean=self._mean,
                std=self._std,
                inplace=True,
            )
            if normalize_img
            else lambda x: x
        )
        self.area_limit = area_limit
        self.max_num_tiles = max_num_tiles
        self.use_thumbnail = use_thumbnail
        if self.use_thumbnail != "no":
            self.thumbnail_transform = ImageTransform(
                size=self.size,
                normalize_img=normalize_img,
            )

    @staticmethod
    def _factors(n: int):
        """Return all factors of a number."""
        return set(
            reduce(
                list.__add__,
                ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
            )
        )

    def _find_supported_aspect_ratios(self):
        """
        This function computes all the allowed aspect ratios for a fixed
        number of input chunks.

        For example, with `num_tiles=5`, it will return:
        {
            0.2: [(1, 5)],
            5.0: [(5, 1)],
            0.25: [(1, 4)],
            1.0: [(2, 2), (1, 1)],
            4.0: [(4, 1)],
            0.3333333333333333: [(1, 3)],
            3.0: [(3, 1)],
            0.5: [(1, 2)],
            2.0: [(2, 1)]
        }
        """
        asp_dict = {}
        for chunk_size in range(self.max_num_tiles, 0, -1):
            _factors = sorted(VariableSizeImageTransform._factors(chunk_size))
            _asp_ratios = [(x, chunk_size // x) for x in _factors]
            for ratio in _asp_ratios:
                k = ratio[0] / ratio[1]
                if k not in asp_dict:
                    asp_dict[k] = [ratio]
                else:
                    asp_dict[k].append(ratio)
        return asp_dict

    def _find_closest_aspect_ratio(self, img_width: int, img_height: int) -> Tuple:
        """
        Given an image width, height and target number of chunks
        this function will find the closest supported aspect ratio.
        """
        tgt_ar = img_width / img_height
        asp_dict = self._find_supported_aspect_ratios()
        cl_d, cl_p = 1e23, None
        if tgt_ar >= 1:
            cl_p = min(
                [k for k in asp_dict.keys() if k <= tgt_ar],
                key=lambda x: abs(x - tgt_ar),
            )
            v = asp_dict[cl_p]
            # select width
            widths = [(idx, self.size * vv[0]) for idx, vv in enumerate(v)]
            tgt_idx = max(widths, key=lambda x: x[1])[0]
        else:
            cl_p = min(
                [k for k in asp_dict.keys() if k > tgt_ar],
                key=lambda x: abs(1 / x - 1 / tgt_ar),
            )
            v = asp_dict[cl_p]
            # select height
            heights = [(idx, self.size * vv[1]) for idx, vv in enumerate(v)]
            tgt_idx = max(heights, key=lambda x: x[1])[0]
        out = v[tgt_idx]
        return out

    def _resize(
        self, image: Image.Image, target_width: int, target_height: int
    ) -> Image.Image:
        # Resize longer edge to given size.
        w, h = image.size
        scale = w / h

        if scale > 1.0:
            # width > height
            new_w = target_width
            new_h = math.floor(new_w / scale)
        else:
            # height >= width
            new_h = target_height
            new_w = math.floor(new_h * scale)

        image = F.resize(image, (new_h, new_w))
        return image

    def _pad(self, image: Image.Image, new_width: int, new_height: int) -> Image.Image:
        mean_per_channel = tuple(
            np.clip(np.array(image).mean(axis=(0, 1)), 0, 255).astype(np.uint8)
        )
        new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0))  # type: ignore
        new_im.paste(image)
        return new_im

    def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
        # Split image into number of required tiles (width x height)
        num_channels, height, width = image.size()
        image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
        # Permute dimensions to reorder the axes
        image = image.permute(1, 3, 0, 2, 4).contiguous()
        # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
        image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
        return image

    def _get_image_height_width(
        self, image_width: int, image_height: int, target_width: int, target_height: int
    ) -> Tuple[int, int]:
        """
        Given image width, height and target width, height for the canvas, return the dimensions of how the image would be resized
        with aspect ratio preservation.
        """
        scale = image_width / image_height

        if scale > 1.0:
            # Width is larger than height

            # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
            rescaling_factor = min(
                target_width / image_width, target_height / image_height
            )

            # Set new width to target width and height to the rescaled height.
            new_w = rescaling_factor * image_width
            new_h = math.floor(new_w / scale)

        else:
            # Height is larger than width

            # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
            rescaling_factor = min(
                target_width / image_width, target_height / image_height
            )

            # Set new height to target height and width to the rescaled width.
            new_h = rescaling_factor * image_height
            new_w = math.floor(new_h * scale)

        return new_w, new_h

    def _fit_image_to_canvas(
        self, img_width: int, img_height: int, area_limit: bool
    ) -> Any:
        """
        Given an image width, height and target number of chunks this function will see if the image
        can be fit into any of the canvases that can be built from arranging the tiles in a grid.
        If the image can be fit onto several canvases, it will return the canvas where the shorter edge
        of the image will be largest.

        If area_limit is set to True, the tie-breaking prefers the canvas where area is less than 2x the original area.
        """
        # Initialize the optimal canvas to None. If no canvas is found where image fits, function returns None.
        optimal_canvas = None
        optimal_image_width_height = None

        scale = img_width / img_height

        # Gather all potential supported image resolutions and iterate through them to find best match
        potential_arrangements = [
            item
            for sublist in self._find_supported_aspect_ratios().values()
            for item in sublist
        ]
        for n_w, n_h in potential_arrangements:
            # Compute the canvas size
            canvas_width, canvas_height = n_w * self.size, n_h * self.size

            # Check if image can fit into the canvas without downsampling
            if canvas_width >= img_width and canvas_height >= img_height:
                # If we did not find a good canvas yet, we will use the current one
                if optimal_canvas is None:
                    # Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling
                    optimal_canvas = (n_w, n_h)
                    optimal_image_width_height = self._get_image_height_width(
                        image_width=img_width,
                        image_height=img_height,
                        target_width=n_w * self.size,
                        target_height=n_h * self.size,
                    )
                else:
                    # If we already found an optimal canvas before, we will check if the shorter edge of the image will be larger than the current optimal canvas.
                    # This means we can potentially upsample the image resolution which is beneficial to performance.
                    image_width_height = self._get_image_height_width(
                        image_width=img_width,
                        image_height=img_height,
                        target_width=n_w * self.size,
                        target_height=n_h * self.size,
                    )
                    if area_limit:
                        # Prioritize aspect ratio, and choose best within area limit when tied.
                        curr_scale = image_width_height[0] / image_width_height[1]
                        optim_scale = (
                            optimal_image_width_height[0]
                            / optimal_image_width_height[1]
                        )
                        if abs(scale - curr_scale) < abs(scale - optim_scale):
                            # 1. optimize aspect ratio
                            optimal_canvas = (n_w, n_h)
                            optimal_image_width_height = image_width_height
                        elif abs(scale - curr_scale) == abs(scale - optim_scale):
                            # 2. optimize area
                            if (
                                image_width_height[0] * image_width_height[1]
                                < 2 * img_width * img_height
                            ):
                                # 2.1 area is less than 2x the original area
                                optimal_canvas = (n_w, n_h)
                                optimal_image_width_height = image_width_height
                    else:
                        # NOTE: L3V dynamic tiling. Prioritize biggest canvas.
                        if (
                            scale < 1.0
                            and (image_width_height[0] >= optimal_image_width_height[0])
                        ) or (
                            scale >= 1.0
                            and (image_width_height[1] >= optimal_image_width_height[1])
                        ):
                            optimal_canvas = (n_w, n_h)
                            optimal_image_width_height = image_width_height
        return optimal_canvas

    def __call__(self, image: Image.Image) -> Tuple[Any, Any]:
        assert isinstance(image, Image.Image), type(image)
        if self.use_thumbnail != "no":
            thumbnail = self.thumbnail_transform(image)[0]

        w, h = image.size
        # Check if the image can be fit to the canvas without downsampling
        ar = self._fit_image_to_canvas(
            img_width=w, img_height=h, area_limit=self.area_limit
        )
        if ar is None:
            # If we did not find a canvas, we have to find the closest aspect ratio and downsample the image
            ar = self._find_closest_aspect_ratio(img_width=w, img_height=h)

        image = F.resize(
            image,
            (ar[1] * self.size, ar[0] * self.size),  # (h, w)
            interpolation=InterpolationMode.BICUBIC,
        )
        image = self._pad(image, ar[0] * self.size, ar[1] * self.size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        image = self._split(image, ar[0], ar[1])  # type: ignore
        if self.use_thumbnail == "before":
            image = torch.cat((thumbnail, image), dim=0)
        elif self.use_thumbnail == "after":
            image = torch.cat((image, thumbnail), dim=0)
        elif self.use_thumbnail == "both":
            image = torch.cat((thumbnail, image, thumbnail), dim=0)

        return image, ar
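A short usage sketch of the transforms above; the image path is a placeholder and the "thumb+tile" setting with 4 tiles is only illustrative:

    # Illustrative sketch: tile a PIL image into up to 4 chunks plus a leading thumbnail.
    from PIL import Image

    transform = get_image_transform(
        vision_input_type="thumb+tile", image_res=336, max_num_tiles=4
    )
    img = Image.open("example.jpg").convert("RGB")     # placeholder path
    tiles, (n_w, n_h) = transform(img)
    # tiles: (1 thumbnail + n_w * n_h tiles, 3, 336, 336); (n_w, n_h) is the chosen tile grid.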
core/utils.py
ADDED
@@ -0,0 +1,40 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional

import torch


@dataclass
class InitArgs:
    use_gaussian: bool = True  # gaussian vs uniform
    coeff_std: Optional[float] = None  # std coeff multiplier
    no_init: bool = False


def get_init_fn(
    args: InitArgs, input_dim: int, init_depth: Optional[int]
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Init functions.
    """
    if args.no_init:
        return lambda x: x

    # standard deviation
    std = 1 / math.sqrt(input_dim)
    std = std if args.coeff_std is None else (args.coeff_std * std)

    # rescale with depth
    if init_depth is not None:
        std = std / math.sqrt(2 * init_depth)

    # gaussian vs uniform
    if args.use_gaussian:
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )
    else:
        bound = math.sqrt(3) * std  # ensure the standard deviation is `std`
        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
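A small sketch of how this initializer factory might be applied; the linear layer shape and depth are illustrative, not values taken from the repo:

    # Illustrative sketch: initialize a projection weight with a depth-rescaled truncated normal.
    import torch.nn as nn

    init_fn = get_init_fn(InitArgs(use_gaussian=True), input_dim=1024, init_depth=12)
    proj = nn.Linear(1024, 4096, bias=False)
    init_fn(proj.weight)   # std = (1 / sqrt(1024)) / sqrt(2 * 12), truncated at +/- 3 std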
core/vision_encoder/__init__.py
ADDED
File without changes
core/vision_encoder/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (159 Bytes).
core/vision_encoder/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (218 Bytes).
core/vision_encoder/__pycache__/config.cpython-312.pyc
ADDED
Binary file (4.31 kB).
core/vision_encoder/__pycache__/config.cpython-313.pyc
ADDED
Binary file (4.26 kB).
core/vision_encoder/__pycache__/pe.cpython-312.pyc
ADDED
Binary file (42.9 kB).
core/vision_encoder/__pycache__/pe.cpython-313.pyc
ADDED
Binary file (38.1 kB).
core/vision_encoder/__pycache__/pe_lora.cpython-312.pyc
ADDED
Binary file (35.5 kB).
core/vision_encoder/__pycache__/rope.cpython-312.pyc
ADDED
Binary file (14.6 kB).
core/vision_encoder/__pycache__/rope.cpython-313.pyc
ADDED
Binary file (14.6 kB).
core/vision_encoder/__pycache__/tokenizer.cpython-312.pyc
ADDED
Binary file (17.2 kB).
core/vision_encoder/__pycache__/tokenizer.cpython-313.pyc
ADDED
Binary file (17.3 kB).
core/vision_encoder/__pycache__/transforms.cpython-312.pyc
ADDED
Binary file (3.51 kB).
core/vision_encoder/__pycache__/transforms.cpython-313.pyc
ADDED
Binary file (3.3 kB).
core/vision_encoder/config.py
ADDED
@@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
Include all available vision encoder configurations.
"""

from dataclasses import dataclass, replace

from typing import Optional

from huggingface_hub import hf_hub_download


def fetch_pe_checkpoint(name: str, path: Optional[str] = None):
    path = path or f"hf://facebook/{name}:{name}.pt"

    if path.startswith("hf://"):
        # Load from huggingface
        path = path[len("hf://"):]
        repo, file = path.split(":")

        return hf_hub_download(repo_id=repo, filename=file)
    else:
        return path


@dataclass
class PEConfig:
    """ Vision Tower Config. """
    patch_size: int
    width: int
    layers: int
    heads: int
    mlp_ratio: float
    output_dim: Optional[int]

    ls_init_value: float = None
    drop_path: float = 0.0

    image_size: int = 224
    use_abs_posemb: bool = True
    use_cls_token: bool = False
    use_rope2d: bool = True

    pool_type: str = "attn"
    attn_pooler_heads: int = 8

    use_ln_pre: bool = True
    use_ln_post: bool = True


@dataclass
class PETextConfig:
    """ Text Tower Config. """
    context_length: int
    width: int
    heads: int
    layers: int

    output_dim: int

    mlp_ratio: float = 4.0
    vocab_size: int = 49408


PE_VISION_CONFIG = {}
PE_TEXT_CONFIG = {}


#########################################
#               PE CORE                 #
#########################################

PE_VISION_CONFIG["PE-Core-G14-448"] = PEConfig(
    image_size=448,
    patch_size=14,
    width=1536,
    layers=50,
    heads=16,
    mlp_ratio=8960 / 1536,
    pool_type="attn",
    output_dim=1280,
    use_cls_token=False,
)
PE_TEXT_CONFIG["PE-Core-G14-448"] = PETextConfig(
    context_length=72,
    width=1280,
    heads=20,
    layers=24,
    output_dim=1280
)


PE_VISION_CONFIG["PE-Core-L14-336"] = PEConfig(
    image_size=336,
    patch_size=14,
    width=1024,
    layers=24,
    heads=16,
    mlp_ratio=4.0,
    pool_type="attn",
    output_dim=1024,
    use_cls_token=True,
)
PE_TEXT_CONFIG["PE-Core-L14-336"] = PETextConfig(
    context_length=32,
    width=1024,
    heads=16,
    layers=24,
    output_dim=1024
)


PE_VISION_CONFIG["PE-Core-B16-224"] = PEConfig(
    image_size=224,
    patch_size=16,
    width=768,
    layers=12,
    heads=12,
    mlp_ratio=4.0,
    pool_type="attn",
    output_dim=1024,
    use_cls_token=True,
)
PE_TEXT_CONFIG["PE-Core-B16-224"] = PE_TEXT_CONFIG["PE-Core-L14-336"]


PE_VISION_CONFIG["PE-Core-S16-384"] = PEConfig(
    image_size=384,
    patch_size=16,
    width=384,
    layers=12,
    heads=6,
    mlp_ratio=4.0,
    pool_type="attn",
    output_dim=512,
    use_cls_token=True,
)
PE_TEXT_CONFIG["PE-Core-S16-384"] = PETextConfig(
    context_length=32,
    width=512,
    heads=8,
    layers=12,
    output_dim=512
)


PE_VISION_CONFIG["PE-Core-T16-384"] = PEConfig(
    image_size=384,
    patch_size=16,
    width=192,
    layers=12,
    heads=3,
    mlp_ratio=4.0,
    pool_type="attn",
    output_dim=512,
    use_cls_token=True,
)
PE_TEXT_CONFIG["PE-Core-T16-384"] = PE_TEXT_CONFIG["PE-Core-S16-384"]


#########################################
#               PE Lang                 #
#########################################

PE_VISION_CONFIG["PE-Lang-G14-448"] = replace(
    PE_VISION_CONFIG["PE-Core-G14-448"],
    image_size=448,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
    ls_init_value=0.1,
    layers=47,
)

PE_VISION_CONFIG["PE-Lang-L14-448"] = replace(
    PE_VISION_CONFIG["PE-Core-L14-336"],
    image_size=448,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
    ls_init_value=0.1,
    layers=23
)


# Stage 2 checkpoints for PLM-8B and PLM-3B respectively. Pretrained with tiling.
# Use these checkpoints if you're building a model that uses tiling downstream!
PE_VISION_CONFIG["PE-Lang-G14-448-Tiling"] = PE_VISION_CONFIG["PE-Lang-G14-448"]
PE_VISION_CONFIG["PE-Lang-L14-448-Tiling"] = PE_VISION_CONFIG["PE-Lang-L14-448"]


#########################################
#             PE Spatial                #
#########################################

PE_VISION_CONFIG["PE-Spatial-G14-448"] = replace(
    PE_VISION_CONFIG["PE-Core-G14-448"],
    image_size=448,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
    ls_init_value=0.1,
)

# No layerscale on the smaller spatial models
PE_VISION_CONFIG["PE-Spatial-L14-448"] = replace(
    PE_VISION_CONFIG["PE-Core-L14-336"],
    image_size=448,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
)


PE_VISION_CONFIG["PE-Spatial-B16-512"] = replace(
    PE_VISION_CONFIG["PE-Core-B16-224"],
    image_size=512,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
)


PE_VISION_CONFIG["PE-Spatial-S16-512"] = replace(
    PE_VISION_CONFIG["PE-Core-S16-384"],
    image_size=512,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
)


PE_VISION_CONFIG["PE-Spatial-T16-512"] = replace(
    PE_VISION_CONFIG["PE-Core-T16-384"],
    image_size=512,
    pool_type="none",
    use_ln_post=False,
    output_dim=None,
)
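For reference, a minimal usage sketch (illustrative, not part of the commit) of how the registries above are consumed; the key "PE-Core-B16-224" is one of the names defined in this file, and fetch_pe_checkpoint resolves the default hf://facebook/<name>:<name>.pt path through hf_hub_download:

from core.vision_encoder.config import PE_VISION_CONFIG, fetch_pe_checkpoint

cfg = PE_VISION_CONFIG["PE-Core-B16-224"]   # a PEConfig dataclass
print(cfg.width, cfg.layers, cfg.output_dim)  # 768 12 1024
# Resolves to repo "facebook/PE-Core-B16-224", file "PE-Core-B16-224.pt"
# (downloads it if that Hugging Face repo is reachable):
ckpt_path = fetch_pe_checkpoint("PE-Core-B16-224")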
core/vision_encoder/pe.py
ADDED
@@ -0,0 +1,833 @@
from collections import OrderedDict
from dataclasses import asdict
from functools import partial
from logging import getLogger
from typing import Callable, Optional, Literal

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from timm.layers import DropPath
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.utils.checkpoint import checkpoint
import types
from core.vision_encoder.rope import Rope2D
from core.vision_encoder.config import PEConfig, PETextConfig, PE_VISION_CONFIG, PE_TEXT_CONFIG, fetch_pe_checkpoint


logger = getLogger()


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.dim = dim
        self.init_values = init_values

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

    def init_tensors(self):
        self.gamma = nn.Parameter(self.init_values * torch.ones(self.dim))


class AttentionPooling(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_probe: int = 1,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
        self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)
        self.layernorm = norm_layer(embed_dim)
        self.mlp_width = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(self.embed_dim, self.mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(self.mlp_width, self.embed_dim)),
                ]
            )
        )
        self._is_converted = False

    def forward(self, x: torch.Tensor):
        # This is the original forward method that will be replaced.
        batch, _, _ = x.shape
        q = self.probe.repeat((batch, 1, 1)).to(x.dtype)
        x_attn = self.attn(q, x, x, need_weights=False)[0]
        x = x_attn + self.mlp(self.layernorm(x_attn))
        return x


class SelfAttention(nn.Module):
    r"""
    Implements sequence packed attention and RoPe
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        rope: Optional[nn.Module] = None,
    ):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        # To make this compatibile with nn.MultiHeadAttention
        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))

        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        self.rope = rope
        self.scale = self.head_dim ** (-0.5)

    def init_tensors(self):
        xavier_uniform_(self.in_proj_weight)
        constant_(self.in_proj_bias, 0.0)
        constant_(self.out_proj.bias, 0.0)

    def del_muda(self):
        del self.in_proj_weight
        del self.in_proj_bias

    def migrate_weights(self):
        """
        MUST be called *after* loading the state_dict.
        This copies the weights from the old Parameters to the new nn.Linear layer.
        """
        # Use torch.no_grad() to ensure this is done without tracking gradients
        with torch.no_grad():
            self.in_proj.weight.copy_(self.in_proj_weight)
            self.in_proj.bias.copy_(self.in_proj_bias)

        # del self.in_proj_weight
        # del self.in_proj_bias
        # print("Migration complete. Old parameters have been removed.")

    def forward(self, x, attn_mask=None, need_weights=False):
        batch, seq, embed_dim = x.shape

        # proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
        proj = self.in_proj(x)
        # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
        proj = (
            proj.unflatten(-1, (3, embed_dim))
            .unsqueeze(0)
            .transpose(0, -2)
            .squeeze(-2)
            .contiguous()
        )
        q, k, v = proj[0], proj[1], proj[2]

        # Use "q_" so that we don't accidentally quit in pdb :)
        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

        if self.rope:
            q, k = self.rope(q, k)

        if not need_weights:
            # Original efficient path
            attn = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
            )
            attn = rearrange(attn, "b h s d -> b s (h d)")
            return self.out_proj(attn)
        else:
            # Path to get attention weights
            q_scaled = q * self.scale
            # attn_weights shape: (batch, num_heads, seq_len, seq_len)
            attn_weights = torch.matmul(q_scaled, k.transpose(-2, -1))

            if attn_mask is not None:
                attn_weights += attn_mask

            attn_weights = F.softmax(attn_weights, dim=-1)

            attn_output = torch.matmul(attn_weights, v)
            attn_output = rearrange(attn_output, "b h s d -> b s (h d)")

            output = self.out_proj(attn_output)
            return output, attn_weights


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
    ):
        super().__init__()

        if rope:
            self.attn = SelfAttention(d_model, n_head, rope=rope)
        else:
            self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )

    def _call_attn(
        self,
        q_x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
    ):

        if attn_mask is not None:
            if not attn_mask.dtype == torch.bool:
                attn_mask = attn_mask.to(q_x.dtype)

        if isinstance(self.attn, SelfAttention):
            # Pass the flag to your custom SelfAttention
            return self.attn(q_x, attn_mask=attn_mask, need_weights=need_weights)
        else:
            # Standard nn.MultiheadAttention
            return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=need_weights)[0]

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
    ):
        attn_result = self._call_attn(self.ln_1(x), attn_mask=attn_mask, need_weights=need_weights)

        attn_weights = None
        if need_weights:
            # Unpack the output and the weights
            attn_output, attn_weights = attn_result
        else:
            attn_output = attn_result

        x = x + self.drop_path1(self.ls_1(attn_output))
        x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))

        if need_weights:
            return x, attn_weights  # Return weights

        return x

    def del_muda(self):
        self.attn.del_muda()


class Transformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    rope=rope,
                )
                for _ in range(layers)
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def truncate(self, layer_idx: int):
        """ Delete layers so the last layer is the given layer index. """
        self.layers = ((self.layers + layer_idx) % self.layers) + 1
        self.resblocks = nn.ModuleList(self.resblocks[:self.layers])

    def del_muda(self):
        for resblock in self.resblocks:
            resblock.del_muda()

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        layer_idx: int = -1,
        need_weights: bool = False,  # Add need_weights flag
    ):
        stop_idx = (self.layers + layer_idx) % self.layers

        attention_maps = []  # List to store maps from each layer

        for i, r in enumerate(self.resblocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                if need_weights:
                    raise ValueError("Cannot get attention maps with gradient checkpointing enabled.")
                x = checkpoint(r, x, attn_mask, use_reentrant=False)
            else:
                if need_weights:
                    x, attn_map = r(x, attn_mask=attn_mask, need_weights=True)
                    attention_maps.append(attn_map)
                else:
                    x = r(x, attn_mask=attn_mask, need_weights=False)

            if i == stop_idx:
                break

        if need_weights:
            return x, attention_maps  # Return the list of maps

        return x


class VisionTransformer(nn.Module):
    def __init__(
        self,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
        use_ln_pre: bool = True,
        use_ln_post: bool = True,
        ls_init_value: float = None,
        drop_path: float = 0.0,
        image_size: int = 448,  # Pretrain image size only; you can pass in any image size
        use_abs_posemb: bool = True,
        use_rope2d: bool = True,
        use_cls_token: bool = False,
        output_dim: Optional[int] = 1280,
        attn_pooler_heads: int = 8,
        pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
    ):
        super().__init__()
        assert pool_type in ("attn", "tok", "avg", "none")
        self.pool_type = pool_type
        self.patch_size = patch_size

        self.output_dim = output_dim or width
        self.proj_dim = output_dim
        self.heads = heads
        self.width = width
        self.layers = layers

        self.use_abs_posemb = use_abs_posemb
        self.use_cls_token = use_cls_token
        self.use_rope2d = use_rope2d
        self.image_size = image_size

        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.rope = (
            Rope2D(
                dim=width // heads,
                use_cls_token=self.use_cls_token,
            )
            if self.use_rope2d
            else None
        )

        self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity()
        self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity()

        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            drop_path=drop_path,
            rope=self.rope,
        )

        if pool_type == "attn":
            self.attn_pool = AttentionPooling(
                embed_dim=width,
                num_heads=attn_pooler_heads,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None

        self.init_tensors()

    def del_muda(self):
        self.transformer.del_muda()

    def delete_attn_pool(self):
        del self.attn_pool

    def init_tensors(self):
        def init_submodule_tensors(module):
            for name, child in module.named_children():
                if hasattr(child, "init_tensors"):
                    logger.debug(f"Initializing tensors for submodule: {name}")
                    child.init_tensors()
                init_submodule_tensors(child)

        init_submodule_tensors(self)
        self.rope.init_tensors()

        # class embeddings and positional embeddings
        init_scale = self.width**-0.5

        if self.use_cls_token:
            self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))

        if self.use_abs_posemb:
            self.posemb_grid_size = self.image_size // self.patch_size
            self.positional_embedding = nn.Parameter(
                init_scale
                * torch.randn(
                    int(self.use_cls_token) + self.posemb_grid_size**2, self.width
                )
            )

        if self.proj_dim is not None:
            self.proj = nn.Parameter(
                init_scale * torch.randn(self.width, self.proj_dim)
            )

    def load_ckpt(self, ckpt_path: str, verbose: bool = True):
        _sd = torch.load(ckpt_path, weights_only=True)
        if "state_dict" in _sd:
            _sd = _sd["state_dict"]
        elif "weights" in _sd:
            _sd = _sd["weights"]

        # for backwards compatibility
        _sd = {k.replace("module.", ""): v for k, v in _sd.items()}
        if any(k.startswith("visual.") for k in _sd):
            _sd = {k.replace("visual.", ""): v for k, v in _sd.items() if "visual" in k}

        m, u = self.load_state_dict(_sd, strict=False)

        if verbose or (m or u):
            logger.info(f"Missing keys for loading vision encoder: {m}")
            logger.info(f"Unexpected keys for loading vision encoder: {u}")
            print(f"Missing keys for loading vision encoder: {m}")
            print(f"Unexpected keys for loading vision encoder: {u}")

    def truncate(self, layer_idx: int):
        """ Delete layers so the last layer is the given layer index. """
        self.transformer.truncate(layer_idx)
        self.layers = self.transformer.layers

    @classmethod
    def from_config(
        cls,
        name: str,
        pretrained: bool = False,
        checkpoint_path: Optional[str] = None,
        **kwdargs
    ):
        if name not in PE_VISION_CONFIG:
            raise RuntimeError(f"{name} not found in configs.")

        args = asdict(PE_VISION_CONFIG[name])
        args.update(kwdargs)

        model = cls(**args)
        if pretrained:
            model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))

        return model

    @classmethod
    def available_configs(cls):
        return list(PE_VISION_CONFIG.keys())

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.set_grad_checkpointing(enable=enable)

    def _sample_abs_posemb(self, grid_h: int, grid_w: int):
        """Interpolates the absolute position embedding if necessary."""
        if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
            return self.positional_embedding[None, ...]

        pos_embed = self.positional_embedding
        if self.use_cls_token:
            cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]

        pos_embed = (
            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        pos_embed = F.interpolate(
            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()

        if self.use_cls_token:
            pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)

        return pos_embed[None, ...]

    def _pool(self, x: torch.Tensor):
        if self.pool_type == "tok":
            return x[:, 0]
        elif self.pool_type == "avg":
            return x.mean(dim=1)
        elif self.pool_type == "attn":
            return self.attn_pool(x).squeeze(1)
        elif self.pool_type == "none":
            return x
        else:
            raise NotImplementedError

    def forward_features(
        self,
        x: torch.Tensor,
        norm: bool = False,
        layer_idx: int = -1,
        strip_cls_token: bool = False,
        need_weights: bool = False,  # Add need_weights flag
    ):
        batch, _, h, w = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size

        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)

        if self.use_cls_token:
            x = torch.cat(
                [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                dim=1,
            )

        if self.use_abs_posemb:
            x = x + self._sample_abs_posemb(grid_h, grid_w)

        if self.use_rope2d:
            self.rope.update_grid(x.device, grid_h, grid_w)

        x = self.ln_pre(x)

        # Get output from the transformer
        transformer_output = self.transformer(x, layer_idx=layer_idx, need_weights=need_weights)

        attention_maps = None
        if need_weights:
            x, attention_maps = transformer_output
        else:
            x = transformer_output

        if norm:
            x = self.ln_post(x)

        if strip_cls_token and self.use_cls_token:
            x = x[:, 1:, :]

        if need_weights:
            return x, attention_maps  # Return maps

        return x

    def forward(self, x: torch.Tensor, **kwargs):
        x = self.forward_features(x, norm=True, **kwargs)
        x = self._pool(x)

        if self.proj_dim is not None:
            x = x @ self.proj

        return x


class TextTransformer(nn.Module):
    def __init__(
        self,
        context_length: int = 72,
        vocab_size: int = 49408,
        width: int = 512,
        heads: int = 8,
        layers: int = 12,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        output_dim: int = 1280,
        no_causal_mask: bool = False,
        pad_id: int = 0,
        pool_type: str = "argmax",
        proj_bias: bool = False,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
        output_tokens: bool = False,
        use_ln_post: bool = True,
    ):
        super().__init__()
        assert pool_type in ("first", "last", "argmax", "none")
        self.pool_type = pool_type
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pad_id = pad_id
        self.layers = layers

        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))

        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

        self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()

        if no_causal_mask:
            self.attn_mask = None
        else:
            self.register_buffer(
                "attn_mask", self.build_causal_mask(), persistent=False
            )

        if pool_type == "attn" or pool_type == "attn_eos":
            self.attn_pool = AttentionPooling(
                embed_dim=width,
                num_heads=heads,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
        else:  # argmax
            self.attn_pool = None

        if proj_bias:
            self.text_projection = nn.Linear(width, output_dim)
        else:
            self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def build_causal_mask(self):
        # lazily create causal attention mask, with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def load_ckpt(self, ckpt_path: str, verbose: bool = True):
        _sd = torch.load(ckpt_path, weights_only=True)
        if "state_dict" in _sd:
            _sd = _sd["state_dict"]
        elif "weights" in _sd:
            _sd = _sd["weights"]

        _sd = {k.replace("module.", ""): v for k, v in _sd.items()}

        m, u = self.load_state_dict(_sd, strict=False)

        if verbose or (m or u):
            logger.info(f"Missing keys for loading model: {m}")
            logger.info(f"Unexpected keys for loading model: {u}")
            print(f"Missing keys for loading model: {m}")
            print(f"Unexpected keys for loading model: {u}")

    def build_cls_mask(self, text):
        cls_mask = (text != self.pad_id).unsqueeze(1)
        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
        additive_mask = torch.empty(cls_mask.shape, device=cls_mask.device)
        additive_mask.fill_(0)
        additive_mask.masked_fill_(~cls_mask, float("-inf"))
        additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
        return additive_mask

    def text_global_pool(
        self, x, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
    ):
        if pool_type == "first":
            pooled, tokens = x[:, 0], x[:, 1:]
        elif pool_type == "last":
            pooled, tokens = x[:, -1], x[:, :-1]
        elif pool_type == "argmax":
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            assert text is not None
            pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
        else:
            pooled = tokens = x

        return pooled, tokens

    def forward(self, text):
        seq_len = text.shape[1]
        x = self.token_embedding(
            text
        )
        attn_mask = self.attn_mask
        if attn_mask is not None:
            attn_mask = attn_mask[:seq_len, :seq_len]

        x = x + self.positional_embedding[:seq_len]
        x = self.transformer(x, attn_mask=attn_mask)

        x = self.ln_final(x)
        pooled, tokens = self.text_global_pool(x, text, pool_type=self.pool_type)

        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection

        if self.output_tokens:
            return pooled, tokens

        return pooled


class CLIP(TextTransformer):
    def __init__(
        self,
        vision_cfg: PEConfig,
        text_cfg: PETextConfig,
        init_logit_scale: float = np.log(1 / 0.07)
    ):
        super(CLIP, self).__init__(**asdict(text_cfg))
        self.visual = VisionTransformer(**asdict(vision_cfg))
        self.image_size = self.visual.image_size  # For ease of use
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

    def encode_image(self, image, normalize: bool = False):
        x = self.visual(image)
        return F.normalize(x, dim=-1) if normalize else x

    def encode_video(self, video, normalize: bool = False):  # b n c h w
        b, n, c, h, w = video.shape
        frms = video.reshape(b * n, c, h, w)
        frm_feats = self.encode_image(frms, normalize=normalize)
        video_feats = frm_feats.reshape(b, n, -1)
        video_feats = video_feats.mean(dim=1)
        return video_feats

    def encode_text(self, text, normalize: bool = False):
        x = super().forward(text)
        return F.normalize(x, dim=-1) if normalize else x

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
    ):
        image_features = (
            self.encode_image(image, normalize=True) if image is not None else None
        )
        text_features = (
            self.encode_text(text, normalize=True) if text is not None else None
        )
        return image_features, text_features, self.logit_scale.exp()

    @classmethod
    def from_config(
        cls,
        name: str,
        pretrained: bool = False,
        checkpoint_path: Optional[str] = None  # To load your own
    ):
        if name not in PE_VISION_CONFIG or name not in PE_TEXT_CONFIG:
            raise RuntimeError(f"{name} not found in configs.")

        model = cls(PE_VISION_CONFIG[name], PE_TEXT_CONFIG[name])
        if pretrained:
            model.load_ckpt(fetch_pe_checkpoint(name, checkpoint_path))

        return model

    @classmethod
    def available_configs(cls):
        return [k for k in PE_VISION_CONFIG if k in PE_TEXT_CONFIG]
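A brief usage sketch (illustrative only, not part of the commit): build the CLIP wrapper from a registered config and run an image/text pair through it to check output shapes. The random token ids stand in for output of the package tokenizer, and with pretrained=True the published checkpoint would be fetched first.

import torch
from core.vision_encoder.pe import CLIP

model = CLIP.from_config("PE-Core-B16-224", pretrained=False).eval()
image = torch.randn(1, 3, model.image_size, model.image_size)   # 224x224 for this config
tokens = torch.randint(1, 49408, (1, 32))                       # placeholder ids; context_length is 32 here
with torch.no_grad():
    img_feat, txt_feat, logit_scale = model(image, tokens)
print(img_feat.shape, txt_feat.shape)                           # both (1, 1024) for PE-Core-B16-224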
core/vision_encoder/rope.py
ADDED
@@ -0,0 +1,347 @@
from math import pi
from typing import Literal, Optional, Union

import torch
from einops import rearrange, repeat
from torch import Tensor, broadcast_tensors, einsum, nn
from torch.amp import autocast
from torch.nn import Module

# helper functions


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# broadcat, as tortoise-tts was using it


def broadcat(tensors, dim=-1):
    broadcasted_tensors = broadcast_tensors(*tensors)
    return torch.cat(broadcasted_tensors, dim=dim)


# rotary embedding helper functions


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    dtype = t.dtype

    if t.ndim == 3:
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    t_left, t, t_right = (
        t[..., :start_index],
        t[..., start_index:end_index],
        t[..., end_index:],
    )
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = torch.cat((t_left, t, t_right), dim=-1)

    return out.type(dtype)


# learned rotation helpers


def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    if exists(freq_ranges):
        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(rotations, "... r f -> ... (r f)")

    rotations = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(rotations, t, start_index=start_index)


# classes


class RotaryEmbedding(Module):
    def __init__(
        self,
        dim,
        custom_freqs: Optional[Tensor] = None,
        freqs_for: Union[
            Literal["lang"], Literal["pixel"], Literal["constant"]
        ] = "lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        seq_before_head_dim=False,
        cache_if_possible=True,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()

        self.cache_if_possible = cache_if_possible

        self.tmp_store("cached_freqs", None)
        self.tmp_store("cached_scales", None)

        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # dummy for device

        self.tmp_store("dummy", torch.tensor(0))

        # default sequence dimension

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors

        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # xpos

        self.use_xpos = use_xpos
        if not use_xpos:
            self.tmp_store("scale", None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = xpos_scale_base
        self.tmp_store("scale", scale)

        # add apply_rotary_emb as static method

        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

    @property
    def device(self):
        return self.dummy.device

    def tmp_store(self, key, value):
        self.register_buffer(key, value, persistent=False)

    def get_seq_pos(self, seq_len, device, dtype, offset=0):
        return (
            torch.arange(seq_len, device=device, dtype=dtype) + offset
        ) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert (
            not self.use_xpos
        ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        freqs = self.forward(
            self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset),
            seq_len=seq_len,
            offset=offset,
        )

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")

        return apply_rotary_emb(freqs, t, seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        seq_dim = default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        rotated_q = self.rotate_queries_or_keys(
            q, seq_dim=seq_dim, offset=k_len - q_len + offset
        )
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")
            scale = rearrange(scale, "n d -> n 1 d")

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
        assert self.use_xpos

        should_cache = self.cache_if_possible and exists(seq_len)

        if (
            should_cache
            and exists(self.cached_scales)
            and (seq_len + offset) <= self.cached_scales.shape[0]
        ):
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            scale = torch.cat((scale, scale), dim=-1)

        if should_cache:
            self.tmp_store("cached_scales", scale)

        return scale

    def get_axial_freqs(self, *dims):
        Colon = slice(None)
        all_freqs = []

        for ind, dim in enumerate(dims):
            if self.freqs_for == "pixel":
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            freqs = self.forward(pos, seq_len=dim)

            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast("cuda", enabled=False)
    def forward(self, t: Tensor, seq_len=None, offset=0):
        should_cache = (
            self.cache_if_possible
            and not self.learned_freq
            and exists(seq_len)
            and self.freqs_for != "pixel"
        )

        if (
            should_cache
            and exists(self.cached_freqs)
            and (offset + seq_len) <= self.cached_freqs.shape[0]
        ):
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache:
            self.tmp_store("cached_freqs", freqs.detach())

        return freqs


class Rope2D:
    """ Helper class to apply RoPE2D as well as interpolate on the fly. """

    def __init__(self, dim, use_cls_token=False):
        self.dim = dim
        self.use_cls_token = use_cls_token
        self.grid_size = None
        self.freq = None

    def init_tensors(self):
        self.rope = RotaryEmbedding(self.dim // 2)

    def update_grid(self, device, grid_h, grid_w):
        if self.grid_size != (grid_h, grid_w):
            self.grid_size = (grid_h, grid_w)

            self.rope = self.rope.to(device)

            if self.use_cls_token:
                # +1 to leave space for the cls token to be (0, 0)
                grid_y_range = torch.arange(grid_h, device=device) + 1
                grid_x_range = torch.arange(grid_w, device=device) + 1
            else:
                grid_y_range = torch.arange(grid_h, device=device)
                grid_x_range = torch.arange(grid_w, device=device)

            freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
            freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
            freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)

            if self.use_cls_token:
                freq = torch.cat(
                    [torch.zeros(1, freq.shape[-1], device=device), freq], dim=0
                )

            self.freq = freq[None, ...]

        self.freq = self.freq.to(device)

    def __call__(self, q, k):
        # batch, heads, seq, dim = q.shape
        q = apply_rotary_emb(self.freq[:, None, :, :], q)
        k = apply_rotary_emb(self.freq[:, None, :, :], k)

        return q, k
core/vision_encoder/transforms.py
ADDED
@@ -0,0 +1,86 @@
import torchvision.transforms as T


def get_image_transform(
    image_size: int,
    center_crop: bool = False,
    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR  # We used bilinear during training
):
    if center_crop:
        crop = [
            T.Resize(image_size, interpolation=interpolation),
            T.CenterCrop(image_size)
        ]
    else:
        # "Squash": most versatile
        crop = [
            T.Resize((image_size, image_size), interpolation=interpolation)
        ]

    return T.Compose(crop + [
        T.Lambda(lambda x: x.convert("RGB")),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
    ])


# NOTE: the block below re-imports torchvision and redefines get_image_transform
# verbatim (adding the PIL import); the second definition shadows the one above.
import torchvision.transforms as T

from PIL import Image


def get_image_transform(
    image_size: int,
    center_crop: bool = False,
    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR  # We used bilinear during training
):
    if center_crop:
        crop = [
            T.Resize(image_size, interpolation=interpolation),
            T.CenterCrop(image_size)
        ]
    else:
        # "Squash": most versatile
        crop = [
            T.Resize((image_size, image_size), interpolation=interpolation)
        ]

    return T.Compose(crop + [
        T.Lambda(lambda x: x.convert("RGB")),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
    ])


def _convert_to_rgb(image: Image.Image) -> Image.Image:
    """Converts a PIL Image to RGB format."""
    return image.convert("RGB")


def get_image_transform_fix(
    image_size: int,
    center_crop: bool = False,
    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR
):
    if center_crop:
        crop = [
            T.Resize(image_size, interpolation=interpolation),
            T.CenterCrop(image_size)
        ]
    else:
        # "Squash": most versatile
        crop = [
            T.Resize((image_size, image_size), interpolation=interpolation)
        ]

    return T.Compose(crop + [
        T.Lambda(_convert_to_rgb),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
    ])


# NOTE: SimpleTokenizer is not imported in this file; get_text_tokenizer needs the
# tokenizer class (e.g. from the vision_encoder tokenizer module) to be importable.
def get_text_tokenizer(context_length: int):
    return SimpleTokenizer(context_length=context_length)
|
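A quick usage sketch of the factory above; the 336-pixel target size is an arbitrary example, not a value taken from this repo's configs:

from PIL import Image

from core.vision_encoder.transforms import get_image_transform_fix

# Any PIL image works; here we fabricate a dummy RGB image.
img = Image.new("RGB", (500, 375), color=(128, 64, 32))

preprocess = get_image_transform_fix(image_size=336)   # "squash" resize by default
tensor = preprocess(img)

print(tensor.shape)   # torch.Size([3, 336, 336]); values are normalized into [-1, 1]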
core/vision_projector/base.py
ADDED
|
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from abc import ABC, abstractmethod

from torch import nn


class BaseProjector(nn.Module, ABC):
    def __init__(self):
        super().__init__()
        self.adaptive_avg_pool = None

    @abstractmethod
    def setup_projector(self):
        """
        Set up the projector used by forward() (stored as self.projector) in subclasses.
        """
        pass

    def forward(self, x):
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.projector(x)
        x = x.permute(1, 0, 2)
        if self.adaptive_avg_pool is not None:
            x = self.adaptive_avg_pool(x)
        return x
|
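To make the abstract contract concrete, a hypothetical minimal subclass (not part of this repo) showing that `setup_projector` only has to populate `self.projector`; the base `forward` then handles the NLD -> LND permutes and optional pooling:

import torch
from torch import nn

from core.vision_projector.base import BaseProjector


class LinearProjector(BaseProjector):
    """Toy projector: a single linear map from the vision width to the target width."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.setup_projector()

    def setup_projector(self):
        # BaseProjector.forward expects the module to be stored as `self.projector`.
        self.projector = nn.Linear(self.in_dim, self.out_dim)


x = torch.randn(2, 16, 1024)                   # (batch, tokens, vision width)
print(LinearProjector(1024, 4096)(x).shape)    # torch.Size([2, 16, 4096])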
core/vision_projector/mlp.py
ADDED
|
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import math

import torch
import torch.nn.functional as F
from torch import nn

from core.utils import get_init_fn
from core.vision_projector.base import BaseProjector


class AdaptiveAvgPooling(nn.Module):
    def __init__(self, pooling_ratio=2):
        super(AdaptiveAvgPooling, self).__init__()
        self.pooling_ratio = pooling_ratio

    def forward(self, x):
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        assert h * h == num_tokens

        shape = (h // self.pooling_ratio, h // self.pooling_ratio)
        x = x.permute(0, 2, 1).reshape(b, -1, h, h)
        x = F.adaptive_avg_pool2d(x, shape)
        x = x.flatten(2).transpose(1, 2)

        return x


class MLPProjector(BaseProjector):
    def __init__(self, args):
        super().__init__()
        self.setup_projector(args)
        self.pooling_ratio = args.pooling_ratio
        self.adaptive_avg_pool = AdaptiveAvgPooling(pooling_ratio=args.pooling_ratio)
        self.remove_vision_class_token = args.remove_vision_class_token

    def init_tensors(self):
        self.init_method(self.projector[0].weight)
        self.init_method(self.projector[0].bias)
        self.init_method(self.projector[2].weight)
        self.init_method(self.projector[2].bias)

    def setup_projector(self, args):
        self.init_method = get_init_fn(args.mlp_init, args.dim, init_depth=None)
        input_size = args.vision_model["width"]
        output_size = args.dim
        self.projector = nn.Sequential(
            nn.Linear(
                in_features=input_size,
                out_features=output_size,
                bias=True,
                dtype=torch.get_default_dtype(),
            ),
            nn.GELU(),
            nn.Linear(
                in_features=output_size,
                out_features=output_size,
                bias=True,
                dtype=torch.get_default_dtype(),
            ),
        )
|
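For intuition, the adaptive pooling above cuts the token count by a factor of pooling_ratio**2. A standalone shape check that mirrors `AdaptiveAvgPooling.forward` (the 576-token count is just an illustrative 24 x 24 grid, not a config value):

import math

import torch
import torch.nn.functional as F

b, num_tokens, c, pooling_ratio = 2, 576, 1024, 2      # 576 tokens = a 24 x 24 grid
x = torch.randn(b, num_tokens, c)

h = int(math.sqrt(num_tokens))
x = x.permute(0, 2, 1).reshape(b, -1, h, h)            # (b, c, 24, 24)
x = F.adaptive_avg_pool2d(x, (h // pooling_ratio, h // pooling_ratio))   # (b, c, 12, 12)
x = x.flatten(2).transpose(1, 2)

print(x.shape)   # torch.Size([2, 144, 1024]) -> 576 tokens pooled down to 144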
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
torch
torchvision
opencv-python-headless
Pillow
numpy
einops
peft
python-dotenv
tqdm
gradio
|
setup.py
ADDED
|
@@ -0,0 +1,7 @@
from setuptools import setup, find_packages

setup(
    name='pe_adaptation',
    version='0.1',
    packages=find_packages(),
)
|
src/model.py
ADDED
|
@@ -0,0 +1,809 @@
from einops import rearrange
from torch.nn import functional as F
from dotenv import load_dotenv
import os
import sys

from core.vision_encoder.pe import SelfAttention, AttentionPooling
import torch.nn as nn
from typing import Dict, List
from utils.task_config import Task
import torch
from typing import Optional, Union, Mapping, OrderedDict
from src.dlora import *
from peft import PeftModel, get_peft_model, LoraConfig

DROPOUT_P = 0.5


class MTLModel(nn.Module):
    def __init__(self, backbone, tasks: List[Task],
                 rank: int = 64,
                 use_lora: bool = True,
                 truncate_idx: int = 22,
                 last_lora_layers: int = -99,
                 lora_dropout: float = 0.5,
                 use_mtl_lora: bool = False,
                 use_deep_head: bool = False,
                 use_batch_norm: bool = True,
                 use_mtl_attn_pool: bool = True,
                 use_dora: bool = True):

        super().__init__()
        self.use_mtl_attn_pool = use_mtl_attn_pool
        self.tasks = tasks
        self.use_mtl_lora = use_mtl_lora
        self.use_deep_head = use_deep_head
        self.use_lora = use_lora
        self.use_mtlora = use_mtl_lora
        output_dim = backbone.output_dim
        # log_vars is for uncertainty weighting
        self.log_vars = nn.Parameter(torch.zeros(len(tasks)))
        task_names = [task.name for task in tasks]
        self.backbone = backbone
        width = backbone.width
        heads = backbone.heads
        rope = backbone.rope

        if self.use_mtl_lora:
            # save the last residual attention block, as its weights seed the new MTL version
            orig_last_block = backbone.transformer.resblocks[-1]
            self.ln_post = backbone.ln_post

            # save the attention pooling, as its weights seed the task-specific attention pooling layers
            orig_attn_pool = backbone.attn_pool.to('cuda')

            self.backbone.truncate(layer_idx=truncate_idx)  # the 23rd block becomes the last one (its index is 22)

            # MTL block that produces t task-specific feature maps, plus a shared one
            self.mtl_layer = MTLoRAResidualAttentionBlock(
                d_model=width,
                n_head=heads,
                rope=rope,
                r={'shared': rank, **{name: rank for name in task_names}},
                tasks=task_names,
                shared_mode='matrix',
                lora_shared_scale=0.0  # We do not use the shared matrix, so we set its scale to 0
            )

            self.mtl_layer.load_from_original_block(orig_last_block)
            print("MTL-LoRA final block created and initialized from pretrained weights.")

            if self.use_mtl_attn_pool:
                self.attn_pool = MTLoRAAttentionPooling(
                    embed_dim=width,
                    num_heads=8,
                    tasks=task_names,
                    r={'shared': rank, **{name: rank for name in task_names}},
                    lora_dropout=lora_dropout,
                    lora_task_scale=1.0,
                    lora_shared_scale=0.0
                )
                self.attn_pool.load_from_original(orig_attn_pool)
            else:
                self.task_specific_attn_pool = nn.ModuleDict({
                    task.name: AttentionPooling(embed_dim=width, num_heads=8)
                    for task in self.tasks
                })
                for task in self.tasks:
                    self.task_specific_attn_pool[task.name].load_state_dict(orig_attn_pool.state_dict())
                print("Task-specific Attention Pooling layers created and initialized.")

            del self.backbone.attn_pool

        if use_lora:
            # You can modify this list if you want to target only attention layers or MLP layers
            target_layers = ["attn.in_proj", "attn.out_proj", "mlp.c_fc", "mlp.c_proj"]
            target_modules = []
            for name, param in self.backbone.named_modules():
                if not isinstance(param, nn.Linear):
                    continue
                is_target_layer = any(s in name for s in target_layers)
                if is_target_layer:
                    if "attn_pool" in name:
                        target_modules.append(name)
                    elif "transformer.resblocks" in name:
                        layer_idx = int(name.split('.')[2])
                        if layer_idx >= last_lora_layers:
                            target_modules.append(name)

            lora_config = LoraConfig(
                r=rank,
                lora_alpha=rank,
                target_modules=target_modules,
                use_dora=use_dora,
                lora_dropout=lora_dropout,
                bias="none"
            )

            self.backbone = get_peft_model(self.backbone, lora_config)
            print("PEFT LoRA module added")

        if self.use_deep_head == False:
            self.prediction_layers = nn.ModuleDict({
                task.name: nn.Sequential(
                    nn.BatchNorm1d(backbone.output_dim) if use_batch_norm else nn.Identity(),
                    nn.Dropout(p=DROPOUT_P),
                    nn.Linear(backbone.output_dim, len(task.class_labels))
                )
                for task in self.tasks
            })
            print("Task-specific prediction heads created.")
        else:
            self.prediction_layers = nn.ModuleDict({
                task.name: nn.Sequential(
                    nn.BatchNorm1d(backbone.output_dim) if use_batch_norm else nn.Identity(),
                    nn.Dropout(p=DROPOUT_P),
                    nn.Linear(backbone.output_dim, backbone.output_dim),
                    nn.GELU(),
                    nn.Linear(backbone.output_dim, len(task.class_labels)),
                )
                for task in self.tasks
            })
            print("Task-specific prediction deep-heads created.")

        self.backbone.del_muda()

    def enable_gradient_checkpointing(self):
        """Call this method after setting up parameter requires_grad"""
        backbone_has_trainable = any(param.requires_grad for param in self.backbone.parameters())
        if backbone_has_trainable:
            self.backbone.set_grad_checkpointing()
            print("Gradient checkpointing enabled for backbone (has trainable parameters)")
        else:
            print("Gradient checkpointing not enabled - backbone has no trainable parameters")

    def forward(self, x: torch.Tensor):
        if self.use_mtl_lora:
            return self._forward_mtl_block(x)
        else:
            return self._forward_shared(x)

    def _forward_shared(self, x: torch.Tensor):
        logits = {}

        # if self.attention_specific_pool == True:
        #     features = self.backbone.forward_features(x, norm=True, strip_cls_token=False)
        #     for task in self.tasks:
        #         pooled_feat = self.task_specific_attn_pool[task_name](features)
        #         pooled_feat = pooled_feat.squeeze(1)
        #         logits[task_name] = self.prediction_layers[task_name](pooled_feat)
        # else:
        features = self.backbone(x)
        # print(features.shape)
        for task in self.tasks:
            logits[task.name] = self.prediction_layers[task.name](features)

        return logits

    def _forward_mtl_block(self, x: torch.Tensor, return_feat=False, feat_to_return="None"):
        # Shared feature map from the backbone.
        # norm=False, because the normalization is "trained" on the output of the last ResidualAttentionBlock,
        # so we normalize the task-specific feature maps instead of the shared one.
        # strip_cls_token=False, because the PE paper showed it is beneficial to keep it.
        features = self.backbone.forward_features(x, norm=False, strip_cls_token=False)

        # Equal for each task, as our MTL layer follows a task-agnostic layer
        task_features_input = {task.name: features for task in self.tasks}

        # Also returns a shared feature map, which is discarded.
        # task_features is a dictionary: the key is the task name, the value is a tensor of shape
        # (batch_size, n_tokens, d_model) representing the task-specific feature map.
        _, task_features = self.mtl_layer(features, x_tasks=task_features_input)

        normalized_task_features = {
            task.name: self.ln_post(task_features[task.name])
            for task in self.tasks
        }

        if self.use_mtl_attn_pool:
            pooled_features = self.attn_pool(normalized_task_features)
        else:
            pooled_features = {}
            for task in self.tasks:
                feat = normalized_task_features[task.name]
                pooled_features[task.name] = self.task_specific_attn_pool[task.name](feat)

        # this is used for PCA/t-SNE visualization
        if return_feat:
            if feat_to_return == "Age":
                return pooled_features['Age']
            elif feat_to_return == "Emotion":
                return pooled_features['Emotion']
            elif feat_to_return == "Gender":
                return pooled_features['Gender']

        logits = {}
        for task in self.tasks:
            # Squeeze the pooling dimension (1)
            pooled_feat = pooled_features[task.name].squeeze(1)  # (batch, 1, d_model) -> (batch, d_model)
            logits[task.name] = self.prediction_layers[task.name](pooled_feat)

        return logits

    def save_whole_model(self, filepath: str):
        print(f"Saving model state_dict to {filepath}")
        torch.save(self.state_dict(), filepath)

    def load_model(self, filepath: str, map_location='cuda'):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.use_lora or self.use_mtlora:
            self.backbone.merge_and_unload()
        self.to(device)
        state_dict = torch.load(filepath, map_location=map_location)
        self.load_state_dict(state_dict, strict=True)

    def save_adapters_peft(self, save_directory: str):

        print(f"Saving adapters to directory: {save_directory}")
        os.makedirs(save_directory, exist_ok=True)

        custom_layers_state_dict = {
            'prediction_layers': self.prediction_layers.state_dict()
        }

        if self.use_lora:
            self.backbone.save_pretrained(save_directory)

        if self.use_mtlora:
            custom_layers_state_dict['mtl_layer'] = self.mtl_layer.state_dict()
            # custom_layers_state_dict['task_specific_attn_pooling'] = self.task_specific_attn_pool.state_dict()
            custom_layers_state_dict['mtl_attn_pool'] = self.attn_pool.state_dict()

        torch.save(custom_layers_state_dict, os.path.join(save_directory, 'custom_layers.pt'))
        print("Successfully saved PEFT backbone and custom task heads.")

    def load_heads(self, filepaths: List[str], device='cuda'):

        for ckpt in filepaths:
            checkpoint = torch.load(ckpt, map_location=device)
            model_state_dict = self.state_dict()

            if "prediction_layers" in checkpoint:
                for loaded_key, value in checkpoint["prediction_layers"].items():
                    new_key = loaded_key

                    # Remap prefix: 'heads.emotion.' -> 'prediction_layers.Emotion.'
                    if new_key.startswith('heads.emotion.'):
                        new_key = new_key.replace('heads.emotion.', 'prediction_layers.Emotion.')

                    if new_key.startswith('heads.age.'):
                        new_key = new_key.replace('heads.age.', 'prediction_layers.Age.')

                    if new_key.startswith('heads.gender.'):
                        new_key = new_key.replace('heads.gender.', 'prediction_layers.Gender.')

                    # Remap final layer index for deep head: '.5.' -> '.4.'
                    if '.5.' in new_key:
                        new_key = new_key.replace('.5.', '.4.')

                    if new_key in model_state_dict:
                        if model_state_dict[new_key].shape == value.shape:
                            model_state_dict[new_key].copy_(value)

    def load_adapters_peft(self, load_directory: str, custom_head_name: str = 'custom_layers.pt'):

        print(f"Loading adapters from directory: {load_directory}")
        if self.use_lora:
            self.backbone = self.backbone.merge_and_unload()
            self.backbone = PeftModel.from_pretrained(self.backbone, load_directory)

        custom_layers_path = os.path.join(load_directory, custom_head_name)
        if not os.path.exists(custom_layers_path):
            raise FileNotFoundError(f"Custom task heads file not found at {custom_layers_path}")

        checkpoint = torch.load(custom_layers_path, map_location=("cuda" if torch.cuda.is_available() else "cpu"))

        self.prediction_layers.load_state_dict(checkpoint['prediction_layers'])

        if self.use_mtlora:
            try:
                self.mtl_layer.load_state_dict(checkpoint['mtl_layer'][0])
            except KeyError:
                self.mtl_layer.load_state_dict(checkpoint['mtl_layer'])
            self.attn_pool.load_state_dict(checkpoint['mtl_attn_pool'])

        print("Successfully loaded PEFT backbone and custom task heads.")

    def save_trained(self, filepath: str):

        trainable_param_names = {name for name, param in self.named_parameters() if param.requires_grad}
        trainable_module_paths = {'.'.join(name.split('.')[:-1]) for name in trainable_param_names}

        state_to_save = {}
        full_state_dict = self.state_dict()

        for key, value in full_state_dict.items():
            if key in trainable_param_names:
                state_to_save[key] = value
                continue

            current_module_path = '.'.join(key.split('.')[:-1])
            if current_module_path in trainable_module_paths:
                state_to_save[key] = value

        print(f"Saving {len(state_to_save)} state entries (parameters and buffers) to {filepath}")
        torch.save(state_to_save, filepath)

    def load_trained_legacy(self, filepath: str, device='cuda'):
        """Some checkpoints were trained with a different model class,
        so their key names must be remapped to match this model class."""
        print(f"Loading trained states from structured checkpoint: {filepath}")

        checkpoint = torch.load(filepath, map_location=device)

        model_state_dict = self.state_dict()

        loaded_keys_count = 0
        skipped_keys = []
        remapped_keys_examples = {}

        if "backbone_state_dict" in checkpoint:
            print("\n--- Processing Backbone Weights ---")
            for loaded_key, value in checkpoint["backbone_state_dict"].items():
                new_key = loaded_key

                if new_key.startswith('strategy.backbone.'):
                    new_key = new_key.replace('strategy.backbone.', 'backbone.')

                if 'attn.in_proj_weight' in new_key and 'attn.in_proj.weight' not in new_key:
                    new_key = new_key.replace('attn.in_proj_weight', 'attn.in_proj.weight')
                if 'attn.in_proj_bias' in new_key and 'attn.in_proj.bias' not in new_key:
                    new_key = new_key.replace('attn.in_proj_bias', 'attn.in_proj.bias')

                if new_key in model_state_dict:
                    if model_state_dict[new_key].shape == value.shape:
                        model_state_dict[new_key].copy_(value)
                        loaded_keys_count += 1
                        if loaded_key != new_key and len(remapped_keys_examples) < 5:
                            remapped_keys_examples[loaded_key] = new_key
                    else:
                        skipped_keys.append(f"{loaded_key} (Shape Mismatch: Model {model_state_dict[new_key].shape} vs Ckpt {value.shape})")
                else:
                    skipped_keys.append(f"{loaded_key} (as {new_key}) -> Not found in model")

        if "prediction_layers" in checkpoint:
            print("\n--- Processing Prediction Head Weights ---")
            for loaded_key, value in checkpoint["prediction_layers"].items():
                new_key = loaded_key

                if new_key.startswith('heads.emotion.'):
                    new_key = new_key.replace('heads.emotion.', 'prediction_layers.Emotion.')

                if new_key.startswith('heads.age.'):
                    new_key = new_key.replace('heads.age.', 'prediction_layers.Age.')

                if new_key.startswith('heads.gender.'):
                    new_key = new_key.replace('heads.gender.', 'prediction_layers.Gender.')

                if '.5.' in new_key:
                    new_key = new_key.replace('.5.', '.4.')

                # Validate, load, and update trackers
                if new_key in model_state_dict:
                    if model_state_dict[new_key].shape == value.shape:
                        model_state_dict[new_key].copy_(value)
                        loaded_keys_count += 1
                        if loaded_key != new_key and len(remapped_keys_examples) < 10:
                            remapped_keys_examples[loaded_key] = new_key
                    else:
                        skipped_keys.append(f"{loaded_key} (Shape Mismatch: Model {model_state_dict[new_key].shape} vs Ckpt {value.shape})")
                else:
                    skipped_keys.append(f"{loaded_key} (as {new_key}) -> Not found in model")

        if "attn_pool" in checkpoint:
            print("\n--- Processing Attention Pool Weights ---")
            for loaded_key, value in checkpoint["attn_pool"].items():
                # The attn_pool keys in the source file also have the 'strategy.backbone' prefix
                new_key = loaded_key.replace('strategy.backbone.attn_pool.', 'backbone.attn_pool.')

                # Validate, load, and update trackers
                if new_key in model_state_dict:
                    if model_state_dict[new_key].shape == value.shape:
                        model_state_dict[new_key].copy_(value)
                        loaded_keys_count += 1
                        if loaded_key != new_key and len(remapped_keys_examples) < 15:
                            remapped_keys_examples[loaded_key] = new_key
                    else:
                        skipped_keys.append(f"{loaded_key} (Shape Mismatch: Model {model_state_dict[new_key].shape} vs Ckpt {value.shape})")
                else:
                    skipped_keys.append(f"{loaded_key} (as {new_key}) -> Not found in model")

        if loaded_keys_count == 0:
            print('LOADED 0')
            self.load_state_dict(torch.load(filepath, map_location=device), strict=False)


class MTLoRAResidualAttentionBlock(nn.Module):
    """Adaptation of the Perception Encoder ResidualAttentionBlock with MTLoRA, producing t task-specific feature maps plus a shared feature map."""
    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer = nn.GELU,
        norm_layer = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
        r: Union[int, Mapping[str, int]] = 0,
        lora_shared_scale: float = 1.0,
        lora_task_scale: float = 1.0,
        lora_dropout: float = DROPOUT_P,
        tasks=None,
        trainable_scale_shared=False,
        trainable_scale_per_task=False,
        shared_mode: str = 'matrix',
    ):
        super().__init__()
        self.tasks = tasks
        self.num_heads = n_head
        self.head_dim = d_model // n_head
        self.scale = self.head_dim ** -0.5
        self.rope = rope

        task_scales = {t: lora_task_scale for t in tasks}

        # Multi-task LoRA for the QKV matrices
        # (MTLoRAQKV does not actually compute attention; it returns the shared QKV matrices and the task-specific QKV matrices)
        self.attn = MTLoRAQKV(
            in_features=d_model,
            out_features=d_model,
            r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
            lora_dropout=lora_dropout, tasks=tasks, trainable_scale_shared=trainable_scale_shared,
            trainable_scale_per_task=trainable_scale_per_task, shared_mode=shared_mode
        )

        # Multi-task LoRA for the output projection matrix in MHA
        self.out_proj = MTLoRALinear(
            in_features=d_model,
            out_features=d_model,
            r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
            lora_dropout=lora_dropout, tasks=tasks, trainable_scale_shared=trainable_scale_shared,
            trainable_scale_per_task=trainable_scale_per_task, shared_mode=shared_mode
        )

        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # LoRA-enabled MLP
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict([
                ("c_fc", MTLoRALinear(
                    d_model, mlp_width, r=r, lora_shared_scale=lora_shared_scale,
                    lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks,
                    trainable_scale_shared=trainable_scale_shared, trainable_scale_per_task=trainable_scale_per_task,
                    shared_mode=shared_mode
                )),
                ("gelu", act_layer()),
                ("c_proj", MTLoRALinear(
                    mlp_width, d_model, r=r, lora_shared_scale=lora_shared_scale,
                    lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks,
                    trainable_scale_shared=trainable_scale_shared, trainable_scale_per_task=trainable_scale_per_task,
                    shared_mode=shared_mode
                )),
            ])
        )

    def _call_attn(
        self,
        x_shared: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        x_tasks: Optional[Dict[str, torch.Tensor]] = None,
    ):
        # s is the number of patches/tokens (sequence length)
        proj, proj_tasks = self.attn(x_shared, x_tasks)  # proj is (b s 3*d_model), proj_tasks is a dict of (b s 3*d_model), one entry per task

        def compute_attention(projection_tensor):
            # Reshape Q, K, V
            # projection_tensor is (b s 3*d_model); it needs to be split and rearranged
            _, s, _ = projection_tensor.shape
            # output_features from MTLoRAQKV is d_model, so 3 * d_model in total
            split_size = self.attn.q.linear.out_features  # This should be d_model

            # Unflatten into (b s 3 d_model) then permute to get (3 b s d_model)
            q, k, v = projection_tensor.unflatten(-1, (3, split_size)).permute(2, 0, 1, 3).contiguous()
            # Rearrange for multi-head attention (b h s d)
            q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
            k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
            v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

            if self.rope:
                q, k = self.rope(q, k)

            attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
            return rearrange(attn_output, "b h s d -> b s (h d)")

        # Process the shared path
        attn_result = compute_attention(proj)

        # Process the task-specific paths
        attn_tasks_results = {}
        if proj_tasks:
            for task, task_proj in proj_tasks.items():
                attn_tasks_results[task] = compute_attention(task_proj)

        # Apply the output projection
        # out_proj is an MTLoRALinear, so its forward expects (x, x_tasks)
        shared_out, tasks_out = self.out_proj(attn_result, x_tasks=attn_tasks_results if attn_tasks_results else None)

        return shared_out, tasks_out

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        x_tasks: Optional[Dict[str, torch.Tensor]] = None,
    ):
        # Attention block
        norm_x = self.ln_1(x)
        norm_x_tasks = {task: self.ln_1(x_tasks[task]) for task in self.tasks} if x_tasks else None

        attn_out, attn_tasks_out = self._call_attn(norm_x, attn_mask=attn_mask, x_tasks=norm_x_tasks)

        x = x + self.drop_path1(self.ls_1(attn_out))
        if attn_tasks_out and x_tasks:
            for task in self.tasks:
                x_tasks[task] = x_tasks[task] + self.drop_path1(self.ls_1(attn_tasks_out[task]))

        # MLP block
        norm_x = self.ln_2(x)
        norm_x_tasks = {task: self.ln_2(x_tasks[task]) for task in self.tasks} if x_tasks else None

        # The MTLoRALinear forward needs to be called directly for the sequential MLP
        mlp_fc_out, mlp_fc_tasks_out = self.mlp.c_fc(norm_x, norm_x_tasks)
        gelu_out = self.mlp.gelu(mlp_fc_out)
        gelu_tasks_out = {task: self.mlp.gelu(mlp_fc_tasks_out[task]) for task in self.tasks} if mlp_fc_tasks_out else None

        mlp_proj_out, mlp_proj_tasks_out = self.mlp.c_proj(gelu_out, gelu_tasks_out)

        x = x + self.drop_path2(self.ls_2(mlp_proj_out))
        if mlp_proj_tasks_out and x_tasks:
            for task in self.tasks:
                x_tasks[task] = x_tasks[task] + self.drop_path2(self.ls_2(mlp_proj_tasks_out[task]))

        return x, x_tasks

    def load_from_original_block(self, original_block):
        """
        Initializes the weights of this block from a pre-trained ResidualAttentionBlock.
        The LoRA-specific parameters are reset to their initial state.
        """
        with torch.no_grad():
            # Copy LayerNorm and LayerScale weights
            self.ln_1.load_state_dict(original_block.ln_1.state_dict())
            self.ln_2.load_state_dict(original_block.ln_2.state_dict())
            self.ls_1.load_state_dict(original_block.ls_1.state_dict())
            self.ls_2.load_state_dict(original_block.ls_2.state_dict())

            # Copy MLP weights into the .linear attribute of the MTLoRALinear layers
            self.mlp.c_fc.linear.load_state_dict(original_block.mlp.c_fc.state_dict())
            self.mlp.c_proj.linear.load_state_dict(original_block.mlp.c_proj.state_dict())

            # Copy attention weights
            # Both SelfAttention and nn.MultiheadAttention store the QKV weights combined
            if isinstance(original_block.attn, SelfAttention):
                # Using migrate_weights ensures the Parameters are copied to the Linear layer first,
                # then we can extract from the Linear layer
                original_block.attn.migrate_weights()  # Ensure weights are in .in_proj and .out_proj

                # Split the combined weight and bias tensors into Q, K, V from .in_proj
                qkv_weight = original_block.attn.in_proj.weight
                qkv_bias = original_block.attn.in_proj.bias

                q_w, k_w, v_w = qkv_weight.chunk(3)
                q_b, k_b, v_b = qkv_bias.chunk(3)

                # Load into the .linear attributes of the MTLoRAQKV module
                self.attn.q.linear.weight.copy_(q_w)
                self.attn.q.linear.bias.copy_(q_b)

                self.attn.k.linear.weight.copy_(k_w)
                self.attn.k.linear.bias.copy_(k_b)

                self.attn.v.linear.weight.copy_(v_w)
                self.attn.v.linear.bias.copy_(v_b)

                # Load the output projection weights
                self.out_proj.linear.load_state_dict(original_block.attn.out_proj.state_dict())
            elif isinstance(original_block.attn, nn.MultiheadAttention):
                self.attn.q.linear.weight.copy_(original_block.attn.in_proj_weight[:self.attn.q.linear.out_features, :])
                self.attn.q.linear.bias.copy_(original_block.attn.in_proj_bias[:self.attn.q.linear.out_features])

                self.attn.k.linear.weight.copy_(original_block.attn.in_proj_weight[self.attn.q.linear.out_features:2*self.attn.q.linear.out_features, :])
                self.attn.k.linear.bias.copy_(original_block.attn.in_proj_bias[self.attn.q.linear.out_features:2*self.attn.q.linear.out_features])

                self.attn.v.linear.weight.copy_(original_block.attn.in_proj_weight[2*self.attn.q.linear.out_features:3*self.attn.q.linear.out_features, :])
                self.attn.v.linear.bias.copy_(original_block.attn.in_proj_bias[2*self.attn.q.linear.out_features:3*self.attn.q.linear.out_features])

                self.out_proj.linear.weight.copy_(original_block.attn.out_proj.weight)
                self.out_proj.linear.bias.copy_(original_block.attn.out_proj.bias)

            else:
                raise TypeError(f"Unsupported attention module type in original_block: {type(original_block.attn)}")

        # After loading the pretrained weights, re-initialize the LoRA-specific parameters.
        # This ensures that at the start of finetuning the LoRA adjustment is zero.
        self.attn.reset_parameters()
        self.out_proj.reset_parameters()
        self.mlp.c_fc.reset_parameters()
        self.mlp.c_proj.reset_parameters()

        print("Successfully loaded weights from original ResidualAttentionBlock and reset LoRA parameters.")


class MTLoRAAttentionPooling(nn.Module):
    """
    An MT-LoRA equivalent of the AttentionPooling transformer block.

    This module replicates the full original architecture:
    1. Task-specific probes for attention pooling.
    2. MT-LoRA enabled Q/K/V and output projections.
    3. A LayerNorm layer.
    4. An MLP block with MT-LoRA enabled linear layers.
    5. A final residual connection, matching the original's structure.
    """
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        tasks: List[str],
        r: Union[int, Mapping[str, int]] = 0,
        lora_shared_scale: float = 1.0,
        lora_task_scale: float = 1.0,
        lora_dropout: float = 0.0,
        mlp_ratio: int = 4,
        act_layer = nn.GELU,
        norm_layer = nn.LayerNorm,
    ):
        super().__init__()
        self.tasks = tasks
        self.num_heads = num_heads

        self.probe = nn.ParameterDict({
            task: nn.Parameter(torch.randn(1, 1, embed_dim))
            for task in tasks
        })

        task_scales = {t: lora_task_scale for t in tasks}

        self.q_proj = MTLoRALinear(
            embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
            lora_dropout=lora_dropout, tasks=tasks
        )
        self.k_proj = MTLoRALinear(
            embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
            lora_dropout=lora_dropout, tasks=tasks
        )
        self.v_proj = MTLoRALinear(
            embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
            lora_dropout=lora_dropout, tasks=tasks
        )
        self.out_proj = MTLoRALinear(
            embed_dim, embed_dim, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=task_scales,
            lora_dropout=lora_dropout, tasks=tasks
        )

        self.layernorm = norm_layer(embed_dim)
        mlp_width = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict([
                ("c_fc", MTLoRALinear(
                    embed_dim, mlp_width, r=r, lora_shared_scale=lora_shared_scale,
                    lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks
                )),
                ("gelu", nn.GELU()),
                ("c_proj", MTLoRALinear(
                    mlp_width, embed_dim, r=r, lora_shared_scale=lora_shared_scale,
                    lora_task_scale=task_scales, lora_dropout=lora_dropout, tasks=tasks
                )),
            ])
        )

    def load_from_original(self, original_pool: AttentionPooling):
        """Initializes all weights from the pretrained AttentionPooling block."""
        with torch.no_grad():
            original_attn = original_pool.attn

            for task in self.tasks:
                self.probe[task].copy_(original_pool.probe)

            q_w, k_w, v_w = original_attn.in_proj_weight.chunk(3)
            q_b, k_b, v_b = original_attn.in_proj_bias.chunk(3)

            self.q_proj.linear.weight.copy_(q_w)
            self.q_proj.linear.bias.copy_(q_b)
            self.k_proj.linear.weight.copy_(k_w)
            self.k_proj.linear.bias.copy_(k_b)
            self.v_proj.linear.weight.copy_(v_w)
            self.v_proj.linear.bias.copy_(v_b)

            self.out_proj.linear.load_state_dict(original_attn.out_proj.state_dict())

            self.layernorm.load_state_dict(original_pool.layernorm.state_dict())

            self.mlp.c_fc.linear.load_state_dict(original_pool.mlp.c_fc.state_dict())
            self.mlp.c_proj.linear.load_state_dict(original_pool.mlp.c_proj.state_dict())

        self.q_proj.reset_parameters()
        self.k_proj.reset_parameters()
        self.v_proj.reset_parameters()
        self.out_proj.reset_parameters()
        self.mlp.c_fc.reset_parameters()
        self.mlp.c_proj.reset_parameters()
        print("Full MT-LoRA Attention Pooling block created and initialized from pretrained weights.")

    def forward(self, x_tasks: Dict[str, torch.Tensor]):
        """
        Forward pass that correctly handles a unique input for each task.

        In this version, K and V are calculated inside the loop from the
        task-specific input 'x', and each task has its own unique probe.
        """

        final_outputs = {}
        for task, x in x_tasks.items():
            B, N, C = x.shape
            probe = self.probe[task].repeat(B, 1, 1)

            _, q_task_dict = self.q_proj(probe, x_tasks={task: probe})
            q = q_task_dict[task]

            _, k_task_dict = self.k_proj(x, x_tasks={task: x})
            k = k_task_dict[task]

            _, v_task_dict = self.v_proj(x, x_tasks={task: x})
            v = v_task_dict[task]

            q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
            k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
            v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)

            attn_out = F.scaled_dot_product_attention(q, k, v)
            attn_out_rearranged = rearrange(attn_out, 'b h n d -> b n (h d)')

            _, out_proj_dict = self.out_proj(attn_out_rearranged, x_tasks={task: attn_out_rearranged})
            x_attn = out_proj_dict[task]

            norm_attn = self.layernorm(x_attn)

            _, fc_task_dict = self.mlp.c_fc(norm_attn, x_tasks={task: norm_attn})
            gelu_out = self.mlp.gelu(fc_task_dict[task])
            _, proj_task_dict = self.mlp.c_proj(gelu_out, x_tasks={task: gelu_out})
            mlp_out = proj_task_dict[task]

            final_outputs[task] = x_attn + mlp_out

        return final_outputs
|
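The `log_vars` parameter declared in `MTLModel.__init__` is reserved for uncertainty weighting of the per-task losses; the training loop is not part of this file, so the sketch below shows one common homoscedastic-uncertainty formulation (loss values and task order are hypothetical), not necessarily the exact recipe of this repo's training script:

import torch


def uncertainty_weighted_loss(task_losses, log_vars, task_order):
    """Combine per-task losses as sum_i(exp(-s_i) * L_i + s_i), with s_i = log(sigma_i^2)."""
    total = torch.zeros((), device=log_vars.device)
    for i, task_name in enumerate(task_order):
        s = log_vars[i]
        total = total + torch.exp(-s) * task_losses[task_name] + s
    return total


# Hypothetical usage with three tasks matching the prediction heads remapped above:
log_vars = torch.nn.Parameter(torch.zeros(3))
losses = {"Age": torch.tensor(1.2), "Gender": torch.tensor(0.4), "Emotion": torch.tensor(0.9)}
print(uncertainty_weighted_loss(losses, log_vars, ["Age", "Gender", "Emotion"]))   # tensor(2.5000, grad_fn=...)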
utils/__pycache__/commons.cpython-313.pyc
ADDED
Binary file (8.06 kB)

utils/__pycache__/dataset.cpython-313.pyc
ADDED
Binary file (26.6 kB)

utils/__pycache__/face_detector.cpython-313.pyc
ADDED
Binary file (6.43 kB)

utils/__pycache__/task_config.cpython-313.pyc
ADDED
Binary file (1.19 kB)
utils/commons.py
ADDED
|
@@ -0,0 +1,158 @@
"""Contains functions used for loading and logging models."""

import sys
import os
from transformers import AutoModel, AutoProcessor

import os
import core.vision_encoder.pe as pe
import core.vision_encoder.transforms as transforms_pe
from core.vision_encoder.config import PE_VISION_CONFIG
import torchvision.transforms as transforms
from PIL import Image
import requests


def print_trainable_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    percent = (trainable_params / total_params * 100) if total_params > 0 else 0
    print("\n--- Summary ---")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")
    print(f"Percentage: {percent:.2f}%")


def get_backbone_pe(version, print_info=False, apply_migration_flag=False, pretrained=True):
    """
    Load a PE ViT model; return the model, its transforms and the size of the output (embedding dimension of the last token).
    """
    print(f'Loading {version}...')
    backbone = pe.VisionTransformer.from_config(version, pretrained=pretrained)
    backbone_config = PE_VISION_CONFIG[version]
    transform = transforms_pe.get_image_transform_fix(image_size=backbone_config.image_size)

    print("\nYou can ignore the Missing keys list above.")
    print(f"Applying migration = {apply_migration_flag}")

    if print_info:
        attnpool = backbone.attn_pool
        print(f'embed_dim={attnpool.embed_dim}\nnum_heads={attnpool.num_heads}')
        print(f'OUTPUT DIM = {backbone_config.output_dim}')

    def apply_migration(m):
        if isinstance(m, pe.SelfAttention):
            m.migrate_weights()

    if apply_migration_flag == True:  # when testing/resuming, no migration should be used
        print('[MIGRATION] Migrating weights for PEFT compatibility')
        backbone.apply(apply_migration)

    return backbone, transform, backbone_config.output_dim


def get_backbone_dinov3(model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m", print_info=False):
    print(f"Loading Hugging Face model: {model_name}")
    processor = AutoProcessor.from_pretrained(model_name)

    # Extract the image processing configuration from the loaded processor
    image_processor_config = processor
    image_size = image_processor_config.size['height']
    image_mean = image_processor_config.image_mean
    image_std = image_processor_config.image_std

    transform = transforms.Compose([
        transforms.Lambda(_convert_to_rgb),
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=image_mean, std=image_std)
    ])

    # Load the model and return only the vision backbone
    vision_model = AutoModel.from_pretrained(model_name)

    if print_info:
        print(f'\nVISION CONFIGS:\n{vision_model.config}')
        print(f'\n\n\n{vision_model}')

    return vision_model, transform, vision_model.config.hidden_size


def get_backbone_siglip2(model_name: str = 'google/siglip2-base-patch16-224', print_info=False):
    """
    Load a SigLIP2 ViT model; return the model, its transforms and the size of the output (embedding dimension of the last token).
    """
    print(f"Loading Hugging Face model: {model_name}")
    processor = AutoProcessor.from_pretrained(model_name)

    # Extract the image processing configuration from the loaded processor
    image_processor_config = processor.image_processor
    image_size = image_processor_config.size['height']
    image_mean = image_processor_config.image_mean
    image_std = image_processor_config.image_std

    transform = transforms.Compose([
        transforms.Lambda(_convert_to_rgb),
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=image_mean, std=image_std)
    ])

    # Load the model and return only the vision backbone
    model = AutoModel.from_pretrained(model_name)
    vision_model = model.vision_model

    if print_info:
        print(f'\nVISION CONFIGS:\n{vision_model.config}')
        print(f'\n\n***************MHAP\n{vision_model.head}')

    return vision_model, transform, vision_model.config.hidden_size


def _convert_to_rgb(image: Image.Image) -> Image.Image:
    """Converts a PIL Image to RGB format."""
    return image.convert("RGB")


def get_backbone(version: str, apply_migration: bool = False):
    """
    Returns a vision transformer backbone.
    Args:
        version: Name of the backbone to use (PE-Core, SigLIP2 or DINOv3)
        apply_migration: if True, migrates PE SelfAttention weights for PEFT compatibility (PE only)
    """
    if 'PE-Core-' in version:
        return get_backbone_pe(version, False, apply_migration)
    elif 'siglip2' in version:
        print('[LOADING SIGLIP2]')
        return get_backbone_siglip2(version)
    elif 'dinov3' in version:
        return get_backbone_dinov3(version)


def send_telegram_message(message: str):
    """Sends a message to a Telegram chat using credentials from the config."""
    # Get credentials from your config object. Use getattr for safety.
    token = os.getenv("BOT_TOKEN")
    chat_id = "1220514183"

    if not token or not chat_id:
        # Silently fail if credentials are not set
        return

    api_url = f"https://api.telegram.org/bot{token}/sendMessage"
    payload = {
        'chat_id': chat_id,
        'text': message,
        'parse_mode': 'Markdown'  # For nice formatting (bold, italics, etc.)
    }

    try:
        response = requests.post(api_url, data=payload, timeout=10)
        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
    except requests.exceptions.RequestException as e:
        # Don't crash the training loop if Telegram is down
        print(f"\nWarning: Could not send Telegram message. Error: {e}")
|
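A hedged usage sketch of the loaders above; the PE version string is a placeholder whose validity depends on the keys of `PE_VISION_CONFIG`, and loading will download pretrained weights:

from PIL import Image

from utils.commons import get_backbone, print_trainable_params

# Placeholder version string: any key of PE_VISION_CONFIG (or a supported siglip2/dinov3 id) goes here.
backbone, transform, output_dim = get_backbone("PE-Core-L14-336")
print_trainable_params(backbone)

img = Image.new("RGB", (448, 448))
pixels = transform(img).unsqueeze(0)   # (1, 3, H, W), ready to feed to the backbone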
utils/deploy.prototxt
ADDED
@@ -0,0 +1,1789 @@
| 1 |
+
input: "data"
|
| 2 |
+
input_shape {
|
| 3 |
+
dim: 1
|
| 4 |
+
dim: 3
|
| 5 |
+
dim: 300
|
| 6 |
+
dim: 300
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
layer {
|
| 10 |
+
name: "data_bn"
|
| 11 |
+
type: "BatchNorm"
|
| 12 |
+
bottom: "data"
|
| 13 |
+
top: "data_bn"
|
| 14 |
+
param {
|
| 15 |
+
lr_mult: 0.0
|
| 16 |
+
}
|
| 17 |
+
param {
|
| 18 |
+
lr_mult: 0.0
|
| 19 |
+
}
|
| 20 |
+
param {
|
| 21 |
+
lr_mult: 0.0
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
layer {
|
| 25 |
+
name: "data_scale"
|
| 26 |
+
type: "Scale"
|
| 27 |
+
bottom: "data_bn"
|
| 28 |
+
top: "data_bn"
|
| 29 |
+
param {
|
| 30 |
+
lr_mult: 1.0
|
| 31 |
+
decay_mult: 1.0
|
| 32 |
+
}
|
| 33 |
+
param {
|
| 34 |
+
lr_mult: 2.0
|
| 35 |
+
decay_mult: 1.0
|
| 36 |
+
}
|
| 37 |
+
scale_param {
|
| 38 |
+
bias_term: true
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
layer {
|
| 42 |
+
name: "conv1_h"
|
| 43 |
+
type: "Convolution"
|
| 44 |
+
bottom: "data_bn"
|
| 45 |
+
top: "conv1_h"
|
| 46 |
+
param {
|
| 47 |
+
lr_mult: 1.0
|
| 48 |
+
decay_mult: 1.0
|
| 49 |
+
}
|
| 50 |
+
param {
|
| 51 |
+
lr_mult: 2.0
|
| 52 |
+
decay_mult: 1.0
|
| 53 |
+
}
|
| 54 |
+
convolution_param {
|
| 55 |
+
num_output: 32
|
| 56 |
+
pad: 3
|
| 57 |
+
kernel_size: 7
|
| 58 |
+
stride: 2
|
| 59 |
+
weight_filler {
|
| 60 |
+
type: "msra"
|
| 61 |
+
variance_norm: FAN_OUT
|
| 62 |
+
}
|
| 63 |
+
bias_filler {
|
| 64 |
+
type: "constant"
|
| 65 |
+
value: 0.0
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
layer {
|
| 70 |
+
name: "conv1_bn_h"
|
| 71 |
+
type: "BatchNorm"
|
| 72 |
+
bottom: "conv1_h"
|
| 73 |
+
top: "conv1_h"
|
| 74 |
+
param {
|
| 75 |
+
lr_mult: 0.0
|
| 76 |
+
}
|
| 77 |
+
param {
|
| 78 |
+
lr_mult: 0.0
|
| 79 |
+
}
|
| 80 |
+
param {
|
| 81 |
+
lr_mult: 0.0
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
layer {
|
| 85 |
+
name: "conv1_scale_h"
|
| 86 |
+
type: "Scale"
|
| 87 |
+
bottom: "conv1_h"
|
| 88 |
+
top: "conv1_h"
|
| 89 |
+
param {
|
| 90 |
+
lr_mult: 1.0
|
| 91 |
+
decay_mult: 1.0
|
| 92 |
+
}
|
| 93 |
+
param {
|
| 94 |
+
lr_mult: 2.0
|
| 95 |
+
decay_mult: 1.0
|
| 96 |
+
}
|
| 97 |
+
scale_param {
|
| 98 |
+
bias_term: true
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
layer {
|
| 102 |
+
name: "conv1_relu"
|
| 103 |
+
type: "ReLU"
|
| 104 |
+
bottom: "conv1_h"
|
| 105 |
+
top: "conv1_h"
|
| 106 |
+
}
|
| 107 |
+
layer {
|
| 108 |
+
name: "conv1_pool"
|
| 109 |
+
type: "Pooling"
|
| 110 |
+
bottom: "conv1_h"
|
| 111 |
+
top: "conv1_pool"
|
| 112 |
+
pooling_param {
|
| 113 |
+
kernel_size: 3
|
| 114 |
+
stride: 2
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
layer {
|
| 118 |
+
name: "layer_64_1_conv1_h"
|
| 119 |
+
type: "Convolution"
|
| 120 |
+
bottom: "conv1_pool"
|
| 121 |
+
top: "layer_64_1_conv1_h"
|
| 122 |
+
param {
|
| 123 |
+
lr_mult: 1.0
|
| 124 |
+
decay_mult: 1.0
|
| 125 |
+
}
|
| 126 |
+
convolution_param {
|
| 127 |
+
num_output: 32
|
| 128 |
+
bias_term: false
|
| 129 |
+
pad: 1
|
| 130 |
+
kernel_size: 3
|
| 131 |
+
stride: 1
|
| 132 |
+
weight_filler {
|
| 133 |
+
type: "msra"
|
| 134 |
+
}
|
| 135 |
+
bias_filler {
|
| 136 |
+
type: "constant"
|
| 137 |
+
value: 0.0
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
}
|
| 141 |
+
layer {
|
| 142 |
+
name: "layer_64_1_bn2_h"
|
| 143 |
+
type: "BatchNorm"
|
| 144 |
+
bottom: "layer_64_1_conv1_h"
|
| 145 |
+
top: "layer_64_1_conv1_h"
|
| 146 |
+
param {
|
| 147 |
+
lr_mult: 0.0
|
| 148 |
+
}
|
| 149 |
+
param {
|
| 150 |
+
lr_mult: 0.0
|
| 151 |
+
}
|
| 152 |
+
param {
|
| 153 |
+
lr_mult: 0.0
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
layer {
|
| 157 |
+
name: "layer_64_1_scale2_h"
|
| 158 |
+
type: "Scale"
|
| 159 |
+
bottom: "layer_64_1_conv1_h"
|
| 160 |
+
top: "layer_64_1_conv1_h"
|
| 161 |
+
param {
|
| 162 |
+
lr_mult: 1.0
|
| 163 |
+
decay_mult: 1.0
|
| 164 |
+
}
|
| 165 |
+
param {
|
| 166 |
+
lr_mult: 2.0
|
| 167 |
+
decay_mult: 1.0
|
| 168 |
+
}
|
| 169 |
+
scale_param {
|
| 170 |
+
bias_term: true
|
| 171 |
+
}
|
| 172 |
+
}
|
| 173 |
+
layer {
|
| 174 |
+
name: "layer_64_1_relu2"
|
| 175 |
+
type: "ReLU"
|
| 176 |
+
bottom: "layer_64_1_conv1_h"
|
| 177 |
+
top: "layer_64_1_conv1_h"
|
| 178 |
+
}
|
| 179 |
+
layer {
|
| 180 |
+
name: "layer_64_1_conv2_h"
|
| 181 |
+
type: "Convolution"
|
| 182 |
+
bottom: "layer_64_1_conv1_h"
|
| 183 |
+
top: "layer_64_1_conv2_h"
|
| 184 |
+
param {
|
| 185 |
+
lr_mult: 1.0
|
| 186 |
+
decay_mult: 1.0
|
| 187 |
+
}
|
| 188 |
+
convolution_param {
|
| 189 |
+
num_output: 32
|
| 190 |
+
bias_term: false
|
| 191 |
+
pad: 1
|
| 192 |
+
kernel_size: 3
|
| 193 |
+
stride: 1
|
| 194 |
+
weight_filler {
|
| 195 |
+
type: "msra"
|
| 196 |
+
}
|
| 197 |
+
bias_filler {
|
| 198 |
+
type: "constant"
|
| 199 |
+
value: 0.0
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
layer {
|
| 204 |
+
name: "layer_64_1_sum"
|
| 205 |
+
type: "Eltwise"
|
| 206 |
+
bottom: "layer_64_1_conv2_h"
|
| 207 |
+
bottom: "conv1_pool"
|
| 208 |
+
top: "layer_64_1_sum"
|
| 209 |
+
}
|
| 210 |
+
layer {
|
| 211 |
+
name: "layer_128_1_bn1_h"
|
| 212 |
+
type: "BatchNorm"
|
| 213 |
+
bottom: "layer_64_1_sum"
|
| 214 |
+
top: "layer_128_1_bn1_h"
|
| 215 |
+
param {
|
| 216 |
+
lr_mult: 0.0
|
| 217 |
+
}
|
| 218 |
+
param {
|
| 219 |
+
lr_mult: 0.0
|
| 220 |
+
}
|
| 221 |
+
param {
|
| 222 |
+
lr_mult: 0.0
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
layer {
|
| 226 |
+
name: "layer_128_1_scale1_h"
|
| 227 |
+
type: "Scale"
|
| 228 |
+
bottom: "layer_128_1_bn1_h"
|
| 229 |
+
top: "layer_128_1_bn1_h"
|
| 230 |
+
param {
|
| 231 |
+
lr_mult: 1.0
|
| 232 |
+
decay_mult: 1.0
|
| 233 |
+
}
|
| 234 |
+
param {
|
| 235 |
+
lr_mult: 2.0
|
| 236 |
+
decay_mult: 1.0
|
| 237 |
+
}
|
| 238 |
+
scale_param {
|
| 239 |
+
bias_term: true
|
| 240 |
+
}
|
| 241 |
+
}
|
| 242 |
+
layer {
|
| 243 |
+
name: "layer_128_1_relu1"
|
| 244 |
+
type: "ReLU"
|
| 245 |
+
bottom: "layer_128_1_bn1_h"
|
| 246 |
+
top: "layer_128_1_bn1_h"
|
| 247 |
+
}
|
| 248 |
+
layer {
|
| 249 |
+
name: "layer_128_1_conv1_h"
|
| 250 |
+
type: "Convolution"
|
| 251 |
+
bottom: "layer_128_1_bn1_h"
|
| 252 |
+
top: "layer_128_1_conv1_h"
|
| 253 |
+
param {
|
| 254 |
+
lr_mult: 1.0
|
| 255 |
+
decay_mult: 1.0
|
| 256 |
+
}
|
| 257 |
+
convolution_param {
|
| 258 |
+
num_output: 128
|
| 259 |
+
bias_term: false
|
| 260 |
+
pad: 1
|
| 261 |
+
kernel_size: 3
|
| 262 |
+
stride: 2
|
| 263 |
+
weight_filler {
|
| 264 |
+
type: "msra"
|
| 265 |
+
}
|
| 266 |
+
bias_filler {
|
| 267 |
+
type: "constant"
|
| 268 |
+
value: 0.0
|
| 269 |
+
}
|
| 270 |
+
}
|
| 271 |
+
}
|
| 272 |
+
layer {
|
| 273 |
+
name: "layer_128_1_bn2"
|
| 274 |
+
type: "BatchNorm"
|
| 275 |
+
bottom: "layer_128_1_conv1_h"
|
| 276 |
+
top: "layer_128_1_conv1_h"
|
| 277 |
+
param {
|
| 278 |
+
lr_mult: 0.0
|
| 279 |
+
}
|
| 280 |
+
param {
|
| 281 |
+
lr_mult: 0.0
|
| 282 |
+
}
|
| 283 |
+
param {
|
| 284 |
+
lr_mult: 0.0
|
| 285 |
+
}
|
| 286 |
+
}
|
| 287 |
+
layer {
|
| 288 |
+
name: "layer_128_1_scale2"
|
| 289 |
+
type: "Scale"
|
| 290 |
+
bottom: "layer_128_1_conv1_h"
|
| 291 |
+
top: "layer_128_1_conv1_h"
|
| 292 |
+
param {
|
| 293 |
+
lr_mult: 1.0
|
| 294 |
+
decay_mult: 1.0
|
| 295 |
+
}
|
| 296 |
+
param {
|
| 297 |
+
lr_mult: 2.0
|
| 298 |
+
decay_mult: 1.0
|
| 299 |
+
}
|
| 300 |
+
scale_param {
|
| 301 |
+
bias_term: true
|
| 302 |
+
}
|
| 303 |
+
}
|
| 304 |
+
layer {
|
| 305 |
+
name: "layer_128_1_relu2"
|
| 306 |
+
type: "ReLU"
|
| 307 |
+
bottom: "layer_128_1_conv1_h"
|
| 308 |
+
top: "layer_128_1_conv1_h"
|
| 309 |
+
}
|
| 310 |
+
layer {
|
| 311 |
+
name: "layer_128_1_conv2"
|
| 312 |
+
type: "Convolution"
|
| 313 |
+
bottom: "layer_128_1_conv1_h"
|
| 314 |
+
top: "layer_128_1_conv2"
|
| 315 |
+
param {
|
| 316 |
+
lr_mult: 1.0
|
| 317 |
+
decay_mult: 1.0
|
| 318 |
+
}
|
| 319 |
+
convolution_param {
|
| 320 |
+
num_output: 128
|
| 321 |
+
bias_term: false
|
| 322 |
+
pad: 1
|
| 323 |
+
kernel_size: 3
|
| 324 |
+
stride: 1
|
| 325 |
+
weight_filler {
|
| 326 |
+
type: "msra"
|
| 327 |
+
}
|
| 328 |
+
bias_filler {
|
| 329 |
+
type: "constant"
|
| 330 |
+
value: 0.0
|
| 331 |
+
}
|
| 332 |
+
}
|
| 333 |
+
}
|
| 334 |
+
layer {
|
| 335 |
+
name: "layer_128_1_conv_expand_h"
|
| 336 |
+
type: "Convolution"
|
| 337 |
+
bottom: "layer_128_1_bn1_h"
|
| 338 |
+
top: "layer_128_1_conv_expand_h"
|
| 339 |
+
param {
|
| 340 |
+
lr_mult: 1.0
|
| 341 |
+
decay_mult: 1.0
|
| 342 |
+
}
|
| 343 |
+
convolution_param {
|
| 344 |
+
num_output: 128
|
| 345 |
+
bias_term: false
|
| 346 |
+
pad: 0
|
| 347 |
+
kernel_size: 1
|
| 348 |
+
stride: 2
|
| 349 |
+
weight_filler {
|
| 350 |
+
type: "msra"
|
| 351 |
+
}
|
| 352 |
+
bias_filler {
|
| 353 |
+
type: "constant"
|
| 354 |
+
value: 0.0
|
| 355 |
+
}
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
layer {
|
| 359 |
+
name: "layer_128_1_sum"
|
| 360 |
+
type: "Eltwise"
|
| 361 |
+
bottom: "layer_128_1_conv2"
|
| 362 |
+
bottom: "layer_128_1_conv_expand_h"
|
| 363 |
+
top: "layer_128_1_sum"
|
| 364 |
+
}
|
| 365 |
+
layer {
|
| 366 |
+
name: "layer_256_1_bn1"
|
| 367 |
+
type: "BatchNorm"
|
| 368 |
+
bottom: "layer_128_1_sum"
|
| 369 |
+
top: "layer_256_1_bn1"
|
| 370 |
+
param {
|
| 371 |
+
lr_mult: 0.0
|
| 372 |
+
}
|
| 373 |
+
param {
|
| 374 |
+
lr_mult: 0.0
|
| 375 |
+
}
|
| 376 |
+
param {
|
| 377 |
+
lr_mult: 0.0
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
layer {
|
| 381 |
+
name: "layer_256_1_scale1"
|
| 382 |
+
type: "Scale"
|
| 383 |
+
bottom: "layer_256_1_bn1"
|
| 384 |
+
top: "layer_256_1_bn1"
|
| 385 |
+
param {
|
| 386 |
+
lr_mult: 1.0
|
| 387 |
+
decay_mult: 1.0
|
| 388 |
+
}
|
| 389 |
+
param {
|
| 390 |
+
lr_mult: 2.0
|
| 391 |
+
decay_mult: 1.0
|
| 392 |
+
}
|
| 393 |
+
scale_param {
|
| 394 |
+
bias_term: true
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
layer {
|
| 398 |
+
name: "layer_256_1_relu1"
|
| 399 |
+
type: "ReLU"
|
| 400 |
+
bottom: "layer_256_1_bn1"
|
| 401 |
+
top: "layer_256_1_bn1"
|
| 402 |
+
}
|
| 403 |
+
layer {
|
| 404 |
+
name: "layer_256_1_conv1"
|
| 405 |
+
type: "Convolution"
|
| 406 |
+
bottom: "layer_256_1_bn1"
|
| 407 |
+
top: "layer_256_1_conv1"
|
| 408 |
+
param {
|
| 409 |
+
lr_mult: 1.0
|
| 410 |
+
decay_mult: 1.0
|
| 411 |
+
}
|
| 412 |
+
convolution_param {
|
| 413 |
+
num_output: 256
|
| 414 |
+
bias_term: false
|
| 415 |
+
pad: 1
|
| 416 |
+
kernel_size: 3
|
| 417 |
+
stride: 2
|
| 418 |
+
weight_filler {
|
| 419 |
+
type: "msra"
|
| 420 |
+
}
|
| 421 |
+
bias_filler {
|
| 422 |
+
type: "constant"
|
| 423 |
+
value: 0.0
|
| 424 |
+
}
|
| 425 |
+
}
|
| 426 |
+
}
|
| 427 |
+
layer {
|
| 428 |
+
name: "layer_256_1_bn2"
|
| 429 |
+
type: "BatchNorm"
|
| 430 |
+
bottom: "layer_256_1_conv1"
|
| 431 |
+
top: "layer_256_1_conv1"
|
| 432 |
+
param {
|
| 433 |
+
lr_mult: 0.0
|
| 434 |
+
}
|
| 435 |
+
param {
|
| 436 |
+
lr_mult: 0.0
|
| 437 |
+
}
|
| 438 |
+
param {
|
| 439 |
+
lr_mult: 0.0
|
| 440 |
+
}
|
| 441 |
+
}
|
| 442 |
+
layer {
|
| 443 |
+
name: "layer_256_1_scale2"
|
| 444 |
+
type: "Scale"
|
| 445 |
+
bottom: "layer_256_1_conv1"
|
| 446 |
+
top: "layer_256_1_conv1"
|
| 447 |
+
param {
|
| 448 |
+
lr_mult: 1.0
|
| 449 |
+
decay_mult: 1.0
|
| 450 |
+
}
|
| 451 |
+
param {
|
| 452 |
+
lr_mult: 2.0
|
| 453 |
+
decay_mult: 1.0
|
| 454 |
+
}
|
| 455 |
+
scale_param {
|
| 456 |
+
bias_term: true
|
| 457 |
+
}
|
| 458 |
+
}
|
| 459 |
+
layer {
|
| 460 |
+
name: "layer_256_1_relu2"
|
| 461 |
+
type: "ReLU"
|
| 462 |
+
bottom: "layer_256_1_conv1"
|
| 463 |
+
top: "layer_256_1_conv1"
|
| 464 |
+
}
|
| 465 |
+
layer {
|
| 466 |
+
name: "layer_256_1_conv2"
|
| 467 |
+
type: "Convolution"
|
| 468 |
+
bottom: "layer_256_1_conv1"
|
| 469 |
+
top: "layer_256_1_conv2"
|
| 470 |
+
param {
|
| 471 |
+
lr_mult: 1.0
|
| 472 |
+
decay_mult: 1.0
|
| 473 |
+
}
|
| 474 |
+
convolution_param {
|
| 475 |
+
num_output: 256
|
| 476 |
+
bias_term: false
|
| 477 |
+
pad: 1
|
| 478 |
+
kernel_size: 3
|
| 479 |
+
stride: 1
|
| 480 |
+
weight_filler {
|
| 481 |
+
type: "msra"
|
| 482 |
+
}
|
| 483 |
+
bias_filler {
|
| 484 |
+
type: "constant"
|
| 485 |
+
value: 0.0
|
| 486 |
+
}
|
| 487 |
+
}
|
| 488 |
+
}
|
| 489 |
+
layer {
|
| 490 |
+
name: "layer_256_1_conv_expand"
|
| 491 |
+
type: "Convolution"
|
| 492 |
+
bottom: "layer_256_1_bn1"
|
| 493 |
+
top: "layer_256_1_conv_expand"
|
| 494 |
+
param {
|
| 495 |
+
lr_mult: 1.0
|
| 496 |
+
decay_mult: 1.0
|
| 497 |
+
}
|
| 498 |
+
convolution_param {
|
| 499 |
+
num_output: 256
|
| 500 |
+
bias_term: false
|
| 501 |
+
pad: 0
|
| 502 |
+
kernel_size: 1
|
| 503 |
+
stride: 2
|
| 504 |
+
weight_filler {
|
| 505 |
+
type: "msra"
|
| 506 |
+
}
|
| 507 |
+
bias_filler {
|
| 508 |
+
type: "constant"
|
| 509 |
+
value: 0.0
|
| 510 |
+
}
|
| 511 |
+
}
|
| 512 |
+
}
|
| 513 |
+
layer {
|
| 514 |
+
name: "layer_256_1_sum"
|
| 515 |
+
type: "Eltwise"
|
| 516 |
+
bottom: "layer_256_1_conv2"
|
| 517 |
+
bottom: "layer_256_1_conv_expand"
|
| 518 |
+
top: "layer_256_1_sum"
|
| 519 |
+
}
|
| 520 |
+
layer {
|
| 521 |
+
name: "layer_512_1_bn1"
|
| 522 |
+
type: "BatchNorm"
|
| 523 |
+
bottom: "layer_256_1_sum"
|
| 524 |
+
top: "layer_512_1_bn1"
|
| 525 |
+
param {
|
| 526 |
+
lr_mult: 0.0
|
| 527 |
+
}
|
| 528 |
+
param {
|
| 529 |
+
lr_mult: 0.0
|
| 530 |
+
}
|
| 531 |
+
param {
|
| 532 |
+
lr_mult: 0.0
|
| 533 |
+
}
|
| 534 |
+
}
|
| 535 |
+
layer {
|
| 536 |
+
name: "layer_512_1_scale1"
|
| 537 |
+
type: "Scale"
|
| 538 |
+
bottom: "layer_512_1_bn1"
|
| 539 |
+
top: "layer_512_1_bn1"
|
| 540 |
+
param {
|
| 541 |
+
lr_mult: 1.0
|
| 542 |
+
decay_mult: 1.0
|
| 543 |
+
}
|
| 544 |
+
param {
|
| 545 |
+
lr_mult: 2.0
|
| 546 |
+
decay_mult: 1.0
|
| 547 |
+
}
|
| 548 |
+
scale_param {
|
| 549 |
+
bias_term: true
|
| 550 |
+
}
|
| 551 |
+
}
|
| 552 |
+
layer {
|
| 553 |
+
name: "layer_512_1_relu1"
|
| 554 |
+
type: "ReLU"
|
| 555 |
+
bottom: "layer_512_1_bn1"
|
| 556 |
+
top: "layer_512_1_bn1"
|
| 557 |
+
}
|
| 558 |
+
layer {
|
| 559 |
+
name: "layer_512_1_conv1_h"
|
| 560 |
+
type: "Convolution"
|
| 561 |
+
bottom: "layer_512_1_bn1"
|
| 562 |
+
top: "layer_512_1_conv1_h"
|
| 563 |
+
param {
|
| 564 |
+
lr_mult: 1.0
|
| 565 |
+
decay_mult: 1.0
|
| 566 |
+
}
|
| 567 |
+
convolution_param {
|
| 568 |
+
num_output: 128
|
| 569 |
+
bias_term: false
|
| 570 |
+
pad: 1
|
| 571 |
+
kernel_size: 3
|
| 572 |
+
stride: 1 # 2
|
| 573 |
+
weight_filler {
|
| 574 |
+
type: "msra"
|
| 575 |
+
}
|
| 576 |
+
bias_filler {
|
| 577 |
+
type: "constant"
|
| 578 |
+
value: 0.0
|
| 579 |
+
}
|
| 580 |
+
}
|
| 581 |
+
}
|
| 582 |
+
layer {
|
| 583 |
+
name: "layer_512_1_bn2_h"
|
| 584 |
+
type: "BatchNorm"
|
| 585 |
+
bottom: "layer_512_1_conv1_h"
|
| 586 |
+
top: "layer_512_1_conv1_h"
|
| 587 |
+
param {
|
| 588 |
+
lr_mult: 0.0
|
| 589 |
+
}
|
| 590 |
+
param {
|
| 591 |
+
lr_mult: 0.0
|
| 592 |
+
}
|
| 593 |
+
param {
|
| 594 |
+
lr_mult: 0.0
|
| 595 |
+
}
|
| 596 |
+
}
|
| 597 |
+
layer {
|
| 598 |
+
name: "layer_512_1_scale2_h"
|
| 599 |
+
type: "Scale"
|
| 600 |
+
bottom: "layer_512_1_conv1_h"
|
| 601 |
+
top: "layer_512_1_conv1_h"
|
| 602 |
+
param {
|
| 603 |
+
lr_mult: 1.0
|
| 604 |
+
decay_mult: 1.0
|
| 605 |
+
}
|
| 606 |
+
param {
|
| 607 |
+
lr_mult: 2.0
|
| 608 |
+
decay_mult: 1.0
|
| 609 |
+
}
|
| 610 |
+
scale_param {
|
| 611 |
+
bias_term: true
|
| 612 |
+
}
|
| 613 |
+
}
|
| 614 |
+
layer {
|
| 615 |
+
name: "layer_512_1_relu2"
|
| 616 |
+
type: "ReLU"
|
| 617 |
+
bottom: "layer_512_1_conv1_h"
|
| 618 |
+
top: "layer_512_1_conv1_h"
|
| 619 |
+
}
|
| 620 |
+
layer {
|
| 621 |
+
name: "layer_512_1_conv2_h"
|
| 622 |
+
type: "Convolution"
|
| 623 |
+
bottom: "layer_512_1_conv1_h"
|
| 624 |
+
top: "layer_512_1_conv2_h"
|
| 625 |
+
param {
|
| 626 |
+
lr_mult: 1.0
|
| 627 |
+
decay_mult: 1.0
|
| 628 |
+
}
|
| 629 |
+
convolution_param {
|
| 630 |
+
num_output: 256
|
| 631 |
+
bias_term: false
|
| 632 |
+
pad: 2 # 1
|
| 633 |
+
kernel_size: 3
|
| 634 |
+
stride: 1
|
| 635 |
+
dilation: 2
|
| 636 |
+
weight_filler {
|
| 637 |
+
type: "msra"
|
| 638 |
+
}
|
| 639 |
+
bias_filler {
|
| 640 |
+
type: "constant"
|
| 641 |
+
value: 0.0
|
| 642 |
+
}
|
| 643 |
+
}
|
| 644 |
+
}
|
| 645 |
+
layer {
|
| 646 |
+
name: "layer_512_1_conv_expand_h"
|
| 647 |
+
type: "Convolution"
|
| 648 |
+
bottom: "layer_512_1_bn1"
|
| 649 |
+
top: "layer_512_1_conv_expand_h"
|
| 650 |
+
param {
|
| 651 |
+
lr_mult: 1.0
|
| 652 |
+
decay_mult: 1.0
|
| 653 |
+
}
|
| 654 |
+
convolution_param {
|
| 655 |
+
num_output: 256
|
| 656 |
+
bias_term: false
|
| 657 |
+
pad: 0
|
| 658 |
+
kernel_size: 1
|
| 659 |
+
stride: 1 # 2
|
| 660 |
+
weight_filler {
|
| 661 |
+
type: "msra"
|
| 662 |
+
}
|
| 663 |
+
bias_filler {
|
| 664 |
+
type: "constant"
|
| 665 |
+
value: 0.0
|
| 666 |
+
}
|
| 667 |
+
}
|
| 668 |
+
}
|
| 669 |
+
layer {
|
| 670 |
+
name: "layer_512_1_sum"
|
| 671 |
+
type: "Eltwise"
|
| 672 |
+
bottom: "layer_512_1_conv2_h"
|
| 673 |
+
bottom: "layer_512_1_conv_expand_h"
|
| 674 |
+
top: "layer_512_1_sum"
|
| 675 |
+
}
|
| 676 |
+
layer {
|
| 677 |
+
name: "last_bn_h"
|
| 678 |
+
type: "BatchNorm"
|
| 679 |
+
bottom: "layer_512_1_sum"
|
| 680 |
+
top: "layer_512_1_sum"
|
| 681 |
+
param {
|
| 682 |
+
lr_mult: 0.0
|
| 683 |
+
}
|
| 684 |
+
param {
|
| 685 |
+
lr_mult: 0.0
|
| 686 |
+
}
|
| 687 |
+
param {
|
| 688 |
+
lr_mult: 0.0
|
| 689 |
+
}
|
| 690 |
+
}
|
| 691 |
+
layer {
|
| 692 |
+
name: "last_scale_h"
|
| 693 |
+
type: "Scale"
|
| 694 |
+
bottom: "layer_512_1_sum"
|
| 695 |
+
top: "layer_512_1_sum"
|
| 696 |
+
param {
|
| 697 |
+
lr_mult: 1.0
|
| 698 |
+
decay_mult: 1.0
|
| 699 |
+
}
|
| 700 |
+
param {
|
| 701 |
+
lr_mult: 2.0
|
| 702 |
+
decay_mult: 1.0
|
| 703 |
+
}
|
| 704 |
+
scale_param {
|
| 705 |
+
bias_term: true
|
| 706 |
+
}
|
| 707 |
+
}
|
| 708 |
+
layer {
|
| 709 |
+
name: "last_relu"
|
| 710 |
+
type: "ReLU"
|
| 711 |
+
bottom: "layer_512_1_sum"
|
| 712 |
+
top: "fc7"
|
| 713 |
+
}
|
| 714 |
+
|
| 715 |
+
layer {
|
| 716 |
+
name: "conv6_1_h"
|
| 717 |
+
type: "Convolution"
|
| 718 |
+
bottom: "fc7"
|
| 719 |
+
top: "conv6_1_h"
|
| 720 |
+
param {
|
| 721 |
+
lr_mult: 1
|
| 722 |
+
decay_mult: 1
|
| 723 |
+
}
|
| 724 |
+
param {
|
| 725 |
+
lr_mult: 2
|
| 726 |
+
decay_mult: 0
|
| 727 |
+
}
|
| 728 |
+
convolution_param {
|
| 729 |
+
num_output: 128
|
| 730 |
+
pad: 0
|
| 731 |
+
kernel_size: 1
|
| 732 |
+
stride: 1
|
| 733 |
+
weight_filler {
|
| 734 |
+
type: "xavier"
|
| 735 |
+
}
|
| 736 |
+
bias_filler {
|
| 737 |
+
type: "constant"
|
| 738 |
+
value: 0
|
| 739 |
+
}
|
| 740 |
+
}
|
| 741 |
+
}
|
| 742 |
+
layer {
|
| 743 |
+
name: "conv6_1_relu"
|
| 744 |
+
type: "ReLU"
|
| 745 |
+
bottom: "conv6_1_h"
|
| 746 |
+
top: "conv6_1_h"
|
| 747 |
+
}
|
| 748 |
+
layer {
|
| 749 |
+
name: "conv6_2_h"
|
| 750 |
+
type: "Convolution"
|
| 751 |
+
bottom: "conv6_1_h"
|
| 752 |
+
top: "conv6_2_h"
|
| 753 |
+
param {
|
| 754 |
+
lr_mult: 1
|
| 755 |
+
decay_mult: 1
|
| 756 |
+
}
|
| 757 |
+
param {
|
| 758 |
+
lr_mult: 2
|
| 759 |
+
decay_mult: 0
|
| 760 |
+
}
|
| 761 |
+
convolution_param {
|
| 762 |
+
num_output: 256
|
| 763 |
+
pad: 1
|
| 764 |
+
kernel_size: 3
|
| 765 |
+
stride: 2
|
| 766 |
+
weight_filler {
|
| 767 |
+
type: "xavier"
|
| 768 |
+
}
|
| 769 |
+
bias_filler {
|
| 770 |
+
type: "constant"
|
| 771 |
+
value: 0
|
| 772 |
+
}
|
| 773 |
+
}
|
| 774 |
+
}
|
| 775 |
+
layer {
|
| 776 |
+
name: "conv6_2_relu"
|
| 777 |
+
type: "ReLU"
|
| 778 |
+
bottom: "conv6_2_h"
|
| 779 |
+
top: "conv6_2_h"
|
| 780 |
+
}
|
| 781 |
+
layer {
|
| 782 |
+
name: "conv7_1_h"
|
| 783 |
+
type: "Convolution"
|
| 784 |
+
bottom: "conv6_2_h"
|
| 785 |
+
top: "conv7_1_h"
|
| 786 |
+
param {
|
| 787 |
+
lr_mult: 1
|
| 788 |
+
decay_mult: 1
|
| 789 |
+
}
|
| 790 |
+
param {
|
| 791 |
+
lr_mult: 2
|
| 792 |
+
decay_mult: 0
|
| 793 |
+
}
|
| 794 |
+
convolution_param {
|
| 795 |
+
num_output: 64
|
| 796 |
+
pad: 0
|
| 797 |
+
kernel_size: 1
|
| 798 |
+
stride: 1
|
| 799 |
+
weight_filler {
|
| 800 |
+
type: "xavier"
|
| 801 |
+
}
|
| 802 |
+
bias_filler {
|
| 803 |
+
type: "constant"
|
| 804 |
+
value: 0
|
| 805 |
+
}
|
| 806 |
+
}
|
| 807 |
+
}
|
| 808 |
+
layer {
|
| 809 |
+
name: "conv7_1_relu"
|
| 810 |
+
type: "ReLU"
|
| 811 |
+
bottom: "conv7_1_h"
|
| 812 |
+
top: "conv7_1_h"
|
| 813 |
+
}
|
| 814 |
+
layer {
|
| 815 |
+
name: "conv7_2_h"
|
| 816 |
+
type: "Convolution"
|
| 817 |
+
bottom: "conv7_1_h"
|
| 818 |
+
top: "conv7_2_h"
|
| 819 |
+
param {
|
| 820 |
+
lr_mult: 1
|
| 821 |
+
decay_mult: 1
|
| 822 |
+
}
|
| 823 |
+
param {
|
| 824 |
+
lr_mult: 2
|
| 825 |
+
decay_mult: 0
|
| 826 |
+
}
|
| 827 |
+
convolution_param {
|
| 828 |
+
num_output: 128
|
| 829 |
+
pad: 1
|
| 830 |
+
kernel_size: 3
|
| 831 |
+
stride: 2
|
| 832 |
+
weight_filler {
|
| 833 |
+
type: "xavier"
|
| 834 |
+
}
|
| 835 |
+
bias_filler {
|
| 836 |
+
type: "constant"
|
| 837 |
+
value: 0
|
| 838 |
+
}
|
| 839 |
+
}
|
| 840 |
+
}
|
| 841 |
+
layer {
|
| 842 |
+
name: "conv7_2_relu"
|
| 843 |
+
type: "ReLU"
|
| 844 |
+
bottom: "conv7_2_h"
|
| 845 |
+
top: "conv7_2_h"
|
| 846 |
+
}
|
| 847 |
+
layer {
|
| 848 |
+
name: "conv8_1_h"
|
| 849 |
+
type: "Convolution"
|
| 850 |
+
bottom: "conv7_2_h"
|
| 851 |
+
top: "conv8_1_h"
|
| 852 |
+
param {
|
| 853 |
+
lr_mult: 1
|
| 854 |
+
decay_mult: 1
|
| 855 |
+
}
|
| 856 |
+
param {
|
| 857 |
+
lr_mult: 2
|
| 858 |
+
decay_mult: 0
|
| 859 |
+
}
|
| 860 |
+
convolution_param {
|
| 861 |
+
num_output: 64
|
| 862 |
+
pad: 0
|
| 863 |
+
kernel_size: 1
|
| 864 |
+
stride: 1
|
| 865 |
+
weight_filler {
|
| 866 |
+
type: "xavier"
|
| 867 |
+
}
|
| 868 |
+
bias_filler {
|
| 869 |
+
type: "constant"
|
| 870 |
+
value: 0
|
| 871 |
+
}
|
| 872 |
+
}
|
| 873 |
+
}
|
| 874 |
+
layer {
|
| 875 |
+
name: "conv8_1_relu"
|
| 876 |
+
type: "ReLU"
|
| 877 |
+
bottom: "conv8_1_h"
|
| 878 |
+
top: "conv8_1_h"
|
| 879 |
+
}
|
| 880 |
+
layer {
|
| 881 |
+
name: "conv8_2_h"
|
| 882 |
+
type: "Convolution"
|
| 883 |
+
bottom: "conv8_1_h"
|
| 884 |
+
top: "conv8_2_h"
|
| 885 |
+
param {
|
| 886 |
+
lr_mult: 1
|
| 887 |
+
decay_mult: 1
|
| 888 |
+
}
|
| 889 |
+
param {
|
| 890 |
+
lr_mult: 2
|
| 891 |
+
decay_mult: 0
|
| 892 |
+
}
|
| 893 |
+
convolution_param {
|
| 894 |
+
num_output: 128
|
| 895 |
+
pad: 1
|
| 896 |
+
kernel_size: 3
|
| 897 |
+
stride: 1
|
| 898 |
+
weight_filler {
|
| 899 |
+
type: "xavier"
|
| 900 |
+
}
|
| 901 |
+
bias_filler {
|
| 902 |
+
type: "constant"
|
| 903 |
+
value: 0
|
| 904 |
+
}
|
| 905 |
+
}
|
| 906 |
+
}
|
| 907 |
+
layer {
|
| 908 |
+
name: "conv8_2_relu"
|
| 909 |
+
type: "ReLU"
|
| 910 |
+
bottom: "conv8_2_h"
|
| 911 |
+
top: "conv8_2_h"
|
| 912 |
+
}
|
| 913 |
+
layer {
|
| 914 |
+
name: "conv9_1_h"
|
| 915 |
+
type: "Convolution"
|
| 916 |
+
bottom: "conv8_2_h"
|
| 917 |
+
top: "conv9_1_h"
|
| 918 |
+
param {
|
| 919 |
+
lr_mult: 1
|
| 920 |
+
decay_mult: 1
|
| 921 |
+
}
|
| 922 |
+
param {
|
| 923 |
+
lr_mult: 2
|
| 924 |
+
decay_mult: 0
|
| 925 |
+
}
|
| 926 |
+
convolution_param {
|
| 927 |
+
num_output: 64
|
| 928 |
+
pad: 0
|
| 929 |
+
kernel_size: 1
|
| 930 |
+
stride: 1
|
| 931 |
+
weight_filler {
|
| 932 |
+
type: "xavier"
|
| 933 |
+
}
|
| 934 |
+
bias_filler {
|
| 935 |
+
type: "constant"
|
| 936 |
+
value: 0
|
| 937 |
+
}
|
| 938 |
+
}
|
| 939 |
+
}
|
| 940 |
+
layer {
|
| 941 |
+
name: "conv9_1_relu"
|
| 942 |
+
type: "ReLU"
|
| 943 |
+
bottom: "conv9_1_h"
|
| 944 |
+
top: "conv9_1_h"
|
| 945 |
+
}
|
| 946 |
+
layer {
|
| 947 |
+
name: "conv9_2_h"
|
| 948 |
+
type: "Convolution"
|
| 949 |
+
bottom: "conv9_1_h"
|
| 950 |
+
top: "conv9_2_h"
|
| 951 |
+
param {
|
| 952 |
+
lr_mult: 1
|
| 953 |
+
decay_mult: 1
|
| 954 |
+
}
|
| 955 |
+
param {
|
| 956 |
+
lr_mult: 2
|
| 957 |
+
decay_mult: 0
|
| 958 |
+
}
|
| 959 |
+
convolution_param {
|
| 960 |
+
num_output: 128
|
| 961 |
+
pad: 1
|
| 962 |
+
kernel_size: 3
|
| 963 |
+
stride: 1
|
| 964 |
+
weight_filler {
|
| 965 |
+
type: "xavier"
|
| 966 |
+
}
|
| 967 |
+
bias_filler {
|
| 968 |
+
type: "constant"
|
| 969 |
+
value: 0
|
| 970 |
+
}
|
| 971 |
+
}
|
| 972 |
+
}
|
| 973 |
+
layer {
|
| 974 |
+
name: "conv9_2_relu"
|
| 975 |
+
type: "ReLU"
|
| 976 |
+
bottom: "conv9_2_h"
|
| 977 |
+
top: "conv9_2_h"
|
| 978 |
+
}
|
| 979 |
+
layer {
|
| 980 |
+
name: "conv4_3_norm"
|
| 981 |
+
type: "Normalize"
|
| 982 |
+
bottom: "layer_256_1_bn1"
|
| 983 |
+
top: "conv4_3_norm"
|
| 984 |
+
norm_param {
|
| 985 |
+
across_spatial: false
|
| 986 |
+
scale_filler {
|
| 987 |
+
type: "constant"
|
| 988 |
+
value: 20
|
| 989 |
+
}
|
| 990 |
+
channel_shared: false
|
| 991 |
+
}
|
| 992 |
+
}
|
| 993 |
+
layer {
|
| 994 |
+
name: "conv4_3_norm_mbox_loc"
|
| 995 |
+
type: "Convolution"
|
| 996 |
+
bottom: "conv4_3_norm"
|
| 997 |
+
top: "conv4_3_norm_mbox_loc"
|
| 998 |
+
param {
|
| 999 |
+
lr_mult: 1
|
| 1000 |
+
decay_mult: 1
|
| 1001 |
+
}
|
| 1002 |
+
param {
|
| 1003 |
+
lr_mult: 2
|
| 1004 |
+
decay_mult: 0
|
| 1005 |
+
}
|
| 1006 |
+
convolution_param {
|
| 1007 |
+
num_output: 16
|
| 1008 |
+
pad: 1
|
| 1009 |
+
kernel_size: 3
|
| 1010 |
+
stride: 1
|
| 1011 |
+
weight_filler {
|
| 1012 |
+
type: "xavier"
|
| 1013 |
+
}
|
| 1014 |
+
bias_filler {
|
| 1015 |
+
type: "constant"
|
| 1016 |
+
value: 0
|
| 1017 |
+
}
|
| 1018 |
+
}
|
| 1019 |
+
}
|
| 1020 |
+
layer {
|
| 1021 |
+
name: "conv4_3_norm_mbox_loc_perm"
|
| 1022 |
+
type: "Permute"
|
| 1023 |
+
bottom: "conv4_3_norm_mbox_loc"
|
| 1024 |
+
top: "conv4_3_norm_mbox_loc_perm"
|
| 1025 |
+
permute_param {
|
| 1026 |
+
order: 0
|
| 1027 |
+
order: 2
|
| 1028 |
+
order: 3
|
| 1029 |
+
order: 1
|
| 1030 |
+
}
|
| 1031 |
+
}
|
| 1032 |
+
layer {
|
| 1033 |
+
name: "conv4_3_norm_mbox_loc_flat"
|
| 1034 |
+
type: "Flatten"
|
| 1035 |
+
bottom: "conv4_3_norm_mbox_loc_perm"
|
| 1036 |
+
top: "conv4_3_norm_mbox_loc_flat"
|
| 1037 |
+
flatten_param {
|
| 1038 |
+
axis: 1
|
| 1039 |
+
}
|
| 1040 |
+
}
|
| 1041 |
+
layer {
|
| 1042 |
+
name: "conv4_3_norm_mbox_conf"
|
| 1043 |
+
type: "Convolution"
|
| 1044 |
+
bottom: "conv4_3_norm"
|
| 1045 |
+
top: "conv4_3_norm_mbox_conf"
|
| 1046 |
+
param {
|
| 1047 |
+
lr_mult: 1
|
| 1048 |
+
decay_mult: 1
|
| 1049 |
+
}
|
| 1050 |
+
param {
|
| 1051 |
+
lr_mult: 2
|
| 1052 |
+
decay_mult: 0
|
| 1053 |
+
}
|
| 1054 |
+
convolution_param {
|
| 1055 |
+
num_output: 8 # 84
|
| 1056 |
+
pad: 1
|
| 1057 |
+
kernel_size: 3
|
| 1058 |
+
stride: 1
|
| 1059 |
+
weight_filler {
|
| 1060 |
+
type: "xavier"
|
| 1061 |
+
}
|
| 1062 |
+
bias_filler {
|
| 1063 |
+
type: "constant"
|
| 1064 |
+
value: 0
|
| 1065 |
+
}
|
| 1066 |
+
}
|
| 1067 |
+
}
|
| 1068 |
+
layer {
|
| 1069 |
+
name: "conv4_3_norm_mbox_conf_perm"
|
| 1070 |
+
type: "Permute"
|
| 1071 |
+
bottom: "conv4_3_norm_mbox_conf"
|
| 1072 |
+
top: "conv4_3_norm_mbox_conf_perm"
|
| 1073 |
+
permute_param {
|
| 1074 |
+
order: 0
|
| 1075 |
+
order: 2
|
| 1076 |
+
order: 3
|
| 1077 |
+
order: 1
|
| 1078 |
+
}
|
| 1079 |
+
}
|
| 1080 |
+
layer {
|
| 1081 |
+
name: "conv4_3_norm_mbox_conf_flat"
|
| 1082 |
+
type: "Flatten"
|
| 1083 |
+
bottom: "conv4_3_norm_mbox_conf_perm"
|
| 1084 |
+
top: "conv4_3_norm_mbox_conf_flat"
|
| 1085 |
+
flatten_param {
|
| 1086 |
+
axis: 1
|
| 1087 |
+
}
|
| 1088 |
+
}
|
| 1089 |
+
layer {
|
| 1090 |
+
name: "conv4_3_norm_mbox_priorbox"
|
| 1091 |
+
type: "PriorBox"
|
| 1092 |
+
bottom: "conv4_3_norm"
|
| 1093 |
+
bottom: "data"
|
| 1094 |
+
top: "conv4_3_norm_mbox_priorbox"
|
| 1095 |
+
prior_box_param {
|
| 1096 |
+
min_size: 30.0
|
| 1097 |
+
max_size: 60.0
|
| 1098 |
+
aspect_ratio: 2
|
| 1099 |
+
flip: true
|
| 1100 |
+
clip: false
|
| 1101 |
+
variance: 0.1
|
| 1102 |
+
variance: 0.1
|
| 1103 |
+
variance: 0.2
|
| 1104 |
+
variance: 0.2
|
| 1105 |
+
step: 8
|
| 1106 |
+
offset: 0.5
|
| 1107 |
+
}
|
| 1108 |
+
}
|
| 1109 |
+
layer {
|
| 1110 |
+
name: "fc7_mbox_loc"
|
| 1111 |
+
type: "Convolution"
|
| 1112 |
+
bottom: "fc7"
|
| 1113 |
+
top: "fc7_mbox_loc"
|
| 1114 |
+
param {
|
| 1115 |
+
lr_mult: 1
|
| 1116 |
+
decay_mult: 1
|
| 1117 |
+
}
|
| 1118 |
+
param {
|
| 1119 |
+
lr_mult: 2
|
| 1120 |
+
decay_mult: 0
|
| 1121 |
+
}
|
| 1122 |
+
convolution_param {
|
| 1123 |
+
num_output: 24
|
| 1124 |
+
pad: 1
|
| 1125 |
+
kernel_size: 3
|
| 1126 |
+
stride: 1
|
| 1127 |
+
weight_filler {
|
| 1128 |
+
type: "xavier"
|
| 1129 |
+
}
|
| 1130 |
+
bias_filler {
|
| 1131 |
+
type: "constant"
|
| 1132 |
+
value: 0
|
| 1133 |
+
}
|
| 1134 |
+
}
|
| 1135 |
+
}
|
| 1136 |
+
layer {
|
| 1137 |
+
name: "fc7_mbox_loc_perm"
|
| 1138 |
+
type: "Permute"
|
| 1139 |
+
bottom: "fc7_mbox_loc"
|
| 1140 |
+
top: "fc7_mbox_loc_perm"
|
| 1141 |
+
permute_param {
|
| 1142 |
+
order: 0
|
| 1143 |
+
order: 2
|
| 1144 |
+
order: 3
|
| 1145 |
+
order: 1
|
| 1146 |
+
}
|
| 1147 |
+
}
|
| 1148 |
+
layer {
|
| 1149 |
+
name: "fc7_mbox_loc_flat"
|
| 1150 |
+
type: "Flatten"
|
| 1151 |
+
bottom: "fc7_mbox_loc_perm"
|
| 1152 |
+
top: "fc7_mbox_loc_flat"
|
| 1153 |
+
flatten_param {
|
| 1154 |
+
axis: 1
|
| 1155 |
+
}
|
| 1156 |
+
}
|
| 1157 |
+
layer {
|
| 1158 |
+
name: "fc7_mbox_conf"
|
| 1159 |
+
type: "Convolution"
|
| 1160 |
+
bottom: "fc7"
|
| 1161 |
+
top: "fc7_mbox_conf"
|
| 1162 |
+
param {
|
| 1163 |
+
lr_mult: 1
|
| 1164 |
+
decay_mult: 1
|
| 1165 |
+
}
|
| 1166 |
+
param {
|
| 1167 |
+
lr_mult: 2
|
| 1168 |
+
decay_mult: 0
|
| 1169 |
+
}
|
| 1170 |
+
convolution_param {
|
| 1171 |
+
num_output: 12 # 126
|
| 1172 |
+
pad: 1
|
| 1173 |
+
kernel_size: 3
|
| 1174 |
+
stride: 1
|
| 1175 |
+
weight_filler {
|
| 1176 |
+
type: "xavier"
|
| 1177 |
+
}
|
| 1178 |
+
bias_filler {
|
| 1179 |
+
type: "constant"
|
| 1180 |
+
value: 0
|
| 1181 |
+
}
|
| 1182 |
+
}
|
| 1183 |
+
}
|
| 1184 |
+
layer {
|
| 1185 |
+
name: "fc7_mbox_conf_perm"
|
| 1186 |
+
type: "Permute"
|
| 1187 |
+
bottom: "fc7_mbox_conf"
|
| 1188 |
+
top: "fc7_mbox_conf_perm"
|
| 1189 |
+
permute_param {
|
| 1190 |
+
order: 0
|
| 1191 |
+
order: 2
|
| 1192 |
+
order: 3
|
| 1193 |
+
order: 1
|
| 1194 |
+
}
|
| 1195 |
+
}
|
| 1196 |
+
layer {
|
| 1197 |
+
name: "fc7_mbox_conf_flat"
|
| 1198 |
+
type: "Flatten"
|
| 1199 |
+
bottom: "fc7_mbox_conf_perm"
|
| 1200 |
+
top: "fc7_mbox_conf_flat"
|
| 1201 |
+
flatten_param {
|
| 1202 |
+
axis: 1
|
| 1203 |
+
}
|
| 1204 |
+
}
|
| 1205 |
+
layer {
|
| 1206 |
+
name: "fc7_mbox_priorbox"
|
| 1207 |
+
type: "PriorBox"
|
| 1208 |
+
bottom: "fc7"
|
| 1209 |
+
bottom: "data"
|
| 1210 |
+
top: "fc7_mbox_priorbox"
|
| 1211 |
+
prior_box_param {
|
| 1212 |
+
min_size: 60.0
|
| 1213 |
+
max_size: 111.0
|
| 1214 |
+
aspect_ratio: 2
|
| 1215 |
+
aspect_ratio: 3
|
| 1216 |
+
flip: true
|
| 1217 |
+
clip: false
|
| 1218 |
+
variance: 0.1
|
| 1219 |
+
variance: 0.1
|
| 1220 |
+
variance: 0.2
|
| 1221 |
+
variance: 0.2
|
| 1222 |
+
step: 16
|
| 1223 |
+
offset: 0.5
|
| 1224 |
+
}
|
| 1225 |
+
}
|
| 1226 |
+
layer {
|
| 1227 |
+
name: "conv6_2_mbox_loc"
|
| 1228 |
+
type: "Convolution"
|
| 1229 |
+
bottom: "conv6_2_h"
|
| 1230 |
+
top: "conv6_2_mbox_loc"
|
| 1231 |
+
param {
|
| 1232 |
+
lr_mult: 1
|
| 1233 |
+
decay_mult: 1
|
| 1234 |
+
}
|
| 1235 |
+
param {
|
| 1236 |
+
lr_mult: 2
|
| 1237 |
+
decay_mult: 0
|
| 1238 |
+
}
|
| 1239 |
+
convolution_param {
|
| 1240 |
+
num_output: 24
|
| 1241 |
+
pad: 1
|
| 1242 |
+
kernel_size: 3
|
| 1243 |
+
stride: 1
|
| 1244 |
+
weight_filler {
|
| 1245 |
+
type: "xavier"
|
| 1246 |
+
}
|
| 1247 |
+
bias_filler {
|
| 1248 |
+
type: "constant"
|
| 1249 |
+
value: 0
|
| 1250 |
+
}
|
| 1251 |
+
}
|
| 1252 |
+
}
|
| 1253 |
+
layer {
|
| 1254 |
+
name: "conv6_2_mbox_loc_perm"
|
| 1255 |
+
type: "Permute"
|
| 1256 |
+
bottom: "conv6_2_mbox_loc"
|
| 1257 |
+
top: "conv6_2_mbox_loc_perm"
|
| 1258 |
+
permute_param {
|
| 1259 |
+
order: 0
|
| 1260 |
+
order: 2
|
| 1261 |
+
order: 3
|
| 1262 |
+
order: 1
|
| 1263 |
+
}
|
| 1264 |
+
}
|
| 1265 |
+
layer {
|
| 1266 |
+
name: "conv6_2_mbox_loc_flat"
|
| 1267 |
+
type: "Flatten"
|
| 1268 |
+
bottom: "conv6_2_mbox_loc_perm"
|
| 1269 |
+
top: "conv6_2_mbox_loc_flat"
|
| 1270 |
+
flatten_param {
|
| 1271 |
+
axis: 1
|
| 1272 |
+
}
|
| 1273 |
+
}
|
| 1274 |
+
layer {
|
| 1275 |
+
name: "conv6_2_mbox_conf"
|
| 1276 |
+
type: "Convolution"
|
| 1277 |
+
bottom: "conv6_2_h"
|
| 1278 |
+
top: "conv6_2_mbox_conf"
|
| 1279 |
+
param {
|
| 1280 |
+
lr_mult: 1
|
| 1281 |
+
decay_mult: 1
|
| 1282 |
+
}
|
| 1283 |
+
param {
|
| 1284 |
+
lr_mult: 2
|
| 1285 |
+
decay_mult: 0
|
| 1286 |
+
}
|
| 1287 |
+
convolution_param {
|
| 1288 |
+
num_output: 12 # 126
|
| 1289 |
+
pad: 1
|
| 1290 |
+
kernel_size: 3
|
| 1291 |
+
stride: 1
|
| 1292 |
+
weight_filler {
|
| 1293 |
+
type: "xavier"
|
| 1294 |
+
}
|
| 1295 |
+
bias_filler {
|
| 1296 |
+
type: "constant"
|
| 1297 |
+
value: 0
|
| 1298 |
+
}
|
| 1299 |
+
}
|
| 1300 |
+
}
|
| 1301 |
+
layer {
|
| 1302 |
+
name: "conv6_2_mbox_conf_perm"
|
| 1303 |
+
type: "Permute"
|
| 1304 |
+
bottom: "conv6_2_mbox_conf"
|
| 1305 |
+
top: "conv6_2_mbox_conf_perm"
|
| 1306 |
+
permute_param {
|
| 1307 |
+
order: 0
|
| 1308 |
+
order: 2
|
| 1309 |
+
order: 3
|
| 1310 |
+
order: 1
|
| 1311 |
+
}
|
| 1312 |
+
}
|
| 1313 |
+
layer {
|
| 1314 |
+
name: "conv6_2_mbox_conf_flat"
|
| 1315 |
+
type: "Flatten"
|
| 1316 |
+
bottom: "conv6_2_mbox_conf_perm"
|
| 1317 |
+
top: "conv6_2_mbox_conf_flat"
|
| 1318 |
+
flatten_param {
|
| 1319 |
+
axis: 1
|
| 1320 |
+
}
|
| 1321 |
+
}
|
| 1322 |
+
layer {
|
| 1323 |
+
name: "conv6_2_mbox_priorbox"
|
| 1324 |
+
type: "PriorBox"
|
| 1325 |
+
bottom: "conv6_2_h"
|
| 1326 |
+
bottom: "data"
|
| 1327 |
+
top: "conv6_2_mbox_priorbox"
|
| 1328 |
+
prior_box_param {
|
| 1329 |
+
min_size: 111.0
|
| 1330 |
+
max_size: 162.0
|
| 1331 |
+
aspect_ratio: 2
|
| 1332 |
+
aspect_ratio: 3
|
| 1333 |
+
flip: true
|
| 1334 |
+
clip: false
|
| 1335 |
+
variance: 0.1
|
| 1336 |
+
variance: 0.1
|
| 1337 |
+
variance: 0.2
|
| 1338 |
+
variance: 0.2
|
| 1339 |
+
step: 32
|
| 1340 |
+
offset: 0.5
|
| 1341 |
+
}
|
| 1342 |
+
}
|
| 1343 |
+
layer {
|
| 1344 |
+
name: "conv7_2_mbox_loc"
|
| 1345 |
+
type: "Convolution"
|
| 1346 |
+
bottom: "conv7_2_h"
|
| 1347 |
+
top: "conv7_2_mbox_loc"
|
| 1348 |
+
param {
|
| 1349 |
+
lr_mult: 1
|
| 1350 |
+
decay_mult: 1
|
| 1351 |
+
}
|
| 1352 |
+
param {
|
| 1353 |
+
lr_mult: 2
|
| 1354 |
+
decay_mult: 0
|
| 1355 |
+
}
|
| 1356 |
+
convolution_param {
|
| 1357 |
+
num_output: 24
|
| 1358 |
+
pad: 1
|
| 1359 |
+
kernel_size: 3
|
| 1360 |
+
stride: 1
|
| 1361 |
+
weight_filler {
|
| 1362 |
+
type: "xavier"
|
| 1363 |
+
}
|
| 1364 |
+
bias_filler {
|
| 1365 |
+
type: "constant"
|
| 1366 |
+
value: 0
|
| 1367 |
+
}
|
| 1368 |
+
}
|
| 1369 |
+
}
|
| 1370 |
+
layer {
|
| 1371 |
+
name: "conv7_2_mbox_loc_perm"
|
| 1372 |
+
type: "Permute"
|
| 1373 |
+
bottom: "conv7_2_mbox_loc"
|
| 1374 |
+
top: "conv7_2_mbox_loc_perm"
|
| 1375 |
+
permute_param {
|
| 1376 |
+
order: 0
|
| 1377 |
+
order: 2
|
| 1378 |
+
order: 3
|
| 1379 |
+
order: 1
|
| 1380 |
+
}
|
| 1381 |
+
}
|
| 1382 |
+
layer {
|
| 1383 |
+
name: "conv7_2_mbox_loc_flat"
|
| 1384 |
+
type: "Flatten"
|
| 1385 |
+
bottom: "conv7_2_mbox_loc_perm"
|
| 1386 |
+
top: "conv7_2_mbox_loc_flat"
|
| 1387 |
+
flatten_param {
|
| 1388 |
+
axis: 1
|
| 1389 |
+
}
|
| 1390 |
+
}
|
| 1391 |
+
layer {
|
| 1392 |
+
name: "conv7_2_mbox_conf"
|
| 1393 |
+
type: "Convolution"
|
| 1394 |
+
bottom: "conv7_2_h"
|
| 1395 |
+
top: "conv7_2_mbox_conf"
|
| 1396 |
+
param {
|
| 1397 |
+
lr_mult: 1
|
| 1398 |
+
decay_mult: 1
|
| 1399 |
+
}
|
| 1400 |
+
param {
|
| 1401 |
+
lr_mult: 2
|
| 1402 |
+
decay_mult: 0
|
| 1403 |
+
}
|
| 1404 |
+
convolution_param {
|
| 1405 |
+
num_output: 12 # 126
|
| 1406 |
+
pad: 1
|
| 1407 |
+
kernel_size: 3
|
| 1408 |
+
stride: 1
|
| 1409 |
+
weight_filler {
|
| 1410 |
+
type: "xavier"
|
| 1411 |
+
}
|
| 1412 |
+
bias_filler {
|
| 1413 |
+
type: "constant"
|
| 1414 |
+
value: 0
|
| 1415 |
+
}
|
| 1416 |
+
}
|
| 1417 |
+
}
|
| 1418 |
+
layer {
|
| 1419 |
+
name: "conv7_2_mbox_conf_perm"
|
| 1420 |
+
type: "Permute"
|
| 1421 |
+
bottom: "conv7_2_mbox_conf"
|
| 1422 |
+
top: "conv7_2_mbox_conf_perm"
|
| 1423 |
+
permute_param {
|
| 1424 |
+
order: 0
|
| 1425 |
+
order: 2
|
| 1426 |
+
order: 3
|
| 1427 |
+
order: 1
|
| 1428 |
+
}
|
| 1429 |
+
}
|
| 1430 |
+
layer {
|
| 1431 |
+
name: "conv7_2_mbox_conf_flat"
|
| 1432 |
+
type: "Flatten"
|
| 1433 |
+
bottom: "conv7_2_mbox_conf_perm"
|
| 1434 |
+
top: "conv7_2_mbox_conf_flat"
|
| 1435 |
+
flatten_param {
|
| 1436 |
+
axis: 1
|
| 1437 |
+
}
|
| 1438 |
+
}
|
| 1439 |
+
layer {
|
| 1440 |
+
name: "conv7_2_mbox_priorbox"
|
| 1441 |
+
type: "PriorBox"
|
| 1442 |
+
bottom: "conv7_2_h"
|
| 1443 |
+
bottom: "data"
|
| 1444 |
+
top: "conv7_2_mbox_priorbox"
|
| 1445 |
+
prior_box_param {
|
| 1446 |
+
min_size: 162.0
|
| 1447 |
+
max_size: 213.0
|
| 1448 |
+
aspect_ratio: 2
|
| 1449 |
+
aspect_ratio: 3
|
| 1450 |
+
flip: true
|
| 1451 |
+
clip: false
|
| 1452 |
+
variance: 0.1
|
| 1453 |
+
variance: 0.1
|
| 1454 |
+
variance: 0.2
|
| 1455 |
+
variance: 0.2
|
| 1456 |
+
step: 64
|
| 1457 |
+
offset: 0.5
|
| 1458 |
+
}
|
| 1459 |
+
}
|
| 1460 |
+
layer {
|
| 1461 |
+
name: "conv8_2_mbox_loc"
|
| 1462 |
+
type: "Convolution"
|
| 1463 |
+
bottom: "conv8_2_h"
|
| 1464 |
+
top: "conv8_2_mbox_loc"
|
| 1465 |
+
param {
|
| 1466 |
+
lr_mult: 1
|
| 1467 |
+
decay_mult: 1
|
| 1468 |
+
}
|
| 1469 |
+
param {
|
| 1470 |
+
lr_mult: 2
|
| 1471 |
+
decay_mult: 0
|
| 1472 |
+
}
|
| 1473 |
+
convolution_param {
|
| 1474 |
+
num_output: 16
|
| 1475 |
+
pad: 1
|
| 1476 |
+
kernel_size: 3
|
| 1477 |
+
stride: 1
|
| 1478 |
+
weight_filler {
|
| 1479 |
+
type: "xavier"
|
| 1480 |
+
}
|
| 1481 |
+
bias_filler {
|
| 1482 |
+
type: "constant"
|
| 1483 |
+
value: 0
|
| 1484 |
+
}
|
| 1485 |
+
}
|
| 1486 |
+
}
|
| 1487 |
+
layer {
|
| 1488 |
+
name: "conv8_2_mbox_loc_perm"
|
| 1489 |
+
type: "Permute"
|
| 1490 |
+
bottom: "conv8_2_mbox_loc"
|
| 1491 |
+
top: "conv8_2_mbox_loc_perm"
|
| 1492 |
+
permute_param {
|
| 1493 |
+
order: 0
|
| 1494 |
+
order: 2
|
| 1495 |
+
order: 3
|
| 1496 |
+
order: 1
|
| 1497 |
+
}
|
| 1498 |
+
}
|
| 1499 |
+
layer {
|
| 1500 |
+
name: "conv8_2_mbox_loc_flat"
|
| 1501 |
+
type: "Flatten"
|
| 1502 |
+
bottom: "conv8_2_mbox_loc_perm"
|
| 1503 |
+
top: "conv8_2_mbox_loc_flat"
|
| 1504 |
+
flatten_param {
|
| 1505 |
+
axis: 1
|
| 1506 |
+
}
|
| 1507 |
+
}
|
| 1508 |
+
layer {
|
| 1509 |
+
name: "conv8_2_mbox_conf"
|
| 1510 |
+
type: "Convolution"
|
| 1511 |
+
bottom: "conv8_2_h"
|
| 1512 |
+
top: "conv8_2_mbox_conf"
|
| 1513 |
+
param {
|
| 1514 |
+
lr_mult: 1
|
| 1515 |
+
decay_mult: 1
|
| 1516 |
+
}
|
| 1517 |
+
param {
|
| 1518 |
+
lr_mult: 2
|
| 1519 |
+
decay_mult: 0
|
| 1520 |
+
}
|
| 1521 |
+
convolution_param {
|
| 1522 |
+
num_output: 8 # 84
|
| 1523 |
+
pad: 1
|
| 1524 |
+
kernel_size: 3
|
| 1525 |
+
stride: 1
|
| 1526 |
+
weight_filler {
|
| 1527 |
+
type: "xavier"
|
| 1528 |
+
}
|
| 1529 |
+
bias_filler {
|
| 1530 |
+
type: "constant"
|
| 1531 |
+
value: 0
|
| 1532 |
+
}
|
| 1533 |
+
}
|
| 1534 |
+
}
|
| 1535 |
+
layer {
|
| 1536 |
+
name: "conv8_2_mbox_conf_perm"
|
| 1537 |
+
type: "Permute"
|
| 1538 |
+
bottom: "conv8_2_mbox_conf"
|
| 1539 |
+
top: "conv8_2_mbox_conf_perm"
|
| 1540 |
+
permute_param {
|
| 1541 |
+
order: 0
|
| 1542 |
+
order: 2
|
| 1543 |
+
order: 3
|
| 1544 |
+
order: 1
|
| 1545 |
+
}
|
| 1546 |
+
}
|
| 1547 |
+
layer {
|
| 1548 |
+
name: "conv8_2_mbox_conf_flat"
|
| 1549 |
+
type: "Flatten"
|
| 1550 |
+
bottom: "conv8_2_mbox_conf_perm"
|
| 1551 |
+
top: "conv8_2_mbox_conf_flat"
|
| 1552 |
+
flatten_param {
|
| 1553 |
+
axis: 1
|
| 1554 |
+
}
|
| 1555 |
+
}
|
| 1556 |
+
layer {
|
| 1557 |
+
name: "conv8_2_mbox_priorbox"
|
| 1558 |
+
type: "PriorBox"
|
| 1559 |
+
bottom: "conv8_2_h"
|
| 1560 |
+
bottom: "data"
|
| 1561 |
+
top: "conv8_2_mbox_priorbox"
|
| 1562 |
+
prior_box_param {
|
| 1563 |
+
min_size: 213.0
|
| 1564 |
+
max_size: 264.0
|
| 1565 |
+
aspect_ratio: 2
|
| 1566 |
+
flip: true
|
| 1567 |
+
clip: false
|
| 1568 |
+
variance: 0.1
|
| 1569 |
+
variance: 0.1
|
| 1570 |
+
variance: 0.2
|
| 1571 |
+
variance: 0.2
|
| 1572 |
+
step: 100
|
| 1573 |
+
offset: 0.5
|
| 1574 |
+
}
|
| 1575 |
+
}
|
| 1576 |
+
layer {
|
| 1577 |
+
name: "conv9_2_mbox_loc"
|
| 1578 |
+
type: "Convolution"
|
| 1579 |
+
bottom: "conv9_2_h"
|
| 1580 |
+
top: "conv9_2_mbox_loc"
|
| 1581 |
+
param {
|
| 1582 |
+
lr_mult: 1
|
| 1583 |
+
decay_mult: 1
|
| 1584 |
+
}
|
| 1585 |
+
param {
|
| 1586 |
+
lr_mult: 2
|
| 1587 |
+
decay_mult: 0
|
| 1588 |
+
}
|
| 1589 |
+
convolution_param {
|
| 1590 |
+
num_output: 16
|
| 1591 |
+
pad: 1
|
| 1592 |
+
kernel_size: 3
|
| 1593 |
+
stride: 1
|
| 1594 |
+
weight_filler {
|
| 1595 |
+
type: "xavier"
|
| 1596 |
+
}
|
| 1597 |
+
bias_filler {
|
| 1598 |
+
type: "constant"
|
| 1599 |
+
value: 0
|
| 1600 |
+
}
|
| 1601 |
+
}
|
| 1602 |
+
}
|
| 1603 |
+
layer {
|
| 1604 |
+
name: "conv9_2_mbox_loc_perm"
|
| 1605 |
+
type: "Permute"
|
| 1606 |
+
bottom: "conv9_2_mbox_loc"
|
| 1607 |
+
top: "conv9_2_mbox_loc_perm"
|
| 1608 |
+
permute_param {
|
| 1609 |
+
order: 0
|
| 1610 |
+
order: 2
|
| 1611 |
+
order: 3
|
| 1612 |
+
order: 1
|
| 1613 |
+
}
|
| 1614 |
+
}
|
| 1615 |
+
layer {
|
| 1616 |
+
name: "conv9_2_mbox_loc_flat"
|
| 1617 |
+
type: "Flatten"
|
| 1618 |
+
bottom: "conv9_2_mbox_loc_perm"
|
| 1619 |
+
top: "conv9_2_mbox_loc_flat"
|
| 1620 |
+
flatten_param {
|
| 1621 |
+
axis: 1
|
| 1622 |
+
}
|
| 1623 |
+
}
|
| 1624 |
+
layer {
|
| 1625 |
+
name: "conv9_2_mbox_conf"
|
| 1626 |
+
type: "Convolution"
|
| 1627 |
+
bottom: "conv9_2_h"
|
| 1628 |
+
top: "conv9_2_mbox_conf"
|
| 1629 |
+
param {
|
| 1630 |
+
lr_mult: 1
|
| 1631 |
+
decay_mult: 1
|
| 1632 |
+
}
|
| 1633 |
+
param {
|
| 1634 |
+
lr_mult: 2
|
| 1635 |
+
decay_mult: 0
|
| 1636 |
+
}
|
| 1637 |
+
convolution_param {
|
| 1638 |
+
num_output: 8 # 84
|
| 1639 |
+
pad: 1
|
| 1640 |
+
kernel_size: 3
|
| 1641 |
+
stride: 1
|
| 1642 |
+
weight_filler {
|
| 1643 |
+
type: "xavier"
|
| 1644 |
+
}
|
| 1645 |
+
bias_filler {
|
| 1646 |
+
type: "constant"
|
| 1647 |
+
value: 0
|
| 1648 |
+
}
|
| 1649 |
+
}
|
| 1650 |
+
}
|
| 1651 |
+
layer {
|
| 1652 |
+
name: "conv9_2_mbox_conf_perm"
|
| 1653 |
+
type: "Permute"
|
| 1654 |
+
bottom: "conv9_2_mbox_conf"
|
| 1655 |
+
top: "conv9_2_mbox_conf_perm"
|
| 1656 |
+
permute_param {
|
| 1657 |
+
order: 0
|
| 1658 |
+
order: 2
|
| 1659 |
+
order: 3
|
| 1660 |
+
order: 1
|
| 1661 |
+
}
|
| 1662 |
+
}
|
| 1663 |
+
layer {
|
| 1664 |
+
name: "conv9_2_mbox_conf_flat"
|
| 1665 |
+
type: "Flatten"
|
| 1666 |
+
bottom: "conv9_2_mbox_conf_perm"
|
| 1667 |
+
top: "conv9_2_mbox_conf_flat"
|
| 1668 |
+
flatten_param {
|
| 1669 |
+
axis: 1
|
| 1670 |
+
}
|
| 1671 |
+
}
|
| 1672 |
+
layer {
|
| 1673 |
+
name: "conv9_2_mbox_priorbox"
|
| 1674 |
+
type: "PriorBox"
|
| 1675 |
+
bottom: "conv9_2_h"
|
| 1676 |
+
bottom: "data"
|
| 1677 |
+
top: "conv9_2_mbox_priorbox"
|
| 1678 |
+
prior_box_param {
|
| 1679 |
+
min_size: 264.0
|
| 1680 |
+
max_size: 315.0
|
| 1681 |
+
aspect_ratio: 2
|
| 1682 |
+
flip: true
|
| 1683 |
+
clip: false
|
| 1684 |
+
variance: 0.1
|
| 1685 |
+
variance: 0.1
|
| 1686 |
+
variance: 0.2
|
| 1687 |
+
variance: 0.2
|
| 1688 |
+
step: 300
|
| 1689 |
+
offset: 0.5
|
| 1690 |
+
}
|
| 1691 |
+
}
|
| 1692 |
+
layer {
|
| 1693 |
+
name: "mbox_loc"
|
| 1694 |
+
type: "Concat"
|
| 1695 |
+
bottom: "conv4_3_norm_mbox_loc_flat"
|
| 1696 |
+
bottom: "fc7_mbox_loc_flat"
|
| 1697 |
+
bottom: "conv6_2_mbox_loc_flat"
|
| 1698 |
+
bottom: "conv7_2_mbox_loc_flat"
|
| 1699 |
+
bottom: "conv8_2_mbox_loc_flat"
|
| 1700 |
+
bottom: "conv9_2_mbox_loc_flat"
|
| 1701 |
+
top: "mbox_loc"
|
| 1702 |
+
concat_param {
|
| 1703 |
+
axis: 1
|
| 1704 |
+
}
|
| 1705 |
+
}
|
| 1706 |
+
layer {
|
| 1707 |
+
name: "mbox_conf"
|
| 1708 |
+
type: "Concat"
|
| 1709 |
+
bottom: "conv4_3_norm_mbox_conf_flat"
|
| 1710 |
+
bottom: "fc7_mbox_conf_flat"
|
| 1711 |
+
bottom: "conv6_2_mbox_conf_flat"
|
| 1712 |
+
bottom: "conv7_2_mbox_conf_flat"
|
| 1713 |
+
bottom: "conv8_2_mbox_conf_flat"
|
| 1714 |
+
bottom: "conv9_2_mbox_conf_flat"
|
| 1715 |
+
top: "mbox_conf"
|
| 1716 |
+
concat_param {
|
| 1717 |
+
axis: 1
|
| 1718 |
+
}
|
| 1719 |
+
}
|
| 1720 |
+
layer {
|
| 1721 |
+
name: "mbox_priorbox"
|
| 1722 |
+
type: "Concat"
|
| 1723 |
+
bottom: "conv4_3_norm_mbox_priorbox"
|
| 1724 |
+
bottom: "fc7_mbox_priorbox"
|
| 1725 |
+
bottom: "conv6_2_mbox_priorbox"
|
| 1726 |
+
bottom: "conv7_2_mbox_priorbox"
|
| 1727 |
+
bottom: "conv8_2_mbox_priorbox"
|
| 1728 |
+
bottom: "conv9_2_mbox_priorbox"
|
| 1729 |
+
top: "mbox_priorbox"
|
| 1730 |
+
concat_param {
|
| 1731 |
+
axis: 2
|
| 1732 |
+
}
|
| 1733 |
+
}
|
| 1734 |
+
|
| 1735 |
+
layer {
|
| 1736 |
+
name: "mbox_conf_reshape"
|
| 1737 |
+
type: "Reshape"
|
| 1738 |
+
bottom: "mbox_conf"
|
| 1739 |
+
top: "mbox_conf_reshape"
|
| 1740 |
+
reshape_param {
|
| 1741 |
+
shape {
|
| 1742 |
+
dim: 0
|
| 1743 |
+
dim: -1
|
| 1744 |
+
dim: 2
|
| 1745 |
+
}
|
| 1746 |
+
}
|
| 1747 |
+
}
|
| 1748 |
+
layer {
|
| 1749 |
+
name: "mbox_conf_softmax"
|
| 1750 |
+
type: "Softmax"
|
| 1751 |
+
bottom: "mbox_conf_reshape"
|
| 1752 |
+
top: "mbox_conf_softmax"
|
| 1753 |
+
softmax_param {
|
| 1754 |
+
axis: 2
|
| 1755 |
+
}
|
| 1756 |
+
}
|
| 1757 |
+
layer {
|
| 1758 |
+
name: "mbox_conf_flatten"
|
| 1759 |
+
type: "Flatten"
|
| 1760 |
+
bottom: "mbox_conf_softmax"
|
| 1761 |
+
top: "mbox_conf_flatten"
|
| 1762 |
+
flatten_param {
|
| 1763 |
+
axis: 1
|
| 1764 |
+
}
|
| 1765 |
+
}
|
| 1766 |
+
|
| 1767 |
+
layer {
|
| 1768 |
+
name: "detection_out"
|
| 1769 |
+
type: "DetectionOutput"
|
| 1770 |
+
bottom: "mbox_loc"
|
| 1771 |
+
bottom: "mbox_conf_flatten"
|
| 1772 |
+
bottom: "mbox_priorbox"
|
| 1773 |
+
top: "detection_out"
|
| 1774 |
+
include {
|
| 1775 |
+
phase: TEST
|
| 1776 |
+
}
|
| 1777 |
+
detection_output_param {
|
| 1778 |
+
num_classes: 2
|
| 1779 |
+
share_location: true
|
| 1780 |
+
background_label_id: 0
|
| 1781 |
+
nms_param {
|
| 1782 |
+
nms_threshold: 0.45
|
| 1783 |
+
top_k: 400
|
| 1784 |
+
}
|
| 1785 |
+
code_type: CENTER_SIZE
|
| 1786 |
+
keep_top_k: 200
|
| 1787 |
+
confidence_threshold: 0.01
|
| 1788 |
+
}
|
| 1789 |
+
}
|
utils/face_detector.py
ADDED
@@ -0,0 +1,105 @@
"""Face detector, used only for the demo to crop faces, as datasets have already been face-cropped"""

import os
import cv2
import numpy as np


DEFAULT_FACE_DETECTOR = "utils/res10_300x300_ssd_iter_140000_fp16.caffemodel"
DEFAULT_DEPLOY = "utils/deploy.prototxt"

def enclosing_square(rect):
    # Build a square that contains the input rectangle
    x, y, w, h = rect
    side = max(w, h)
    # Center the square on the original bbox
    cx = x + w // 2
    cy = y + h // 2
    x_new = cx - side // 2
    y_new = cy - side // 2
    return (x_new, y_new, side, side)


def cut(frame, roi):
    pA = (int(roi[0]), int(roi[1]))
    pB = (int(roi[0] + roi[2]), int(roi[1] + roi[3]))  # pB will be an internal point
    W, H = frame.shape[1], frame.shape[0]
    A0 = pA[0] if pA[0] >= 0 else 0
    A1 = pA[1] if pA[1] >= 0 else 0
    data = frame[A1:pB[1], A0:pB[0]]
    if pB[0] < W and pB[1] < H and pA[0] >= 0 and pA[1] >= 0:
        return data
    # ROI partially outside the frame: zero-pad the missing area
    w, h = int(roi[2]), int(roi[3])
    img = np.zeros((h, w, frame.shape[2]), dtype=np.uint8)
    offX = int(-roi[0]) if roi[0] < 0 else 0
    offY = int(-roi[1]) if roi[1] < 0 else 0
    np.copyto(img[offY:offY + data.shape[0], offX:offX + data.shape[1]], data)
    return img

class FaceDetector:
    """Face detector to spot faces inside a picture."""
    def __init__(self, face_detector=DEFAULT_FACE_DETECTOR, deploy=DEFAULT_DEPLOY, confidence_threshold=0.8):
        self.detector = cv2.dnn.readNetFromCaffe(deploy, face_detector)
        self.confidence_threshold = confidence_threshold

    def detect(self, image, pad_rect=True):
        blob = cv2.dnn.blobFromImage(image, 1.0, (300, 300), [104, 117, 123], False, False)
        frameHeight, frameWidth, channels = image.shape
        self.detector.setInput(blob)
        detections = self.detector.forward()

        faces_result = []
        for i in range(detections.shape[2]):
            confidence = detections[0, 0, i, 2]
            if confidence > self.confidence_threshold:
                x1 = int(detections[0, 0, i, 3] * frameWidth)
                y1 = int(detections[0, 0, i, 4] * frameHeight)
                x2 = int(detections[0, 0, i, 5] * frameWidth)
                y2 = int(detections[0, 0, i, 6] * frameHeight)
                f = (x1, y1, x2 - x1, y2 - y1)  # bbox: (x, y, w, h)
                if f[2] > 1 and f[3] > 1:
                    rect = enclosing_square(f) if pad_rect else f
                    img_crop = cut(image, rect)
                    if img_crop.shape[0] > 0 and img_crop.shape[1] > 0:
                        faces_result.append((img_crop, confidence, rect))  # use rect (the square) as the final bbox
        if len(faces_result) == 0:
            return None
        return faces_result

if __name__ == "__main__":
    input_folder = "src/demo_images"
    output_crop_folder = "./test/detector/crop"
    output_bbox_folder = "./test/detector/bbox"

    os.makedirs(output_crop_folder, exist_ok=True)
    os.makedirs(output_bbox_folder, exist_ok=True)

    face_detector = FaceDetector(confidence_threshold=0.8)

    image_files = sorted([
        f for f in os.listdir(input_folder)
        if os.path.isfile(os.path.join(input_folder, f))
    ])

    for img_file in image_files:
        img_path = os.path.join(input_folder, img_file)
        img = cv2.imread(img_path)
        if img is None:
            continue

        faces = face_detector.detect(img, pad_rect=True)
        base_name = os.path.splitext(os.path.basename(img_path))[0]

        if faces is not None:
            # Save the face crops
            for idx, (crop, confidence, bbox) in enumerate(faces):
                crop_path = os.path.join(output_crop_folder, f"{base_name}_face{idx}.jpg")
                cv2.imwrite(crop_path, crop)

            # Save the original image with the square bboxes drawn on it
            img_bbox = img.copy()
            for idx, (_, _, bbox) in enumerate(faces):
                x, y, w, h = bbox  # bbox is already square
                cv2.rectangle(img_bbox, (x, y), (x + w, y + h), (0, 0, 255), 2)  # red in BGR
            bbox_path = os.path.join(output_bbox_folder, f"{base_name}_bbox.jpg")
            cv2.imwrite(bbox_path, img_bbox)
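Beyond the batch script in __main__, the detector is presumably used one image at a time by the demo app; a minimal single-image sketch (the input path and the choice of keeping only the highest-confidence face are illustrative, not taken from app.py):

import cv2
from utils.face_detector import FaceDetector

detector = FaceDetector(confidence_threshold=0.8)
img = cv2.imread("portrait.jpg")  # illustrative path
faces = detector.detect(img, pad_rect=True)  # list of (crop, confidence, (x, y, w, h)) or None
if faces is not None:
    # Keep the most confident detection; with pad_rect=True the crop is already square.
    crop, confidence, bbox = max(faces, key=lambda f: f[1])
    cv2.imwrite("portrait_face.jpg", crop)

The square crops from enclosing_square presumably let a later fixed-size resize avoid distorting the face aspect ratio.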
utils/task_config.py
ADDED
@@ -0,0 +1,21 @@
"""Task class definition"""

from dataclasses import dataclass
from typing import List, Type
import torch.nn as nn


@dataclass
class Task:
    """Encapsulates all configuration for a single task."""
    name: str
    class_labels: List[str]
    criterion: Type[nn.Module]
    weight: float = 1.0
    use_weighted_loss: bool = False

    @property
    def num_classes(self) -> int:
        return len(self.class_labels)
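A hypothetical instantiation of the dataclass above (the task name, labels, and loss are illustrative, not taken from the training configuration):

import torch.nn as nn
from utils.task_config import Task

gender_task = Task(
    name="gender",
    class_labels=["male", "female"],
    criterion=nn.CrossEntropyLoss,  # the criterion field stores the loss class, not an instance
    weight=1.0,
    use_weighted_loss=False,
)
print(gender_task.num_classes)  # -> 2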