Spaces:
Running on CPU Upgrade

whisky-wheel / app.py
ziem-io's picture
Update: Display review for 'No tasting notes detected'
87cb4a4
raw
history blame
11.4 kB
# Standardbibliotheken
import os # Umgebungsvariablen (z.B. HF_TOKEN)
import time # Timing / Performance-Messung
import random # Zufallswerte (z.B. Beispiel-Reviews)
import html # HTML-Escaping für sichere Ausgabe in Gradio
import types # Monkeypatching von Instanzen (fastText .predict)
import functools # Caching (reuse a single DeepL client across requests)
import logging # Warnings during model loading (e.g. checkpoint key mismatches)
import numpy as np # Numerische Arrays und Wahrscheinlichkeiten
# Machine Learning / NLP
import torch # PyTorch: Modelle, Tensoren, Devices
import fasttext # Sprach-ID-Modell (lid.176)
# Folgende sind notwendig, auch wenn sie nicht explizit genutzt werden:
import sentencepiece # Pflicht für SentencePiece-basierte Tokenizer (z.B. DeBERTa v3)
import tiktoken # Optionaler Converter (verhindert Fallback-Fehler bei Tokenizer)
from langid.langid import LanguageIdentifier, model # Alternative Sprach-ID
# Hugging Face Ökosystem
import spaces # HF Spaces-Dekoratoren (@spaces.GPU)
from transformers import AutoTokenizer # Tokenizer laden (use_fast=False für DeBERTa v3)
from huggingface_hub import hf_hub_download # Download von Dateien/Weights aus dem HF Hub
from safetensors.torch import load_file # Sicheres & schnelles Laden von Weights (.safetensors)
# Übersetzung
import deepl # DeepL API für automatische Übersetzung
# UI / Serving
import gradio as gr # Web-UI für Demo/Spaces
# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor, BertBinaryClassifier
from lib.bert_regressor_utils import (
    predict_flavours, # Hauptfunktion: Vorhersage der 8 Aromenachsen
    cleanup_tasting_note
)
from lib.wheel import build_svg_with_values # SVG-Rendering für Flavour Wheel
from lib.examples import EXAMPLES # Beispiel-Reviews (vordefiniert)
##################################################################################
# Inference device: CUDA when torch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
### Settings #####################################################################
# Cleanup model (optional "Cleanup Tasting Note" step): DistilBERT-based
# binary sentence classifier. The weights file name comes from the environment.
MODEL_BASE_CLEANUP = "distilbert-base-uncased"
REPO_ID_CLEANUP = "ziem-io/binary_classifier_is_review_sentence"
MODEL_FILE_CLEANUP = os.environ.get("MODEL_FILE_CLEANUP")  # Space secret; None when unset
# Flavour model: DeBERTa-v3-based multi-head regressor.
MODEL_BASE_CLASSIFY = "microsoft/deberta-v3-base"
REPO_ID_CLASSIFY = "ziem-io/flavour_regressor_multi_head"
MODEL_FILE_CLASSIFY = os.environ.get("MODEL_FILE_CLASSIFY")  # Space secret; None when unset
# Access token, only required if the model repositories are private.
HF_TOKEN = os.environ.get("HF_TOKEN")  # Space secret; None when unset
# Key for the DeepL API used to translate non-English input.
DEEPL_API_KEY = os.environ.get("DEEPL_API_KEY")  # Space secret; None when unset
##################################################################################
# --- Download weights: cleanup model ---
weights_path_cleanup = hf_hub_download(
    repo_id=REPO_ID_CLEANUP,
    filename=MODEL_FILE_CLEANUP,
    token=HF_TOKEN,
)
# --- Tokenizer: cleanup model ---
tokenizer_cleanup = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLEANUP,
    use_fast=True,  # DistilBERT works fine with the fast tokenizer
)
model_cleanup = BertBinaryClassifier(
    pretrained_model_name=MODEL_BASE_CLEANUP
)
state_cleanup = load_file(weights_path_cleanup)  # safetensors -> dict[str, Tensor]
# strict=True: checkpoint and model keys must match exactly, otherwise this raises.
model_cleanup.load_state_dict(state_cleanup, strict=True)
# load_state_dict copies the tensors into the model's own parameters, so the
# loaded dict would otherwise keep a second full copy of the weights alive.
del state_cleanup
model_cleanup.to(device).eval()
##################################################################################
# --- Download weights: classify model ---
weights_path_classify = hf_hub_download(
    repo_id=REPO_ID_CLASSIFY,
    filename=MODEL_FILE_CLASSIFY,
    token=HF_TOKEN,
)
# --- Tokenizer: classify model (SentencePiece-based, slow tokenizer) ---
tokenizer_classify = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLASSIFY,
    use_fast=False,
)
model_classify = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE_CLASSIFY
)
state_classify = load_file(weights_path_classify)  # safetensors -> dict[str, Tensor]
# strict=False tolerates key differences between checkpoint and model
# (strict=True would require an exact match). Report any differences so a
# checkpoint/architecture mismatch does not go unnoticed, without crashing.
res = model_classify.load_state_dict(state_classify, strict=False)
if res.missing_keys or res.unexpected_keys:
    logging.getLogger(__name__).warning(
        "Classify model loaded with key mismatches: missing=%s unexpected=%s",
        res.missing_keys,
        res.unexpected_keys,
    )
