File size: 4,537 Bytes
ca7a2c2 9e98b5a ca7a2c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
"""Visual RAG Tool - Image similarity search using local SigLIP embeddings.
Schema: place_image_embeddings (place_id, embedding, image_url, metadata)
places_metadata (place_id, name, category, rating, raw_data)
Uses local SigLIP model (ViT-B-16-SigLIP) for generating 768-dim image embeddings.
"""
from dataclasses import dataclass
from typing import Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.shared.integrations.siglip_client import get_siglip_client
@dataclass
class ImageSearchResult:
    """One place returned by a visual similarity search.

    Attributes:
        place_id: Identifier of the matched place.
        name: Name of the place.
        category: Category of the place.
        rating: Rating of the place.
        similarity: Similarity score of the place to the query image.
        matched_images: How many images of the place contributed to the
            score; defaults to 1.
        image_url: URL of an image representing the match; defaults to an
            empty string.
    """

    # Required fields (order matters for positional construction).
    place_id: str
    name: str
    category: str
    rating: float
    similarity: float

    # Optional fields with defaults.
    matched_images: int = 1
    image_url: str = ""
# Tool definition for agent - imported from centralized prompts
from app.shared.prompts import RETRIEVE_SIMILAR_VISUALS_TOOL as TOOL_DEFINITION
async def retrieve_similar_visuals(
    db: AsyncSession,
    image_url: str | None = None,
    image_bytes: bytes | None = None,
    limit: int = 10,
    threshold: float = 0.2,
) -> list[ImageSearchResult]:
    """Find places whose stored images are most similar to a query image.

    The query image is embedded with the local SigLIP client and compared
    against ``place_image_embeddings`` using the ``<=>`` distance operator.
    Matching image rows are joined to ``places_metadata`` and then grouped
    per place, because a place may own several images.

    Args:
        db: Async database session used to run the similarity query.
        image_url: URL of the query image. Used only when ``image_bytes``
            is missing or empty.
        image_bytes: Raw bytes of the query image. Takes precedence over
            ``image_url``.
        limit: Maximum number of places to return.
        threshold: Minimum similarity (``1 - distance``) an image row must
            exceed to be considered.

    Returns:
        Up to ``limit`` results ordered by descending average similarity,
        where the average is taken over the place's matched image rows and
        rounded to 4 decimals. Returns an empty list when ``limit`` is not
        positive, when no query image is supplied, or when the embedding
        could not be produced.
    """
    # Nothing to do for a non-positive limit (a negative value would
    # otherwise silently slice results from the end).
    if limit <= 0:
        return []

    # Validate the inputs before touching the SigLIP client.
    if not image_bytes and not image_url:
        return []

    siglip = get_siglip_client()

    # NOTE(review): the embed calls are not awaited, so they appear to run
    # synchronously inside this coroutine — confirm whether they should be
    # offloaded to a worker thread.
    if image_bytes:
        image_embedding = siglip.embed_image_bytes(image_bytes)
    else:
        image_embedding = siglip.embed_image_url(image_url)

    if image_embedding is None:
        return []

    # Textual vector form "[v1,v2,...]"; passed as a bound parameter and cast
    # server-side, so the statement text stays constant and the (large)
    # literal is not interpolated into the SQL three times.
    embedding_literal = (
        "[" + ",".join(str(x) for x in image_embedding.tolist()) + "]"
    )

    # Image rows fetched before grouping. At least 100 (the historical cap),
    # but never fewer than ``limit`` so large limits can be satisfied.
    candidate_rows = max(100, limit)

    sql = text("""
        SELECT
            e.place_id,
            e.image_url,
            1 - (e.embedding <=> CAST(:embedding AS vector)) AS similarity,
            m.name,
            m.category,
            m.rating
        FROM place_image_embeddings e
        JOIN places_metadata m ON e.place_id = m.place_id
        WHERE 1 - (e.embedding <=> CAST(:embedding AS vector)) > :threshold
            AND m.name IS NOT NULL
            AND m.name != ''
        ORDER BY e.embedding <=> CAST(:embedding AS vector)
        LIMIT :candidate_rows
    """)
    results = await db.execute(
        sql,
        {
            "embedding": embedding_literal,
            "threshold": threshold,
            "candidate_rows": candidate_rows,
        },
    )
    rows = results.fetchall()

    # Group image rows by place: accumulate the score sum and count, and
    # remember the single best-scoring image explicitly.
    per_place: dict = {}
    for row in rows:
        score = float(row.similarity)
        entry = per_place.get(row.place_id)
        if entry is None:
            per_place[row.place_id] = {
                "total_score": score,
                "count": 1,
                "data": row,  # first row supplies name/category/rating
                "best_score": score,
                "best_image": row.image_url,
            }
            continue
        entry["total_score"] += score
        entry["count"] += 1
        if score > entry["best_score"]:
            entry["best_score"] = score
            entry["best_image"] = row.image_url

    # Rank places by their average similarity and keep the top ``limit``.
    ranked = sorted(
        per_place.items(),
        key=lambda item: item[1]["total_score"] / item[1]["count"],
        reverse=True,
    )[:limit]

    return [
        ImageSearchResult(
            place_id=pid,
            name=entry["data"].name or "",
            category=entry["data"].category or "",
            rating=float(entry["data"].rating or 0),
            similarity=round(entry["total_score"] / entry["count"], 4),
            matched_images=entry["count"],
            image_url=entry["best_image"] or "",
        )
        for pid, entry in ranked
    ]
async def search_by_image_url(
    db: AsyncSession,
    image_url: str,
    limit: int = 10,
    threshold: float = 0.2,
) -> list[ImageSearchResult]:
    """Search places by the URL of a query image.

    Thin convenience wrapper around ``retrieve_similar_visuals``.

    Args:
        db: Async database session used to run the similarity query.
        image_url: URL of the query image.
        limit: Maximum number of places to return.
        threshold: Minimum similarity forwarded to
            ``retrieve_similar_visuals``; the default matches that
            function's own default.

    Returns:
        Whatever ``retrieve_similar_visuals`` returns for this image URL.
    """
    return await retrieve_similar_visuals(
        db=db,
        image_url=image_url,
        limit=limit,
        threshold=threshold,
    )
async def search_by_image_bytes(
    db: AsyncSession,
    image_bytes: bytes,
    limit: int = 10,
    threshold: float = 0.2,
) -> list[ImageSearchResult]:
    """Search places using the raw bytes of a query image.

    Thin convenience wrapper around ``retrieve_similar_visuals``.

    Args:
        db: Async database session used to run the similarity query.
        image_bytes: Raw bytes of the query image (e.g. an upload).
        limit: Maximum number of places to return.
        threshold: Minimum similarity forwarded to
            ``retrieve_similar_visuals``; the default matches that
            function's own default.

    Returns:
        Whatever ``retrieve_similar_visuals`` returns for these image bytes.
    """
    return await retrieve_similar_visuals(
        db=db,
        image_bytes=image_bytes,
        limit=limit,
        threshold=threshold,
    )
|