Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 11,439 Bytes
9a5d8ec ba84108 9a5d8ec ba84108 9a5d8ec ba84108 406793a bb3d05e e31ea41 bb3d05e ba84108 9b9fd34 bb3d05e ba84108 bb3d05e 9b9fd34 9cfe75e b6df863 265d5e8 e6317fa 265d5e8 e6317fa 3d475b8 406793a 3d475b8 9b9fd34 b6df863 e12a61c 9b9fd34 b6df863 9b9fd34 e31ea41 9b9fd34 db5e662 b6df863 9b9fd34 e12a61c 87f6a56 9cfe75e 325ed03 b6ecc6e 6370ba9 ba84108 b6ecc6e 9ea7979 ba84108 406793a 71c9493 406793a 9a2984d 87cb4a4 9a2984d f52f2a8 9a2984d f52f2a8 9a2984d 9cfe75e defc447 0856c76 9a2984d 9b9fd34 406793a 1e3cd55 9ea7979 9a2984d 9b9fd34 1e3cd55 9a2984d 9b9fd34 f1b3d54 9a2984d f1b3d54 9b9fd34 9a2984d 406793a 9a2984d 177a8e6 d5e14ad 177a8e6 9a2984d f1b3d54 ba1e296 9a2984d b81d9a3 9a2984d b81d9a3 ae8c8c9 0856c76 9a2984d 0856c76 924a06f 0856c76 6af94a0 ae8c8c9 87cb4a4 9a2984d ae8c8c9 0eb70bb 87cb4a4 ae8c8c9 0eb70bb ba1e296 9a2984d b81d9a3 ba1e296 9a2984d f1b3d54 f35e5d2 9a2984d 1e3cd55 c921016 9a2984d c921016 598b6e6 9a2984d ba1e296 db5e662 9a2984d ba1e296 23ced0f 4c56373 5a001f3 4c56373 c921016 9a2984d 1e3cd55 cd07d29 79079d1 fa07bfe 2954446 79079d1 4612b23 f2be8d8 4612b23 9cfe75e b0c4a5b f0fd942 06353d7 296279c 1f7a544 296279c 06353d7 be1d31d 2954446 c50c843 518369b c50c843 518369b 4612b23 c50c843 5a001f3 ec632c7 0856c76 c50c843 ad4809c c50c843 1e3cd55 c50c843 79079d1 0856c76 fa07bfe ac240ce f0fd942 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
# Standardbibliotheken
import os # Umgebungsvariablen (z.B. HF_TOKEN)
import time # Timing / Performance-Messung
import random # Zufallswerte (z.B. Beispiel-Reviews)
import html # HTML-Escaping für sichere Ausgabe in Gradio
import types # Monkeypatching von Instanzen (fastText .predict)
import numpy as np # Numerische Arrays und Wahrscheinlichkeiten
# Machine Learning / NLP
import torch # PyTorch: Modelle, Tensoren, Devices
import fasttext # Sprach-ID-Modell (lid.176)
# Folgende sind notwendig, auch wenn sie nicht explizit genutzt werden:
import sentencepiece # Pflicht für SentencePiece-basierte Tokenizer (z.B. DeBERTa v3)
import tiktoken # Optionaler Converter (verhindert Fallback-Fehler bei Tokenizer)
from langid.langid import LanguageIdentifier, model # Alternative Sprach-ID
# Hugging Face Ökosystem
import spaces # HF Spaces-Dekoratoren (@spaces.GPU)
from transformers import AutoTokenizer # Tokenizer laden (use_fast=False für DeBERTa v3)
from huggingface_hub import hf_hub_download # Download von Dateien/Weights aus dem HF Hub
from safetensors.torch import load_file # Sicheres & schnelles Laden von Weights (.safetensors)
# Übersetzung
import deepl # DeepL API für automatische Übersetzung
# UI / Serving
import gradio as gr # Web-UI für Demo/Spaces
# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor, BertBinaryClassifier
from lib.bert_regressor_utils import (
predict_flavours, # Hauptfunktion: Vorhersage der 8 Aromenachsen
cleanup_tasting_note
)
from lib.wheel import build_svg_with_values # SVG-Rendering für Flavour Wheel
from lib.examples import EXAMPLES # Beispiel-Reviews (vordefiniert)
##################################################################################
# Prefer GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

### Settings #####################################################################
# Base model + fine-tuned weights for the sentence cleanup classifier.
MODEL_BASE_CLEANUP = "distilbert-base-uncased"
REPO_ID_CLEANUP = "ziem-io/binary_classifier_is_review_sentence"
MODEL_FILE_CLEANUP = os.getenv("MODEL_FILE_CLEANUP")  # set in Space secrets

# Base model + fine-tuned weights for the multi-head flavour regressor.
MODEL_BASE_CLASSIFY = "microsoft/deberta-v3-base"
REPO_ID_CLASSIFY = "ziem-io/flavour_regressor_multi_head"
MODEL_FILE_CLASSIFY = os.getenv("MODEL_FILE_CLASSIFY")  # set in Space secrets

