UnMelow committed
Commit d77255d · verified · 1 Parent(s): 26508ea

Update app.py

Files changed (1)
  1. app.py +363 -488
app.py CHANGED
@@ -1,12 +1,8 @@
  import os
- import sys
  import re
- import shutil
  import tempfile
- import warnings
- import base64
- from io import StringIO, BytesIO
- from typing import List, Tuple
+ from io import BytesIO
+ from typing import List, Tuple, Optional
 
  import gradio as gr
  import torch
@@ -15,185 +11,61 @@ from PIL import Image, ImageDraw, ImageFont, ImageOps
  import fitz  # PyMuPDF
 
  from transformers import (
-     AutoModel,
-     AutoTokenizer,
      AutoProcessor,
      VisionEncoderDecoderModel,
      BlipProcessor,
      BlipForConditionalGeneration,
  )
 
- # --- Optional HF Spaces GPU decorator (safe fallback for local runs) ---
- try:
-     import spaces  # type: ignore
-
-     gpu_decorator = spaces.GPU
- except Exception:
-     def gpu_decorator(*args, **kwargs):
-         def wrap(fn):
-             return fn
-         return wrap
-
-
- # =========================
- # Device / dtype utilities
- # =========================
- def get_device() -> str:
-     return "cuda" if torch.cuda.is_available() else "cpu"
-
-
- def get_cuda_dtype() -> torch.dtype:
-     # bf16 only on supported GPUs (Ampere+). Otherwise fp16.
-     try:
-         if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-             return torch.bfloat16
-     except Exception:
-         pass
-     return torch.float16
-
-
- DEVICE = get_device()
- CUDA_DTYPE = get_cuda_dtype() if DEVICE == "cuda" else torch.float32
-
-
- # =========================
- # Model names
- # =========================
- DEEPSEEK_OCR_NAME = os.getenv("DEEPSEEK_OCR_MODEL", "deepseek-ai/DeepSeek-OCR")
- # Optional pin to a specific revision/commit to avoid auto-updating remote code.
- DEEPSEEK_OCR_REVISION = os.getenv("DEEPSEEK_OCR_REVISION", None)
-
- TROCR_NAME = os.getenv("TROCR_MODEL", "microsoft/trocr-base-printed")
- BLIP_NAME = os.getenv("BLIP_MODEL", "Salesforce/blip-image-captioning-base")
-
-
- # =========================
- # Load DeepSeek-OCR safely
- # =========================
- def load_deepseek_ocr():
-     tokenizer = AutoTokenizer.from_pretrained(
-         DEEPSEEK_OCR_NAME,
-         trust_remote_code=True,
-         revision=DEEPSEEK_OCR_REVISION,
-     )
-
-     base_kwargs = dict(
-         trust_remote_code=True,
-         use_safetensors=True,
-         revision=DEEPSEEK_OCR_REVISION,
-     )
-
-     # IMPORTANT:
-     # - Do NOT force flash_attention_2 on CPU.
-     # - On CUDA: try flash_attention_2, but gracefully fall back if unavailable.
-     if DEVICE == "cuda":
-         # Try FlashAttention2 first
-         try:
-             model = AutoModel.from_pretrained(
-                 DEEPSEEK_OCR_NAME,
-                 torch_dtype=CUDA_DTYPE,
-                 _attn_implementation="flash_attention_2",
-                 **base_kwargs,
-             )
-         except Exception as e:
-             warnings.warn(
-                 f"FlashAttention2 unavailable or failed ({e}). Falling back to SDPA/eager."
-             )
-             # Try SDPA
-             try:
-                 model = AutoModel.from_pretrained(
-                     DEEPSEEK_OCR_NAME,
-                     torch_dtype=CUDA_DTYPE,
-                     _attn_implementation="sdpa",
-                     **base_kwargs,
-                 )
-             except Exception:
-                 # Final fallback
-                 model = AutoModel.from_pretrained(
-                     DEEPSEEK_OCR_NAME,
-                     torch_dtype=CUDA_DTYPE,
-                     _attn_implementation="eager",
-                     **base_kwargs,
-                 )
-
-         model = model.eval().to(DEVICE)
-
-     else:
-         # CPU path: no flash attention, use float32 for stability
-         model = AutoModel.from_pretrained(
-             DEEPSEEK_OCR_NAME,
-             torch_dtype=torch.float32,
-             _attn_implementation="eager",
-             **base_kwargs,
-         )
-         model = model.eval().to(DEVICE)
-
-     return tokenizer, model
-
-
- tokenizer, deepseek_model = load_deepseek_ocr()
-
-
- # =========================
- # Load TrOCR and BLIP
- # =========================
- def load_trocr():
-     processor = AutoProcessor.from_pretrained(TROCR_NAME)
-     model = VisionEncoderDecoderModel.from_pretrained(TROCR_NAME).eval()
-     if DEVICE == "cuda":
-         model = model.to(DEVICE).to(dtype=CUDA_DTYPE)
-     else:
-         model = model.to(DEVICE)
-     return processor, model
-
-
- def load_blip():
-     processor = BlipProcessor.from_pretrained(BLIP_NAME)
-     model = BlipForConditionalGeneration.from_pretrained(BLIP_NAME).eval()
-     if DEVICE == "cuda":
-         model = model.to(DEVICE).to(dtype=CUDA_DTYPE)
-     else:
-         model = model.to(DEVICE)
-     return processor, model
-
-
- trocr_processor, trocr_model = load_trocr()
- blip_processor, blip_model = load_blip()
-
-
- # =========================
- # App configs
- # =========================
- MODEL_CONFIGS = {
-     "Gundam": {"base_size": 1024, "image_size": 640, "crop_mode": True},
-     "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
-     "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
-     "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
-     "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
- }
-
- TASK_PROMPTS = {
-     "📋 Markdown": {
-         "prompt": "<image>\n<|grounding|>Convert the document to markdown.",
-         "has_grounding": True,
-     },
-     # NOTE: Free OCR is now handled by TrOCR (fast, text-only)
-     "📝 Free OCR": {"prompt": "", "has_grounding": False},
-     # Locate stays on DeepSeek (grounding)
-     "📍 Locate": {
-         "prompt": "<image>\nLocate <|ref|>text<|/ref|> in the image.",
-         "has_grounding": True,
-     },
-     # Describe is now handled by BLIP
-     "🔍 Describe": {"prompt": "", "has_grounding": False},
-     "✏️ Custom": {"prompt": "", "has_grounding": False},
- }
-
-
- # =========================
+ # -------------------------
+ # CPU-only setup
+ # -------------------------
+ DEVICE = torch.device("cpu")
+ torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
+
+ TROCR_NAME = os.getenv("TROCR_MODEL", "microsoft/trocr-base-printed")
+ BLIP_NAME = os.getenv("BLIP_MODEL", "Salesforce/blip-image-captioning-base")
+
+ # -------------------------
+ # Models (CPU)
+ # -------------------------
+ trocr_processor = AutoProcessor.from_pretrained(TROCR_NAME)
+ trocr_model = VisionEncoderDecoderModel.from_pretrained(TROCR_NAME).eval().to(DEVICE)
+
+ blip_processor = BlipProcessor.from_pretrained(BLIP_NAME)
+ blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_NAME).eval().to(DEVICE)
+
+ # -------------------------
+ # Optional: pytesseract (for boxes on images)
+ # -------------------------
+ def _try_import_tesseract():
+     try:
+         import pytesseract  # type: ignore
+         # Quick sanity check: version call triggers binary lookup
+         _ = pytesseract.get_tesseract_version()
+         return pytesseract
+     except Exception:
+         return None
+
+ PYTESS = _try_import_tesseract()
+
+ # -------------------------
+ # UI / tasks
+ # -------------------------
+ TASKS = [
+     "OCR",
+     "Markdown",
+     "Locate",
+     "Describe",
+ ]
+
+ DEFAULT_DPI = 200  # PDF render DPI
+
+
+ # -------------------------
  # Helpers
- # =========================
- def safe_load_font(size: int = 30) -> ImageFont.FreeTypeFont:
+ # -------------------------
+ def _safe_font(size: int = 28):
      candidates = [
          "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
          "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
@@ -203,373 +75,376 @@ def safe_load_font(size: int = 30) -> ImageFont.FreeTypeFont:
          if os.path.exists(p):
              return ImageFont.truetype(p, size)
      except Exception:
-         continue
+         pass
      return ImageFont.load_default()
 
 
- def extract_grounding_references(text: str):
-     pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
-     return re.findall(pattern, text, re.DOTALL)
-
-
- def draw_bounding_boxes(image: Image.Image, refs, extract_images: bool = False):
-     img_w, img_h = image.size
-     img_draw = image.copy()
-     draw = ImageDraw.Draw(img_draw)
-     overlay = Image.new("RGBA", img_draw.size, (0, 0, 0, 0))
-     draw2 = ImageDraw.Draw(overlay)
-     font = safe_load_font(30)
-     crops = []
-
-     color_map = {}
-     np.random.seed(42)
-
-     for ref in refs:
-         label = ref[1]
-         if label not in color_map:
-             color_map[label] = (
-                 int(np.random.randint(50, 255)),
-                 int(np.random.randint(50, 255)),
-                 int(np.random.randint(50, 255)),
-             )
-
-         color = color_map[label]
-         try:
-             coords = eval(ref[2])
-         except Exception:
-             continue
-
-         color_a = color + (60,)
-
-         for box in coords:
-             x1, y1, x2, y2 = (
-                 int(box[0] / 999 * img_w),
-                 int(box[1] / 999 * img_h),
-                 int(box[2] / 999 * img_w),
-                 int(box[3] / 999 * img_h),
-             )
-
-             if extract_images and label == "image":
-                 crops.append(image.crop((x1, y1, x2, y2)))
-
-             width = 5 if label == "title" else 3
-             draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
-             draw2.rectangle([x1, y1, x2, y2], fill=color_a)
-
-             text_bbox = draw.textbbox((0, 0), label, font=font)
-             tw, th = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1]
-             ty = max(0, y1 - 20)
-             draw.rectangle([x1, ty, x1 + tw + 4, ty + th + 4], fill=color)
-             draw.text((x1 + 2, ty + 2), label, font=font, fill=(255, 255, 255))
-
-     img_draw.paste(overlay, (0, 0), overlay)
-     return img_draw, crops
-
-
- def clean_output(text: str, include_images: bool = False) -> str:
-     if not text:
-         return ""
-     pattern = r"(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)"
-     matches = re.findall(pattern, text, re.DOTALL)
-     img_num = 0
-
-     for match in matches:
-         if "<|ref|>image<|/ref|>" in match[0]:
-             if include_images:
-                 text = text.replace(match[0], f"\n\n**[Figure {img_num + 1}]**\n\n", 1)
-                 img_num += 1
-             else:
-                 text = text.replace(match[0], "", 1)
-         else:
-             text = re.sub(rf"(?m)^[^\n]*{re.escape(match[0])}[^\n]*\n?", "", text)
-
-     return text.strip()
-
-
- def embed_images(markdown: str, crops: List[Image.Image]) -> str:
-     if not crops:
-         return markdown
-     for i, img in enumerate(crops):
-         buf = BytesIO()
-         img.save(buf, format="PNG")
-         b64 = base64.b64encode(buf.getvalue()).decode()
-         markdown = markdown.replace(
-             f"**[Figure {i + 1}]**",
-             f"\n\n![Figure {i + 1}](data:image/png;base64,{b64})\n\n",
-             1,
-         )
-     return markdown
-
-
- def trocr_ocr(image: Image.Image) -> str:
-     if image.mode != "RGB":
-         image = image.convert("RGB")
-     pixel_values = trocr_processor(images=image, return_tensors="pt").pixel_values.to(DEVICE)
-     with torch.no_grad():
-         # Keep generation modest (faster)
-         generated_ids = trocr_model.generate(pixel_values, max_new_tokens=256)
-     text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-     return text.strip()
-
-
- def blip_describe(image: Image.Image) -> str:
-     if image.mode != "RGB":
-         image = image.convert("RGB")
-     inputs = blip_processor(images=image, return_tensors="pt").to(DEVICE)
-     with torch.no_grad():
-         out = blip_model.generate(**inputs, max_new_tokens=80)
-     caption = blip_processor.decode(out[0], skip_special_tokens=True)
-     return caption.strip()
-
-
- # =========================
- # Core processing
- # =========================
- @gpu_decorator(duration=60)
- def process_image(image: Image.Image, mode: str, task: str, custom_prompt: str):
-     if image is None:
-         return "Error: upload image", "", "", None, []
-
-     if task in ["✏️ Custom", "📍 Locate"] and not custom_prompt.strip():
-         return "Error: enter prompt", "", "", None, []
-
-     if image.mode in ("RGBA", "LA", "P"):
-         image = image.convert("RGB")
-     image = ImageOps.exif_transpose(image)
-
-     # --- Route tasks to the best backend ---
-     if task == "📝 Free OCR":
-         text = trocr_ocr(image)
-         if not text:
-             return "No text", "", "", None, []
-         md = "```text\n" + text + "\n```"
-         return text, md, text, None, []
-
-     if task == "🔍 Describe":
-         desc = blip_describe(image)
-         if not desc:
-             return "No description", "", "", None, []
-         md = f"**Description:** {desc}"
-         return desc, md, desc, None, []
-
-     # --- DeepSeek-OCR for Markdown / Locate / Custom ---
-     config = MODEL_CONFIGS[mode]
-
-     if task == "✏️ Custom":
-         prompt = f"<image>\n{custom_prompt.strip()}"
-         has_grounding = "<|grounding|>" in custom_prompt
-     elif task == "📍 Locate":
-         prompt = f"<image>\nLocate <|ref|>{custom_prompt.strip()}<|/ref|> in the image."
-         has_grounding = True
-     else:
-         prompt = TASK_PROMPTS[task]["prompt"]
-         has_grounding = TASK_PROMPTS[task]["has_grounding"]
-
-     tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
-     image.save(tmp.name, "JPEG", quality=95)
-     tmp.close()
-     out_dir = tempfile.mkdtemp()
-
-     stdout = sys.stdout
-     sys.stdout = StringIO()
-
-     try:
-         deepseek_model.infer(
-             tokenizer=tokenizer,
-             prompt=prompt,
-             image_file=tmp.name,
-             output_path=out_dir,
-             base_size=config["base_size"],
-             image_size=config["image_size"],
-             crop_mode=config["crop_mode"],
-         )
-
-         result = "\n".join(
-             [
-                 l
-                 for l in sys.stdout.getvalue().split("\n")
-                 if not any(
-                     s in l
-                     for s in [
-                         "image:",
-                         "other:",
-                         "PATCHES",
-                         "====",
-                         "BASE:",
-                         "%|",
-                         "torch.Size",
-                     ]
-                 )
-             ]
-         ).strip()
-
-     finally:
-         sys.stdout = stdout
-     try:
-         os.unlink(tmp.name)
-     except Exception:
-         pass
-     shutil.rmtree(out_dir, ignore_errors=True)
-
-     if not result:
-         return "No text", "", "", None, []
-
-     cleaned = clean_output(result, include_images=False)
-     markdown = clean_output(result, include_images=True)
-
-     img_out = None
-     crops = []
-
-     if has_grounding and "<|ref|>" in result:
-         refs = extract_grounding_references(result)
-         if refs:
-             img_out, crops = draw_bounding_boxes(image, refs, extract_images=True)
-
-     markdown = embed_images(markdown, crops)
-
-     return cleaned, markdown, result, img_out, crops
-
-
- @gpu_decorator(duration=60)
- def process_pdf(path: str, mode: str, task: str, custom_prompt: str, page_num: int):
-     doc = fitz.open(path)
-     total_pages = len(doc)
-     if page_num < 1 or page_num > total_pages:
-         doc.close()
-         return f"Invalid page number. PDF has {total_pages} pages.", "", "", None, []
-     page = doc.load_page(page_num - 1)
-     pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), alpha=False)
-     img = Image.open(BytesIO(pix.tobytes("png")))
-     doc.close()
-     return process_image(img, mode, task, custom_prompt)
-
-
- def process_file(path: str, mode: str, task: str, custom_prompt: str, page_num: int):
-     if not path:
-         return "Error: upload file", "", "", None, []
-     if path.lower().endswith(".pdf"):
-         return process_pdf(path, mode, task, custom_prompt, page_num)
-     return process_image(Image.open(path), mode, task, custom_prompt)
-
-
- def toggle_prompt(task: str):
-     if task == "✏️ Custom":
-         return gr.update(visible=True, label="Custom Prompt", placeholder="Add <|grounding|> for boxes")
-     if task == "📍 Locate":
-         return gr.update(visible=True, label="Text to Locate", placeholder="Enter text")
-     return gr.update(visible=False)
-
-
- def select_boxes(task: str):
-     if task == "📍 Locate":
-         return gr.update(selected="tab_boxes")
-     return gr.update()
-
-
- def get_pdf_page_count(file_path: str) -> int:
-     if not file_path or not file_path.lower().endswith(".pdf"):
-         return 1
-     doc = fitz.open(file_path)
-     count = len(doc)
-     doc.close()
-     return count
-
-
- def load_image(file_path: str, page_num: int = 1):
-     if not file_path:
-         return None
-     if file_path.lower().endswith(".pdf"):
-         doc = fitz.open(file_path)
-         page_idx = max(0, min(int(page_num) - 1, len(doc) - 1))
-         page = doc.load_page(page_idx)
-         pix = page.get_pixmap(matrix=fitz.Matrix(300 / 72, 300 / 72), alpha=False)
-         img = Image.open(BytesIO(pix.tobytes("png")))
-         doc.close()
-         return img
-     return Image.open(file_path)
-
-
- def update_page_selector(file_path: str):
-     if not file_path:
-         return gr.update(visible=False)
-     if file_path.lower().endswith(".pdf"):
-         page_count = get_pdf_page_count(file_path)
-         return gr.update(
-             visible=True,
-             maximum=page_count,
-             value=1,
-             minimum=1,
-             label=f"Select Page (1-{page_count})",
-         )
-     return gr.update(visible=False)
-
-
- # =========================
- # UI
- # =========================
- with gr.Blocks(theme=gr.themes.Soft(), title="DeepSeek-OCR + TrOCR + BLIP") as demo:
-     gr.Markdown(
-         f"""
- # DeepSeek-OCR Demo (with TrOCR + BLIP)
-
- This app supports:
- - **Markdown**: DeepSeek-OCR (structured markdown + optional grounding boxes)
- - **Free OCR**: TrOCR (fast text-only OCR)
- - **Locate**: DeepSeek-OCR (grounding boxes)
- - **Describe**: BLIP (image captioning)
-
- Runtime device: **{DEVICE}**
- """
-     )
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             file_in = gr.File(label="Upload Image or PDF", file_types=["image", ".pdf"], type="filepath")
-             input_img = gr.Image(label="Input Image", type="pil", height=300)
-             page_selector = gr.Number(label="Select Page", value=1, minimum=1, step=1, visible=False)
-
-             mode = gr.Dropdown(list(MODEL_CONFIGS.keys()), value="Gundam", label="Mode")
-             task = gr.Dropdown(list(TASK_PROMPTS.keys()), value="📋 Markdown", label="Task")
-             prompt = gr.Textbox(label="Prompt", lines=2, visible=False)
-
-             btn = gr.Button("Extract", variant="primary", size="lg")
-
-         with gr.Column(scale=2):
-             with gr.Tabs() as tabs:
-                 with gr.Tab("Text", id="tab_text"):
-                     text_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)
-                 with gr.Tab("Markdown Preview", id="tab_markdown"):
-                     md_out = gr.Markdown("")
-                 with gr.Tab("Boxes", id="tab_boxes"):
-                     img_out = gr.Image(type="pil", height=500, show_label=False)
-                 with gr.Tab("Cropped Images", id="tab_crops"):
-                     gallery = gr.Gallery(show_label=False, columns=3, height=400)
-                 with gr.Tab("Raw Text", id="tab_raw"):
-                     raw_out = gr.Textbox(lines=20, show_copy_button=True, show_label=False)
-
-     # File / PDF page handling
-     file_in.change(load_image, [file_in, page_selector], [input_img])
-     file_in.change(update_page_selector, [file_in], [page_selector])
-     page_selector.change(load_image, [file_in, page_selector], [input_img])
-
-     # Prompt visibility and tab switch
-     task.change(toggle_prompt, [task], [prompt])
-     task.change(select_boxes, [task], [tabs])
-
-     def run(image, file_path, mode, task, custom_prompt, page_num):
-         if file_path:
-             return process_file(file_path, mode, task, custom_prompt, int(page_num))
-         if image is not None:
-             return process_image(image, mode, task, custom_prompt)
-         return "Error: upload file or image", "", "", None, []
-
-     submit_event = btn.click(
-         run,
-         [input_img, file_in, mode, task, prompt, page_selector],
-         [text_out, md_out, raw_out, img_out, gallery],
-     )
-     submit_event.then(select_boxes, [task], [tabs])
+ def _to_rgb(img: Image.Image) -> Image.Image:
+     if img.mode in ("RGBA", "LA", "P"):
+         img = img.convert("RGB")
+     return ImageOps.exif_transpose(img)
+
+
+ def _tokenize(s: str) -> List[str]:
+     return re.findall(r"[A-Za-zА-Яа-я0-9]+", s.lower())
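+
+ # e.g. _tokenize("Invoice No. 42") -> ["invoice", "no", "42"]: case is folded
+ # and punctuation dropped, so the phrase matching below is insensitive to
+ # both; Latin, Cyrillic, and digit runs survive as tokens.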
+
+
+ def trocr_ocr(img: Image.Image) -> str:
+     img = _to_rgb(img)
+     inputs = trocr_processor(images=img, return_tensors="pt")
+     pixel_values = inputs.pixel_values.to(DEVICE)
+     with torch.no_grad():
+         ids = trocr_model.generate(pixel_values, max_new_tokens=256)
+     text = trocr_processor.batch_decode(ids, skip_special_tokens=True)[0]
+     return text.strip()
+
+
+ def blip_describe(img: Image.Image) -> str:
+     img = _to_rgb(img)
+     inputs = blip_processor(images=img, return_tensors="pt").to(DEVICE)
+     with torch.no_grad():
+         out = blip_model.generate(**inputs, max_new_tokens=80)
+     return blip_processor.decode(out[0], skip_special_tokens=True).strip()
+
+
+ def render_pdf_page(path: str, page_num: int, dpi: int = DEFAULT_DPI) -> Tuple[fitz.Document, fitz.Page, Image.Image, float]:
+     doc = fitz.open(path)
+     page_idx = max(0, min(page_num - 1, len(doc) - 1))
+     page = doc.load_page(page_idx)
+     zoom = dpi / 72.0
+     pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom), alpha=False)
+     img = Image.open(BytesIO(pix.tobytes("png")))
+     return doc, page, img, zoom
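+
+ # render_pdf_page: zoom = dpi / 72 because PDF geometry is in points
+ # (1 pt = 1/72 inch); e.g. at DEFAULT_DPI = 200 a 612x792 pt Letter page
+ # renders to 1700x2200 px, and the same zoom later maps word boxes to pixels.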
+
+
+ def pdf_has_text(page: fitz.Page) -> bool:
+     # "words" is empty for scanned pages
+     words = page.get_text("words")
+     return bool(words)
+
+
+ def pdf_extract_text(page: fitz.Page) -> str:
+     txt = page.get_text("text") or ""
+     return txt.strip()
+
+
+ def pdf_to_markdown_simple(page: fitz.Page) -> str:
+     """
+     Lightweight markdown for selectable-text PDFs.
+     - Uses span sizes to guess headers.
+     - No heavy layout logic (keeps it stable and fast on CPU).
+     """
+     data = page.get_text("dict")
+     spans = []
+     for b in data.get("blocks", []):
+         for ln in b.get("lines", []):
+             for sp in ln.get("spans", []):
+                 t = (sp.get("text") or "").strip()
+                 if t:
+                     spans.append(float(sp.get("size", 0.0)))
+     if not spans:
+         return ""
+
+     med = float(np.median(spans))
+     h1_thr = med * 1.60
+     h2_thr = med * 1.35
+
+     lines_out: List[str] = []
+     for b in data.get("blocks", []):
+         if b.get("type") != 0:
+             continue
+         for ln in b.get("lines", []):
+             parts = []
+             sizes = []
+             for sp in ln.get("spans", []):
+                 t = (sp.get("text") or "")
+                 if t.strip():
+                     parts.append(t.strip())
+                     sizes.append(float(sp.get("size", 0.0)))
+             if not parts:
+                 continue
+             line = " ".join(parts).strip()
+             sz = max(sizes) if sizes else med
+
+             if sz >= h1_thr:
+                 lines_out.append("# " + line)
+             elif sz >= h2_thr:
+                 lines_out.append("## " + line)
+             else:
+                 lines_out.append(line)
+
+         lines_out.append("")  # paragraph break
+
+     md = "\n".join(lines_out).strip()
+     return md
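+
+ # pdf_to_markdown_simple example: with a 10 pt median span size, h1_thr = 16.0
+ # and h2_thr = 13.5, so an 18 pt line becomes "# ...", a 14 pt line "## ...",
+ # and body-size lines pass through as plain paragraph text.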
+
+
+ def draw_rects(img: Image.Image, rects_px: List[Tuple[int, int, int, int]]) -> Image.Image:
+     out = img.copy()
+     draw = ImageDraw.Draw(out)
+     overlay = Image.new("RGBA", out.size, (0, 0, 0, 0))
+     draw2 = ImageDraw.Draw(overlay)
+     for (x0, y0, x1, y1) in rects_px:
+         draw.rectangle([x0, y0, x1, y1], outline=(0, 160, 255), width=3)
+         draw2.rectangle([x0, y0, x1, y1], fill=(0, 160, 255, 60))
+     out.paste(overlay, (0, 0), overlay)
+     return out
+
+
+ def locate_in_pdf_words(page: fitz.Page, query: str) -> List[Tuple[float, float, float, float]]:
+     """
+     Returns list of rectangles in PDF coordinate space (points).
+     Uses exact word sequence match (token-based).
+     """
+     q = _tokenize(query)
+     if not q:
+         return []
+
+     words = page.get_text("words")  # x0, y0, x1, y1, "word", block, line, wordno
+     if not words:
+         return []
+
+     w_tokens = [_tokenize(w[4])[0] if _tokenize(w[4]) else "" for w in words]
+     rects: List[Tuple[float, float, float, float]] = []
+
+     n = len(w_tokens)
+     m = len(q)
+     for i in range(0, n - m + 1):
+         if w_tokens[i:i + m] == q:
+             xs0 = [float(words[j][0]) for j in range(i, i + m)]
+             ys0 = [float(words[j][1]) for j in range(i, i + m)]
+             xs1 = [float(words[j][2]) for j in range(i, i + m)]
+             ys1 = [float(words[j][3]) for j in range(i, i + m)]
+             rects.append((min(xs0), min(ys0), max(xs1), max(ys1)))
+
+     return rects
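+
+ # e.g. locate_in_pdf_words(page, "total amount") matches consecutive words
+ # tokenizing to ["total", "amount"] and appends one rect per occurrence: the
+ # min/max union of the matched word boxes, still in PDF points.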
+
+
+ def locate_in_image_tesseract(img: Image.Image, query: str) -> Tuple[List[Tuple[int, int, int, int]], str]:
+     """
+     Returns pixel-space rectangles for the located phrase, plus a short status message.
+     If pytesseract is not available, returns an empty list and a message.
+     """
+     if PYTESS is None:
+         return [], "Tesseract not available: no boxes for images."
+
+     q = _tokenize(query)
+     if not q:
+         return [], "Empty query."
+
+     img = _to_rgb(img)
+     # Request DICT output so every field comes back as a parallel list
+     data = PYTESS.image_to_data(img, output_type=PYTESS.Output.DICT)
+
+     texts = data.get("text", [])
+     left = data.get("left", [])
+     top = data.get("top", [])
+     width = data.get("width", [])
+     height = data.get("height", [])
+     conf = data.get("conf", [])
+
+     tokens = []
+     boxes = []
+     for i, t in enumerate(texts):
+         t = (t or "").strip()
+         if not t:
+             continue
+         tok = _tokenize(t)
+         if not tok:
+             continue
+         # Keep only "reasonable" confidence if numeric
+         try:
+             c = float(conf[i])
+             if c < 0:
+                 continue
+         except Exception:
+             pass
+
+         tokens.append(tok[0])
+         boxes.append((int(left[i]), int(top[i]), int(left[i] + width[i]), int(top[i] + height[i])))
+
+     rects: List[Tuple[int, int, int, int]] = []
+     n = len(tokens)
+     m = len(q)
+     for i in range(0, n - m + 1):
+         if tokens[i:i + m] == q:
+             xs0 = [boxes[j][0] for j in range(i, i + m)]
+             ys0 = [boxes[j][1] for j in range(i, i + m)]
+             xs1 = [boxes[j][2] for j in range(i, i + m)]
+             ys1 = [boxes[j][3] for j in range(i, i + m)]
+             rects.append((min(xs0), min(ys0), max(xs1), max(ys1)))
+
+     if not rects:
+         return [], "Not found."
+     return rects, "Found."
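+
+ # Same sliding-window token match as the PDF path, but over Tesseract word
+ # boxes; a multi-token OCR "word" contributes only its first token (tok[0]),
+ # so punctuation-glued phrases can be missed on noisy scans.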
 
+
+ def as_markdown_block(text: str) -> str:
+     if not text.strip():
+         return ""
+     return "```text\n" + text.strip() + "\n```"
+
+
+ # -------------------------
+ # Main run
+ # -------------------------
+ def process(path: str, task: str, page_num: int, query: str):
+     if not path:
+         return "Upload a file.", "", None
+
+     ext = os.path.splitext(path)[1].lower()
+
+     # ---------- PDF ----------
+     if ext == ".pdf":
+         doc, page, page_img, zoom = render_pdf_page(path, page_num, dpi=DEFAULT_DPI)
+         try:
+             if task == "Describe":
+                 caption = blip_describe(page_img)
+                 return caption, as_markdown_block(caption), None
+
+             if task == "OCR":
+                 if pdf_has_text(page):
+                     txt = pdf_extract_text(page)
+                 else:
+                     txt = trocr_ocr(page_img)
+                 return txt, as_markdown_block(txt), None
+
+             if task == "Markdown":
+                 if pdf_has_text(page):
+                     md = pdf_to_markdown_simple(page)
+                     if not md:
+                         txt = pdf_extract_text(page)
+                         md = as_markdown_block(txt)
+                 else:
+                     txt = trocr_ocr(page_img)
+                     md = as_markdown_block(txt)
+                 return md, md, None
+
+             if task == "Locate":
+                 if not query.strip():
+                     return "Enter text to locate.", "", page_img
+
+                 # 1) Prefer precise PDF word boxes (selectable text)
+                 rects_pdf = locate_in_pdf_words(page, query)
+                 if rects_pdf:
+                     # Convert PDF points -> pixels using the same render zoom
+                     rects_px = []
+                     for (x0, y0, x1, y1) in rects_pdf:
+                         rects_px.append((int(x0 * zoom), int(y0 * zoom), int(x1 * zoom), int(y1 * zoom)))
+                     boxed = draw_rects(page_img, rects_px)
+                     return "Found.", "", boxed
+
+                 # 2) Fallback: on a scanned page, try tesseract boxes on the rendered image
+                 rects_px, msg = locate_in_image_tesseract(page_img, query)
+                 boxed = draw_rects(page_img, rects_px) if rects_px else page_img
+                 return msg, "", boxed
+
+             return "Unknown task.", "", None
+         finally:
+             doc.close()
+
+     # ---------- Image ----------
+     img = _to_rgb(Image.open(path))
+
+     if task == "Describe":
+         caption = blip_describe(img)
+         return caption, as_markdown_block(caption), None
+
+     if task == "OCR":
+         txt = trocr_ocr(img)
+         return txt, as_markdown_block(txt), None
+
+     if task == "Markdown":
+         txt = trocr_ocr(img)
+         md = as_markdown_block(txt)
+         return md, md, None
+
+     if task == "Locate":
+         if not query.strip():
+             return "Enter text to locate.", "", img
+
+         rects_px, msg = locate_in_image_tesseract(img, query)
+         boxed = draw_rects(img, rects_px) if rects_px else img
+         return msg, "", boxed
+
+     return "Unknown task.", "", None
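+
+ # Every process() branch returns the same (text, markdown, boxed image)
+ # triple, which maps one-to-one onto the out_text / out_md / out_boxes
+ # components wired up below.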
+
+
+ # -------------------------
+ # UI helpers
+ # -------------------------
+ def update_page_selector(file_path: str):
+     if not file_path:
+         return gr.update(visible=False), gr.update(value=None)
+
+     ext = os.path.splitext(file_path)[1].lower()
+     if ext != ".pdf":
+         return gr.update(visible=False), gr.update(value=_to_rgb(Image.open(file_path)))
+
+     doc = fitz.open(file_path)
+     pages = len(doc)
+     doc.close()
+
+     # Show first page preview
+     _, _, img, _ = render_pdf_page(file_path, 1, dpi=DEFAULT_DPI)
+     return (
+         gr.update(visible=True, minimum=1, maximum=max(1, pages), value=1),
+         gr.update(value=img),
+     )
+
+
+ def update_preview(file_path: str, page_num: int):
+     if not file_path:
+         return None
+     ext = os.path.splitext(file_path)[1].lower()
+     if ext != ".pdf":
+         return _to_rgb(Image.open(file_path))
+     _, _, img, _ = render_pdf_page(file_path, int(page_num), dpi=DEFAULT_DPI)
+     return img
+
+
+ def toggle_query(task: str):
+     return gr.update(visible=(task == "Locate"))
+
+
+ # -------------------------
+ # Build app (minimal style)
+ # -------------------------
+ theme = gr.themes.Base(
+     font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"],
+ )
+
+ with gr.Blocks(theme=theme, title="Doc Tool (CPU)") as demo:
+     with gr.Row():
+         with gr.Column(scale=1, min_width=320):
+             file_in = gr.File(label="File", file_types=["image", ".pdf"], type="filepath")
+             page_num = gr.Slider(label="Page", minimum=1, maximum=1, value=1, step=1, visible=False)
+             task = gr.Dropdown(label="Task", choices=TASKS, value="OCR")
+             query = gr.Textbox(label="Query", visible=False, placeholder="Text to locate")
+
+             run_btn = gr.Button("Run", variant="primary")
+
+         with gr.Column(scale=2):
+             preview = gr.Image(label="Preview", type="pil", height=360)
+             out_text = gr.Textbox(label="Output", lines=10)
+             out_md = gr.Markdown()
+
+     out_boxes = gr.Image(label="Boxes", type="pil", height=360)
+
+     file_in.change(update_page_selector, inputs=[file_in], outputs=[page_num, preview])
+     page_num.change(update_preview, inputs=[file_in, page_num], outputs=[preview])
+     task.change(toggle_query, inputs=[task], outputs=[query])
+
+     def on_run(file_path, task_name, page, q):
+         text, md, boxed = process(file_path, task_name, int(page), q or "")
+         return text, md, boxed
+
+     run_btn.click(
+         on_run,
+         inputs=[file_in, task, page_num, query],
+         outputs=[out_text, out_md, out_boxes],
+     )
 
  if __name__ == "__main__":
-     demo.queue(max_size=20).launch()
+     # Disable SSR to avoid extra startup noise
+     demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)