# load_state_dict copied the tensors into the model; free the duplicate copy.
del state_classify
model_classify.to(device).eval()
### Check if lang is english #####################################################
# Shared langid identifier used by _is_eng(). norm_probs=True makes classify()
# return a normalised probability in [0, 1] alongside the language code.
ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)
def _is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
t = (text or "").strip()
if len(t) < min_chars:
return True, 0.0
lang, prob = ID.classify(t) # prob ∈ [0,1]
return (lang == "en" and prob >= threshold), float(prob)
@functools.lru_cache(maxsize=1)
def _deepl_translator(auth_key):
    """Return a DeepL client for *auth_key*, created on first use and then reused."""
    return deepl.Translator(auth_key)


def _translate_en(text: str, target_lang: str = "EN-GB"):
    """Translate *text* with the DeepL API and return the translated string.

    ``target_lang`` is passed to DeepL unchanged (default: British English).
    Errors raised by the DeepL client propagate to the caller.
    """
    # Reuse one client (and its HTTP session) instead of constructing a new
    # Translator on every request.
    # NOTE(review): the shared client may be used by concurrent Gradio
    # requests; confirm deepl.Translator is safe to share across threads.
    result = _deepl_translator(DEEPL_API_KEY).translate_text(
        text, target_lang=target_lang
    )
    return result.text
def _render_cleanup_html(cleanup_meta):
"""
Renders cleanup_meta into HTML with struck-through non-note sentences.
"""
html_out = "<p>"
for s in cleanup_meta:
sent = html.escape(s.get("sentence", ""))
if s.get("is_note"):
html_out += f"{sent} "
else:
html_out += f"<span style='text-decoration: line-through; color: gray;'>{sent}</span> "
html_out += "</p>"
return html_out
### Do actual prediction #########################################################
def _cleanup_info_html(review_status: str, cleanup_meta: list) -> str:
    """Build the info HTML that reports the outcome of the cleanup step.

    ``review_status`` is the status returned by ``cleanup_tasting_note``.
    For "mixed" and "noise_only", ``cleanup_meta`` is also rendered with
    removed sentences struck through. Any other status is reported as
    "no cleanup necessary".
    """
    if review_status == "mixed":
        # Show which parts were kept vs. removed
        return "<strong>Your text has been cleaned up</strong>" + _render_cleanup_html(cleanup_meta)
    if review_status == "noise_only":
        # No tasting notes detected in the text
        return "<strong>No tasting notes detected</strong>" + _render_cleanup_html(cleanup_meta)
    # review_status == "review_only": the text consists entirely of tasting notes
    return "<strong>No cleanup was necessary</strong>"


