UnMelow commited on
Commit
63add86
·
verified ·
1 Parent(s): 0124ea1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +636 -357
app.py CHANGED
@@ -1,394 +1,673 @@
1
  import os
2
- import re
3
- from io import BytesIO
4
- from typing import List, Tuple
 
5
 
6
  import gradio as gr
7
  import torch
8
- import numpy as np
9
- from PIL import Image, ImageDraw, ImageOps
10
- import fitz # PyMuPDF
11
 
12
  from transformers import (
13
- TrOCRProcessor,
14
- VisionEncoderDecoderModel,
15
- BlipProcessor,
16
- BlipForConditionalGeneration,
17
  )
18
- from transformers.utils import logging as hf_logging
19
 
20
- # -------------------------
21
- # CPU-only, quieter logs
22
- # -------------------------
23
- hf_logging.set_verbosity_error()
24
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
25
-
26
  DEVICE = torch.device("cpu")
27
  torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
28
 
29
- TROCR_NAME = os.getenv("TROCR_MODEL", "microsoft/trocr-base-printed")
30
- BLIP_NAME = os.getenv("BLIP_MODEL", "Salesforce/blip-image-captioning-base")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- DEFAULT_DPI = 200
33
- MAX_SIDE = int(os.getenv("MAX_SIDE", "1600")) # soft cap for CPU speed
 
 
 
34
 
35
- # -------------------------
36
- # Models (CPU)
37
- # -------------------------
38
- trocr_processor = TrOCRProcessor.from_pretrained(TROCR_NAME)
39
- trocr_model = VisionEncoderDecoderModel.from_pretrained(TROCR_NAME).eval().to(DEVICE)
40
 
41
- blip_processor = BlipProcessor.from_pretrained(BLIP_NAME)
42
- blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_NAME).eval().to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # -------------------------
45
- # Optional: Tesseract for image boxes
46
- # -------------------------
47
- def _try_import_tesseract():
48
- try:
49
- import pytesseract # type: ignore
50
- _ = pytesseract.get_tesseract_version()
51
- return pytesseract
52
- except Exception:
53
- return None
54
 
55
- PYTESS = _try_import_tesseract()
56
 
57
- TASKS = ["OCR", "Markdown", "Locate", "Describe"]
 
 
58
 
 
 
 
 
 
59
 
60
- # -------------------------
61
- # Helpers
62
- # -------------------------
63
- def _to_rgb(img: Image.Image) -> Image.Image:
64
- if img.mode in ("RGBA", "LA", "P"):
65
- img = img.convert("RGB")
66
- img = ImageOps.exif_transpose(img)
67
 
68
- # Keep CPU inference reasonable
69
- w, h = img.size
70
- m = max(w, h)
71
- if m > MAX_SIDE:
72
- scale = MAX_SIDE / float(m)
73
- img = img.resize((int(w * scale), int(h * scale)), Image.Resampling.LANCZOS)
74
  return img
75
 
76
 
