"""Gradio demo that turns whisky tasting notes into flavour-wheel predictions.

At import time the script downloads two sets of fine-tuned weights from the
Hugging Face Hub (a sentence-level cleanup classifier and a multi-head flavour
regressor), builds the Gradio UI and launches it.

NOTE(review): several HTML string literals below contain only text padded
with newlines; block-level markup may have been lost from this copy of the
file -- confirm against the deployed app.
"""

# Standard library
import os        # Environment variables (e.g. HF_TOKEN)
import time      # Timing / performance measurement
import random    # Random values (e.g. example reviews)
import html      # HTML escaping for safe output in Gradio
import types     # Monkeypatching of instances (fastText .predict)
import warnings  # Surface non-fatal problems while loading weights

import numpy as np  # Numeric arrays and probabilities

# Machine learning / NLP
import torch     # PyTorch: models, tensors, devices
import fasttext  # Language-ID model (lid.176)

# The following are required even though they are not referenced explicitly:
import sentencepiece  # Mandatory for SentencePiece-based tokenizers (e.g. DeBERTa v3)
import tiktoken       # Optional converter (prevents tokenizer fallback errors)

from langid.langid import LanguageIdentifier, model  # Alternative language ID

# Hugging Face ecosystem
import spaces                                # HF Spaces decorators (@spaces.GPU)
from transformers import AutoTokenizer       # Tokenizer loading (use_fast=False for DeBERTa v3)
from huggingface_hub import hf_hub_download  # Download files/weights from the HF Hub
from safetensors.torch import load_file      # Safe and fast loading of .safetensors weights

# Translation
import deepl  # DeepL API for automatic translation

# UI / serving
import gradio as gr  # Web UI for demo/Spaces

# Project-specific modules
from lib.bert_regressor import BertMultiHeadRegressor, BertBinaryClassifier
from lib.bert_regressor_utils import (
    predict_flavours,  # Main function: prediction of the 8 flavour axes
    cleanup_tasting_note,
)
from lib.wheel import build_svg_with_values  # SVG rendering for the flavour wheel
from lib.examples import EXAMPLES            # Predefined example reviews

##################################################################################

device = "cuda" if torch.cuda.is_available() else "cpu"

### Settings #####################################################################

MODEL_BASE_CLEANUP = "distilbert-base-uncased"
REPO_ID_CLEANUP = "ziem-io/binary_classifier_is_review_sentence"
MODEL_FILE_CLEANUP = os.getenv("MODEL_FILE_CLEANUP")  # store in the Space secrets

MODEL_BASE_CLASSIFY = "microsoft/deberta-v3-base"
REPO_ID_CLASSIFY = "ziem-io/flavour_regressor_multi_head"
MODEL_FILE_CLASSIFY = os.getenv("MODEL_FILE_CLASSIFY")  # store in the Space secrets

# (optional) only needed if the model repo is private:
HF_TOKEN = os.getenv("HF_TOKEN")  # store in the Space secrets

DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")  # store in the Space secrets

# Fail fast with a readable message instead of passing filename=None to the
# Hub download below.
for _env_name, _env_value in (
    ("MODEL_FILE_CLEANUP", MODEL_FILE_CLEANUP),
    ("MODEL_FILE_CLASSIFY", MODEL_FILE_CLASSIFY),
):
    if not _env_value:
        raise RuntimeError(
            f"Environment variable {_env_name} is not set; "
            "it must name the weights file to download."
        )

##################################################################################

# --- Download weights: cleanup model ---
weights_path_cleanup = hf_hub_download(
    repo_id=REPO_ID_CLEANUP,
    filename=MODEL_FILE_CLEANUP,
    token=HF_TOKEN,
)

# --- Tokenizer: cleanup model ---
tokenizer_cleanup = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLEANUP,
    use_fast=True,  # DistilBERT works fine with the fast tokenizer
)

model_cleanup = BertBinaryClassifier(
    pretrained_model_name=MODEL_BASE_CLEANUP
)
state_cleanup = load_file(weights_path_cleanup)  # safetensors -> dict[str, Tensor]
# strict=True: the checkpoint keys must match the model exactly.
res = model_cleanup.load_state_dict(state_cleanup, strict=True)
model_cleanup.to(device).eval()

##################################################################################

# --- Download weights: classify model ---
weights_path_classify = hf_hub_download(
    repo_id=REPO_ID_CLASSIFY,
    filename=MODEL_FILE_CLASSIFY,
    token=HF_TOKEN,
)