@spaces.GPU(duration=10)  # GPU seconds per call
def predict(review_raw: str, do_cleanup: bool):
    """Predict the flavour profile of a whisky tasting note for the Gradio UI.

    Pipeline:
      1. Normalise the input.
      2. Translate it with DeepL when it is not detected as English.
      3. Optionally remove non-tasting-note sentences (``do_cleanup``).
      4. Run the flavour regressor and render the flavour wheel.

    Args:
        review_raw: Text entered by the user; ``None`` is treated as empty.
        do_cleanup: Whether to run the sentence-cleanup model first.

    Returns:
        ``(info_html, wheel_html, json_dict)``. Empty input returns
        ``("Please enter a review.", "", {})``. When cleanup finds no tasting
        notes, no prediction is run and the flavour values stay empty/zero.
    """
    # Normalize input (handle None and trim whitespace)
    review_raw = (review_raw or "").strip()
    if not review_raw:
        return "Please enter a review.", "", {}

    html_info_out = ""
    is_translated = False

    # Automatically translate text that is not detected as English.
    # NOTE(review): from here on review_raw holds the translated text, which is
    # also what ends up in json["review"]["raw"].
    review_is_eng, _ = _is_eng(review_raw)
    if not review_is_eng:
        review_raw = _translate_en(review_raw)
        html_info_out += (
            "<strong style='display:block'>Your text has been automatically translated:</strong>"
            f"<p>{html.escape(review_raw)}</p>"
        )
        is_translated = True

    # Defaults so every output is defined: without cleanup the full text is
    # treated as tasting note, and the wheel shows all eight axes at zero
    # unless a prediction replaces them.
    prediction_flavours = {}
    prediction_flavours_list = [0] * 8
    review_clean = review_raw
    cleanup_meta = []
    review_status = "review_only"

    # perf_counter is monotonic and therefore the right clock for durations
    t_start_flavours = time.perf_counter()

    # Apply text cleanup only when the checkbox is enabled
    if do_cleanup:
        review_clean, cleanup_meta, review_status = cleanup_tasting_note(
            review_raw,
            model_cleanup,
            tokenizer_cleanup,
            device,
        )
        # Append (not assign) so a translation notice added above is kept.
        html_info_out += _cleanup_info_html(review_status, cleanup_meta)

    # Predict only when the (cleaned) text still contains tasting notes.
    # Without cleanup review_status stays "review_only", so this always runs.
    if review_status in ("review_only", "mixed"):
        prediction_flavours = predict_flavours(
            review_clean,
            model_classify,
            tokenizer_classify,
            device,
        )
        prediction_flavours_list = list(prediction_flavours.values())

    duration = time.perf_counter() - t_start_flavours

    # Build the flavour wheel SVG
    html_wheel_out = build_svg_with_values(prediction_flavours_list)

    # Structured JSON output
    json_out = {
        "result": dict(prediction_flavours),
        "range": {"min": 0, "max": 4},
        "review": {
            "raw": review_raw,
            "clean": review_clean,
            "clean_meta": cleanup_meta,
            "status": review_status,
        },
        "models": {
            "cleanup": MODEL_FILE_CLEANUP,
            "classify": MODEL_FILE_CLASSIFY,
        },
        "device": device,
        "translated": is_translated,
        "duration": round(duration, 3),
    }

    return html_info_out, html_wheel_out, json_out
##################################################################################
def random_text():
    """Return a randomly picked example tasting note from EXAMPLES."""
    example = random.choice(EXAMPLES)
    return example


# Index into EXAMPLES of the note pre-filled in the textbox on page load.
_START_EXAMPLE_INDEX = 20


def _start_text():
    """Return the example tasting note shown when the page first loads."""
    return EXAMPLES[_START_EXAMPLE_INDEX]
### Create Form interface with Gradio Framework ##################################
# Dark mode: lighten the wheel's text labels so they stay readable.
custom_css = (
    "\n"
    "@media (prefers-color-scheme: dark) {\n"
    "svg#wheel > text {\n"
    "fill: rgb(200, 200, 200);\n"
    "}\n"
    "}\n"
)


def _hint_html(text):
    """Return a small grey hint line rendered below an input control."""
    return f"<div style='color: gray; font-size: 0.9em;'>{text}</div>"


with gr.Blocks(css=custom_css) as demo:
    # Title and description
    gr.HTML("<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>")
    gr.HTML(
        "\n"
        "<h3>Automatically turns Whisky Tasting Notes into Flavour Wheels.</h3>\n"
        "<p>This model is a fine-tuned version of "
        "<a href='https://huggingface.co/microsoft/deberta-v3-base'>microsoft/deberta-v3-base</a>"
        " designed to analyze English whisky tasting notes. "
        "It predicts the intensity of eight sensory categories — "
        "<strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, "
        "<strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, "
        "<strong>winey</strong> and <strong>off-notes</strong>"
        " — on a continuous scale from 0 (none) to 4 (extreme).</p>\n"
    )
    gr.HTML(
        "\n"
        "<p style='color: var(--block-title-text-color)'>"
        "Learn more about use cases and get in touch at "
        "<a href='https://www.whisky-wheel.com'>www.whisky-wheel.com</a></p>\n"
    )

    with gr.Row():  # input and output side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Tasting Note",
                lines=8,
                placeholder="Enter whisky tasting note",
                value=_start_text(),
            )
            gr.HTML(_hint_html("Note: Non-English texts will be automatically translated."))
            with gr.Column():
                cleanup_cb = gr.Checkbox(label="BETA: Cleanup Tasting Note", value=False)
                gr.HTML(_hint_html("Removes non–tasting note parts from text."))
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)
        with gr.Column(scale=1):  # right side: output
            html_info_out = gr.HTML(label="Info")
            html_wheel_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Submit: run predict() and fill the info, wheel and JSON outputs
    submit_btn.click(
        predict,
        inputs=[review_box, cleanup_cb],
        outputs=[html_info_out, html_wheel_out, json_out],
    )
    # Load Example: replace the textbox content with a random example
    replace_btn.click(random_text, outputs=review_box)

demo.launch(show_api=False)