File size: 4,537 Bytes
ca7a2c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e98b5a
 
ca7a2c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Visual RAG Tool - Image similarity search using local SigLIP embeddings.

Schema: place_image_embeddings (place_id, embedding, image_url, metadata)
        places_metadata (place_id, name, category, rating, raw_data)
        
Uses local SigLIP model (ViT-B-16-SigLIP) for generating 768-dim image embeddings.
"""

from dataclasses import dataclass
from typing import Optional

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.shared.integrations.siglip_client import get_siglip_client


@dataclass
class ImageSearchResult:
    """One place returned by a visual similarity search.

    Each instance describes a single place, even when several of its stored
    images matched the query image.

    Attributes:
        place_id: Identifier of the place (``place_id`` column).
        name: Place name from ``places_metadata``.
        category: Place category from ``places_metadata``.
        rating: Place rating from ``places_metadata``.
        similarity: Similarity score of the place against the query image.
        matched_images: How many images of the place contributed to the score.
        image_url: URL of the image chosen to represent the match; empty
            string when none is available.
    """

    # Identity and descriptive metadata of the place.
    place_id: str
    name: str
    category: str
    rating: float

    # Match information.
    similarity: float
    matched_images: int = 1
    image_url: str = ""

# Tool definition for agent - imported from centralized prompts
from app.shared.prompts import RETRIEVE_SIMILAR_VISUALS_TOOL as TOOL_DEFINITION


async def retrieve_similar_visuals(
    db: AsyncSession,
    image_url: str | None = None,
    image_bytes: bytes | None = None,
    limit: int = 10,
    threshold: float = 0.2,
) -> list[ImageSearchResult]:
    """
    Visual similarity search using local SigLIP embeddings.

    Embeds the query image, searches ``place_image_embeddings`` joined with
    ``places_metadata`` for images above the similarity threshold, then
    collapses the per-image rows into one result per place ranked by the
    average similarity of that place's matched images.

    Args:
        db: Database session
        image_url: URL of the query image
        image_bytes: Raw image bytes (alternative to URL; takes precedence
            when both are given)
        limit: Maximum number of places to return; values <= 0 yield ``[]``
        threshold: Minimum similarity (``1 - distance``) an image must exceed

    Returns:
        List of places with visual similarity scores, best first. Empty when
        no query image is supplied, the embedding could not be produced, or
        nothing passes the threshold.
    """
    # A non-positive limit can never yield results; bail out before doing any
    # embedding or database work (a negative value would otherwise slice
    # results from the tail).
    if limit <= 0:
        return []

    # Get SigLIP client (singleton)
    siglip = get_siglip_client()

    # Generate image embedding using local model.
    # NOTE(review): these calls look synchronous and may block the event loop
    # while the model runs - confirm against the client implementation.
    if image_bytes:
        image_embedding = siglip.embed_image_bytes(image_bytes)
    elif image_url:
        image_embedding = siglip.embed_image_url(image_url)
    else:
        return []

    if image_embedding is None:
        return []

    # Serialize the embedding in pgvector's text format: "[v1,v2,...]".
    embedding_str = "[" + ",".join(str(x) for x in image_embedding.tolist()) + "]"

    # Fetch at least as many image rows as places requested, so that a large
    # `limit` is not silently capped by the candidate window.
    candidate_limit = max(100, limit)

    # The embedding is sent as a bound parameter rather than spliced into the
    # SQL text. CAST(... AS vector) is used instead of `::vector` because the
    # `::` cast syntax directly after a bind name prevents SQLAlchemy from
    # recognising the parameter.
    sql = text("""
        SELECT
            e.place_id,
            e.image_url,
            1 - (e.embedding <=> CAST(:embedding AS vector)) AS similarity,
            m.name,
            m.category,
            m.rating
        FROM place_image_embeddings e
        JOIN places_metadata m ON e.place_id = m.place_id
        WHERE 1 - (e.embedding <=> CAST(:embedding AS vector)) > :threshold
          AND m.name IS NOT NULL
          AND m.name != ''
        ORDER BY e.embedding <=> CAST(:embedding AS vector)
        LIMIT :candidate_limit
    """)

    results = await db.execute(sql, {
        "embedding": embedding_str,
        "threshold": threshold,
        "candidate_limit": candidate_limit,
    })

    return _aggregate_place_matches(results.fetchall(), limit)


def _aggregate_place_matches(rows, limit: int) -> list[ImageSearchResult]:
    """Collapse per-image rows into at most ``limit`` per-place results.

    Args:
        rows: Iterable of result rows exposing ``place_id``, ``image_url``,
            ``similarity``, ``name``, ``category`` and ``rating`` attributes.
        limit: Maximum number of places to return.

    Returns:
        Places sorted by descending average similarity; each carries the URL
        of its single highest-scoring image.
    """
    places: dict[str, dict] = {}

    for row in rows:
        score = float(row.similarity)
        entry = places.get(row.place_id)

        if entry is None:
            places[row.place_id] = {
                'row': row,
                'total_score': score,
                'count': 1,
                'best_score': score,
                'best_image': row.image_url,
            }
            continue

        entry['total_score'] += score
        entry['count'] += 1

        # Track the best image by its own score, not relative to the running
        # average, so the choice does not depend on the order rows arrive in.
        if score > entry['best_score']:
            entry['best_score'] = score
            entry['best_image'] = row.image_url

    # Rank places by the mean similarity of their matched images.
    ranked = sorted(
        places.items(),
        key=lambda item: item[1]['total_score'] / item[1]['count'],
        reverse=True,
    )[:limit]

    return [
        ImageSearchResult(
            place_id=pid,
            name=entry['row'].name or '',
            category=entry['row'].category or '',
            rating=float(entry['row'].rating or 0),
            similarity=round(entry['total_score'] / entry['count'], 4),
            matched_images=entry['count'],
            image_url=entry['best_image'] or '',
        )
        for pid, entry in ranked
    ]


async def search_by_image_url(
    db: AsyncSession,
    image_url: str,
    limit: int = 10,
) -> list[ImageSearchResult]:
    """Find visually similar places for the image located at ``image_url``.

    Convenience wrapper around ``retrieve_similar_visuals`` that forwards
    only the URL and the result limit; the similarity threshold is left at
    that function's default.
    """
    matches = await retrieve_similar_visuals(
        db=db,
        image_url=image_url,
        limit=limit,
    )
    return matches


async def search_by_image_bytes(
    db: AsyncSession,
    image_bytes: bytes,
    limit: int = 10,
) -> list[ImageSearchResult]:
    """Find visually similar places for an image supplied as raw bytes.

    Convenience wrapper around ``retrieve_similar_visuals`` that forwards
    only the bytes and the result limit; the similarity threshold is left at
    that function's default.
    """
    matches = await retrieve_similar_visuals(
        db=db,
        image_bytes=image_bytes,
        limit=limit,
    )
    return matches