# --- Tokenizer: classify model (SentencePiece!) ---
tokenizer_classify = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLASSIFY,
    use_fast=False,
)

model_classify = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE_CLASSIFY
)
state_classify = load_file(weights_path_classify)  # safetensors -> dict[str, Tensor]
# strict=False tolerates key mismatches (use strict=True once the keys match
# exactly); report any mismatch so a partially loaded model does not go
# unnoticed.
res = model_classify.load_state_dict(state_classify, strict=False)
if res.missing_keys or res.unexpected_keys:
    warnings.warn(
        "Classify weights loaded non-strictly: "
        f"missing={list(res.missing_keys)}, "
        f"unexpected={list(res.unexpected_keys)}"
    )
model_classify.to(device).eval()

### Check if lang is english #####################################################

ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)


def _is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    """Decide whether *text* should be treated as English.

    Args:
        text: Input text; ``None`` is treated as empty.
        min_chars: Stripped texts shorter than this are not classified and
            count as English.
        threshold: Minimum probability required for an "en" classification.

    Returns:
        ``(is_english, probability)``. For texts shorter than *min_chars*
        this is ``(True, 0.0)``.
    """
    stripped = (text or "").strip()
    if len(stripped) < min_chars:
        return True, 0.0
    lang, prob = ID.classify(stripped)  # prob in [0, 1] (norm_probs=True)
    return (lang == "en" and prob >= threshold), float(prob)


def _translate_en(text: str, target_lang: str = "EN-GB"):
    """Translate *text* via the DeepL API and return the translated string.

    A new ``deepl.Translator`` is created from ``DEEPL_API_KEY`` on each call.
    """
    deepl_client = deepl.Translator(DEEPL_API_KEY)
    result = deepl_client.translate_text(text, target_lang=target_lang)
    return result.text


def _render_cleanup_html(cleanup_meta):
    """Render *cleanup_meta* as HTML with struck-through non-note sentences.

    Args:
        cleanup_meta: Iterable of dicts; the keys read here are
            ``"sentence"`` and ``"is_note"``.

    Returns:
        An HTML fragment in which every sentence is HTML-escaped and the
        sentences not flagged as tasting notes are wrapped in ``<s>``.
    """
    fragments = []
    for entry in cleanup_meta:
        # A missing or None sentence renders as empty text instead of crashing.
        sentence = html.escape(entry.get("sentence") or "")
        if entry.get("is_note"):
            fragments.append(f"{sentence} ")
        else:
            # Removed (non-note) sentences are shown struck through.
            fragments.append(f"<s>{sentence}</s> ")
    return "\n\n" + "".join(fragments) + "\n\n"


### Do actual prediction #########################################################

@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review_raw: str, do_cleanup: bool):
    """Run the full pipeline for one tasting note.

    Steps: normalise the input, translate it to English if it is not detected
    as English, optionally strip non-tasting-note sentences, predict the
    flavour intensities and render the flavour wheel.

    Args:
        review_raw: Text entered by the user (``None`` is treated as empty).
        do_cleanup: Whether to run the cleanup classifier before prediction.

    Returns:
        ``(info_html, wheel_svg, json_payload)``. For empty input this is
        ``("Please enter a review.", "", {})``.
    """
    # Normalize input (handle None and trim whitespace)
    review_raw = (review_raw or "").strip()

    # Abort early if no text is provided
    if not review_raw:
        return "Please enter a review.", "", {}

    is_translated = False
    html_info_out = ""

    # Detect the language and translate non-English text automatically
    review_is_eng, _review_lang_prob = _is_eng(review_raw)
    if not review_is_eng:
        review_raw = _translate_en(review_raw)
        html_info_out += (
            "Your text has been automatically translated:"
            f"\n\n{html.escape(review_raw)}\n\n"
        )
        is_translated = True

    # Initialize prediction outputs
    prediction_flavours = {}
    prediction_flavours_list = [0, 0, 0, 0, 0, 0, 0, 0]

    # Start timing (covers cleanup and inference); perf_counter is monotonic
    t_start_flavours = time.perf_counter()

    # Defaults so every variable is defined: without cleanup the full text
    # is treated as a tasting note.
    review_clean = review_raw
    cleanup_meta = []
    review_status = "review_only"

    # Apply text cleanup only when the checkbox is enabled
    if do_cleanup:
        review_clean, cleanup_meta, review_status = cleanup_tasting_note(
            review_raw,
            model_cleanup,
            tokenizer_cleanup,
            device,
        )

        # Append the cleanup summary; every branch appends so that an earlier
        # translation notice is kept.
        if review_status == "mixed":
            # Show which parts were kept vs. removed
            html_info_out += "Your text has been cleaned up"
            html_info_out += _render_cleanup_html(cleanup_meta)
        elif review_status == "noise_only":
            # No tasting notes detected in the text
            html_info_out += "No tasting notes detected"
            html_info_out += _render_cleanup_html(cleanup_meta)
        else:  # review_status == "review_only"
            # Text consists entirely of tasting notes; no cleanup needed
            html_info_out += "No cleanup was necessary"

    # Run flavour prediction only if review content is present. Without
    # cleanup the status is always "review_only", so this single membership
    # test also covers the do_cleanup=False case.
    if review_status in ("review_only", "mixed"):
        prediction_flavours = predict_flavours(
            review_clean,
            model_classify,
            tokenizer_classify,
            device,
        )
        prediction_flavours_list = list(prediction_flavours.values())

    # Stop timing
    t_end_flavours = time.perf_counter()

    # Build the flavour wheel SVG
    html_wheel_out = build_svg_with_values(prediction_flavours_list)

    # Prepare structured JSON output
    json_out = {
        "result": dict(prediction_flavours),
        "range": {"min": 0, "max": 4},
        "review": {
            "raw": review_raw,
            "clean": review_clean,
            "clean_meta": cleanup_meta,
            "status": review_status,
        },
        "models": {
            "cleanup": MODEL_FILE_CLEANUP,
            "classify": MODEL_FILE_CLASSIFY,
        },
        "device": device,
        "translated": is_translated,
        "duration": round(t_end_flavours - t_start_flavours, 3),
    }

    # Return HTML info, flavour wheel, and JSON output
    return html_info_out, html_wheel_out, json_out