# (optional) only needed if the model repos are private:
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space secrets
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")  # set in Space secrets
##################################################################################
# --- Download weights for the cleanup model ---
weights_path_cleanup = hf_hub_download(
    repo_id=REPO_ID_CLEANUP,
    filename=MODEL_FILE_CLEANUP,
    token=HF_TOKEN
)

# --- Tokenizer for the cleanup model ---
# (original comment said "Classify"; this is the *cleanup* tokenizer)
tokenizer_cleanup = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLEANUP,
    use_fast=True  # DistilBERT works fine with the fast tokenizer
)

# Binary classifier head on top of DistilBERT: per-sentence decision
# whether a sentence belongs to the tasting note.
model_cleanup = BertBinaryClassifier(
    pretrained_model_name=MODEL_BASE_CLEANUP
)

state_cleanup = load_file(weights_path_cleanup)  # safetensors -> dict[str, Tensor]
res = model_cleanup.load_state_dict(state_cleanup, strict=True)  # strict=True: checkpoint keys must match exactly
model_cleanup.to(device).eval()  # move to device, inference mode
##################################################################################
# --- Download weights for the classify model (flavour regressor) ---
weights_path_classify = hf_hub_download(
    repo_id=REPO_ID_CLASSIFY,
    filename=MODEL_FILE_CLASSIFY,
    token=HF_TOKEN
)

# --- Tokenizer for the classify model (SentencePiece!) ---
# use_fast=False: DeBERTa v3 ships a SentencePiece tokenizer; the slow
# implementation is used here to avoid fast-tokenizer conversion issues.
tokenizer_classify = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLASSIFY,
    use_fast=False
)

# Multi-head regressor on top of DeBERTa: one output per flavour axis.
model_classify = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE_CLASSIFY
)

state_classify = load_file(weights_path_classify)  # safetensors -> dict[str, Tensor]
# strict=False: presumably the checkpoint does not cover every module key — TODO confirm
res = model_classify.load_state_dict(state_classify, strict=False)
model_classify.to(device).eval()  # move to device, inference mode
### Check if lang is english #####################################################
# Shared langid identifier; norm_probs=True normalizes scores to [0, 1].
ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)
def _is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    """Decide whether *text* is English.

    Returns a tuple ``(is_english, probability)``. Inputs shorter than
    *min_chars* carry too little signal for language identification and
    are optimistically treated as English with probability 0.0.
    """
    stripped = (text or "").strip()
    if len(stripped) < min_chars:
        return True, 0.0
    detected, confidence = ID.classify(stripped)  # confidence in [0, 1]
    is_english = detected == "en" and confidence >= threshold
    return is_english, float(confidence)
def _translate_en(text: str, target_lang: str = "EN-GB"):
    """Translate *text* into English (British by default) via DeepL."""
    translator = deepl.Translator(DEEPL_API_KEY)
    translated = translator.translate_text(text, target_lang=target_lang)
    return translated.text
