update demo
Browse files- CHMCorr.py +1 -0
- app.py +41 -18
- visualization.py +117 -10
CHMCorr.py
CHANGED
|
@@ -494,6 +494,7 @@ def export_visualizations_results(
|
|
| 494 |
"chm-prediction": pfn,
|
| 495 |
"chm-prediction-confidence": pr,
|
| 496 |
"chm-nearest-neighbors": rfiles,
|
|
|
|
| 497 |
"correspondance_map": cmaps,
|
| 498 |
"masked_cos_values": MASKED_COSINE_VALUES,
|
| 499 |
"src-keypoints": list_of_source_points,
|
|
|
|
| 494 |
"chm-prediction": pfn,
|
| 495 |
"chm-prediction-confidence": pr,
|
| 496 |
"chm-nearest-neighbors": rfiles,
|
| 497 |
+
"chm-nearest-neighbors-all": reranked_nns,
|
| 498 |
"correspondance_map": cmaps,
|
| 499 |
"masked_cos_values": MASKED_COSINE_VALUES,
|
| 500 |
"src-keypoints": list_of_source_points,
|
app.py
CHANGED
|
@@ -13,7 +13,7 @@ from PIL import Image
|
|
| 13 |
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
|
| 14 |
from ExtractEmbedding import QueryToEmbedding
|
| 15 |
from CHMCorr import chm_classify_and_visualize
|
| 16 |
-
from visualization import
|
| 17 |
|
| 18 |
csv.field_size_limit(sys.maxsize)
|
| 19 |
|
|
@@ -74,7 +74,7 @@ id_to_bird_name = {
|
|
| 74 |
}
|
| 75 |
|
| 76 |
|
| 77 |
-
def search(query_image,
|
| 78 |
query_embedding = QueryToEmbedding(query_image)
|
| 79 |
scores, indices, labels = searcher.search(query_embedding, k=50)
|
| 80 |
|
|
@@ -101,7 +101,7 @@ def search(query_image, draw_arcs, searcher=searcher):
|
|
| 101 |
query_image, kNN_results, support, training_folder
|
| 102 |
)
|
| 103 |
|
| 104 |
-
fig =
|
| 105 |
|
| 106 |
# Resize the output
|
| 107 |
|
|
@@ -117,35 +117,58 @@ def search(query_image, draw_arcs, searcher=searcher):
|
|
| 117 |
right = (width + new_width) / 2
|
| 118 |
bottom = (height + new_height) / 2
|
| 119 |
|
| 120 |
-
viz_image = image.crop((left +
|
| 121 |
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
|
| 125 |
blocks = gr.Blocks()
|
| 126 |
|
| 127 |
with blocks:
|
| 128 |
gr.Markdown(""" # CHM-Corr DEMO""")
|
| 129 |
-
gr.Markdown(
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
# with gr.Row():
|
| 132 |
input_image = gr.Image(type="filepath")
|
| 133 |
-
with gr.Column():
|
| 134 |
-
arcs_checkbox = gr.Checkbox(label="Draw Arcs")
|
| 135 |
run_btn = gr.Button("Classify")
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
gr.
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
run_btn.click(
|
| 144 |
search,
|
| 145 |
-
inputs=[input_image
|
| 146 |
-
outputs=[viz_plot,
|
| 147 |
)
|
| 148 |
|
|
|
|
| 149 |
if __name__ == "__main__":
|
| 150 |
blocks.launch(
|
| 151 |
debug=True,
|
|
|
|
| 13 |
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
|
| 14 |
from ExtractEmbedding import QueryToEmbedding
|
| 15 |
from CHMCorr import chm_classify_and_visualize
|
| 16 |
+
from visualization import plot_from_reranker_corrmap
|
| 17 |
|
| 18 |
csv.field_size_limit(sys.maxsize)
|
| 19 |
|
|
|
|
| 74 |
}
|
| 75 |
|
| 76 |
|
| 77 |
+
def search(query_image, searcher=searcher):
|
| 78 |
query_embedding = QueryToEmbedding(query_image)
|
| 79 |
scores, indices, labels = searcher.search(query_embedding, k=50)
|
| 80 |
|
|
|
|
| 101 |
query_image, kNN_results, support, training_folder
|
| 102 |
)
|
| 103 |
|
| 104 |
+
fig, chm_output_label = plot_from_reranker_corrmap(chm_output)
|
| 105 |
|
| 106 |
# Resize the output
|
| 107 |
|
|
|
|
| 117 |
right = (width + new_width) / 2
|
| 118 |
bottom = (height + new_height) / 2
|
| 119 |
|
| 120 |
+
viz_image = image.crop((left + 310, top + 60, right - 248, bottom - 80))
|
| 121 |
|
| 122 |
+
chm_output_labels = Counter(
|
| 123 |
+
[
|
| 124 |
+
x.split("/")[-2].replace(".", " ").replace("_", " ")
|
| 125 |
+
for x in chm_output["chm-nearest-neighbors-all"][:20]
|
| 126 |
+
]
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
return viz_image, {l: s / 20.0 for l, s in chm_output_labels.items()}
|
| 130 |
|
| 131 |
|
| 132 |
blocks = gr.Blocks()
|
| 133 |
|
| 134 |
with blocks:
|
| 135 |
gr.Markdown(""" # CHM-Corr DEMO""")
|
| 136 |
+
gr.Markdown(
|
| 137 |
+
""" ### Parameters: N=50, k=20 - Using ``ImageNet Pretrained ResNet50`` features"""
|
| 138 |
+
)
|
| 139 |
|
|
|
|
| 140 |
input_image = gr.Image(type="filepath")
|
|
|
|
|
|
|
| 141 |
run_btn = gr.Button("Classify")
|
| 142 |
+
gr.Markdown(""" ### CHM-Corr Output Visualization """)
|
| 143 |
+
viz_plot = gr.Image(type="pil", label="Visualization")
|
| 144 |
+
with gr.Row():
|
| 145 |
+
with gr.Column():
|
| 146 |
+
gr.Markdown(""" ### CHM-Corr Prediction """)
|
| 147 |
+
labels = gr.Label(label="Prediction")
|
| 148 |
+
with gr.Column():
|
| 149 |
+
gr.Markdown(""" ### Examples """)
|
| 150 |
+
examples = gr.Examples(
|
| 151 |
+
examples=[
|
| 152 |
+
["./examples/bird.jpg"],
|
| 153 |
+
["./examples/Red_Winged_Blackbird_0012_6015.jpg"],
|
| 154 |
+
["./examples/Red_Winged_Blackbird_0025_5342.jpg"],
|
| 155 |
+
["./examples/sample1.jpeg"],
|
| 156 |
+
["./examples/sample2.jpeg"],
|
| 157 |
+
["./examples/Yellow_Headed_Blackbird_0020_8549.jpg"],
|
| 158 |
+
["./examples/Yellow_Headed_Blackbird_0026_8545.jpg"],
|
| 159 |
+
],
|
| 160 |
+
inputs=[input_image],
|
| 161 |
+
outputs=[viz_plot, labels],
|
| 162 |
+
fn=search,
|
| 163 |
+
cache_examples=False,
|
| 164 |
+
)
|
| 165 |
run_btn.click(
|
| 166 |
search,
|
| 167 |
+
inputs=[input_image],
|
| 168 |
+
outputs=[viz_plot, labels],
|
| 169 |
)
|
| 170 |
|
| 171 |
+
|
| 172 |
if __name__ == "__main__":
|
| 173 |
blocks.launch(
|
| 174 |
debug=True,
|
visualization.py
CHANGED
|
@@ -38,7 +38,6 @@ def arg_topK(inputarray, topK=5):
|
|
| 38 |
return np.argsort(inputarray.T.reshape(-1))[::-1][:topK]
|
| 39 |
|
| 40 |
|
| 41 |
-
# FOR MULTI
|
| 42 |
def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
|
| 43 |
"""
|
| 44 |
visualize chm results from a reranker output dict
|
|
@@ -261,14 +260,122 @@ def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
|
|
| 261 |
color="black",
|
| 262 |
fontsize=22,
|
| 263 |
)
|
| 264 |
-
# fig.text(
|
| 265 |
-
# 0.8,
|
| 266 |
-
# 0.95,
|
| 267 |
-
# f"KNN: {reranker_output['knn-prediction']}",
|
| 268 |
-
# ha="right",
|
| 269 |
-
# va="bottom",
|
| 270 |
-
# color="black",
|
| 271 |
-
# fontsize=22,
|
| 272 |
-
# )
|
| 273 |
|
| 274 |
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
return np.argsort(inputarray.T.reshape(-1))[::-1][:topK]
|
| 39 |
|
| 40 |
|
|
|
|
| 41 |
def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
|
| 42 |
"""
|
| 43 |
visualize chm results from a reranker output dict
|
|
|
|
| 260 |
color="black",
|
| 261 |
fontsize=22,
|
| 262 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
return fig
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def plot_from_reranker_corrmap(reranker_output, draw_box=True):
|
| 268 |
+
"""
|
| 269 |
+
visualize chm results from a reranker output dict
|
| 270 |
+
"""
|
| 271 |
+
|
| 272 |
+
### SET COLORS
|
| 273 |
+
cmap = matplotlib.cm.get_cmap("gist_rainbow")
|
| 274 |
+
rgba = cmap(0.5)
|
| 275 |
+
colors = []
|
| 276 |
+
for k in range(5):
|
| 277 |
+
colors.append(cmap(k / 5.0))
|
| 278 |
+
|
| 279 |
+
### SET POINTS
|
| 280 |
+
A = np.linspace(1 + 17, 240 - 17 - 1, 7)
|
| 281 |
+
point_list = list(product(A, A))
|
| 282 |
+
|
| 283 |
+
fig, axes = plt.subplots(
|
| 284 |
+
2,
|
| 285 |
+
7,
|
| 286 |
+
figsize=(25, 8),
|
| 287 |
+
gridspec_kw={
|
| 288 |
+
"wspace": 0,
|
| 289 |
+
"hspace": 0,
|
| 290 |
+
"width_ratios": [1, 0.28, 1, 1, 1, 1, 1],
|
| 291 |
+
},
|
| 292 |
+
facecolor=(1, 1, 1),
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
for i in range(2):
|
| 296 |
+
for j in range(7):
|
| 297 |
+
axes[i][j].axis("off")
|
| 298 |
+
|
| 299 |
+
axes[0][0].imshow(
|
| 300 |
+
display_transform(Image.open(reranker_output["q"]).convert("RGB"))
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
for i in range(min(5, reranker_output["chm-prediction-confidence"])):
|
| 304 |
+
axes[0][2 + i].imshow(
|
| 305 |
+
display_transform(Image.open(reranker_output["q"]).convert("RGB"))
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
# Lower ROWs CHM Top5
|
| 309 |
+
for i in range(min(5, reranker_output["chm-prediction-confidence"])):
|
| 310 |
+
axes[1][2 + i].imshow(
|
| 311 |
+
display_transform(
|
| 312 |
+
Image.open(reranker_output["chm-nearest-neighbors"][i]).convert("RGB")
|
| 313 |
+
)
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
if reranker_output["chm-prediction-confidence"] < 5:
|
| 317 |
+
for i in range(reranker_output["chm-prediction-confidence"], 5):
|
| 318 |
+
axes[0][2 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
|
| 319 |
+
axes[1][2 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
|
| 320 |
+
|
| 321 |
+
nzm = reranker_output["non_zero_mask"]
|
| 322 |
+
# Go throught top 5 nearest images
|
| 323 |
+
|
| 324 |
+
# #################################################################################
|
| 325 |
+
if draw_box:
|
| 326 |
+
# SQUARAES
|
| 327 |
+
for NC in range(min(5, reranker_output["chm-prediction-confidence"])):
|
| 328 |
+
# ON SOURCE
|
| 329 |
+
valid_patches_source = arg_topK(
|
| 330 |
+
reranker_output["masked_cos_values"][NC], topK=nzm
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# ON QUERY
|
| 334 |
+
target_masked_patches = arg_topK(
|
| 335 |
+
reranker_output["masked_cos_values"][NC], topK=nzm
|
| 336 |
+
)
|
| 337 |
+
valid_patches_target = [
|
| 338 |
+
reranker_output["correspondance_map"][NC][x]
|
| 339 |
+
for x in target_masked_patches
|
| 340 |
+
]
|
| 341 |
+
valid_patches_target = [(x[0] * 7) + x[1] for x in valid_patches_target]
|
| 342 |
+
|
| 343 |
+
patch_colors = [c for c in colors]
|
| 344 |
+
overlaps = [
|
| 345 |
+
item
|
| 346 |
+
for item, count in Counter(valid_patches_target).items()
|
| 347 |
+
if count > 1
|
| 348 |
+
]
|
| 349 |
+
|
| 350 |
+
for O in overlaps:
|
| 351 |
+
indices = [i for i, val in enumerate(valid_patches_target) if val == O]
|
| 352 |
+
for ii in indices[1:]:
|
| 353 |
+
patch_colors[ii] = patch_colors[indices[0]]
|
| 354 |
+
|
| 355 |
+
for i in valid_patches_source:
|
| 356 |
+
Psource = point_list[i]
|
| 357 |
+
rect = patches.Rectangle(
|
| 358 |
+
(Psource[0] - 16, Psource[1] - 16),
|
| 359 |
+
32,
|
| 360 |
+
32,
|
| 361 |
+
linewidth=2,
|
| 362 |
+
edgecolor=patch_colors[valid_patches_source.tolist().index(i)],
|
| 363 |
+
facecolor="none",
|
| 364 |
+
alpha=1,
|
| 365 |
+
)
|
| 366 |
+
axes[0][2 + NC].add_patch(rect)
|
| 367 |
+
|
| 368 |
+
for i in valid_patches_target:
|
| 369 |
+
Psource = point_list[i]
|
| 370 |
+
rect = patches.Rectangle(
|
| 371 |
+
(Psource[0] - 16, Psource[1] - 16),
|
| 372 |
+
32,
|
| 373 |
+
32,
|
| 374 |
+
linewidth=2,
|
| 375 |
+
edgecolor=patch_colors[valid_patches_target.index(i)],
|
| 376 |
+
facecolor="none",
|
| 377 |
+
alpha=1,
|
| 378 |
+
)
|
| 379 |
+
axes[1][2 + NC].add_patch(rect)
|
| 380 |
+
|
| 381 |
+
return fig, reranker_output["chm-prediction"]
|