import os
import re
import hashlib
from dataclasses import dataclass
from collections import OrderedDict
from typing import List, Tuple, Optional, Dict, Any

import numpy as np
import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModel,
    pipeline,
)
from transformers.utils import logging as hf_logging

# =========================
# CPU-only + quieter logs
# =========================
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
hf_logging.set_verbosity_error()

torch.set_grad_enabled(False)
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))

# =========================
# Models (3 transformers)
# =========================
SUM_MODEL_CANDIDATES = [
    "d0rj/rut5-base-summ",            # RU summarization
    "cointegrated/rut5-base-absum",   # RU summarization fallback
]
QA_MODEL_CANDIDATES = [
    "mrm8488/bert-multi-cased-finetuned-xquadv1",  # multilingual QA
    "mrm8488/bert-multi-cased-finedtuned-xquad-tydiqa-goldp",
]
EMB_MODEL_CANDIDATES = [
    "intfloat/multilingual-e5-small",  # retrieval embeddings
    "intfloat/e5-small-v2",
]
DEVICE = -1  # CPU for pipelines

# =========================
# Limits (memory & speed)
# =========================
MAX_TEXT_CHARS = 120_000
CHUNK_CHARS = 1400
MAX_CHUNKS = 140
EMB_BATCH = 16
TOPK_DEFAULT = 5
CTX_MAX_CHARS = 4500

# =========================
# Helpers
# =========================
RU_STOP = {
    "и","в","во","на","но","а","что","это","как","к","ко","из","за","по","у","от","до","при","для","над",
    "под","же","ли","бы","не","ни","то","его","ее","их","мы","вы","они","она","он","оно","этот","эта","эти",
    "там","тут","здесь","так","такие","такой","есть","быть","был","была","были","будет","будут",
}


def safe_text(s: str, max_chars: int = MAX_TEXT_CHARS) -> str:
    s = (s or "").strip()
    if len(s) > max_chars:
        s = s[:max_chars].rstrip() + "\n\n[Обрезано по лимиту длины]"
    return s


def normalize_space(s: str) -> str:
    return re.sub(r"\s+", " ", (s or "")).strip()


def split_into_chunks(text: str) -> List[str]:
    text = safe_text(text)
    paras = [p.strip() for p in re.split(r"\n\s*\n+", text) if p.strip()]
    chunks = []
    buf = ""
    for p in paras:
        if not buf:
            buf = p
        elif len(buf) + 2 + len(p) <= CHUNK_CHARS:
            buf = buf + "\n\n" + p
        else:
            chunks.append(buf.strip())
            buf = p
        if len(chunks) >= MAX_CHUNKS:
            break
    if buf and len(chunks) < MAX_CHUNKS:
        chunks.append(buf.strip())

    # If still too big, split long chunks by sentences
    sent_re = re.compile(r"(?<=[\.\!\?…])\s+")
    final_chunks = []
    for c in chunks:
        if len(c) <= int(CHUNK_CHARS * 1.6):
            final_chunks.append(c)
            continue
        sents = [x.strip() for x in sent_re.split(c) if x.strip()]
        b = ""
        for s in sents:
            if not b:
                b = s
            elif len(b) + 1 + len(s) <= CHUNK_CHARS:
                b = b + " " + s
            else:
                final_chunks.append(b.strip())
                b = s
            if len(final_chunks) >= MAX_CHUNKS:
                break
        if b and len(final_chunks) < MAX_CHUNKS:
            final_chunks.append(b.strip())
        if len(final_chunks) >= MAX_CHUNKS:
            break
    return final_chunks[:MAX_CHUNKS]


def sha_key(text: str) -> str:
    h = hashlib.sha1(text.encode("utf-8")).hexdigest()
    return h[:12]
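

# Illustrative sketch, not wired into the app: shows how split_into_chunks packs
# paragraph-separated text into pieces bounded by roughly CHUNK_CHARS. The helper
# name _example_chunking is hypothetical and is never called by the UI.
def _example_chunking() -> None:
    toy = ("Первый абзац. " * 120) + "\n\n" + ("Второй абзац. " * 120)
    chunks = split_into_chunks(toy)
    # every chunk stays within ~1.6 * CHUNK_CHARS characters by construction
    assert all(len(c) <= int(CHUNK_CHARS * 1.6) for c in chunks)
    print(f"{len(chunks)} chunk(s), longest has {max(len(c) for c in chunks)} chars")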


# =========================
# Global model holders
# =========================
_SUM_PIPE = None
_SUM_ID = None
_QA_PIPE = None
_QA_ID = None
_EMB_TOK = None
_EMB_MODEL = None
_EMB_ID = None


def _try_load_summarizer() -> Tuple[Any, str]:
    last_err = None
    for mid in SUM_MODEL_CANDIDATES:
        try:
            pipe = pipeline("summarization", model=mid, device=DEVICE)
            return pipe, mid
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Cannot load summarization model. Last error: {last_err}")


def _try_load_qa() -> Tuple[Any, str]:
    last_err = None
    for mid in QA_MODEL_CANDIDATES:
        try:
            pipe = pipeline("question-answering", model=mid, device=DEVICE)
            return pipe, mid
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Cannot load QA model. Last error: {last_err}")


def _try_load_emb() -> Tuple[Any, Any, str]:
    last_err = None
    for mid in EMB_MODEL_CANDIDATES:
        try:
            tok = AutoTokenizer.from_pretrained(mid, use_fast=True)
            model = AutoModel.from_pretrained(mid, torch_dtype=torch.float32, low_cpu_mem_usage=True).eval()
            return tok, model, mid
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Cannot load embedding model. Last error: {last_err}")


def get_models():
    global _SUM_PIPE, _SUM_ID, _QA_PIPE, _QA_ID, _EMB_TOK, _EMB_MODEL, _EMB_ID
    if _SUM_PIPE is None:
        _SUM_PIPE, _SUM_ID = _try_load_summarizer()
    if _QA_PIPE is None:
        _QA_PIPE, _QA_ID = _try_load_qa()
    if _EMB_MODEL is None:
        _EMB_TOK, _EMB_MODEL, _EMB_ID = _try_load_emb()
    return _SUM_PIPE, _SUM_ID, _QA_PIPE, _QA_ID, _EMB_TOK, _EMB_MODEL, _EMB_ID


# =========================
# Embeddings + retrieval
# =========================
@torch.inference_mode()
def _mean_pool(last_hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    m = mask.unsqueeze(-1).bool()
    x = last_hidden.masked_fill(~m, 0.0)
    summed = x.sum(dim=1)
    denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1)
    return summed / denom


@torch.inference_mode()
def embed_texts(texts: List[str], is_query: bool) -> np.ndarray:
    _, _, _, _, tok, model, _ = get_models()
    prefix = "query: " if is_query else "passage: "
    batch_texts = [prefix + normalize_space(t) for t in texts]
    vecs = []
    for i in range(0, len(batch_texts), EMB_BATCH):
        batch = batch_texts[i:i + EMB_BATCH]
        enc = tok(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
        out = model(**enc)
        pooled = _mean_pool(out.last_hidden_state, enc["attention_mask"])
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        vecs.append(pooled.cpu().numpy().astype(np.float32))
    return np.vstack(vecs) if vecs else np.zeros((0, 384), dtype=np.float32)


def topk_cosine(q: np.ndarray, mat: np.ndarray, k: int) -> List[Tuple[int, float]]:
    scores = (mat @ q.reshape(-1, 1)).squeeze(1)
    if scores.size == 0:
        return []
    k = max(1, min(k, scores.size))
    idx = np.argpartition(-scores, k - 1)[:k]
    idx = idx[np.argsort(-scores[idx])]
    return [(int(i), float(scores[i])) for i in idx]


@dataclass
class Index:
    key: str
    text: str
    chunks: List[str]
    emb: np.ndarray


# Small LRU cache (keeps RAM bounded)
_INDEX_CACHE: "OrderedDict[str, Index]" = OrderedDict()
CACHE_MAX = 4


def get_index(text: str) -> Index:
    text = safe_text(text)
    k = sha_key(text)
    if k in _INDEX_CACHE:
        _INDEX_CACHE.move_to_end(k)
        return _INDEX_CACHE[k]
    chunks = split_into_chunks(text)
    emb = embed_texts(chunks, is_query=False) if chunks else np.zeros((0, 384), dtype=np.float32)
    idx = Index(key=k, text=text, chunks=chunks, emb=emb)
    _INDEX_CACHE[k] = idx
    _INDEX_CACHE.move_to_end(k)
    while len(_INDEX_CACHE) > CACHE_MAX:
        _INDEX_CACHE.popitem(last=False)
    return idx


def retrieve(idx: Index, query: str, k: int) -> List[Tuple[float, str]]:
    query = (query or "").strip()
    if not query or idx.emb.shape[0] == 0:
        return []
    qv = embed_texts([query], is_query=True)[0]
    hits = topk_cosine(qv, idx.emb, k=k)
    return [(score, idx.chunks[i]) for i, score in hits]
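

# Illustrative sketch, not wired into the app: end-to-end retrieval over a toy
# text (chunk -> embed -> cosine top-k). Calling it loads the real embedding
# model, so it is left uncalled; the name _example_retrieval is hypothetical.
def _example_retrieval() -> None:
    toy = (
        "Стандартизация снижает вариабельность процессов.\n\n"
        "Выводы делаются на основе собранных данных."
    )
    idx = get_index(toy)  # chunks the text, embeds passages, caches the index
    for score, chunk in retrieve(idx, "что снижает вариабельность?", k=2):
        print(f"{score:.3f}  {chunk[:60]}")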


# =========================
# Summarization (hierarchical, stable)
# =========================
def summarize_one(text: str, out_max: int, out_min: int) -> str:
    sum_pipe, _, _, _, _, _, _ = get_models()
    text = normalize_space(text)
    if not text:
        return ""
    # pipeline expects token lengths; we keep conservative values
    res = sum_pipe(text, max_length=out_max, min_length=out_min, do_sample=False)
    if isinstance(res, list) and res:
        return (res[0].get("summary_text") or "").strip()
    return ""


def smart_summary(text: str) -> str:
    text = safe_text(text)
    if not text:
        return "Нет текста."
    chunks = split_into_chunks(text)
    if not chunks:
        return "Нет текста."
    # For short text: direct
    if len(text) < 2500 and len(chunks) <= 2:
        s = summarize_one(text, out_max=220, out_min=80)
        return s if s else summarize_one(text, out_max=160, out_min=50)
    # For long text: summarize chunks then summarize the combined summaries
    parts = chunks[:8]
    partial = []
    for p in parts:
        sp = summarize_one(p, out_max=140, out_min=40)
        if sp:
            partial.append(sp)
    combined = " ".join(partial).strip()
    if not combined:
        combined = " ".join(parts)[:4000]
    final = summarize_one(combined, out_max=240, out_min=90)
    if not final:
        final = summarize_one(combined, out_max=180, out_min=60)
    return final if final else "Не удалось получить пересказ."


def make_title(text: str, summary: str) -> str:
    # heuristic title: first 8–12 words of summary, else first sentence of text
    src = summary.strip() if summary.strip() else normalize_space(text[:500])
    words = [w for w in re.split(r"\s+", src) if w]
    title = " ".join(words[:12]).strip(" .,:;—-")
    return title if title else "Краткий пересказ"


# =========================
# QA Chat (retrieval + pipeline QA)
# =========================
def qa_answer(question: str, context: str) -> Tuple[str, str, float]:
    _, _, qa_pipe, _, _, _, _ = get_models()
    question = (question or "").strip()
    context = (context or "").strip()
    if not question or not context:
        return "", "", 0.0
    context = context[:CTX_MAX_CHARS]
    out = qa_pipe(question=question, context=context)
    ans = (out.get("answer") or "").strip()
    score = float(out.get("score") or 0.0)
    start = int(out.get("start") or 0)
    end = int(out.get("end") or 0)
    # evidence snippet
    left = max(0, start - 140)
    right = min(len(context), end + 220)
    snippet = context[left:right].strip()
    if left > 0:
        snippet = "…" + snippet
    if right < len(context):
        snippet = snippet + "…"
    return ans, snippet, score
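

# Illustrative sketch, not wired into the app: extractive QA over a short
# context string; loads the QA pipeline when called, so it stays uncalled.
# The helper name _example_qa is hypothetical.
def _example_qa() -> None:
    ctx = "Стандартизация снижает вариабельность процессов и упрощает контроль качества."
    ans, evidence, score = qa_answer("Что снижает вариабельность процессов?", ctx)
    print(f"answer={ans!r}  score={score:.2f}")
    print(f"evidence={evidence!r}")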


# =========================
# Quiz (heuristic questions; answers via retrieval+QA)
# =========================
def _sentences(text: str) -> List[str]:
    # very simple sentence splitter
    text = normalize_space(text)
    if not text:
        return []
    parts = re.split(r"(?<=[\.\!\?…])\s+", text)
    out = []
    for p in parts:
        p = p.strip()
        if 40 <= len(p) <= 240:
            out.append(p)
    return out


def _keywords(text: str) -> Dict[str, int]:
    words = re.findall(r"[А-Яа-яЁёA-Za-z\-]{3,}", text.lower())
    freq: Dict[str, int] = {}
    for w in words:
        if w in RU_STOP:
            continue
        freq[w] = freq.get(w, 0) + 1
    return freq


def generate_quiz_questions(text: str, n: int) -> List[str]:
    n = int(max(1, min(n, 12)))
    sents = _sentences(text)
    if not sents:
        return []
    freq = _keywords(text)
    if not freq:
        # fallback: use first sentences
        sents = sents[:n]
        return [f"О чем говорится в утверждении: «{s}»?" for s in sents]
    scored = []
    for s in sents:
        ws = re.findall(r"[А-Яа-яЁёA-Za-z\-]{3,}", s.lower())
        score = sum(freq.get(w, 0) for w in ws if w not in RU_STOP)
        scored.append((score, s))
    scored.sort(key=lambda x: x[0], reverse=True)
    questions = []
    for _, s in scored[: min(len(scored), n * 2)]:
        ws = [w for w in re.findall(r"[А-Яа-яЁёA-Za-z\-]{3,}", s.lower()) if w not in RU_STOP]
        if not ws:
            continue
        # choose "keyword" to blank
        kw = max(ws, key=lambda w: freq.get(w, 0))
        # blank first occurrence (case-insensitive)
        blanked = re.sub(re.escape(kw), "____", s, count=1, flags=re.IGNORECASE)
        q = f"Заполните пропуск: {blanked}"
        questions.append(q)
        if len(questions) >= n:
            break
    return questions[:n]
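

# Illustrative sketch, not wired into the app: heuristic cloze generation on a
# toy paragraph. Purely string-based, no models are loaded. The helper name
# _example_quiz is hypothetical and never called.
def _example_quiz() -> None:
    toy = (
        "Стандартизация процессов снижает вариабельность и повышает качество работы. "
        "Контроль качества опирается на стандартизацию и регулярные измерения данных."
    )
    for q in generate_quiz_questions(toy, n=2):
        print(q)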


# =========================
# Gradio actions
# =========================
def on_load_models() -> str:
    try:
        sum_pipe, sum_id, qa_pipe, qa_id, emb_tok, emb_model, emb_id = get_models()
        return (
            "Модели загружены.\n"
            f"- Summarization: {sum_id}\n"
            f"- QA: {qa_id}\n"
            f"- Embeddings: {emb_id}\n"
        )
    except Exception as e:
        return f"Ошибка загрузки моделей: {e}"


def on_summary(text: str) -> str:
    try:
        text = safe_text(text)
        if not text:
            return "Нет текста."
        s = smart_summary(text)
        title = make_title(text, s)
        return f"### Заголовок\n{title}\n\n### Пересказ\n{s}"
    except Exception as e:
        return f"Ошибка: {e}"


def on_search(text: str, query: str, k: int) -> str:
    try:
        text = safe_text(text)
        query = (query or "").strip()
        if not text:
            return "Нет текста."
        if not query:
            return "Введите запрос."
        idx = get_index(text)
        hits = retrieve(idx, query, int(max(1, min(k, 10))))
        if not hits:
            return "Ничего не найдено."
        out = ["### Результаты"]
        for i, (score, chunk) in enumerate(hits, 1):
            out.append(f"**{i}. score={score:.3f}**\n{chunk}\n")
        return "\n".join(out).strip()
    except Exception as e:
        return f"Ошибка: {e}"


def on_quiz(text: str, n: int) -> str:
    try:
        text = safe_text(text)
        if not text:
            return "Нет текста."
        idx = get_index(text)
        questions = generate_quiz_questions(text, int(n))
        if not questions:
            return "Не удалось сгенерировать вопросы."
        lines = ["### Вопросы и ответы (с доказательством)"]
        for i, q in enumerate(questions, 1):
            # For cloze question, try to answer via QA using retrieved context.
            # We convert cloze to a QA-style question by removing "Заполните пропуск:"
            qa_q = re.sub(r"^Заполните пропуск:\s*", "", q).strip()
            hits = retrieve(idx, qa_q, k=5)
            ctx = "\n\n".join([c for _, c in hits]) if hits else text[:CTX_MAX_CHARS]
            ctx = ctx[:CTX_MAX_CHARS]
            ans, ev, score = qa_answer(qa_q, ctx)
            if not ans or score < 0.08:
                ans = "В тексте это не указано (или требуется переформулировать вопрос)."
            lines.append(f"**{i}. {q}**")
            lines.append(f"- Ответ: {ans}")
            lines.append(f"- Фрагмент: {ev}")
            lines.append("")
        return "\n".join(lines).strip()
    except Exception as e:
        return f"Ошибка: {e}"


def on_chat(text: str, history: List[Tuple[str, str]], user_q: str):
    try:
        text = safe_text(text)
        user_q = (user_q or "").strip()
        history = history or []
        if not text:
            history.append((user_q, "Нет текста. Вставьте текст слева."))
            return history, ""
        if not user_q:
            return history, ""
        idx = get_index(text)
        hits = retrieve(idx, user_q, k=5)
        ctx = "\n\n".join([c for _, c in hits]) if hits else text[:CTX_MAX_CHARS]
        ctx = ctx[:CTX_MAX_CHARS]
        ans, ev, score = qa_answer(user_q, ctx)
        if not ans or score < 0.08:
            reply = "Ответ по тексту не найден. Попробуйте переформулировать вопрос или уточнить термин."
        else:
            reply = f"Ответ: {ans}\n\nДоказательство:\n{ev}"
        history.append((user_q, reply))
        return history, ""
    except Exception as e:
        history = history or []
        history.append((user_q, f"Ошибка: {e}"))
        return history, ""


# =========================
# UI (minimal)
# =========================
with gr.Blocks(title="RU Text Assistant (CPU, 3 Transformers)") as demo:
    with gr.Row():
        with gr.Column(scale=2):
            text_in = gr.Textbox(label="Текст (русский)", lines=16, placeholder="Вставьте текст для анализа…")
            load_btn = gr.Button("Загрузить модели", variant="secondary")
            model_status = gr.Textbox(label="Статус", lines=5, interactive=False)
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("Пересказ"):
                    sum_btn = gr.Button("Сделать пересказ", variant="primary")
                    sum_out = gr.Markdown()
                with gr.Tab("Поиск"):
                    query_in = gr.Textbox(label="Запрос", placeholder="Например: стандартизация, вариабельность, вывод…")
                    k_in = gr.Slider(1, 10, value=TOPK_DEFAULT, step=1, label="Top-K")
                    search_btn = gr.Button("Найти фрагменты", variant="primary")
                    search_out = gr.Markdown()
                with gr.Tab("Вопросы"):
                    n_in = gr.Slider(1, 12, value=6, step=1, label="Количество вопросов")
                    quiz_btn = gr.Button("Сгенерировать и проверить", variant="primary")
                    quiz_out = gr.Markdown()
                with gr.Tab("Чат по тексту"):
                    chat = gr.Chatbot(height=420)
                    user_q = gr.Textbox(label="Вопрос", lines=1, placeholder="Задайте вопрос по тексту…")
                    send_btn = gr.Button("Отправить", variant="primary")

    load_btn.click(on_load_models, outputs=[model_status])
    sum_btn.click(on_summary, inputs=[text_in], outputs=[sum_out])
    search_btn.click(on_search, inputs=[text_in, query_in, k_in], outputs=[search_out])
    quiz_btn.click(on_quiz, inputs=[text_in, n_in], outputs=[quiz_out])
    send_btn.click(on_chat, inputs=[text_in, chat, user_q], outputs=[chat, user_q])
    user_q.submit(on_chat, inputs=[text_in, chat, user_q], outputs=[chat, user_q])


if __name__ == "__main__":
    demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860, show_error=True)
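
# Assumed local invocation (the file name app.py is an assumption, not stated above):
#   python app.py
# then open http://localhost:7860 in a browser. Model weights are downloaded from
# the Hugging Face Hub on first use, so the first request can take a while.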