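"""RU Text Assistant: a CPU-only Gradio app for working with Russian texts.

Features: abstractive summarization (rut5 candidates), semantic search over
text chunks (multilingual E5 embeddings), fill-in-the-blank quiz generation,
and extractive QA chat. All models are loaded lazily on first use.
"""
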
import os
import re
import hashlib
from dataclasses import dataclass
from collections import OrderedDict
from typing import List, Tuple, Optional, Dict, Any

import numpy as np
import torch
import gradio as gr

from transformers import (
    AutoTokenizer,
    AutoModel,
    pipeline,
)
from transformers.utils import logging as hf_logging

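# --- Runtime configuration: keep the app CPU-friendly and the logs quiet ---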
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_NO_ADVISORY_WARNINGS", "1")
hf_logging.set_verbosity_error()

torch.set_grad_enabled(False)
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))

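# --- Model candidates: the first ID that loads successfully is used.
#     DEVICE = -1 selects CPU for transformers pipelines.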
SUM_MODEL_CANDIDATES = [
    "d0rj/rut5-base-summ",
    "cointegrated/rut5-base-absum",
]

QA_MODEL_CANDIDATES = [
    "mrm8488/bert-multi-cased-finetuned-xquadv1",
    "mrm8488/bert-multi-cased-finedtuned-xquad-tydiqa-goldp",
]

EMB_MODEL_CANDIDATES = [
    "intfloat/multilingual-e5-small",
    "intfloat/e5-small-v2",
]

DEVICE = -1

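# --- Text, chunking and retrieval limits ---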
MAX_TEXT_CHARS = 120_000
CHUNK_CHARS = 1400
MAX_CHUNKS = 140
EMB_BATCH = 16

TOPK_DEFAULT = 5
CTX_MAX_CHARS = 4500

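# --- Russian stop words and basic text utilities ---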
RU_STOP = {
    "и","в","во","на","но","а","что","это","как","к","ко","из","за","по","у","от","до","при","для","над",
    "под","же","ли","бы","не","ни","то","его","ее","их","мы","вы","они","она","он","оно","этот","эта","эти",
    "там","тут","здесь","так","такие","такой","есть","быть","был","была","были","будет","будут",
}


def safe_text(s: str, max_chars: int = MAX_TEXT_CHARS) -> str:
    s = (s or "").strip()
    if len(s) > max_chars:
        s = s[:max_chars].rstrip() + "\n\n[Обрезано по лимиту длины]"
    return s


def normalize_space(s: str) -> str:
    return re.sub(r"\s+", " ", (s or "")).strip()


def split_into_chunks(text: str) -> List[str]:
    text = safe_text(text)
    paras = [p.strip() for p in re.split(r"\n\s*\n+", text) if p.strip()]
    chunks = []
    buf = ""

    for p in paras:
        if not buf:
            buf = p
        elif len(buf) + 2 + len(p) <= CHUNK_CHARS:
            buf = buf + "\n\n" + p
        else:
            chunks.append(buf.strip())
            buf = p
        if len(chunks) >= MAX_CHUNKS:
            break

    if buf and len(chunks) < MAX_CHUNKS:
        chunks.append(buf.strip())

    sent_re = re.compile(r"(?<=[\.\!\?…])\s+")
    final_chunks = []
    for c in chunks:
        if len(c) <= int(CHUNK_CHARS * 1.6):
            final_chunks.append(c)
            continue
        sents = [x.strip() for x in sent_re.split(c) if x.strip()]
        b = ""
        for s in sents:
            if not b:
                b = s
            elif len(b) + 1 + len(s) <= CHUNK_CHARS:
                b = b + " " + s
            else:
                final_chunks.append(b.strip())
                b = s
            if len(final_chunks) >= MAX_CHUNKS:
                break
        if b and len(final_chunks) < MAX_CHUNKS:
            final_chunks.append(b.strip())
        if len(final_chunks) >= MAX_CHUNKS:
            break

    return final_chunks[:MAX_CHUNKS]


def sha_key(text: str) -> str:
    h = hashlib.sha1(text.encode("utf-8")).hexdigest()
    return h[:12]

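# --- Lazily loaded model singletons: each loader walks its candidate list and
#     keeps the first model that loads; failures fall through to the next ID ---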
_SUM_PIPE = None
_SUM_ID = None

_QA_PIPE = None
_QA_ID = None

_EMB_TOK = None
_EMB_MODEL = None
_EMB_ID = None


def _try_load_summarizer() -> Tuple[Any, str]:
    last_err = None
    for mid in SUM_MODEL_CANDIDATES:
        try:
            pipe = pipeline("summarization", model=mid, device=DEVICE)
            return pipe, mid
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Cannot load summarization model. Last error: {last_err}")


def _try_load_qa() -> Tuple[Any, str]:
    last_err = None
    for mid in QA_MODEL_CANDIDATES:
        try:
            pipe = pipeline("question-answering", model=mid, device=DEVICE)
            return pipe, mid
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Cannot load QA model. Last error: {last_err}")


def _try_load_emb() -> Tuple[Any, Any, str]:
    last_err = None
    for mid in EMB_MODEL_CANDIDATES:
        try:
            tok = AutoTokenizer.from_pretrained(mid, use_fast=True)
            model = AutoModel.from_pretrained(mid, torch_dtype=torch.float32, low_cpu_mem_usage=True).eval()
            return tok, model, mid
        except Exception as e:
            last_err = e
    raise RuntimeError(f"Cannot load embedding model. Last error: {last_err}")


def get_models():
    global _SUM_PIPE, _SUM_ID, _QA_PIPE, _QA_ID, _EMB_TOK, _EMB_MODEL, _EMB_ID

    if _SUM_PIPE is None:
        _SUM_PIPE, _SUM_ID = _try_load_summarizer()
    if _QA_PIPE is None:
        _QA_PIPE, _QA_ID = _try_load_qa()
    if _EMB_MODEL is None:
        _EMB_TOK, _EMB_MODEL, _EMB_ID = _try_load_emb()

    return _SUM_PIPE, _SUM_ID, _QA_PIPE, _QA_ID, _EMB_TOK, _EMB_MODEL, _EMB_ID

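# --- Embedding and retrieval helpers: mean pooling over token states, cosine
#     top-k on normalized vectors; embed_texts adds E5-style "query:"/"passage:" prefixes ---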
@torch.inference_mode()
def _mean_pool(last_hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    m = mask.unsqueeze(-1).bool()
    x = last_hidden.masked_fill(~m, 0.0)
    summed = x.sum(dim=1)
    denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1)
    return summed / denom


@torch.inference_mode()
def embed_texts(texts: List[str], is_query: bool) -> np.ndarray:
    _, _, _, _, tok, model, _ = get_models()
    prefix = "query: " if is_query else "passage: "
    batch_texts = [prefix + normalize_space(t) for t in texts]

    vecs = []
    for i in range(0, len(batch_texts), EMB_BATCH):
        batch = batch_texts[i:i + EMB_BATCH]
        enc = tok(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
        out = model(**enc)
        pooled = _mean_pool(out.last_hidden_state, enc["attention_mask"])
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        vecs.append(pooled.cpu().numpy().astype(np.float32))

    return np.vstack(vecs) if vecs else np.zeros((0, 384), dtype=np.float32)


def topk_cosine(q: np.ndarray, mat: np.ndarray, k: int) -> List[Tuple[int, float]]:
    scores = (mat @ q.reshape(-1, 1)).squeeze(1)
    if scores.size == 0:
        return []
    k = max(1, min(k, scores.size))
    idx = np.argpartition(-scores, k - 1)[:k]
    idx = idx[np.argsort(-scores[idx])]
    return [(int(i), float(scores[i])) for i in idx]

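# --- Per-text chunk index with a small LRU cache keyed by a SHA-1 prefix ---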
@dataclass
class Index:
    key: str
    text: str
    chunks: List[str]
    emb: np.ndarray


_INDEX_CACHE: "OrderedDict[str, Index]" = OrderedDict()
CACHE_MAX = 4


def get_index(text: str) -> Index:
    text = safe_text(text)
    k = sha_key(text)
    if k in _INDEX_CACHE:
        _INDEX_CACHE.move_to_end(k)
        return _INDEX_CACHE[k]

    chunks = split_into_chunks(text)
    emb = embed_texts(chunks, is_query=False) if chunks else np.zeros((0, 384), dtype=np.float32)
    idx = Index(key=k, text=text, chunks=chunks, emb=emb)

    _INDEX_CACHE[k] = idx
    _INDEX_CACHE.move_to_end(k)
    while len(_INDEX_CACHE) > CACHE_MAX:
        _INDEX_CACHE.popitem(last=False)

    return idx


def retrieve(idx: Index, query: str, k: int) -> List[Tuple[float, str]]:
    query = (query or "").strip()
    if not query or idx.emb.shape[0] == 0:
        return []
    qv = embed_texts([query], is_query=True)[0]
    hits = topk_cosine(qv, idx.emb, k=k)
    return [(score, idx.chunks[i]) for i, score in hits]

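# --- Summarization: short texts in one pass, long texts via map-reduce over chunks ---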
def summarize_one(text: str, out_max: int, out_min: int) -> str:
    sum_pipe, _, _, _, _, _, _ = get_models()
    text = normalize_space(text)
    if not text:
        return ""

    res = sum_pipe(text, max_length=out_max, min_length=out_min, do_sample=False)
    if isinstance(res, list) and res:
        return (res[0].get("summary_text") or "").strip()
    return ""


def smart_summary(text: str) -> str:
    text = safe_text(text)
    if not text:
        return "Нет текста."

    chunks = split_into_chunks(text)
    if not chunks:
        return "Нет текста."

    if len(text) < 2500 and len(chunks) <= 2:
        s = summarize_one(text, out_max=220, out_min=80)
        return s if s else summarize_one(text, out_max=160, out_min=50)

    parts = chunks[:8]
    partial = []
    for p in parts:
        sp = summarize_one(p, out_max=140, out_min=40)
        if sp:
            partial.append(sp)

    combined = " ".join(partial).strip()
    if not combined:
        combined = " ".join(parts)[:4000]

    final = summarize_one(combined, out_max=240, out_min=90)
    if not final:
        final = summarize_one(combined, out_max=180, out_min=60)

    return final if final else "Не удалось получить пересказ."


def make_title(text: str, summary: str) -> str:
    src = summary.strip() if summary.strip() else normalize_space(text[:500])
    words = [w for w in re.split(r"\s+", src) if w]
    title = " ".join(words[:12]).strip(" .,:;—-")
    return title if title else "Краткий пересказ"

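# --- Extractive question answering over a retrieved, length-capped context window ---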
def qa_answer(question: str, context: str) -> Tuple[str, str, float]:
    _, _, qa_pipe, _, _, _, _ = get_models()
    question = (question or "").strip()
    context = (context or "").strip()
    if not question or not context:
        return "", "", 0.0

    context = context[:CTX_MAX_CHARS]
    out = qa_pipe(question=question, context=context)
    ans = (out.get("answer") or "").strip()
    score = float(out.get("score") or 0.0)
    start = int(out.get("start") or 0)
    end = int(out.get("end") or 0)

    left = max(0, start - 140)
    right = min(len(context), end + 220)
    snippet = context[left:right].strip()
    if left > 0:
        snippet = "…" + snippet
    if right < len(context):
        snippet = snippet + "…"

    return ans, snippet, score

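# --- Quiz generation: frequency-scored sentences turned into fill-in-the-blank questions ---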
def _sentences(text: str) -> List[str]:
    text = normalize_space(text)
    if not text:
        return []
    parts = re.split(r"(?<=[\.\!\?…])\s+", text)
    out = []
    for p in parts:
        p = p.strip()
        if 40 <= len(p) <= 240:
            out.append(p)
    return out


def _keywords(text: str) -> Dict[str, int]:
    words = re.findall(r"[А-Яа-яЁёA-Za-z\-]{3,}", text.lower())
    freq: Dict[str, int] = {}
    for w in words:
        if w in RU_STOP:
            continue
        freq[w] = freq.get(w, 0) + 1
    return freq


def generate_quiz_questions(text: str, n: int) -> List[str]:
    n = int(max(1, min(n, 12)))
    sents = _sentences(text)
    if not sents:
        return []

    freq = _keywords(text)
    if not freq:
        sents = sents[:n]
        return [f"О чем говорится в утверждении: «{s}»?" for s in sents]

    scored = []
    for s in sents:
        ws = re.findall(r"[А-Яа-яЁёA-Za-z\-]{3,}", s.lower())
        score = sum(freq.get(w, 0) for w in ws if w not in RU_STOP)
        scored.append((score, s))
    scored.sort(key=lambda x: x[0], reverse=True)

    questions = []
    for _, s in scored[: min(len(scored), n * 2)]:
        ws = [w for w in re.findall(r"[А-Яа-яЁёA-Za-z\-]{3,}", s.lower()) if w not in RU_STOP]
        if not ws:
            continue

        kw = max(ws, key=lambda w: freq.get(w, 0))

        blanked = re.sub(re.escape(kw), "____", s, count=1, flags=re.IGNORECASE)
        q = f"Заполните пропуск: {blanked}"
        questions.append(q)
        if len(questions) >= n:
            break

    return questions[:n]

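# --- Gradio callbacks: each handler wraps its work in try/except and reports errors back to the UI ---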
def on_load_models() -> str:
    try:
        sum_pipe, sum_id, qa_pipe, qa_id, emb_tok, emb_model, emb_id = get_models()
        return (
            "Модели загружены.\n"
            f"- Summarization: {sum_id}\n"
            f"- QA: {qa_id}\n"
            f"- Embeddings: {emb_id}\n"
        )
    except Exception as e:
        return f"Ошибка загрузки моделей: {e}"


def on_summary(text: str) -> str:
    try:
        text = safe_text(text)
        if not text:
            return "Нет текста."
        s = smart_summary(text)
        title = make_title(text, s)
        return f"### Заголовок\n{title}\n\n### Пересказ\n{s}"
    except Exception as e:
        return f"Ошибка: {e}"


def on_search(text: str, query: str, k: int) -> str:
    try:
        text = safe_text(text)
        query = (query or "").strip()
        if not text:
            return "Нет текста."
        if not query:
            return "Введите запрос."
        idx = get_index(text)
        hits = retrieve(idx, query, int(max(1, min(k, 10))))
        if not hits:
            return "Ничего не найдено."
        out = ["### Результаты"]
        for i, (score, chunk) in enumerate(hits, 1):
            out.append(f"**{i}. score={score:.3f}**\n{chunk}\n")
        return "\n".join(out).strip()
    except Exception as e:
        return f"Ошибка: {e}"


def on_quiz(text: str, n: int) -> str:
    try:
        text = safe_text(text)
        if not text:
            return "Нет текста."
        idx = get_index(text)

        questions = generate_quiz_questions(text, int(n))
        if not questions:
            return "Не удалось сгенерировать вопросы."

        lines = ["### Вопросы и ответы (с доказательством)"]
        for i, q in enumerate(questions, 1):
            qa_q = re.sub(r"^Заполните пропуск:\s*", "", q).strip()
            hits = retrieve(idx, qa_q, k=5)
            ctx = "\n\n".join([c for _, c in hits]) if hits else text[:CTX_MAX_CHARS]
            ctx = ctx[:CTX_MAX_CHARS]

            ans, ev, score = qa_answer(qa_q, ctx)
            if not ans or score < 0.08:
                ans = "В тексте это не указано (или требуется переформулировать вопрос)."

            lines.append(f"**{i}. {q}**")
            lines.append(f"- Ответ: {ans}")
            lines.append(f"- Фрагмент: {ev}")
            lines.append("")
        return "\n".join(lines).strip()
    except Exception as e:
        return f"Ошибка: {e}"


def on_chat(text: str, history: List[Tuple[str, str]], user_q: str):
    try:
        text = safe_text(text)
        user_q = (user_q or "").strip()
        history = history or []

        if not text:
            history.append((user_q, "Нет текста. Вставьте текст слева."))
            return history, ""

        if not user_q:
            return history, ""

        idx = get_index(text)
        hits = retrieve(idx, user_q, k=5)
        ctx = "\n\n".join([c for _, c in hits]) if hits else text[:CTX_MAX_CHARS]
        ctx = ctx[:CTX_MAX_CHARS]

        ans, ev, score = qa_answer(user_q, ctx)
        if not ans or score < 0.08:
            reply = "Ответ по тексту не найден. Попробуйте переформулировать вопрос или уточнить термин."
        else:
            reply = f"Ответ: {ans}\n\nДоказательство:\n{ev}"

        history.append((user_q, reply))
        return history, ""
    except Exception as e:
        history = history or []
        history.append((user_q, f"Ошибка: {e}"))
        return history, ""

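# --- UI layout and event wiring ---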
with gr.Blocks(title="RU Text Assistant (CPU, 3 Transformers)") as demo:
    with gr.Row():
        with gr.Column(scale=2):
            text_in = gr.Textbox(label="Текст (русский)", lines=16, placeholder="Вставьте текст для анализа…")
            load_btn = gr.Button("Загрузить модели", variant="secondary")
            model_status = gr.Textbox(label="Статус", lines=5, interactive=False)

        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("Пересказ"):
                    sum_btn = gr.Button("Сделать пересказ", variant="primary")
                    sum_out = gr.Markdown()

                with gr.Tab("Поиск"):
                    query_in = gr.Textbox(label="Запрос", placeholder="Например: стандартизация, вариабельность, вывод…")
                    k_in = gr.Slider(1, 10, value=TOPK_DEFAULT, step=1, label="Top-K")
                    search_btn = gr.Button("Найти фрагменты", variant="primary")
                    search_out = gr.Markdown()

                with gr.Tab("Вопросы"):
                    n_in = gr.Slider(1, 12, value=6, step=1, label="Количество вопросов")
                    quiz_btn = gr.Button("Сгенерировать и проверить", variant="primary")
                    quiz_out = gr.Markdown()

                with gr.Tab("Чат по тексту"):
                    chat = gr.Chatbot(height=420)
                    user_q = gr.Textbox(label="Вопрос", lines=1, placeholder="Задайте вопрос по тексту…")
                    send_btn = gr.Button("Отправить", variant="primary")

    load_btn.click(on_load_models, outputs=[model_status])
    sum_btn.click(on_summary, inputs=[text_in], outputs=[sum_out])
    search_btn.click(on_search, inputs=[text_in, query_in, k_in], outputs=[search_out])
    quiz_btn.click(on_quiz, inputs=[text_in, n_in], outputs=[quiz_out])
    send_btn.click(on_chat, inputs=[text_in, chat, user_q], outputs=[chat, user_q])
    user_q.submit(on_chat, inputs=[text_in, chat, user_q], outputs=[chat, user_q])


if __name__ == "__main__":
    demo.queue(max_size=32).launch(server_name="0.0.0.0", server_port=7860, show_error=True)