|
|
"""Visual RAG Tool - Image similarity search using local SigLIP embeddings. |
|
|
|
|
|
Schema: place_image_embeddings (place_id, embedding, image_url, metadata) |
|
|
places_metadata (place_id, name, category, rating, raw_data) |
|
|
|
|
|
Uses local SigLIP model (ViT-B-16-SigLIP) for generating 768-dim image embeddings. |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
from sqlalchemy import text |
|
|
from sqlalchemy.ext.asyncio import AsyncSession |
|
|
|
|
|
from app.shared.integrations.siglip_client import get_siglip_client |
|
|
|
|
|
|
|
|
@dataclass
class ImageSearchResult:
    """A single place returned by a visual similarity search.

    Each instance describes one place together with the score it obtained
    against the query image.
    """

    # Identifier of the place (``place_id`` column of the queried tables).
    place_id: str
    # Descriptive metadata of the place.
    name: str
    category: str
    rating: float
    # Similarity score of the place against the query image.
    similarity: float
    # Number of stored images that contributed to this result.
    matched_images: int = 1
    # URL of a representative matching image; empty string when unknown.
    image_url: str = ""
|
|
|
|
|
|
|
|
from app.shared.prompts import RETRIEVE_SIMILAR_VISUALS_TOOL as TOOL_DEFINITION |
|
|
|
|
|
|
|
|
async def retrieve_similar_visuals(
    db: AsyncSession,
    image_url: str | None = None,
    image_bytes: bytes | None = None,
    limit: int = 10,
    threshold: float = 0.2,
) -> list[ImageSearchResult]:
    """
    Find places whose stored images are visually similar to a query image.

    The query image is embedded with the SigLIP client, then compared against
    the ``place_image_embeddings`` table (joined to ``places_metadata``) using
    the pgvector ``<=>`` distance operator; similarity is ``1 - distance``.
    Matching image rows are grouped per place and each place is scored by the
    mean similarity of its matching images.

    Args:
        db: Database session used to run the similarity query.
        image_url: URL of the query image. Used only when ``image_bytes`` is
            empty or not given.
        image_bytes: Raw image bytes; takes precedence over ``image_url``.
        limit: Maximum number of places to return. Non-positive values
            produce an empty result.
        threshold: Minimum similarity an image row must exceed to be
            considered.

    Returns:
        Places ordered by descending mean similarity, at most ``limit`` of
        them. Empty when no image input is given, when the embedding could
        not be produced, or when nothing exceeds ``threshold``.
    """
    # Cheap input validation first, so that neither the SigLIP client nor the
    # database is touched when there is nothing to search for.
    if limit <= 0:
        return []
    if not image_bytes and not image_url:
        return []

    siglip = get_siglip_client()

    # NOTE(review): both embed calls are synchronous; if they perform network
    # or model work they block the event loop for their duration. Consider
    # offloading (e.g. asyncio.to_thread) once the client's thread-safety is
    # confirmed.
    if image_bytes:
        image_embedding = siglip.embed_image_bytes(image_bytes)
    else:
        image_embedding = siglip.embed_image_url(image_url)

    if image_embedding is None:
        return []

    # Textual vector literal, e.g. "[0.1,0.2,...]".
    # NOTE(review): the literal is interpolated into the SQL text rather than
    # bound. It is built only from the elements of the embedding (presumably
    # floats produced by the model), not from caller-supplied strings. Moving
    # it to a bound parameter (CAST(:embedding AS vector)) would be cleaner
    # but depends on driver support - confirm before changing.
    embedding_str = "[" + ",".join(str(x) for x in image_embedding.tolist()) + "]"

    # Image rows fetched before grouping by place. Kept at 100 for the usual
    # small limits, but never below ``limit`` so that a caller asking for
    # more than 100 places is not silently capped by the query.
    candidate_limit = max(100, limit)

    sql = text(f"""
        SELECT
            e.place_id,
            e.image_url,
            1 - (e.embedding <=> '{embedding_str}'::vector) as similarity,
            m.name,
            m.category,
            m.rating
        FROM place_image_embeddings e
        JOIN places_metadata m ON e.place_id = m.place_id
        WHERE 1 - (e.embedding <=> '{embedding_str}'::vector) > :threshold
          AND m.name IS NOT NULL
          AND m.name != ''
        ORDER BY e.embedding <=> '{embedding_str}'::vector
        LIMIT :candidate_limit
    """)

    results = await db.execute(
        sql,
        {
            "threshold": threshold,
            "candidate_limit": candidate_limit,
        },
    )
    rows = results.fetchall()

    # Group image rows by place, accumulating the similarity sum and count
    # (for the mean) and remembering the single most similar image.
    place_scores: dict = {}
    for row in rows:
        similarity = float(row.similarity)
        stats = place_scores.get(row.place_id)
        if stats is None:
            stats = place_scores[row.place_id] = {
                "total_score": 0.0,
                "count": 0,
                "data": row,
                "best_image": row.image_url,
                "best_similarity": similarity,
            }

        stats["total_score"] += similarity
        stats["count"] += 1

        # Compare against the best similarity seen so far for this place
        # (not the running average), so the choice of representative image
        # does not rely on the rows arriving in sorted order.
        if similarity > stats["best_similarity"]:
            stats["best_similarity"] = similarity
            stats["best_image"] = row.image_url

    # Rank places by mean similarity of their matching images.
    sorted_places = sorted(
        place_scores.items(),
        key=lambda item: item[1]["total_score"] / item[1]["count"],
        reverse=True,
    )[:limit]

    return [
        ImageSearchResult(
            place_id=pid,
            name=stats["data"].name or "",
            category=stats["data"].category or "",
            rating=float(stats["data"].rating or 0),
            similarity=round(stats["total_score"] / stats["count"], 4),
            matched_images=stats["count"],
            image_url=stats["best_image"] or "",
        )
        for pid, stats in sorted_places
    ]
|
|
|
|
|
|
|
|
async def search_by_image_url(
    db: AsyncSession,
    image_url: str,
    limit: int = 10,
) -> list[ImageSearchResult]:
    """Find visually similar places for an image referenced by URL.

    Convenience wrapper around ``retrieve_similar_visuals`` that forwards the
    session, URL and result limit, leaving the similarity threshold at that
    function's default.
    """
    matches = await retrieve_similar_visuals(
        db=db,
        limit=limit,
        image_url=image_url,
    )
    return matches
|
|
|
|
|
|
|
|
async def search_by_image_bytes(
    db: AsyncSession,
    image_bytes: bytes,
    limit: int = 10,
) -> list[ImageSearchResult]:
    """Find visually similar places for an image supplied as raw bytes.

    Convenience wrapper around ``retrieve_similar_visuals`` that forwards the
    session, image bytes and result limit, leaving the similarity threshold
    at that function's default.
    """
    matches = await retrieve_similar_visuals(
        db=db,
        limit=limit,
        image_bytes=image_bytes,
    )
    return matches
|
|
|