77
- def _tokenize(s: str) -> List[str]:
78
- return re.findall(r"[A-Za-zА-Яа-я0-9]+", (s or "").lower())
79
-
80
-
81
- def trocr_ocr(img: Image.Image) -> str:
82
- img = _to_rgb(img)
83
- inputs = trocr_processor(images=img, return_tensors="pt")
84
- pixel_values = inputs.pixel_values.to(DEVICE)
85
- with torch.no_grad():
86
- ids = trocr_model.generate(pixel_values, max_new_tokens=256)
87
- text = trocr_processor.batch_decode(ids, skip_special_tokens=True)[0]
88
- return (text or "").strip()
89
-
90
-
91
- def blip_describe(img: Image.Image) -> str:
92
- img = _to_rgb(img)
93
- inputs = blip_processor(images=img, return_tensors="pt").to(DEVICE)
94
- with torch.no_grad():
95
- out = blip_model.generate(**inputs, max_new_tokens=80)
96
- return blip_processor.decode(out[0], skip_special_tokens=True).strip()
97
-
98
-
99
- def render_pdf_page(path: str, page_num: int, dpi: int = DEFAULT_DPI):
100
- doc = fitz.open(path)
101
- page_idx = max(0, min(int(page_num) - 1, len(doc) - 1))
102
- page = doc.load_page(page_idx)
103
- zoom = dpi / 72.0
104
- pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom), alpha=False)
105
- img = Image.open(BytesIO(pix.tobytes("png")))
106
- return doc, page, _to_rgb(img), zoom
107
-
108
-
109
- def pdf_has_text(page: fitz.Page) -> bool:
110
- return bool(page.get_text("words"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
 
 
 
 
112
 
113
- def pdf_extract_text(page: fitz.Page) -> str:
114
- return (page.get_text("text") or "").strip()
 
115
 
 
116
 
117
- def pdf_to_markdown_simple(page: fitz.Page) -> str:
118
- data = page.get_text("dict")
119
- spans = []
120
- for b in data.get("blocks", []):
121
- for ln in b.get("lines", []):
122
- for sp in ln.get("spans", []):
123
- t = (sp.get("text") or "").strip()
124
- if t:
125
- spans.append(float(sp.get("size", 0.0)))
126
- if not spans:
127
- return ""
128
 
129
- med = float(np.median(spans))
130
- h1_thr = med * 1.60
131
- h2_thr = med * 1.35
 
 
132
 
133
- out_lines: List[str] = []
134
- for b in data.get("blocks", []):
135
- if b.get("type") != 0:
136
- continue
137
- for ln in b.get("lines", []):
138
- parts = []
139
- sizes = []
140
- for sp in ln.get("spans", []):
141
- t = (sp.get("text") or "").strip()
142
- if t:
143
- parts.append(t)
144
- sizes.append(float(sp.get("size", 0.0)))
145
- if not parts:
146
- continue
147
- line = " ".join(parts).strip()
148
- sz = max(sizes) if sizes else med
149
- if sz >= h1_thr:
150
- out_lines.append("# " + line)
151
- elif sz >= h2_thr:
152
- out_lines.append("## " + line)
153
- else:
154
- out_lines.append(line)
155
- out_lines.append("")
156
- return "\n".join(out_lines).strip()
157
-
158
-
159
- def draw_rects(img: Image.Image, rects_px: List[Tuple[int, int, int, int]]) -> Image.Image:
160
- out = img.copy()
161
- draw = ImageDraw.Draw(out)
162
- overlay = Image.new("RGBA", out.size, (0, 0, 0, 0))
163
- draw2 = ImageDraw.Draw(overlay)
164
- for (x0, y0, x1, y1) in rects_px:
165
- draw.rectangle([x0, y0, x1, y1], outline=(0, 160, 255), width=3)
166
- draw2.rectangle([x0, y0, x1, y1], fill=(0, 160, 255, 60))
167
- out.paste(overlay, (0, 0), overlay)
168
- return out
169
-
170
-
171
- def locate_in_pdf_words(page: fitz.Page, query: str) -> List[Tuple[float, float, float, float]]:
172
- q = _tokenize(query)
173
- if not q:
174
- return []
175
- words = page.get_text("words")
176
- if not words:
177
- return []
178
-
179
- w_tokens = []
180
- for w in words:
181
- toks = _tokenize(w[4])
182
- w_tokens.append(toks[0] if toks else "")
183
-
184
- rects = []
185
- n, m = len(w_tokens), len(q)
186
- for i in range(0, n - m + 1):
187
- if w_tokens[i:i + m] == q:
188
- xs0 = [float(words[j][0]) for j in range(i, i + m)]
189
- ys0 = [float(words[j][1]) for j in range(i, i + m)]
190
- xs1 = [float(words[j][2]) for j in range(i, i + m)]
191
- ys1 = [float(words[j][3]) for j in range(i, i + m)]
192
- rects.append((min(xs0), min(ys0), max(xs1), max(ys1)))
193
- return rects
194
-
195
-
196
- def locate_in_image_tesseract(img: Image.Image, query: str):
197
- if PYTESS is None:
198
- return [], "Tesseract not available."
199
- q = _tokenize(query)
200
- if not q:
201
- return [], "Empty query."
202
-
203
- img = _to_rgb(img)
204
- data = PYTESS.image_to_data(img, output_type=PYTESS.Output.DICT)
205
-
206
- texts = data.get("text", [])
207
- left = data.get("left", [])
208
- top = data.get("top", [])
209
- width = data.get("width", [])
210
- height = data.get("height", [])
211
- conf = data.get("conf", [])
212
-
213
- tokens = []
214
- boxes = []
215
- for i, t in enumerate(texts):
216
- t = (t or "").strip()
217
- if not t:
218
- continue
219
- toks = _tokenize(t)
220
- if not toks:
221
- continue
222
- try:
223
- c = float(conf[i])
224
- if c < 0:
225
- continue
226
- except Exception:
227
- pass
228
- tokens.append(toks[0])
229
- boxes.append((int(left[i]), int(top[i]), int(left[i] + width[i]), int(top[i] + height[i])))
230
-
231
- rects_px = []
232
- n, m = len(tokens), len(q)
233
- for i in range(0, n - m + 1):
234
- if tokens[i:i + m] == q:
235
- xs0 = [boxes[j][0] for j in range(i, i + m)]
236
- ys0 = [boxes[j][1] for j in range(i, i + m)]
237
- xs1 = [boxes[j][2] for j in range(i, i + m)]
238
- ys1 = [boxes[j][3] for j in range(i, i + m)]
239
- rects_px.append((min(xs0), min(ys0), max(xs1), max(ys1)))
240
-
241
- return rects_px, ("Found." if rects_px else "Not found.")
242
-
243
-
244
- def as_text_block(s: str) -> str:
245
- s = (s or "").strip()
246
- return s if s else ""
247
-
248
-
249
- # -------------------------
250
- # Core processing
251
- # -------------------------
252
- def process(file_path: str, task: str, page_num: int, query: str):
253
- if not file_path:
254
- return "Upload a file.", "", None, None
255
-
256
- ext = os.path.splitext(file_path)[1].lower()
257
-
258
- # PDF
259
- if ext == ".pdf":
260
- doc, page, page_img, zoom = render_pdf_page(file_path, page_num, dpi=DEFAULT_DPI)
261
- try:
262
- preview = page_img
263
-
264
- if task == "Describe":
265
- cap = blip_describe(page_img)
266
- return cap, cap, None, preview
267
-
268
- if task == "OCR":
269
- txt = pdf_extract_text(page) if pdf_has_text(page) else trocr_ocr(page_img)
270
- return txt, txt, None, preview
271
-
272
- if task == "Markdown":
273
- if pdf_has_text(page):
274
- md = pdf_to_markdown_simple(page)
275
- if not md:
276
- md = pdf_extract_text(page)
277
- else:
278
- md = trocr_ocr(page_img)
279
- return md, md, None, preview
280
-
281
- if task == "Locate":
282
- if not (query or "").strip():
283
- return "Enter query.", "", preview, preview
284
-
285
- # selectable-text PDF: precise boxes
286
- rects_pdf = locate_in_pdf_words(page, query)
287
- if rects_pdf:
288
- rects_px = [(int(x0 * zoom), int(y0 * zoom), int(x1 * zoom), int(y1 * zoom)) for x0, y0, x1, y1 in rects_pdf]
289
- boxed = draw_rects(page_img, rects_px)
290
- return "Found.", "", boxed, preview
291
-
292
- # fallback: render + tesseract
293
- rects_px, msg = locate_in_image_tesseract(page_img, query)
294
- boxed = draw_rects(page_img, rects_px) if rects_px else page_img
295
- return msg, "", boxed, preview
296
-
297
- return "Unknown task.", "", None, preview
298
- finally:
299
- doc.close()
300
-
301
- # Image
302
- img = _to_rgb(Image.open(file_path))
303
- preview = img
304
-
305
- if task == "Describe":
306
- cap = blip_describe(img)
307
- return cap, cap, None, preview
308
-
309
- if task == "OCR":
310
- txt = trocr_ocr(img)
311
- return txt, txt, None, preview
312
-
313
- if task == "Markdown":
314
- md = trocr_ocr(img)
315
- return md, md, None, preview
316
-
317
- if task == "Locate":
318
- if not (query or "").strip():
319
- return "Enter query.", "", img, preview
320
- rects_px, msg = locate_in_image_tesseract(img, query)
321
- boxed = draw_rects(img, rects_px) if rects_px else img
322
- return msg, "", boxed, preview
323
-
324
- return "Unknown task.", "", None, preview
325
-
326
-
327
- # -------------------------
328
- # UI wiring
329
- # -------------------------
330
- def update_page_ui(file_path: str):
331
- if not file_path:
332
- return gr.update(visible=False), None
333
-
334
- ext = os.path.splitext(file_path)[1].lower()
335
- if ext != ".pdf":
336
- return gr.update(visible=False), _to_rgb(Image.open(file_path))
337
-
338
- doc = fitz.open(file_path)
339
- pages = max(1, len(doc))
340
- doc.close()
341
-
342
- _, _, img, _ = render_pdf_page(file_path, 1, dpi=DEFAULT_DPI)
343
- return gr.update(visible=True, minimum=1, maximum=pages, value=1), img
344
-
345
-
346
- def update_preview(file_path: str, page_num: int):
347
- if not file_path:
348
- return None
349
- ext = os.path.splitext(file_path)[1].lower()
350
- if ext != ".pdf":
351
- return _to_rgb(Image.open(file_path))
352
- _, _, img, _ = render_pdf_page(file_path, int(page_num), dpi=DEFAULT_DPI)
353
- return img
354
 
 
 
 
 
355
 
356
- def toggle_query(task: str):
357
- return gr.update(visible=(task == "Locate"))
 
358
 
 
359
 
360
- # -------------------------
361
- # Minimal UI style
362
- # -------------------------
363
- theme = gr.themes.Monochrome(
364
- font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"]
365
- )
366
 
367
- with gr.Blocks(theme=theme, title="Doc Tool (CPU)") as demo:
368
- with gr.Row():
369
- with gr.Column(scale=1, min_width=320):
370
- file_in = gr.File(label="File", file_types=["image", ".pdf"], type="filepath")
371
- page = gr.Slider(label="Page", minimum=1, maximum=1, value=1, step=1, visible=False)
372
- task = gr.Dropdown(label="Task", choices=TASKS, value="OCR")
373
- query = gr.Textbox(label="Query", placeholder="Text to locate", visible=False)
374
- run_btn = gr.Button("Run", variant="primary")
375
-
376
- with gr.Column(scale=2):
377
- with gr.Row():
378
- preview = gr.Image(label="Preview", type="pil", height=320)
379
- boxes = gr.Image(label="Boxes", type="pil", height=320)
380
- out = gr.Textbox(label="Output", lines=10)
381
-
382
- file_in.change(update_page_ui, inputs=[file_in], outputs=[page, preview])
383
- page.change(update_preview, inputs=[file_in, page], outputs=[preview])
384
- task.change(toggle_query, inputs=[task], outputs=[query])
385
-
386
- def on_run(fp, t, p, q):
387
- text, _, boxed, prev = process(fp, t, int(p), q or "")
388
- # keep preview stable; boxes only when relevant
389
- return prev, boxed, as_text_block(text)
390
-
391
- run_btn.click(on_run, inputs=[file_in, task, page, query], outputs=[preview, boxes, out])
392
 
393
  if __name__ == "__main__":
394
- demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
 
1
  import os
2
+ import random
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import List, Tuple, Dict, Optional
6
 
7
  import gradio as gr
8
  import torch
9
+ from PIL import Image, ImageDraw, ImageFont
 
 
10
 
11
  from transformers import (
12
+ AutoTokenizer,
13
+ AutoModel,
14
+ AutoModelForSeq2SeqLM,
15
+ AutoModelForCausalLM,
16
  )
 
17
 
18
+ # ============================================================
19
+ # CPU setup
20
+ # ============================================================
 
21
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
 
22
  DEVICE = torch.device("cpu")
23
  torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
24
 
25
+ # ============================================================
26
+ # 3 Transformers (minimum)
27
+ # 1) Coach (Seq2Seq)
28
+ # 2) Opponent (Causal LM)
29
+ # 3) Embeddings (Encoder)
30
+ # ============================================================
31
+ COACH_MODEL_NAME = os.getenv("COACH_MODEL", "google/flan-t5-small")
32
+ OPP_MODEL_NAME = os.getenv("OPP_MODEL", "distilgpt2")
33
+ EMB_MODEL_NAME = os.getenv("EMB_MODEL", "sentence-transformers/paraphrase-MiniLM-L3-v2")
34
+
35
+ coach_tok = AutoTokenizer.from_pretrained(COACH_MODEL_NAME)
36
+ coach_model = AutoModelForSeq2SeqLM.from_pretrained(COACH_MODEL_NAME).eval().to(DEVICE)
37
+
38
+ opp_tok = AutoTokenizer.from_pretrained(OPP_MODEL_NAME)
39
+ opp_model = AutoModelForCausalLM.from_pretrained(OPP_MODEL_NAME).eval().to(DEVICE)
40
+
41
+ emb_tok = AutoTokenizer.from_pretrained(EMB_MODEL_NAME)
42
+ emb_model = AutoModel.from_pretrained(EMB_MODEL_NAME).eval().to(DEVICE)
43
+
44
+
45
+ # ============================================================
46
+ # Checkers engine (English draughts-like)
47
+ # Pieces:
48
+ # '.' empty
49
+ # 'w' white man (user)
50
+ # 'W' white king
51
+ # 'b' black man (bot)
52
+ # 'B' black king
53
+ #
54
+ # Coordinates:
55
+ # internal: r=0..7 top->bottom, c=0..7 left->right
56
+ # dark squares: (r+c)%2==1
57
+ # Move string:
58
+ # "b6-a5" or "c3-e5-g7" using a-h and 1-8 (1 is bottom row).
59
+ # ============================================================
60
+
61
+ def inside(r: int, c: int) -> bool:
62
+ return 0 <= r < 8 and 0 <= c < 8
63
+
64
+ def is_dark(r: int, c: int) -> bool:
65
+ return (r + c) % 2 == 1
66
+
67
+ def rc_to_alg(r: int, c: int) -> str:
68
+ # a1 bottom-left => internal (7,0)
69
+ file_ = chr(ord("a") + c)
70
+ rank = str(8 - r)
71
+ return f"{file_}{rank}"
72
+
73
+ def alg_to_rc(s: str) -> Tuple[int, int]:
74
+ s = s.strip().lower()
75
+ c = ord(s[0]) - ord("a")
76
+ r = 8 - int(s[1])
77
+ return r, c
78
+
79
+ def move_seq_to_str(seq: List[Tuple[int, int]]) -> str:
80
+ return "-".join(rc_to_alg(r, c) for r, c in seq)
81
+
82
+ def move_str_to_seq(s: str) -> List[Tuple[int, int]]:
83
+ parts = [p.strip() for p in s.split("-") if p.strip()]
84
+ return [alg_to_rc(p) for p in parts]
85
+
86
+ def piece_color(p: str) -> Optional[str]:
87
+ if p in ("w", "W"):
88
+ return "w"
89
+ if p in ("b", "B"):
90
+ return "b"
91
+ return None
92
+
93
+ def is_king(p: str) -> bool:
94
+ return p in ("W", "B")
95
+
96
+
97
+ @dataclass
98
+ class GameState:
99
+ board: List[List[str]]
100
+ turn: str # "w" user, "b" bot
101
+ history: List[str]
102
+ last_analysis: str
103
+
104
+
105
+ def initial_board() -> List[List[str]]:
106
+ b = [["." for _ in range(8)] for _ in range(8)]
107
+ # Black at top rows 0-2 on dark squares
108
+ for r in range(0, 3):
109
+ for c in range(8):
110
+ if is_dark(r, c):
111
+ b[r][c] = "b"
112
+ # White at bottom rows 5-7 on dark squares
113
+ for r in range(5, 8):
114
+ for c in range(8):
115
+ if is_dark(r, c):
116
+ b[r][c] = "w"
117
+ return b
118
+
119
+ def clone_board(board: List[List[str]]) -> List[List[str]]:
120
+ return [row[:] for row in board]
121
+
122
+ def board_to_ascii(board: List[List[str]]) -> str:
123
+ # compact representation for prompting
124
+ lines = []
125
+ for r in range(8):
126
+ lines.append("".join(board[r]))
127
+ return "\n".join(lines)
128
+
129
+ def count_material(board: List[List[str]]) -> Dict[str, float]:
130
+ score = {"w": 0.0, "b": 0.0}
131
+ for r in range(8):
132
+ for c in range(8):
133
+ p = board[r][c]
134
+ if p == "w":
135
+ score["w"] += 1.0
136
+ elif p == "W":
137
+ score["w"] += 1.6
138
+ elif p == "b":
139
+ score["b"] += 1.0
140
+ elif p == "B":
141
+ score["b"] += 1.6
142
+ return score
143
+
144
+ def promote_if_needed(p: str, r: int) -> str:
145
+ if p == "w" and r == 0:
146
+ return "W"
147
+ if p == "b" and r == 7:
148
+ return "B"
149
+ return p
150
+
151
+
152
+ # ----------------------------
153
+ # Move generation
154
+ # ----------------------------
155
+ def move_dirs(p: str) -> List[Tuple[int, int]]:
156
+ # movement directions (step)
157
+ if p == "w":
158
+ return [(-1, -1), (-1, +1)]
159
+ if p == "b":
160
+ return [(+1, -1), (+1, +1)]
161
+ # kings
162
+ if p in ("W", "B"):
163
+ return [(-1, -1), (-1, +1), (+1, -1), (+1, +1)]
164
+ return []
165
+
166
+ def capture_dirs(p: str) -> List[Tuple[int, int]]:
167
+ # English draughts: men capture forward only; kings both ways
168
+ return move_dirs(p)
169
+
170
+ def gen_simple_moves(board: List[List[str]], color: str) -> List[List[Tuple[int, int]]]:
171
+ moves = []
172
+ for r in range(8):
173
+ for c in range(8):
174
+ p = board[r][c]
175
+ if piece_color(p) != color:
176
+ continue
177
+ for dr, dc in move_dirs(p):
178
+ r2, c2 = r + dr, c + dc
179
+ if inside(r2, c2) and board[r2][c2] == ".":
180
+ moves.append([(r, c), (r2, c2)])
181
+ return moves
182
+
183
+ def gen_captures_from(board: List[List[str]], r: int, c: int, p: str) -> List[List[Tuple[int, int]]]:
184
+ """
185
+ Returns capture sequences starting at (r,c), including start and landings.
186
+ If man reaches king row during capture, we stop (promotion at end of move).
187
+ """
188
+ color = piece_color(p)
189
+ assert color in ("w", "b")
190
+
191
+ sequences = []
192
+ found_any = False
193
+
194
+ for dr, dc in capture_dirs(p):
195
+ r_mid, c_mid = r + dr, c + dc
196
+ r2, c2 = r + 2 * dr, c + 2 * dc
197
+ if not (inside(r2, c2) and inside(r_mid, c_mid)):
198
+ continue
199
+ mid_piece = board[r_mid][c_mid]
200
+ if mid_piece == ".":
201
+ continue
202
+ if piece_color(mid_piece) == color:
203
+ continue
204
+ if board[r2][c2] != ".":
205
+ continue
206
 
207
+ # perform capture on a cloned board
208
+ nb = clone_board(board)
209
+ nb[r][c] = "."
210
+ nb[r_mid][c_mid] = "."
211
+ nb[r2][c2] = p # promotion deferred
212
 
213
+ # stop extending if this is a man that reaches king row
214
+ if (p == "w" and r2 == 0) or (p == "b" and r2 == 7):
215
+ sequences.append([(r, c), (r2, c2)])
216
+ found_any = True
217
+ continue
218
 
219
+ tails = gen_captures_from(nb, r2, c2, p)
220
+ if tails:
221
+ for t in tails:
222
+ sequences.append([(r, c)] + t[1:])
223
+ found_any = True
224
+ else:
225
+ sequences.append([(r, c), (r2, c2)])
226
+ found_any = True
227
+
228
+ return sequences if found_any else []
229
+
230
+ def gen_legal_moves(board: List[List[str]], color: str) -> List[List[Tuple[int, int]]]:
231
+ captures = []
232
+ for r in range(8):
233
+ for c in range(8):
234
+ p = board[r][c]
235
+ if piece_color(p) != color:
236
+ continue
237
+ caps = gen_captures_from(board, r, c, p)
238
+ captures.extend(caps)
239
+
240
+ # forced capture rule
241
+ if captures:
242
+ # remove duplicates (can arise via different recursion paths)
243
+ uniq = {}
244
+ for seq in captures:
245
+ key = tuple(seq)
246
+ uniq[key] = seq
247
+ return list(uniq.values())
248
+
249
+ return gen_simple_moves(board, color)
250
+
251
+ def apply_move(board: List[List[str]], seq: List[Tuple[int, int]]) -> List[List[str]]:
252
+ nb = clone_board(board)
253
+ (r0, c0) = seq[0]
254
+ p = nb[r0][c0]
255
+ nb[r0][c0] = "."
256
+
257
+ for i in range(1, len(seq)):
258
+ (r1, c1) = seq[i - 1]
259
+ (r2, c2) = seq[i]
260
+ # capture if jump
261
+ if abs(r2 - r1) == 2 and abs(c2 - c1) == 2:
262
+ rm = (r1 + r2) // 2
263
+ cm = (c1 + c2) // 2
264
+ nb[rm][cm] = "."
265
+
266
+ (rf, cf) = seq[-1]
267
+ p2 = promote_if_needed(p, rf)
268
+ nb[rf][cf] = p2
269
+ return nb
270
+
271
+ def winner(board: List[List[str]]) -> Optional[str]:
272
+ # winner if opponent has no pieces or no moves
273
+ w_cnt = 0
274
+ b_cnt = 0
275
+ for r in range(8):
276
+ for c in range(8):
277
+ if board[r][c] in ("w", "W"):
278
+ w_cnt += 1
279
+ elif board[r][c] in ("b", "B"):
280
+ b_cnt += 1
281
+ if w_cnt == 0:
282
+ return "b"
283
+ if b_cnt == 0:
284
+ return "w"
285
+ if not gen_legal_moves(board, "w"):
286
+ return "b"
287
+ if not gen_legal_moves(board, "b"):
288
+ return "w"
289
+ return None
290
+
291
+
292
+ # ============================================================
293
+ # Simple engine for analysis (not a transformer):
294
+ # minimax on material + mobility, small depth for CPU.
295
+ # ============================================================
296
+ def eval_board(board: List[List[str]]) -> float:
297
+ m = count_material(board)
298
+ # positive => good for white
299
+ score = (m["w"] - m["b"])
300
+ # mobility bonus
301
+ score += 0.04 * (len(gen_legal_moves(board, "w")) - len(gen_legal_moves(board, "b")))
302
+ return score
303
+
304
+ def minimax(board: List[List[str]], color: str, depth: int, alpha: float, beta: float) -> Tuple[float, Optional[List[Tuple[int, int]]]]:
305
+ win = winner(board)
306
+ if win == "w":
307
+ return 10_000.0, None
308
+ if win == "b":
309
+ return -10_000.0, None
310
+
311
+ if depth == 0:
312
+ return eval_board(board), None
313
+
314
+ moves = gen_legal_moves(board, color)
315
+ if not moves:
316
+ # no moves => lose
317
+ return (-10_000.0 if color == "w" else 10_000.0), None
318
+
319
+ best_move = None
320
+
321
+ if color == "w":
322
+ best = -math.inf
323
+ for mv in moves:
324
+ nb = apply_move(board, mv)
325
+ val, _ = minimax(nb, "b", depth - 1, alpha, beta)
326
+ if val > best:
327
+ best = val
328
+ best_move = mv
329
+ alpha = max(alpha, best)
330
+ if beta <= alpha:
331
+ break
332
+ return best, best_move
333
+ else:
334
+ best = math.inf
335
+ for mv in moves:
336
+ nb = apply_move(board, mv)
337
+ val, _ = minimax(nb, "w", depth - 1, alpha, beta)
338
+ if val < best:
339
+ best = val
340
+ best_move = mv
341
+ beta = min(beta, best)
342
+ if beta <= alpha:
343
+ break
344
+ return best, best_move
345
+
346
+
347
+ # ============================================================
348
+ # Embeddings (transformer #3) for retrieving tips
349
+ # ============================================================
350
+ TIPS = [
351
+ "Всегда проверяй обязательный бой: если есть взятие, обычный ход запрещён.",
352
+ "Старайся сохранять дамочную линию: не открывай край без причины.",
353
+ "Не меняйся, если это приводит к потере темпа и отдаёт центр.",
354
+ "Центр важен: контроль диагоналей увеличивает мобильность и шансы на многоходовые взятия.",
355
+ "Перед ходом оцени ответ соперника: что он берёт или чем отвечает на диагонали?",
356
+ "Если видишь возможность мультибоя, считай траекторию до конца — важно, где ты остановишься.",
357
+ "Дамка сильнее: иногда стоит пожертвовать шашку ради прохода в дамки.",
358
+ "Не оставляй одиночные шашки без поддержки — их легко поймать взятием.",
359
+ "Думай про 'вилку' (двойную угрозу) и про то, чтобы не подставлять шашку под обязательный бой.",
360
+ ]
361
+
362
+ @torch.no_grad()
363
+ def embed_text(text: str) -> torch.Tensor:
364
+ toks = emb_tok(text, return_tensors="pt", truncation=True, max_length=128, padding=True)
365
+ toks = {k: v.to(DEVICE) for k, v in toks.items()}
366
+ out = emb_model(**toks)
367
+ # mean pooling
368
+ last = out.last_hidden_state # [B,T,H]
369
+ mask = toks["attention_mask"].unsqueeze(-1) # [B,T,1]
370
+ pooled = (last * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
371
+ pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
372
+ return pooled[0].cpu()
373
+
374
+ TIP_EMBS = torch.stack([embed_text(t) for t in TIPS], dim=0) # [N,H]
375
+
376
+ def retrieve_tips(query: str, k: int = 3) -> List[str]:
377
+ q = embed_text(query)
378
+ sims = (TIP_EMBS @ q.unsqueeze(1)).squeeze(1) # [N]
379
+ top = torch.topk(sims, k=min(k, len(TIPS))).indices.tolist()
380
+ return [TIPS[i] for i in top]
381
+
382
+
383
+ # ============================================================
384
+ # Coach (transformer #1): generates explanation/feedback
385
+ # ============================================================
386
+ @torch.no_grad()
387
+ def coach_generate(prompt: str, max_new_tokens: int = 160) -> str:
388
+ inp = coach_tok(prompt, return_tensors="pt", truncation=True, max_length=512)
389
+ inp = {k: v.to(DEVICE) for k, v in inp.items()}
390
+ out = coach_model.generate(
391
+ **inp,
392
+ max_new_tokens=max_new_tokens,
393
+ do_sample=False,
394
+ num_beams=1,
395
+ )
396
+ text = coach_tok.decode(out[0], skip_special_tokens=True)
397
+ return text.strip()
398
+
399
+
400
+ # ============================================================
401
+ # Opponent (transformer #2): chooses a legal move
402
+ # ============================================================
403
+ @torch.no_grad()
404
+ def opponent_choose_move(board: List[List[str]], legal_moves: List[str]) -> str:
405
+ # distilgpt2 is not instruction-tuned, so we keep it extremely constrained and parse output.
406
+ board_ascii = board_to_ascii(board)
407
+ moves_block = "\n".join([f"- {m}" for m in legal_moves[:40]]) # cap list
408
+ prompt = (
409
+ "You are playing checkers as Black.\n"
410
+ "Choose ONE move exactly from the list. Output only that move.\n"
411
+ f"Board:\n{board_ascii}\n"
412
+ f"Moves:\n{moves_block}\n"
413
+ "Move:"
414
+ )
415
+ inp = opp_tok(prompt, return_tensors="pt", truncation=True, max_length=512)
416
+ inp = {k: v.to(DEVICE) for k, v in inp.items()}
417
+ gen = opp_model.generate(
418
+ **inp,
419
+ max_new_tokens=24,
420
+ do_sample=True,
421
+ top_p=0.85,
422
+ temperature=0.7,
423
+ pad_token_id=opp_tok.eos_token_id,
424
+ )
425
+ out = opp_tok.decode(gen[0], skip_special_tokens=True)
426
+ tail = out.split("Move:")[-1].strip()
427
+
428
+ # parse: pick the first legal move that appears in the generated tail
429
+ for m in legal_moves:
430
+ if m in tail:
431
+ return m
432
+
433
+ # fallback: try extract token pattern like a1-b2
434
+ cand = re.findall(r"[a-h][1-8](?:-[a-h][1-8])+", tail.lower())
435
+ if cand:
436
+ for c in cand:
437
+ if c in legal_moves:
438
+ return c
439
+
440
+ # final fallback: random legal
441
+ return random.choice(legal_moves)
442
+
443
+
444
+ # ============================================================
445
+ # Rendering board
446
+ # ============================================================
447
+ def render_board(board: List[List[str]], size: int = 520) -> Image.Image:
448
+ pad = 20
449
+ cell = (size - 2 * pad) // 8
450
+ img = Image.new("RGB", (size, size), (245, 245, 245))
451
+ d = ImageDraw.Draw(img)
452
+
453
+ dark = (150, 110, 80)
454
+ light = (235, 220, 200)
455
+
456
+ # grid
457
+ for r in range(8):
458
+ for c in range(8):
459
+ x0 = pad + c * cell
460
+ y0 = pad + r * cell
461
+ x1 = x0 + cell
462
+ y1 = y0 + cell
463
+ d.rectangle([x0, y0, x1, y1], fill=(dark if is_dark(r, c) else light))
464
+
465
+ # pieces
466
+ for r in range(8):
467
+ for c in range(8):
468
+ p = board[r][c]
469
+ if p == ".":
470
+ continue
471
+ cx = pad + c * cell + cell // 2
472
+ cy = pad + r * cell + cell // 2
473
+ rad = int(cell * 0.38)
474
 
475
+ if p in ("w", "W"):
476
+ fill = (245, 245, 245)
477
+ outline = (30, 30, 30)
478
+ else:
479
+ fill = (40, 40, 40)
480
+ outline = (230, 230, 230)
 
 
 
 
481
 
482
+ d.ellipse([cx - rad, cy - rad, cx + rad, cy + rad], fill=fill, outline=outline, width=3)
483
 
484
+ if is_king(p):
485
+ # crown marker
486
+ d.ellipse([cx - rad // 2, cy - rad // 2, cx + rad // 2, cy + rad // 2], outline=(255, 215, 0), width=4)
487
 
488
+ # coordinates
489
+ try:
490
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
491
+ except Exception:
492
+ font = None
493
 
494
+ for c in range(8):
495
+ d.text((pad + c * cell + 3, pad + 8 * cell + 2), chr(ord("a") + c), fill=(30, 30, 30), font=font)
496
+ for r in range(8):
497
+ d.text((3, pad + r * cell + 3), str(8 - r), fill=(30, 30, 30), font=font)
 
 
 
498
 
 
 
 
 
 
 
499
  return img
500
 
501
 
502
+ # ============================================================
503
+ # Game logic wrapper
504
+ # ============================================================
505
+ def new_game() -> GameState:
506
+ return GameState(
507
+ board=initial_board(),
508
+ turn="w",
509
+ history=[],
510
+ last_analysis="",
511
+ )
512
+
513
+ def legal_moves_str(board: List[List[str]], color: str) -> List[str]:
514
+ moves = gen_legal_moves(board, color)
515
+ ms = [move_seq_to_str(mv) for mv in moves]
516
+ # stable ordering: captures first (longer sequences first), then lexicographic
517
+ ms.sort(key=lambda s: (-s.count("-"), s))
518
+ return ms
519
+
520
+ def analyze_user_move(board_before: List[List[str]], user_move_str: str) -> str:
521
+ # engine "best move" as baseline (not a transformer)
522
+ depth = int(os.getenv("ANALYSIS_DEPTH", "3"))
523
+ best_val, best_mv = minimax(board_before, "w", depth=depth, alpha=-math.inf, beta=math.inf)
524
+ best_str = move_seq_to_str(best_mv) if best_mv else "(none)"
525
+
526
+ tips = retrieve_tips("шашки: как улучшить ход и не подставиться", k=3)
527
+
528
+ prompt = (
529
+ "Ты тренер по шашкам. Коротко и по делу.\n"
530
+ f"Ход игрока: {user_move_str}\n"
531
+ f"Рекомендованный ход (по анализу): {best_str}\n"
532
+ "Дай объяснение: почему рекомендованный лучше, и какая ошибка/риск в ходе игрока.\n"
533
+ "Добавь 2-3 практических совета.\n"
534
+ "Подсказки:\n"
535
+ + "\n".join(f"- {t}" for t in tips)
536
+ )
537
+ return coach_generate(prompt, max_new_tokens=180)
538
+
539
+
540
+ def step_user_and_bot(state: GameState, user_move: str) -> Tuple[GameState, str]:
541
+ if winner(state.board) is not None:
542
+ return state, "Game already finished."
543
+
544
+ if state.turn != "w":
545
+ return state, "Not your turn."
546
+
547
+ leg = legal_moves_str(state.board, "w")
548
+ if user_move not in leg:
549
+ return state, "Invalid move (not in legal list)."
550
+
551
+ board_before = clone_board(state.board)
552
+ seq = move_str_to_seq(user_move)
553
+ state.board = apply_move(state.board, seq)
554
+ state.history.append(f"White: {user_move}")
555
+ state.turn = "b"
556
+
557
+ # analysis (coach transformer)
558
+ state.last_analysis = analyze_user_move(board_before, user_move)
559
+
560
+ win = winner(state.board)
561
+ if win is not None:
562
+ state.history.append("Result: " + ("White wins" if win == "w" else "Black wins"))
563
+ return state, ("White wins." if win == "w" else "Black wins.")
564
+
565
+ # bot move
566
+ bot_leg = legal_moves_str(state.board, "b")
567
+ if not bot_leg:
568
+ state.history.append("Result: White wins")
569
+ return state, "White wins."
570
+
571
+ bot_move = opponent_choose_move(state.board, bot_leg)
572
+ bot_seq = move_str_to_seq(bot_move)
573
+ state.board = apply_move(state.board, bot_seq)
574
+ state.history.append(f"Black: {bot_move}")
575
+ state.turn = "w"
576
+
577
+ win = winner(state.board)
578
+ if win is not None:
579
+ state.history.append("Result: " + ("White wins" if win == "w" else "Black wins"))
580
+ return state, ("White wins." if win == "w" else "Black wins.")
581
+
582
+ return state, f"Bot played: {bot_move}"
583
+
584
+
585
+ # ============================================================
586
+ # Coach chat (transformer #1 + embeddings #3)
587
+ # ============================================================
588
+ def coach_chat(state: GameState, message: str, chat_hist: List[Tuple[str, str]]):
589
+ msg = (message or "").strip()
590
+ if not msg:
591
+ return chat_hist, ""
592
+
593
+ # Retrieve tips relevant to the question
594
+ tips = retrieve_tips(msg, k=3)
595
+
596
+ # Provide board context
597
+ context = board_to_ascii(state.board)
598
+ last = state.history[-6:] if state.history else []
599
+
600
+ prompt = (
601
+ "Ты тренер по шашкам. Отвечай кратко, но конкретно.\n"
602
+ f"Вопрос игрока: {msg}\n"
603
+ "Контекст партии (последние ходы):\n"
604
+ + ("\n".join(last) if last else "(нет)")
605
+ + "\n"
606
+ "Доска (ASCII):\n"
607
+ + context
608
+ + "\n"
609
+ "Полезные подсказки:\n"
610
+ + "\n".join(f"- {t}" for t in tips)
611
+ + "\n"
612
+ "Ответ:"
613
+ )
614
+
615
+ answer = coach_generate(prompt, max_new_tokens=180)
616
+ chat_hist = chat_hist + [(msg, answer)]
617
+ return chat_hist, ""
618
+
619
+
620
+ # ============================================================
621
+ # UI
622
+ # ============================================================
623
+ theme = gr.themes.Monochrome(font=[gr.themes.GoogleFont("Inter"), "system-ui"])
624
+
625
+ with gr.Blocks(theme=theme, title="Checkers Coach (CPU, 3 Transformers)") as demo:
626
+ state = gr.State(new_game())
627
 
628
+ with gr.Row():
629
+ with gr.Column(scale=1, min_width=360):
630
+ board_img = gr.Image(label="Board", type="pil", height=520)
631
+ status = gr.Textbox(label="Status", value="", interactive=False)
632
 
633
+ move_dd = gr.Dropdown(label="Your move (White)", choices=[], value=None)
634
+ play_btn = gr.Button("Play move", variant="primary")
635
+ new_btn = gr.Button("New game")
636
 
637
+ analysis = gr.Textbox(label="Coach analysis", lines=10, interactive=False)
638
 
639
+ with gr.Column(scale=1, min_width=360):
640
+ hist = gr.Markdown("")
641
+ gr.Markdown("### Coach chat")
642
+ chat = gr.Chatbot(height=360)
643
+ msg = gr.Textbox(label="Message", placeholder="Ask about strategy, mistakes, next plan…")
644
+ send = gr.Button("Send")
 
 
 
 
 
645
 
646
+ def refresh_ui(gs: GameState):
647
+ img = render_board(gs.board)
648
+ leg = legal_moves_str(gs.board, "w") if winner(gs.board) is None else []
649
+ h = "### History\n" + ("\n".join([f"- {x}" for x in gs.history]) if gs.history else "- (empty)")
650
+ return img, ("" if gs.turn == "w" else "Bot thinking / waiting…"), gr.update(choices=leg, value=(leg[0] if leg else None)), gs.last_analysis, h
651
 
652
+ def on_new():
653
+ gs = new_game()
654
+ return (gs, ) + refresh_ui(gs) + ([], "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
656
+ def on_play(gs: GameState, mv: str):
657
+ gs, st = step_user_and_bot(gs, mv or "")
658
+ img, _, dd, an, h = refresh_ui(gs)
659
+ return gs, img, st, dd, an, h
660
 
661
+ def on_send(gs: GameState, m: str, ch: List[Tuple[str, str]]):
662
+ ch, cleared = coach_chat(gs, m, ch or [])
663
+ return ch, cleared
664
 
665
+ demo.load(lambda gs: refresh_ui(gs), inputs=[state], outputs=[board_img, status, move_dd, analysis, hist])
666
 
667
+ new_btn.click(on_new, inputs=[], outputs=[state, board_img, status, move_dd, analysis, hist, chat, msg])
668
+ play_btn.click(on_play, inputs=[state, move_dd], outputs=[state, board_img, status, move_dd, analysis, hist])
 
 
 
 
669
 
670
+ send.click(on_send, inputs=[state, msg, chat], outputs=[chat, msg])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
671
 
672
  if __name__ == "__main__":
673
+ demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)