UnMelow committed on
Commit 1c45e58 · verified · 1 Parent(s): 112c4ca

Update app.py

Files changed (1)
  1. app.py +420 -512
app.py CHANGED
@@ -1,251 +1,218 @@
1
  import os
2
  import re
3
- import time
4
- import math
5
- import threading
6
  from dataclasses import dataclass
7
- from typing import Any, Dict, List, Optional, Tuple
 
8
 
9
  import numpy as np
10
  import torch
11
  import gradio as gr
12
 
13
- from huggingface_hub import HfApi
14
  from transformers import (
15
  AutoTokenizer,
16
  AutoModel,
17
- AutoModelForQuestionAnswering,
18
- AutoModelForSeq2SeqLM,
19
  )
20
  from transformers.utils import logging as hf_logging
21
 
22
 
23
- # ============================================================
24
- # CPU-only + timeouts + quiet logs
25
- # ============================================================
26
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
27
  os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
28
  os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
29
- os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", "5")
30
- os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", "30")
31
-
32
  hf_logging.set_verbosity_error()
33
 
34
- DEVICE = torch.device("cpu")
35
  torch.set_grad_enabled(False)
36
  torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
37
 
38
- # ============================================================
39
- # Hard limits (RAM + speed)
40
- # ============================================================
41
- MAX_INPUT_CHARS = 80_000
42
- MAX_CHUNKS = 120
43
- CHUNK_TARGET_CHARS = 900
44
- EMBED_BATCH = 16
45
-
46
- GEN_MAX_NEW_TOKENS = 240
47
- GEN_MIN_NEW_TOKENS = 80
48
-
49
- QA_MAX_LENGTH = 384
50
- QA_STRIDE = 128
51
- MAX_CONTEXT_CHARS = 4_000
52
- MAX_ANSWER_LEN_TOKENS = 40
53
-
54
- # ============================================================
55
- # 3+ Transformers:
56
- # 1) Generator (RU-friendly): mT5-small
57
- # 2) Embeddings: multilingual-e5-small
58
- # 3) Extractive QA: mBERT xquad
59
- # ============================================================
60
- GEN_CANDIDATES = [
61
- "google/mt5-small",
62
- "google/flan-t5-small",
63
  ]
64
 
65
- EMB_CANDIDATES = [
66
- "intfloat/multilingual-e5-small",
67
- "intfloat/e5-small-v2",
68
  ]
69
 
70
- QA_CANDIDATES = [
71
- "mrm8488/bert-multi-cased-finetuned-xquadv1",
72
- "timopixel/bert-base-multilingual-cased-finetuned-squad",
73
  ]
74
 
75
-
76
- def hf_exists(model_id: str) -> bool:
77
- """
78
- Fast availability check. If no network, we assume it exists and will try to load.
79
- """
80
- try:
81
- api = HfApi()
82
- api.model_info(model_id)
83
- return True
84
- except Exception:
85
- return True
86
-
87
-
88
- def pick_model(candidates: List[str]) -> str:
89
- for mid in candidates:
90
- if hf_exists(mid):
91
- return mid
92
- return candidates[0]
93
-
94
-
95
- GEN_ID = pick_model(GEN_CANDIDATES)
96
- EMB_ID = pick_model(EMB_CANDIDATES)
97
- QA_ID = pick_model(QA_CANDIDATES)
98
-
99
-
100
- # ============================================================
101
- # Lazy loaders (avoid loading everything on start)
102
- # ============================================================
103
- _load_lock = threading.Lock()
104
-
105
- _GEN_TOK = None
106
- _GEN_MODEL = None
107
-
108
- _EMB_TOK = None
109
- _EMB_MODEL = None
110
-
111
- _QA_TOK = None
112
- _QA_MODEL = None
113
-
114
-
115
- def load_gen():
116
- global _GEN_TOK, _GEN_MODEL
117
- with _load_lock:
118
- if _GEN_TOK is not None and _GEN_MODEL is not None:
119
- return _GEN_TOK, _GEN_MODEL
120
- tok = AutoTokenizer.from_pretrained(GEN_ID, use_fast=True)
121
- model = AutoModelForSeq2SeqLM.from_pretrained(
122
- GEN_ID,
123
- torch_dtype=torch.float32,
124
- low_cpu_mem_usage=True,
125
- ).eval()
126
- _GEN_TOK, _GEN_MODEL = tok, model
127
- return tok, model
128
-
129
-
130
- def load_emb():
131
- global _EMB_TOK, _EMB_MODEL
132
- with _load_lock:
133
- if _EMB_TOK is not None and _EMB_MODEL is not None:
134
- return _EMB_TOK, _EMB_MODEL
135
- tok = AutoTokenizer.from_pretrained(EMB_ID, use_fast=True)
136
- model = AutoModel.from_pretrained(
137
- EMB_ID,
138
- torch_dtype=torch.float32,
139
- low_cpu_mem_usage=True,
140
- ).eval()
141
- _EMB_TOK, _EMB_MODEL = tok, model
142
- return tok, model
143
-
144
-
145
- def load_qa():
146
- global _QA_TOK, _QA_MODEL
147
- with _load_lock:
148
- if _QA_TOK is not None and _QA_MODEL is not None:
149
- return _QA_TOK, _QA_MODEL
150
- tok = AutoTokenizer.from_pretrained(QA_ID, use_fast=True)
151
- model = AutoModelForQuestionAnswering.from_pretrained(
152
- QA_ID,
153
- torch_dtype=torch.float32,
154
- low_cpu_mem_usage=True,
155
- ).eval()
156
- _QA_TOK, _QA_MODEL = tok, model
157
- return tok, model
158
-
159
-
160
- # ============================================================
161
- # Utilities
162
- # ============================================================
163
- def safe_trunc(s: str, max_chars: int) -> str:
164
  s = (s or "").strip()
165
  if len(s) > max_chars:
166
- return s[:max_chars].rstrip() + "\n\n[Обрезано по лимиту длины]"
167
  return s
168
 
169
-
170
- def norm_space(s: str) -> str:
171
  return re.sub(r"\s+", " ", (s or "")).strip()
172
 
173
-
174
- def split_chunks(text: str) -> List[str]:
175
- text = safe_trunc(text, MAX_INPUT_CHARS)
176
  paras = [p.strip() for p in re.split(r"\n\s*\n+", text) if p.strip()]
177
-
178
- chunks: List[str] = []
179
  buf = ""
 
180
  for p in paras:
181
  if not buf:
182
  buf = p
183
- continue
184
- if len(buf) + 2 + len(p) <= CHUNK_TARGET_CHARS:
185
  buf = buf + "\n\n" + p
186
  else:
187
  chunks.append(buf.strip())
188
  buf = p
189
  if len(chunks) >= MAX_CHUNKS:
190
  break
 
191
  if buf and len(chunks) < MAX_CHUNKS:
192
  chunks.append(buf.strip())
193
 
194
- # if single paragraph is too big, split by sentences
195
- sent_split = re.compile(r"(?<=[\.\!\?…])\s+")
196
- fixed: List[str] = []
197
  for c in chunks:
198
- if len(c) <= CHUNK_TARGET_CHARS * 1.6:
199
- fixed.append(c)
200
  continue
201
- sents = [s.strip() for s in sent_split.split(c) if s.strip()]
202
  b = ""
203
  for s in sents:
204
  if not b:
205
  b = s
206
- continue
207
- if len(b) + 1 + len(s) <= CHUNK_TARGET_CHARS:
208
  b = b + " " + s
209
  else:
210
- fixed.append(b.strip())
211
  b = s
212
- if len(fixed) >= MAX_CHUNKS:
213
  break
214
- if b and len(fixed) < MAX_CHUNKS:
215
- fixed.append(b.strip())
216
- if len(fixed) >= MAX_CHUNKS:
217
  break
218
 
219
- return fixed[:MAX_CHUNKS]
222
  @torch.inference_mode()
223
- def mean_pool(last_hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
224
  m = mask.unsqueeze(-1).bool()
225
  x = last_hidden.masked_fill(~m, 0.0)
226
  summed = x.sum(dim=1)
227
  denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1)
228
  return summed / denom
229
 
230
-
231
  @torch.inference_mode()
232
  def embed_texts(texts: List[str], is_query: bool) -> np.ndarray:
233
- tok, model = load_emb()
234
-
235
- # E5 prefix convention improves retrieval
236
  prefix = "query: " if is_query else "passage: "
237
- batch_texts = [prefix + norm_space(t) for t in texts]
238
 
239
  vecs = []
240
- for i in range(0, len(batch_texts), EMBED_BATCH):
241
- batch = batch_texts[i:i + EMBED_BATCH]
242
  enc = tok(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
243
  out = model(**enc)
244
- pooled = mean_pool(out.last_hidden_state, enc["attention_mask"])
245
  pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
246
  vecs.append(pooled.cpu().numpy().astype(np.float32))
247
- return np.vstack(vecs)
248
 
 
249
 
250
  def topk_cosine(q: np.ndarray, mat: np.ndarray, k: int) -> List[Tuple[int, float]]:
251
  scores = (mat @ q.reshape(-1, 1)).squeeze(1)
@@ -258,394 +225,335 @@ def topk_cosine(q: np.ndarray, mat: np.ndarray, k: int) -> List[Tuple[int, float
258
 
259
 
260
  @dataclass
261
- class IndexState:
 
262
  text: str
263
  chunks: List[str]
264
- emb: Optional[np.ndarray]
265
 
266
 
267
- def build_index(text: str) -> IndexState:
268
- text = safe_trunc(text, MAX_INPUT_CHARS)
269
- chunks = split_chunks(text)
270
- if not chunks:
271
- return IndexState(text=text, chunks=[], emb=None)
272
- emb = embed_texts(chunks, is_query=False)
273
- return IndexState(text=text, chunks=chunks, emb=emb)
274
 
 
 
 
 
 
 
275
 
276
- def ensure_index(state: Optional[Dict[str, Any]], text: str) -> IndexState:
277
- text = safe_trunc(text, MAX_INPUT_CHARS)
278
- if not state or state.get("text") != text:
279
- st = build_index(text)
280
- return st
281
- return IndexState(text=state["text"], chunks=state["chunks"], emb=state["emb"])
282
 
 
 
 
 
283
 
284
- def retrieve(st: IndexState, query: str, k: int = 5) -> List[Tuple[float, str]]:
 
 
285
  query = (query or "").strip()
286
- if not query or not st.chunks or st.emb is None:
287
  return []
288
  qv = embed_texts([query], is_query=True)[0]
289
- hits = topk_cosine(qv, st.emb, k=k)
290
- return [(score, st.chunks[idx]) for idx, score in hits]
291
-
292
-
293
- # ============================================================
294
- # Generator (mT5 / flan)
295
- # ============================================================
296
- @torch.inference_mode()
297
- def generate_text(prompt: str,
298
- max_new_tokens: int = GEN_MAX_NEW_TOKENS,
299
- min_new_tokens: int = 0,
300
- do_sample: bool = False) -> str:
301
- tok, model = load_gen()
302
- enc = tok(prompt, return_tensors="pt", truncation=True, max_length=512)
303
-
304
- out = model.generate(
305
- **enc,
306
- max_new_tokens=max_new_tokens,
307
- min_new_tokens=min_new_tokens,
308
- num_beams=4 if not do_sample else 1,
309
- do_sample=do_sample,
310
- temperature=0.9 if do_sample else None,
311
- top_p=0.95 if do_sample else None,
312
- repetition_penalty=1.05,
313
- no_repeat_ngram_size=3,
314
- early_stopping=True,
315
- )
316
- s = tok.decode(out[0], skip_special_tokens=True).strip()
317
- return s
318
 
319
 
320
- def robust_summary(selected_text: str) -> Tuple[str, str]:
321
- """
322
- Returns (title, summary). Retries if model outputs too short.
323
- """
324
- selected_text = safe_trunc(selected_text, 4500)
325
-
326
- title_prompt = (
327
- "Сформулируй короткий заголовок (до 12 слов) для текста.\n\n"
328
- f"Текст:\n{selected_text}\n\n"
329
- "Заголовок:"
330
- )
331
- title = generate_text(title_prompt, max_new_tokens=32, min_new_tokens=8, do_sample=False)
332
- title = title.strip().strip('"').strip()
333
-
334
- sum_prompt = (
335
- "Сделай связный пересказ текста на русском языке. "
336
- "Требования: 6–10 предложений, без воды, сохранить ключевые причины, эффекты и вывод.\n\n"
337
- f"Текст:\n{selected_text}\n\n"
338
- "Пересказ:"
339
- )
340
- summary = generate_text(sum_prompt, max_new_tokens=GEN_MAX_NEW_TOKENS, min_new_tokens=GEN_MIN_NEW_TOKENS, do_sample=False)
341
-
342
- # If too short -> retry with bullet format
343
- if len(summary) < 120 and len(selected_text) > 600:
344
- sum_prompt2 = (
345
- "Сделай конспект текста на русском: 6–10 пунктов, каждый пункт 1 строка. "
346
- "Пункты должны покрывать весь текст.\n\n"
347
- f"Текст:\n{selected_text}\n\n"
348
- "Конспект:"
349
- )
350
- summary2 = generate_text(sum_prompt2, max_new_tokens=GEN_MAX_NEW_TOKENS, min_new_tokens=80, do_sample=True)
351
- if len(summary2) > len(summary):
352
- summary = summary2
353
 
354
- return title, summary
355
 
356
 
357
- def generate_questions(text: str, n: int) -> List[str]:
358
- n = int(max(1, min(n, 12)))
359
- text = safe_trunc(text, 3000)
360
-
361
- prompt = (
362
- f"Сгенерируй {n} вопросов для самопроверки по тексту. "
363
- "Вопросы должны проверять понимание причинно-следственных связей и выводов. "
364
- "Формат: нумерованный список.\n\n"
365
- f"Текст:\n{text}\n\n"
366
- "Вопросы:\n"
367
- )
368
- raw = generate_text(prompt, max_new_tokens=220, min_new_tokens=80, do_sample=True)
369
-
370
- # parse numbered list
371
- qs = []
372
- for line in raw.splitlines():
373
- line = line.strip()
374
- m = re.match(r"^\d+[\)\.\-]\s*(.+)$", line)
375
- if m:
376
- q = m.group(1).strip()
377
- if q and not q.endswith("?"):
378
- q += "?"
379
- qs.append(q)
380
- # fallback: split by '?'
381
- if not qs:
382
- parts = [p.strip() for p in re.split(r"\?\s*", raw) if p.strip()]
383
- qs = [(p + "?") for p in parts[:n]]
384
- # unique + cap
385
- seen = set()
386
  out = []
387
- for q in qs:
388
- ql = q.lower()
389
- if ql in seen:
390
- continue
391
- seen.add(ql)
392
- out.append(q)
393
- if len(out) >= n:
394
- break
395
  return out
396
 
397
-
398
- # ============================================================
399
- # Extractive QA (FIXED: remove overflow_to_sample_mapping)
400
- # ============================================================
401
- @torch.inference_mode()
402
- def extractive_qa(question: str, context: str) -> Tuple[str, str]:
403
- question = (question or "").strip()
404
- context = (context or "").strip()
405
- if not question or not context:
406
- return "", ""
407
-
408
- tok, model = load_qa()
409
- context = safe_trunc(context, MAX_CONTEXT_CHARS)
410
-
411
- enc = tok(
412
- question,
413
- context,
414
- truncation="only_second",
415
- max_length=QA_MAX_LENGTH,
416
- stride=QA_STRIDE,
417
- return_overflowing_tokens=True,
418
- return_offsets_mapping=True,
419
- padding=True,
420
- return_tensors="pt",
421
- )
422
-
423
- offset_mapping = enc.pop("offset_mapping")
424
- # IMPORTANT: do not pass these to model
425
- enc.pop("overflow_to_sample_mapping", None)
426
- enc.pop("num_truncated_tokens", None)
427
- enc.pop("special_tokens_mask", None)
428
-
429
- # Only model inputs
430
- model_inputs = {k: v for k, v in enc.items() if k in ("input_ids", "attention_mask", "token_type_ids")}
431
- outputs = model(**model_inputs)
432
-
433
- start = outputs.start_logits.detach().cpu().numpy()
434
- end = outputs.end_logits.detach().cpu().numpy()
435
-
436
- best_score = -1e9
437
- best_span = (0, 0)
438
- best_ctx = context
439
-
440
- for i in range(start.shape[0]):
441
- seq_ids = tok.sequence_ids(i)
442
- offsets = offset_mapping[i].detach().cpu().numpy()
443
-
444
- # context token indices
445
- ctx_idxs = [j for j, sid in enumerate(seq_ids) if sid == 1 and not (offsets[j][0] == 0 and offsets[j][1] == 0)]
446
- if not ctx_idxs:
447
  continue
 
 
448
 
449
- s_logits = start[i]
450
- e_logits = end[i]
451
-
452
- # Take top candidates to avoid O(n^2)
453
- top_s = sorted(ctx_idxs, key=lambda j: s_logits[j], reverse=True)[:20]
454
- top_e = sorted(ctx_idxs, key=lambda j: e_logits[j], reverse=True)[:20]
455
-
456
- for s_idx in top_s:
457
- for e_idx in top_e:
458
- if e_idx < s_idx:
459
- continue
460
- if (e_idx - s_idx) > MAX_ANSWER_LEN_TOKENS:
461
- continue
462
- score = float(s_logits[s_idx] + e_logits[e_idx])
463
- if score > best_score:
464
- a = int(offsets[s_idx][0])
465
- b = int(offsets[e_idx][1])
466
- if b > a:
467
- best_score = score
468
- best_span = (a, b)
469
-
470
- ans = best_ctx[best_span[0]:best_span[1]].strip()
471
- if not ans:
472
- return "", ""
473
-
474
- left = max(0, best_span[0] - 120)
475
- right = min(len(best_ctx), best_span[1] + 180)
476
- snippet = best_ctx[left:right].strip()
477
- if left > 0:
478
- snippet = "…" + snippet
479
- if right < len(best_ctx):
480
- snippet = snippet + "…"
481
 
482
- return ans, snippet
483
 
 
484
 
485
- # ============================================================
486
- # Features
487
- # ============================================================
488
- def select_central_text(st: IndexState, level: str) -> str:
489
- if not st.chunks or st.emb is None:
490
- return ""
491
- emb = st.emb
492
- centroid = emb.mean(axis=0)
493
- centroid = centroid / (np.linalg.norm(centroid) + 1e-12)
494
- sims = (emb @ centroid.reshape(-1, 1)).squeeze(1)
495
-
496
- k = 3 if level == "Коротко" else 6
497
- k = min(k, len(st.chunks))
498
- idx = np.argpartition(-sims, k - 1)[:k]
499
- idx = idx[np.argsort(-sims[idx])]
500
- return "\n\n".join(st.chunks[i] for i in idx.tolist())
501
-
502
-
503
- def do_summary(text: str, state: Optional[Dict[str, Any]], level: str) -> Tuple[str, Dict[str, Any]]:
504
- st = ensure_index(state, text)
505
- selected = select_central_text(st, level)
506
- if not selected:
507
- return "Нет текста для пересказа.", st.__dict__
508
- title, summ = robust_summary(selected)
509
- md = f"### Заголовок\n{title}\n\n### Пересказ\n{summ}"
510
- return md, st.__dict__
511
-
512
-
513
- def do_search(text: str, state: Optional[Dict[str, Any]], query: str, k: int) -> Tuple[str, Dict[str, Any]]:
514
- st = ensure_index(state, text)
515
- query = (query or "").strip()
516
- if not query:
517
- return "Введите запрос.", st.__dict__
518
- hits = retrieve(st, query, k=int(max(1, min(k, 10))))
519
- if not hits:
520
- return "Ничего не найдено.", st.__dict__
521
- out = ["### Результаты\n"]
522
- for i, (score, chunk) in enumerate(hits, 1):
523
- out.append(f"**{i}. score={score:.3f}**\n{chunk}\n")
524
- return "\n".join(out).strip(), st.__dict__
525
-
526
-
527
- def do_quiz(text: str, state: Optional[Dict[str, Any]], n: int) -> Tuple[str, Dict[str, Any]]:
528
- st = ensure_index(state, text)
529
- if not st.chunks:
530
- return "Нет текста.", st.__dict__
531
-
532
- # build a compact source for question generation (central passages)
533
- central = select_central_text(st, "Подробнее")
534
- if not central:
535
- central = safe_trunc(st.text, 3000)
536
-
537
- questions = generate_questions(central, int(n))
538
- if not questions:
539
- return "Не удалось сгенерировать вопросы.", st.__dict__
540
-
541
- # answer each question from retrieved context
542
- lines = ["### Вопросы и ответы\n"]
543
- for i, q in enumerate(questions, 1):
544
- hits = retrieve(st, q, k=4)
545
- ctx = "\n\n".join([c for _, c in hits]) if hits else central
546
- ctx = safe_trunc(ctx, MAX_CONTEXT_CHARS)
547
-
548
- ans, ev = extractive_qa(q, ctx)
549
- if not ans:
550
- # fallback: generator open-book answer
551
- prompt = (
552
- "Ответь на вопрос, используя ТОЛЬКО данный текст. "
553
- "Если ответа нет, скажи 'В тексте это не указано'.\n\n"
554
- f"Текст:\n{ctx}\n\n"
555
- f"Вопрос: {q}\nОтвет:"
556
- )
557
- ans = generate_text(prompt, max_new_tokens=120, min_new_tokens=20, do_sample=False).strip()
558
- ev = ctx[:320].strip()
559
-
560
- lines.append(f"**{i}. {q}**")
561
- lines.append(f"- Ответ: {ans}")
562
- lines.append(f"- Фрагмент: {ev}")
563
- lines.append("")
564
- return "\n".join(lines).strip(), st.__dict__
565
-
566
-
567
- def do_chat(text: str, state: Optional[Dict[str, Any]], chat: List[Tuple[str, str]], user_q: str):
568
- st = ensure_index(state, text)
569
- user_q = (user_q or "").strip()
570
- if not user_q:
571
- return chat, st.__dict__, ""
572
-
573
- hits = retrieve(st, user_q, k=5)
574
- ctx = "\n\n".join([c for _, c in hits]) if hits else safe_trunc(st.text, 2500)
575
- ctx = safe_trunc(ctx, MAX_CONTEXT_CHARS)
576
-
577
- ans, ev = extractive_qa(user_q, ctx)
578
- if not ans:
579
- prompt = (
580
- "Ответь на вопрос по тексту. "
581
- "Если ответа нет, скажи 'В тексте это не указано'.\n\n"
582
- f"Текст:\n{ctx}\n\n"
583
- f"Вопрос: {user_q}\nОтвет:"
584
  )
585
- ans = generate_text(prompt, max_new_tokens=140, min_new_tokens=20, do_sample=False).strip()
586
- ev = ctx[:360].strip()
587
 
588
- reply = f"**Ответ:** {ans}\n\n**Доказательство:**\n{ev}"
589
- chat = (chat or []) + [(user_q, reply)]
590
- return chat, st.__dict__, ""
591
 
 
 
 
592
 
593
- def model_info_text() -> str:
594
- return (
595
- "Используемые модели (3 трансформера):\n"
596
- f"1) Генерация (пересказ/вопросы): {GEN_ID}\n"
597
- f"2) Эмбеддинги (поиск): {EMB_ID}\n"
598
- f"3) Extractive QA (ответ+фрагмент): {QA_ID}\n"
599
- "\nCPU-only, без GPU. Память обычно < 16GB."
600
- )
601
 
 
 
 
 
602
 
603
- # ============================================================
604
- # UI
605
- # ============================================================
606
- with gr.Blocks(title="RU Text Study Assistant (CPU, 3 Transformers)") as demo:
607
- gr.Markdown("## RU Text Study Assistant\nПересказ, вопросы, чат по тексту и семантический поиск. CPU-only, 3 трансформера.")
608
 
609
- state = gr.State({"text": "", "chunks": [], "emb": None})
610
 
 
 
 
 
611
  with gr.Row():
612
  with gr.Column(scale=2):
613
- src_text = gr.Textbox(label="Текст", lines=12, placeholder="Вставьте русский текст (лекция, статья, конспект).")
614
- with gr.Accordion("Модели", open=False):
615
- gr.Textbox(value=model_info_text(), lines=6, interactive=False, show_label=False)
616
 
617
  with gr.Column(scale=3):
618
  with gr.Tabs():
619
  with gr.Tab("Пересказ"):
620
- sum_level = gr.Radio(["Коротко", "Подробнее"], value="Коротко", label="Уровень")
621
  sum_btn = gr.Button("Сделать пересказ", variant="primary")
622
  sum_out = gr.Markdown()
623
 
624
- with gr.Tab("Вопросы"):
625
- q_n = gr.Slider(1, 12, value=6, step=1, label="Количество вопросов")
626
- q_btn = gr.Button("Сгенерировать вопросы", variant="primary")
627
- q_out = gr.Markdown()
628
-
629
- with gr.Tab("Чат по тексту"):
630
- chat = gr.Chatbot(height=380)
631
- with gr.Row():
632
- user_q = gr.Textbox(label="Вопрос", placeholder="Задайте вопрос по тексту…", lines=1)
633
- send = gr.Button("Отправить")
634
- gr.Markdown("Ответ: поиск по чанкам → extractive QA с доказательством → fallback на генерацию.")
635
-
636
  with gr.Tab("Поиск"):
637
- s_q = gr.Textbox(label="Запрос", placeholder="Например: 'вывод', 'метод', 'ограничения'")
638
- s_k = gr.Slider(1, 10, value=5, step=1, label="Топ-K")
639
- s_btn = gr.Button("Найти фрагменты", variant="primary")
640
- s_out = gr.Markdown()
641
-
642
- sum_btn.click(do_summary, inputs=[src_text, state, sum_level], outputs=[sum_out, state])
643
- q_btn.click(do_quiz, inputs=[src_text, state, q_n], outputs=[q_out, state])
644
 
645
- send.click(do_chat, inputs=[src_text, state, chat, user_q], outputs=[chat, state, user_q])
646
- user_q.submit(do_chat, inputs=[src_text, state, chat, user_q], outputs=[chat, state, user_q])
 
 
647
 
648
- s_btn.click(do_search, inputs=[src_text, state, s_q, s_k], outputs=[s_out, state])
649
 
650
  if __name__ == "__main__":
651
  demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
1
  import os
2
  import re
3
+ import hashlib
 
 
4
  from dataclasses import dataclass
5
+ from collections import OrderedDict
6
+ from typing import List, Tuple, Optional, Dict, Any
7
 
8
  import numpy as np
9
  import torch
10
  import gradio as gr
11
 
 
12
  from transformers import (
13
  AutoTokenizer,
14
  AutoModel,
15
+ pipeline,
 
16
  )
17
  from transformers.utils import logging as hf_logging
18
 
19
 
20
+ # =========================
21
+ # CPU-only + quieter logs
22
+ # =========================
23
  os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
24
  os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
25
  os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
 
 
 
26
  hf_logging.set_verbosity_error()
27
 
 
28
  torch.set_grad_enabled(False)
29
  torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
30
 
31
+ # =========================
32
+ # Models (3 transformers)
33
+ # =========================
34
+ SUM_MODEL_CANDIDATES = [
35
+ "d0rj/rut5-base-summ", # RU summarization
36
+ "cointegrated/rut5-base-absum", # RU summarization fallback
37
  ]
38
 
39
+ QA_MODEL_CANDIDATES = [
40
+ "mrm8488/bert-multi-cased-finetuned-xquadv1", # multilingual QA
41
+ "mrm8488/bert-multi-cased-finedtuned-xquad-tydiqa-goldp",
42
  ]
43
 
44
+ EMB_MODEL_CANDIDATES = [
45
+ "intfloat/multilingual-e5-small", # retrieval embeddings
46
+ "intfloat/e5-small-v2",
47
  ]
48
 
49
+ DEVICE = -1 # CPU for pipelines
50
+
51
+ # =========================
52
+ # Limits (memory & speed)
53
+ # =========================
54
+ MAX_TEXT_CHARS = 120_000
55
+ CHUNK_CHARS = 1400
56
+ MAX_CHUNKS = 140
57
+ EMB_BATCH = 16
58
+
59
+ TOPK_DEFAULT = 5
60
+ CTX_MAX_CHARS = 4500
61
+
62
+ # =========================
63
+ # Helpers
64
+ # =========================
65
+ RU_STOP = {
66
+ "и","в","во","на","но","а","что","это","как","к","ко","из","за","по","у","от","до","при","для","над",
67
+ "под","же","ли","бы","не","ни","то","его","ее","их","мы","вы","они","она","он","оно","этот","эта","эти",
68
+ "там","тут","здесь","так","такие","такой","есть","быть","был","была","были","будет","будут"
69
+ }
70
+
71
+ def safe_text(s: str, max_chars: int = MAX_TEXT_CHARS) -> str:
72
  s = (s or "").strip()
73
  if len(s) > max_chars:
74
+ s = s[:max_chars].rstrip() + "\n\n[Обрезано по лимиту длины]"
75
  return s
76
 
77
+ def normalize_space(s: str) -> str:
 
78
  return re.sub(r"\s+", " ", (s or "")).strip()
79
 
80
+ def split_into_chunks(text: str) -> List[str]:
81
+ text = safe_text(text)
 
82
  paras = [p.strip() for p in re.split(r"\n\s*\n+", text) if p.strip()]
83
+ chunks = []
 
84
  buf = ""
85
+
86
  for p in paras:
87
  if not buf:
88
  buf = p
89
+ elif len(buf) + 2 + len(p) <= CHUNK_CHARS:
 
90
  buf = buf + "\n\n" + p
91
  else:
92
  chunks.append(buf.strip())
93
  buf = p
94
  if len(chunks) >= MAX_CHUNKS:
95
  break
96
+
97
  if buf and len(chunks) < MAX_CHUNKS:
98
  chunks.append(buf.strip())
99
 
100
+ # If still too big, split long chunks by sentences
101
+ sent_re = re.compile(r"(?<=[\.\!\?…])\s+")
102
+ final_chunks = []
103
  for c in chunks:
104
+ if len(c) <= int(CHUNK_CHARS * 1.6):
105
+ final_chunks.append(c)
106
  continue
107
+ sents = [x.strip() for x in sent_re.split(c) if x.strip()]
108
  b = ""
109
  for s in sents:
110
  if not b:
111
  b = s
112
+ elif len(b) + 1 + len(s) <= CHUNK_CHARS:
 
113
  b = b + " " + s
114
  else:
115
+ final_chunks.append(b.strip())
116
  b = s
117
+ if len(final_chunks) >= MAX_CHUNKS:
118
  break
119
+ if b and len(final_chunks) < MAX_CHUNKS:
120
+ final_chunks.append(b.strip())
121
+ if len(final_chunks) >= MAX_CHUNKS:
122
  break
123
 
124
+ return final_chunks[:MAX_CHUNKS]
125
+
126
+ def sha_key(text: str) -> str:
127
+ h = hashlib.sha1(text.encode("utf-8")).hexdigest()
128
+ return h[:12]
129
+
130
+
131
+ # =========================
132
+ # Global model holders
133
+ # =========================
134
+ _SUM_PIPE = None
135
+ _SUM_ID = None
136
 
137
+ _QA_PIPE = None
138
+ _QA_ID = None
139
 
140
+ _EMB_TOK = None
141
+ _EMB_MODEL = None
142
+ _EMB_ID = None
143
+
144
+
145
+ def _try_load_summarizer() -> Tuple[Any, str]:
146
+ last_err = None
147
+ for mid in SUM_MODEL_CANDIDATES:
148
+ try:
149
+ pipe = pipeline("summarization", model=mid, device=DEVICE)
150
+ return pipe, mid
151
+ except Exception as e:
152
+ last_err = e
153
+ raise RuntimeError(f"Cannot load summarization model. Last error: {last_err}")
154
+
155
+ def _try_load_qa() -> Tuple[Any, str]:
156
+ last_err = None
157
+ for mid in QA_MODEL_CANDIDATES:
158
+ try:
159
+ pipe = pipeline("question-answering", model=mid, device=DEVICE)
160
+ return pipe, mid
161
+ except Exception as e:
162
+ last_err = e
163
+ raise RuntimeError(f"Cannot load QA model. Last error: {last_err}")
164
+
165
+ def _try_load_emb() -> Tuple[Any, Any, str]:
166
+ last_err = None
167
+ for mid in EMB_MODEL_CANDIDATES:
168
+ try:
169
+ tok = AutoTokenizer.from_pretrained(mid, use_fast=True)
170
+ model = AutoModel.from_pretrained(mid, torch_dtype=torch.float32, low_cpu_mem_usage=True).eval()
171
+ return tok, model, mid
172
+ except Exception as e:
173
+ last_err = e
174
+ raise RuntimeError(f"Cannot load embedding model. Last error: {last_err}")
175
+
176
+ def get_models():
177
+ global _SUM_PIPE, _SUM_ID, _QA_PIPE, _QA_ID, _EMB_TOK, _EMB_MODEL, _EMB_ID
178
+
179
+ if _SUM_PIPE is None:
180
+ _SUM_PIPE, _SUM_ID = _try_load_summarizer()
181
+ if _QA_PIPE is None:
182
+ _QA_PIPE, _QA_ID = _try_load_qa()
183
+ if _EMB_MODEL is None:
184
+ _EMB_TOK, _EMB_MODEL, _EMB_ID = _try_load_emb()
185
+
186
+ return _SUM_PIPE, _SUM_ID, _QA_PIPE, _QA_ID, _EMB_TOK, _EMB_MODEL, _EMB_ID
187
+
188
+
189
+ # =========================
190
+ # Embeddings + retrieval
191
+ # =========================
192
  @torch.inference_mode()
193
+ def _mean_pool(last_hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
194
  m = mask.unsqueeze(-1).bool()
195
  x = last_hidden.masked_fill(~m, 0.0)
196
  summed = x.sum(dim=1)
197
  denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1)
198
  return summed / denom
199
 
 
200
  @torch.inference_mode()
201
  def embed_texts(texts: List[str], is_query: bool) -> np.ndarray:
202
+ _, _, _, _, tok, model, _ = get_models()
 
 
203
  prefix = "query: " if is_query else "passage: "
204
+ batch_texts = [prefix + normalize_space(t) for t in texts]
205
 
206
  vecs = []
207
+ for i in range(0, len(batch_texts), EMB_BATCH):
208
+ batch = batch_texts[i:i + EMB_BATCH]
209
  enc = tok(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
210
  out = model(**enc)
211
+ pooled = _mean_pool(out.last_hidden_state, enc["attention_mask"])
212
  pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
213
  vecs.append(pooled.cpu().numpy().astype(np.float32))
 
214
 
215
+ return np.vstack(vecs) if vecs else np.zeros((0, 384), dtype=np.float32)
216
 
217
  def topk_cosine(q: np.ndarray, mat: np.ndarray, k: int) -> List[Tuple[int, float]]:
218
  scores = (mat @ q.reshape(-1, 1)).squeeze(1)
 
225
 
226
 
227
  @dataclass
228
+ class Index:
229
+ key: str
230
  text: str
231
  chunks: List[str]
232
+ emb: np.ndarray
233
 
234
 
235
+ # Small LRU cache (keeps RAM bounded)
236
+ _INDEX_CACHE: "OrderedDict[str, Index]" = OrderedDict()
237
+ CACHE_MAX = 4
 
 
 
 
238
 
239
+ def get_index(text: str) -> Index:
240
+ text = safe_text(text)
241
+ k = sha_key(text)
242
+ if k in _INDEX_CACHE:
243
+ _INDEX_CACHE.move_to_end(k)
244
+ return _INDEX_CACHE[k]
245
 
246
+ chunks = split_into_chunks(text)
247
+ emb = embed_texts(chunks, is_query=False) if chunks else np.zeros((0, 384), dtype=np.float32)
248
+ idx = Index(key=k, text=text, chunks=chunks, emb=emb)
 
 
 
249
 
250
+ _INDEX_CACHE[k] = idx
251
+ _INDEX_CACHE.move_to_end(k)
252
+ while len(_INDEX_CACHE) > CACHE_MAX:
253
+ _INDEX_CACHE.popitem(last=False)
254
 
255
+ return idx
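A brief illustration of the LRU behavior used by the index cache above: an OrderedDict where a hit is moved to the end and popitem(last=False) evicts the oldest entry. Minimal, self-contained sketch (not part of app.py; keys and the smaller limit are made up):

from collections import OrderedDict

cache: "OrderedDict[str, str]" = OrderedDict()
limit = 2  # CACHE_MAX is 4 in app.py; 2 here just to show eviction sooner

for key in ["a", "b", "a", "c"]:
    if key in cache:
        cache.move_to_end(key)      # mark as most recently used
    cache[key] = key.upper()
    while len(cache) > limit:
        cache.popitem(last=False)   # evict the least recently used entry

print(list(cache))  # ['a', 'c'] -- 'b' was evicted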
256
+
257
+ def retrieve(idx: Index, query: str, k: int) -> List[Tuple[float, str]]:
258
  query = (query or "").strip()
259
+ if not query or idx.emb.shape[0] == 0:
260
  return []
261
  qv = embed_texts([query], is_query=True)[0]
262
+ hits = topk_cosine(qv, idx.emb, k=k)
263
+ return [(score, idx.chunks[i]) for i, score in hits]
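For orientation, a minimal usage sketch of the retrieval path defined above (get_index + retrieve), assuming app.py from this commit is importable; the sample text and query are invented, and the first call downloads the embedding model:

from app import get_index, retrieve  # assumes this file is on the import path

sample_text = (
    "Стандартизация процессов снижает вариабельность результатов.\n\n"
    "Контроль качества опирается на измеримые показатели."
)
idx = get_index(sample_text)  # chunks + L2-normalized E5 embeddings, LRU-cached
for score, chunk in retrieve(idx, "что снижает вариабельность?", k=2):
    print(f"{score:.3f}  {chunk[:60]}")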
264
 
265
 
266
+ # =========================
267
+ # Summarization (hierarchical, stable)
268
+ # =========================
269
+ def summarize_one(text: str, out_max: int, out_min: int) -> str:
270
+ sum_pipe, _, _, _, _, _, _ = get_models()
271
+ text = normalize_space(text)
272
+ if not text:
273
+ return ""
274
+ # pipeline expects token lengths; we keep conservative values
275
+ res = sum_pipe(text, max_length=out_max, min_length=out_min, do_sample=False)
276
+ if isinstance(res, list) and res:
277
+ return (res[0].get("summary_text") or "").strip()
278
+ return ""
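Worth keeping in mind with the call above: max_length and min_length for the summarization pipeline are token counts, not characters, hence the conservative values used in this file. A hedged standalone sketch of the same call (model ID taken from SUM_MODEL_CANDIDATES; the input string is invented):

from transformers import pipeline

summarizer = pipeline("summarization", model="d0rj/rut5-base-summ", device=-1)  # CPU
text = "Длинный русский текст, который нужно кратко пересказать ..."
out = summarizer(text, max_length=140, min_length=40, do_sample=False)
print(out[0]["summary_text"])  # the pipeline returns a list of dicts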
279
+
280
+ def smart_summary(text: str) -> str:
281
+ text = safe_text(text)
282
+ if not text:
283
+ return "Нет текста."
284
+
285
+ chunks = split_into_chunks(text)
286
+ if not chunks:
287
+ return "Нет текста."
288
+
289
+ # For short text: direct
290
+ if len(text) < 2500 and len(chunks) <= 2:
291
+ s = summarize_one(text, out_max=220, out_min=80)
292
+ return s if s else summarize_one(text, out_max=160, out_min=50)
293
+
294
+ # For long text: summarize chunks then summarize the combined summaries
295
+ parts = chunks[:8]
296
+ partial = []
297
+ for p in parts:
298
+ sp = summarize_one(p, out_max=140, out_min=40)
299
+ if sp:
300
+ partial.append(sp)
301
+
302
+ combined = " ".join(partial).strip()
303
+ if not combined:
304
+ combined = " ".join(parts)[:4000]
305
+
306
+ final = summarize_one(combined, out_max=240, out_min=90)
307
+ if not final:
308
+ final = summarize_one(combined, out_max=180, out_min=60)
309
+
310
+ return final if final else "Не удалось получить пересказ."
311
+
312
+ def make_title(text: str, summary: str) -> str:
313
+ # heuristic title: first 8–12 words of the summary, else the opening words of the text
314
+ src = summary.strip() if summary.strip() else normalize_space(text[:500])
315
+ words = [w for w in re.split(r"\s+", src) if w]
316
+ title = " ".join(words[:12]).strip(" .,:;—-")
317
+ return title if title else "Краткий пересказ"
318
+
319
+
320
+ # =========================
321
+ # QA Chat (retrieval + pipeline QA)
322
+ # =========================
323
+ def qa_answer(question: str, context: str) -> Tuple[str, str, float]:
324
+ _, _, qa_pipe, _, _, _, _ = get_models()
325
+ question = (question or "").strip()
326
+ context = (context or "").strip()
327
+ if not question or not context:
328
+ return "", "", 0.0
329
+
330
+ context = context[:CTX_MAX_CHARS]
331
+ out = qa_pipe(question=question, context=context)
332
+ ans = (out.get("answer") or "").strip()
333
+ score = float(out.get("score") or 0.0)
334
+ start = int(out.get("start") or 0)
335
+ end = int(out.get("end") or 0)
336
+
337
+ # evidence snippet
338
+ left = max(0, start - 140)
339
+ right = min(len(context), end + 220)
340
+ snippet = context[left:right].strip()
341
+ if left > 0:
342
+ snippet = "…" + snippet
343
+ if right < len(context):
344
+ snippet = snippet + "…"
345
 
346
+ return ans, snippet, score
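The evidence snippet above works because the question-answering pipeline returns character offsets (start, end) into the context along with answer and score. A small illustrative sketch (model ID from QA_MODEL_CANDIDATES; example strings are invented):

from transformers import pipeline

qa = pipeline(
    "question-answering",
    model="mrm8488/bert-multi-cased-finetuned-xquadv1",
    device=-1,  # CPU
)
context = "Стандартизация процессов снижает вариабельность результатов измерений."
result = qa(question="Что снижает вариабельность?", context=context)
# result is a dict with "answer", "score", "start", "end"
evidence = context[max(0, result["start"] - 20): result["end"] + 30]
print(result["answer"], round(result["score"], 3), "…" + evidence + "…")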
347
 
348
 
349
+ # =========================
350
+ # Quiz (heuristic questions; answers via retrieval+QA)
351
+ # =========================
352
+ def _sentences(text: str) -> List[str]:
353
+ # very simple sentence splitter
354
+ text = normalize_space(text)
355
+ if not text:
356
+ return []
357
+ parts = re.split(r"(?<=[\.\!\?…])\s+", text)
358
  out = []
359
+ for p in parts:
360
+ p = p.strip()
361
+ if 40 <= len(p) <= 240:
362
+ out.append(p)
 
 
 
 
363
  return out
364
 
365
+ def _keywords(text: str) -> Dict[str, int]:
366
+ words = re.findall(r"[А-Яа-яЁёA-Za-z\-]{3,}", text.lower())
367
+ freq: Dict[str, int] = {}
368
+ for w in words:
369
+ if w in RU_STOP:
370
  continue
371
+ freq[w] = freq.get(w, 0) + 1
372
+ return freq
373
 
374
+ def generate_quiz_questions(text: str, n: int) -> List[str]:
375
+ n = int(max(1, min(n, 12)))
376
+ sents = _sentences(text)
377
+ if not sents:
378
+ return []
379
 
380
+ freq = _keywords(text)
381
+ if not freq:
382
+ # fallback: use first sentences
383
+ sents = sents[:n]
384
+ return [f"О чем говорится в утверждении: «{s}»?" for s in sents]
385
+
386
+ scored = []
387
+ for s in sents:
388
+ ws = re.findall(r"[А-Яа-яЁёA-Za-z\-]{3,}", s.lower())
389
+ score = sum(freq.get(w, 0) for w in ws if w not in RU_STOP)
390
+ scored.append((score, s))
391
+ scored.sort(key=lambda x: x[0], reverse=True)
392
+
393
+ questions = []
394
+ for _, s in scored[: min(len(scored), n * 2)]:
395
+ ws = [w for w in re.findall(r"[А-Яа-яЁёA-Za-z\-]{3,}", s.lower()) if w not in RU_STOP]
396
+ if not ws:
397
+ continue
398
+ # choose "keyword" to blank
399
+ kw = max(ws, key=lambda w: freq.get(w, 0))
400
+ # blank first occurrence (case-insensitive)
401
+ blanked = re.sub(re.escape(kw), "____", s, count=1, flags=re.IGNORECASE)
402
+ q = f"Заполните пропуск: {blanked}"
403
+ questions.append(q)
404
+ if len(questions) >= n:
405
+ break
406
 
407
+ return questions[:n]
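The core of the quiz generation above is a single regex substitution that blanks the chosen keyword, case-insensitively, in one sentence. Tiny self-contained example (the sentence and keyword are invented):

import re

sentence = "Стандартизация снижает вариабельность измерений."
keyword = "вариабельность"  # in app.py this is the most frequent non-stopword
cloze = re.sub(re.escape(keyword), "____", sentence, count=1, flags=re.IGNORECASE)
print(f"Заполните пропуск: {cloze}")
# -> Заполните пропуск: Стандартизация снижает ____ измерений.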
408
 
409
+
410
+ # =========================
411
+ # Gradio actions
412
+ # =========================
413
+ def on_load_models() -> str:
414
+ try:
415
+ sum_pipe, sum_id, qa_pipe, qa_id, emb_tok, emb_model, emb_id = get_models()
416
+ return (
417
+ "Модели загружены.\n"
418
+ f"- Summarization: {sum_id}\n"
419
+ f"- QA: {qa_id}\n"
420
+ f"- Embeddings: {emb_id}\n"
421
  )
422
+ except Exception as e:
423
+ return f"Ошибка загрузки моделей: {e}"
424
 
425
+ def on_summary(text: str) -> str:
426
+ try:
427
+ text = safe_text(text)
428
+ if not text:
429
+ return "Нет текста."
430
+ s = smart_summary(text)
431
+ title = make_title(text, s)
432
+ return f"### Заголовок\n{title}\n\n### Пересказ\n{s}"
433
+ except Exception as e:
434
+ return f"Ошибка: {e}"
435
+
436
+ def on_search(text: str, query: str, k: int) -> str:
437
+ try:
438
+ text = safe_text(text)
439
+ query = (query or "").strip()
440
+ if not text:
441
+ return "Нет текста."
442
+ if not query:
443
+ return "Введите запрос."
444
+ idx = get_index(text)
445
+ hits = retrieve(idx, query, int(max(1, min(k, 10))))
446
+ if not hits:
447
+ return "Ничего не найдено."
448
+ out = ["### Результаты"]
449
+ for i, (score, chunk) in enumerate(hits, 1):
450
+ out.append(f"**{i}. score={score:.3f}**\n{chunk}\n")
451
+ return "\n".join(out).strip()
452
+ except Exception as e:
453
+ return f"Ошибка: {e}"
454
+
455
+ def on_quiz(text: str, n: int) -> str:
456
+ try:
457
+ text = safe_text(text)
458
+ if not text:
459
+ return "Нет текста."
460
+ idx = get_index(text)
461
+
462
+ questions = generate_quiz_questions(text, int(n))
463
+ if not questions:
464
+ return "Не удалось сгенерировать вопросы."
465
+
466
+ lines = ["### Вопросы и ответы (с доказательством)"]
467
+ for i, q in enumerate(questions, 1):
468
+ # For each cloze question, try to answer via QA using retrieved context.
469
+ # We convert cloze to a QA-style question by removing "Заполните пропуск:"
470
+ qa_q = re.sub(r"^Заполните пропуск:\s*", "", q).strip()
471
+ hits = retrieve(idx, qa_q, k=5)
472
+ ctx = "\n\n".join([c for _, c in hits]) if hits else text[:CTX_MAX_CHARS]
473
+ ctx = ctx[:CTX_MAX_CHARS]
474
+
475
+ ans, ev, score = qa_answer(qa_q, ctx)
476
+ if not ans or score < 0.08:
477
+ ans = "В тексте это не указано (или требуется переформулировать вопрос)."
478
+
479
+ lines.append(f"**{i}. {q}**")
480
+ lines.append(f"- Ответ: {ans}")
481
+ lines.append(f"- Фрагмент: {ev}")
482
+ lines.append("")
483
+ return "\n".join(lines).strip()
484
+ except Exception as e:
485
+ return f"Ошибка: {e}"
486
+
487
+ def on_chat(text: str, history: List[Tuple[str, str]], user_q: str):
488
+ try:
489
+ text = safe_text(text)
490
+ user_q = (user_q or "").strip()
491
+ history = history or []
492
 
493
+ if not text:
494
+ history.append((user_q, "Нет текста. Вставьте текст слева."))
495
+ return history, ""
496
 
497
+ if not user_q:
498
+ return history, ""
499
 
500
+ idx = get_index(text)
501
+ hits = retrieve(idx, user_q, k=5)
502
+ ctx = "\n\n".join([c for _, c in hits]) if hits else text[:CTX_MAX_CHARS]
503
+ ctx = ctx[:CTX_MAX_CHARS]
504
 
505
+ ans, ev, score = qa_answer(user_q, ctx)
506
+ if not ans or score < 0.08:
507
+ reply = "Ответ по тексту не найден. Попробуйте переформулировать вопрос или уточнить термин."
508
+ else:
509
+ reply = f"Ответ: {ans}\n\nДоказательство:\n{ev}"
510
+
511
+ history.append((user_q, reply))
512
+ return history, ""
513
+ except Exception as e:
514
+ history = history or []
515
+ history.append((user_q, f"Ошибка: {e}"))
516
+ return history, ""
517
 
 
518
 
519
+ # =========================
520
+ # UI (minimal)
521
+ # =========================
522
+ with gr.Blocks(title="RU Text Assistant (CPU, 3 Transformers)") as demo:
523
  with gr.Row():
524
  with gr.Column(scale=2):
525
+ text_in = gr.Textbox(label="Текст (русский)", lines=16, placeholder="Вставьте текст для анализа…")
526
+ load_btn = gr.Button("Загрузить модели", variant="secondary")
527
+ model_status = gr.Textbox(label="Статус", lines=5, interactive=False)
528
 
529
  with gr.Column(scale=3):
530
  with gr.Tabs():
531
  with gr.Tab("Пересказ"):
 
532
  sum_btn = gr.Button("Сделать пересказ", variant="primary")
533
  sum_out = gr.Markdown()
534
535
  with gr.Tab("Поиск"):
536
+ query_in = gr.Textbox(label="Запрос", placeholder="Например: стандартизация, вариабельность, вывод…")
537
+ k_in = gr.Slider(1, 10, value=TOPK_DEFAULT, step=1, label="Top-K")
538
+ search_btn = gr.Button("Найти фрагменты", variant="primary")
539
+ search_out = gr.Markdown()
 
 
 
540
 
541
+ with gr.Tab("Вопросы"):
542
+ n_in = gr.Slider(1, 12, value=6, step=1, label="Количество вопросов")
543
+ quiz_btn = gr.Button("Сгенерировать и проверить", variant="primary")
544
+ quiz_out = gr.Markdown()
545
 
546
+ with gr.Tab("Чат по тексту"):
547
+ chat = gr.Chatbot(height=420)
548
+ user_q = gr.Textbox(label="Вопрос", lines=1, placeholder="Задайте вопрос по тексту…")
549
+ send_btn = gr.Button("Отправить", variant="primary")
550
+
551
+ load_btn.click(on_load_models, outputs=[model_status])
552
+ sum_btn.click(on_summary, inputs=[text_in], outputs=[sum_out])
553
+ search_btn.click(on_search, inputs=[text_in, query_in, k_in], outputs=[search_out])
554
+ quiz_btn.click(on_quiz, inputs=[text_in, n_in], outputs=[quiz_out])
555
+ send_btn.click(on_chat, inputs=[text_in, chat, user_q], outputs=[chat, user_q])
556
+ user_q.submit(on_chat, inputs=[text_in, chat, user_q], outputs=[chat, user_q])
557
 
558
  if __name__ == "__main__":
559
  demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860, show_error=True)
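For local testing, the updated entrypoint can presumably be started as follows (dependencies inferred from the imports; versions are not pinned in this commit):

# Hypothetical local run, inferred from the imports and the __main__ block above:
#   pip install torch transformers gradio numpy   (possibly sentencepiece for the T5-based summarizer)
#   python app.py
# The Gradio UI is then served on http://localhost:7860 (port set in launch()).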