def _render_cleanup_html(cleanup_meta):
"""
Renders cleanup_meta into HTML with struck-through non-note sentences.
"""
html_out = "<p>"
for s in cleanup_meta:
sent = html.escape(s.get("sentence", ""))
if s.get("is_note"):
html_out += f"{sent} "
else:
html_out += f"<span style='text-decoration: line-through; color: gray;'>{sent}</span> "
html_out += "</p>"
return html_out
### Do actual prediction #########################################################
@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review_raw: str, do_cleanup: bool):
    """Run the full pipeline on a raw tasting-note text.

    Steps: normalize input -> language detection (with automatic DeepL
    translation for non-English text) -> optional cleanup of non-note
    sentences -> flavour regression -> flavour-wheel SVG + JSON metadata.

    Args:
        review_raw: raw user input (may be None or empty).
        do_cleanup: when True, strip non-tasting-note sentences first.

    Returns:
        Tuple (info_html, wheel_html, json_dict). On empty input a plain
        message, an empty string and an empty dict are returned instead.
    """
    # Normalize input (handle None and trim whitespace)
    review_raw = (review_raw or "").strip()
    is_translated = False
    html_info_out = ""

    # Abort early if no text is provided
    if not review_raw:
        return "Please enter a review.", "", {}

    # Detect language of the input text
    review_is_eng, review_lang_prob = _is_eng(review_raw)

    # Automatically translate non-English text
    if not review_is_eng:
        review_raw = _translate_en(review_raw)
        html_info_out += (
            "<strong style='display:block'>Your text has been automatically translated:</strong>"
            f"<p>{html.escape(review_raw)}</p>"
        )
        is_translated = True

    # Initialize prediction outputs (8 flavour axes, all zero)
    prediction_flavours = {}
    prediction_flavours_list = [0, 0, 0, 0, 0, 0, 0, 0]

    # Start timing the model inference
    t_start_flavours = time.time()

    # Defaults so all variables are always defined; without cleanup the
    # full text is treated as a tasting note.
    review_clean = review_raw
    cleanup_meta = []
    review_status = "review_only"

    # Apply text cleanup only when the checkbox is enabled
    if do_cleanup:
        review_clean, cleanup_meta, review_status = cleanup_tasting_note(
            review_raw,
            model_cleanup,
            tokenizer_cleanup,
            device
        )
        # Display cleanup visualization based on detected content
        if review_status == "mixed":
            # BUGFIX: append with "+=" (was "=") so a preceding
            # translation notice is not silently overwritten.
            html_info_out += "<strong>Your text has been cleaned up</strong>"
            html_info_out += _render_cleanup_html(cleanup_meta)
        elif review_status == "noise_only":
            # No tasting notes detected in the text
            html_info_out += "<strong>No tasting notes detected</strong>"
            html_info_out += _render_cleanup_html(cleanup_meta)
        else:  # review_status == "review_only"
            # Text consists entirely of tasting notes; no cleanup needed
            html_info_out += "<strong>No cleanup was necessary</strong>"

    # Run flavour prediction only if review content is present
    if (not do_cleanup) or (review_status in ("review_only", "mixed")):
        prediction_flavours = predict_flavours(
            review_clean,
            model_classify,
            tokenizer_classify,
            device
        )
        prediction_flavours_list = list(prediction_flavours.values())

    # Stop timing inference
    t_end_flavours = time.time()

    # Build the flavour wheel SVG
    html_wheel_out = build_svg_with_values(prediction_flavours_list)

    # Prepare structured JSON output
    json_out = {
        "result": dict(prediction_flavours),  # copy; dict(...) is enough, no .items() needed
        "range": {"min": 0, "max": 4},
        "review": {
            "raw": review_raw,
            "clean": review_clean,
            "clean_meta": cleanup_meta,
            "status": review_status
        },
        "models": {
            "cleanup": MODEL_FILE_CLEANUP,
            "classify": MODEL_FILE_CLASSIFY
        },
        "device": device,
        "translated": is_translated,
        "duration": round((t_end_flavours - t_start_flavours), 3),
    }

    # Return HTML info, flavour wheel, and JSON output
    return html_info_out, html_wheel_out, json_out
##################################################################################
def random_text():
    """Pick a random example review for the input textbox."""
    example = random.choice(EXAMPLES)
    return example
def _start_text():
    """Return the fixed example review shown when the page first loads."""
    initial_example = EXAMPLES[20]
    return initial_example
### Create Form interface with Gradio Framework ##################################
# CSS override injected into the Gradio page: in dark mode, lighten the
# <text> labels of the flavour-wheel SVG so they stay readable.
custom_css = """
@media (prefers-color-scheme: dark) {
svg#wheel > text {
fill: rgb(200, 200, 200);
}
}
"""
# NOTE(review): the original file's indentation was lost in extraction;
# the Row/Column nesting below is reconstructed from logical grouping —
# confirm against the deployed Space layout.
with gr.Blocks(css=custom_css) as demo:
    gr.HTML("<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>")
    gr.HTML("""
<h3>Automatically turns Whisky Tasting Notes into Flavour Wheels.</h3>
<p>This model is a fine-tuned version of <a href='https://huggingface.co/microsoft/deberta-v3-base'>microsoft/deberta-v3-base</a> designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — <strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, <strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, <strong>winey</strong> and <strong>off-notes</strong> — on a continuous scale from 0 (none) to 4 (extreme).</p>
""")
    gr.HTML("""
<p style='color: var(--block-title-text-color)'>Learn more about use cases and get in touch at <a href='https://www.whisky-wheel.com'>www.whisky-wheel.com</a></p>
""")
    with gr.Row():  # input and output side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Tasting Note",
                lines=8,
                placeholder="Enter whisky tasting note",
                value=_start_text(),
            )
            gr.HTML("<div style='color: gray; font-size: 0.9em;'>Note: Non-English texts will be automatically translated.</div>")
            with gr.Column():
                cleanup_cb = gr.Checkbox(
                    label="BETA: Cleanup Tasting Note",
                    value=False
                )
                gr.HTML("<div style='color: gray; font-size: 0.9em;'>Removes non–tasting note parts from text.</div>")
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)
        with gr.Column(scale=1):  # right side: output
            html_info_out = gr.HTML(label="Info")
            html_wheel_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")
    # Submit button: run the prediction pipeline
    submit_btn.click(
        predict,
        inputs=[review_box, cleanup_cb],
        outputs=[html_info_out, html_wheel_out, json_out]
    )
    # Example button: load a random example review into the textbox
    replace_btn.click(random_text, outputs=review_box)

demo.launch(show_api=False)