##################################################################################

def random_text():
    """Return a randomly chosen example review."""
    return random.choice(EXAMPLES)


def _start_text():
    """Return the review pre-filled in the text box.

    Uses the example at index 20; if the list is shorter, falls back to the
    first example, or to an empty string when there are no examples.
    """
    if len(EXAMPLES) > 20:
        return EXAMPLES[20]
    return EXAMPLES[0] if EXAMPLES else ""


### Create Form interface with Gradio Framework ##################################

custom_css = """
@media (prefers-color-scheme: dark) {
    svg#wheel > text {
        fill: rgb(200, 200, 200);
    }
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.HTML("\n\nMulti-Axis Regression of Whisky Tasting Notes\n\n")
    gr.HTML(
        "\n\nAutomatically turns Whisky Tasting Notes into Flavour Wheels."
        "\n\nThis model is a fine-tuned version of microsoft/deberta-v3-base"
        " designed to analyze English whisky tasting notes. It predicts the"
        " intensity of eight sensory categories — grainy, grassy, fragrant,"
        " fruity, peated, woody, winey and off-notes — on a continuous scale"
        " from 0 (none) to 4 (extreme).\n\n"
    )
    gr.HTML(
        "\n\nLearn more about use cases and get in touch at"
        " www.whisky-wheel.com\n\n"
    )

    with gr.Row():  # everything side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Tasting Note",
                lines=8,
                placeholder="Enter whisky tasting note",
                value=_start_text(),
            )
            gr.HTML("\nNote: Non-English texts will be automatically translated.\n")

            with gr.Column():
                cleanup_cb = gr.Checkbox(
                    label="BETA: Cleanup Tasting Note",
                    value=False,
                )
                gr.HTML("\nRemoves non–tasting note parts from text.\n")

            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):  # right side: output
            html_info_out = gr.HTML(label="Info")
            html_wheel_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Event: Submit button
    submit_btn.click(
        predict,
        inputs=[review_box, cleanup_cb],
        outputs=[html_info_out, html_wheel_out, json_out],
    )

    # Event: Load Example button
    replace_btn.click(random_text, outputs=review_box)

demo.launch(show_api=False)
") with gr.Row(): replace_btn = gr.Button("Load Example", variant="secondary", scale=1) submit_btn = gr.Button("Submit", variant="primary", scale=1) with gr.Column(scale=1): # rechte Seite: Output html_info_out = gr.HTML(label="Info") html_wheel_out = gr.HTML(label="Flavour Wheel") json_out = gr.JSON(label="JSON") # Event Button Submit submit_btn.click( predict, inputs=[review_box, cleanup_cb], outputs=[html_info_out, html_wheel_out, json_out] ) # Event Button Submit replace_btn.click(random_text, outputs=review_box) demo.launch(show_api=False)