Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| # Standardbibliotheken | |
| import os # Umgebungsvariablen (z.B. HF_TOKEN) | |
| import time # Timing / Performance-Messung | |
| import random # Zufallswerte (z.B. Beispiel-Reviews) | |
| import html # HTML-Escaping für sichere Ausgabe in Gradio | |
| import types # Monkeypatching von Instanzen (fastText .predict) | |
| import numpy as np # Numerische Arrays und Wahrscheinlichkeiten | |
| # Machine Learning / NLP | |
| import torch # PyTorch: Modelle, Tensoren, Devices | |
| import fasttext # Sprach-ID-Modell (lid.176) | |
| # Folgende sind notwendig, auch wenn sie nicht explizit genutzt werden: | |
| import sentencepiece # Pflicht für SentencePiece-basierte Tokenizer (z.B. DeBERTa v3) | |
| import tiktoken # Optionaler Converter (verhindert Fallback-Fehler bei Tokenizer) | |
| from langid.langid import LanguageIdentifier, model # Alternative Sprach-ID | |
| # Hugging Face Ökosystem | |
| import spaces # HF Spaces-Dekoratoren (@spaces.GPU) | |
| from transformers import AutoTokenizer # Tokenizer laden (use_fast=False für DeBERTa v3) | |
| from huggingface_hub import hf_hub_download # Download von Dateien/Weights aus dem HF Hub | |
| from safetensors.torch import load_file # Sicheres & schnelles Laden von Weights (.safetensors) | |
| # Übersetzung | |
| import deepl # DeepL API für automatische Übersetzung | |
| # UI / Serving | |
| import gradio as gr # Web-UI für Demo/Spaces | |
| # Projektspezifische Module | |
| from lib.bert_regressor import BertMultiHeadRegressor, BertBinaryClassifier | |
| from lib.bert_regressor_utils import ( | |
| predict_flavours, # Hauptfunktion: Vorhersage der 8 Aromenachsen | |
| cleanup_tasting_note | |
| ) | |
| from lib.wheel import build_svg_with_values # SVG-Rendering für Flavour Wheel | |
| from lib.examples import EXAMPLES # Beispiel-Reviews (vordefiniert) | |
##################################################################################
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Settings #####################################################################

# Cleanup model (binary classifier): base checkpoint, Hub repo and weights file.
MODEL_BASE_CLEANUP = "distilbert-base-uncased"
REPO_ID_CLEANUP = "ziem-io/binary_classifier_is_review_sentence"
MODEL_FILE_CLEANUP = os.environ.get("MODEL_FILE_CLEANUP")  # store in the Space secrets

# Flavour model (multi-head regressor): base checkpoint, Hub repo and weights file.
MODEL_BASE_CLASSIFY = "microsoft/deberta-v3-base"
REPO_ID_CLASSIFY = "ziem-io/flavour_regressor_multi_head"
MODEL_FILE_CLASSIFY = os.environ.get("MODEL_FILE_CLASSIFY")  # store in the Space secrets

# (optional) only needed when the model repo is private:
HF_TOKEN = os.environ.get("HF_TOKEN")  # store in the Space secrets

# API key handed to the DeepL client used for automatic translation.
DEEPL_API_KEY = os.environ.get("DEEPL_API_KEY")  # store in the Space secrets
##################################################################################
# --- Download weights for the cleanup model ---
# Fail fast with a readable message: without this check an unset variable
# would hand filename=None to the Hub client and fail far less clearly.
if not MODEL_FILE_CLEANUP:
    raise RuntimeError(
        "Environment variable MODEL_FILE_CLEANUP is not set; it must name "
        f"the weights file inside the Hub repo '{REPO_ID_CLEANUP}'."
    )

weights_path_cleanup = hf_hub_download(
    repo_id=REPO_ID_CLEANUP,
    filename=MODEL_FILE_CLEANUP,
    token=HF_TOKEN,
)

# --- Tokenizer for the cleanup model ---
tokenizer_cleanup = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLEANUP,
    use_fast=True,  # DistilBERT works fine with the fast tokenizer
)

model_cleanup = BertBinaryClassifier(
    pretrained_model_name=MODEL_BASE_CLEANUP
)

state_cleanup = load_file(weights_path_cleanup)  # safetensors -> dict[str, Tensor]

# strict=True: the checkpoint keys are expected to match the model exactly.
# `res` keeps the load report (presumably torch's missing/unexpected-keys
# result; BertBinaryClassifier is defined elsewhere - confirm).
res = model_cleanup.load_state_dict(state_cleanup, strict=True)

# Move to the selected device and switch to inference mode.
model_cleanup.to(device).eval()
##################################################################################
# --- Download weights for the flavour (classify) model ---
# Fail fast with a readable message: without this check an unset variable
# would hand filename=None to the Hub client and fail far less clearly.
if not MODEL_FILE_CLASSIFY:
    raise RuntimeError(
        "Environment variable MODEL_FILE_CLASSIFY is not set; it must name "
        f"the weights file inside the Hub repo '{REPO_ID_CLASSIFY}'."
    )

weights_path_classify = hf_hub_download(
    repo_id=REPO_ID_CLASSIFY,
    filename=MODEL_FILE_CLASSIFY,
    token=HF_TOKEN,
)

# --- Tokenizer for the flavour model (SentencePiece!) ---
tokenizer_classify = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLASSIFY,
    use_fast=False,
)

model_classify = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE_CLASSIFY
)

state_classify = load_file(weights_path_classify)  # safetensors -> dict[str, Tensor]

# strict=False tolerates checkpoint/model key mismatches instead of raising.
res = model_classify.load_state_dict(state_classify, strict=False)

# Because strict=False never raises, report any mismatch explicitly; otherwise
# a partially loaded model would silently serve predictions from untrained
# weights. getattr() keeps this safe should load_state_dict be overridden to
# return something other than torch's missing/unexpected-keys result.
_missing_classify = list(getattr(res, "missing_keys", None) or [])
_unexpected_classify = list(getattr(res, "unexpected_keys", None) or [])
if _missing_classify or _unexpected_classify:
    print(
        "[warning] classify checkpoint does not match the model exactly: "
        f"missing_keys={_missing_classify} "
        f"unexpected_keys={_unexpected_classify}"
    )

# Move to the selected device and switch to inference mode.
model_classify.to(device).eval()
### Check if lang is english #####################################################
# Shared langid identifier used by _is_eng(). It is built with norm_probs=True;
# the calling code treats the returned score as a probability in [0, 1].
ID = LanguageIdentifier.from_modelstring(
    model,
    norm_probs=True,
)
| def _is_eng(text: str, min_chars: int = 6, threshold: float = 0.1): | |
| t = (text or "").strip() | |
| if len(t) < min_chars: | |
| return True, 0.0 | |
| lang, prob = ID.classify(t) # prob ∈ [0,1] | |
| return (lang == "en" and prob >= threshold), float(prob) | |
def _translate_en(text: str, target_lang: str = "EN-GB"):
    """Translate *text* with the DeepL API and return the translated string.

    Args:
        text: Text to translate.
        target_lang: Target language code passed to DeepL
            (defaults to ``"EN-GB"``).

    Returns:
        The ``text`` attribute of the DeepL translation result.

    A new ``deepl.Translator`` is created from ``DEEPL_API_KEY`` on every
    call. Errors raised by the DeepL client are not caught here.
    """
    translator = deepl.Translator(DEEPL_API_KEY)
    return translator.translate_text(text, target_lang=target_lang).text
| def _render_cleanup_html(cleanup_meta): | |
| """ | |
| Renders cleanup_meta into HTML with struck-through non-note sentences. | |
| """ | |
| html_out = "<p>" | |
| for s in cleanup_meta: | |
| sent = html.escape(s.get("sentence", "")) | |
| if s.get("is_note"): | |
| html_out += f"{sent} " | |
| else: | |
| html_out += f"<span style='text-decoration: line-through; color: gray;'>{sent}</span> " | |
| html_out += "</p>" | |
| return html_out | |
### Do actual prediction #########################################################
# NOTE(review): the comment that stood here read "seconds of GPU time per call";
# presumably it belonged to a removed @spaces.GPU(duration=...) decorator - confirm.
def predict(review_raw: str, do_cleanup: bool):
    """Run the full pipeline for one tasting note.

    Steps: normalise the input, translate it to English when it is not
    detected as English, optionally strip non-tasting-note sentences, predict
    the flavour values and build the flavour wheel.

    Args:
        review_raw: Text entered by the user; ``None`` is treated as empty.
        do_cleanup: When true, the text is passed through
            ``cleanup_tasting_note`` before prediction.

    Returns:
        A tuple ``(html_info, html_wheel, json_out)``: an HTML status
        message, the markup returned by ``build_svg_with_values`` and a dict
        with results and metadata. For empty input this is
        ``("Please enter a review.", "", {})``.
    """
    # Normalize input (handle None and trim whitespace)
    review_raw = (review_raw or "").strip()
    is_translated = False
    html_info_out = ""

    # Abort early if no text is provided
    if not review_raw:
        return "Please enter a review.", "", {}

    # Detect the language; the reported probability is not needed here
    review_is_eng, _ = _is_eng(review_raw)

    # Automatically translate non-English text
    if not review_is_eng:
        review_raw = _translate_en(review_raw)
        html_info_out += (
            "<strong style='display:block'>Your text has been automatically translated:</strong>"
            f"<p>{html.escape(review_raw)}</p>"
        )
        is_translated = True

    # Initialize prediction outputs (all-zero wheel when nothing is predicted)
    prediction_flavours = {}
    prediction_flavours_list = [0, 0, 0, 0, 0, 0, 0, 0]

    # Start timing (covers cleanup and flavour prediction); perf_counter is
    # monotonic, so the duration cannot be skewed by wall-clock adjustments
    t_start_flavours = time.perf_counter()

    # Defaults so every variable is defined: without cleanup the full text is
    # treated as a tasting note
    review_clean = review_raw
    cleanup_meta = []
    review_status = "review_only"

    # Apply text cleanup only when the checkbox is enabled
    if do_cleanup:
        review_clean, cleanup_meta, review_status = cleanup_tasting_note(
            review_raw,
            model_cleanup,
            tokenizer_cleanup,
            device,
        )

        # Display cleanup visualization based on detected content
        if review_status == "mixed":
            # Show which parts were kept vs. removed. Append (+=) rather than
            # assign so that a preceding translation notice is kept, exactly
            # as in the other two branches.
            html_info_out += "<strong>Your text has been cleaned up</strong>"
            html_info_out += _render_cleanup_html(cleanup_meta)
        elif review_status == "noise_only":
            # No tasting notes detected in the text
            html_info_out += "<strong>No tasting notes detected</strong>"
            html_info_out += _render_cleanup_html(cleanup_meta)
        else:  # review_status == "review_only"
            # Text consists entirely of tasting notes; no cleanup needed
            html_info_out += "<strong>No cleanup was necessary</strong>"

    # Run flavour prediction only if review content is present
    if (not do_cleanup) or (review_status in ("review_only", "mixed")):
        prediction_flavours = predict_flavours(
            review_clean,
            model_classify,
            tokenizer_classify,
            device,
        )
        prediction_flavours_list = list(prediction_flavours.values())

    # Stop timing inference
    t_end_flavours = time.perf_counter()

    # Build the flavour wheel SVG
    html_wheel_out = build_svg_with_values(prediction_flavours_list)

    # Prepare structured JSON output
    json_out = {
        "result": dict(prediction_flavours.items()),
        "range": {"min": 0, "max": 4},
        "review": {
            "raw": review_raw,
            "clean": review_clean,
            "clean_meta": cleanup_meta,
            "status": review_status,
        },
        "models": {
            "cleanup": MODEL_FILE_CLEANUP,
            "classify": MODEL_FILE_CLASSIFY,
        },
        "device": device,
        "translated": is_translated,
        "duration": round((t_end_flavours - t_start_flavours), 3),
    }

    # Return HTML info, flavour wheel, and JSON output
    return html_info_out, html_wheel_out, json_out
##################################################################################
def random_text():
    """Return one randomly chosen entry of ``EXAMPLES``."""
    picked = random.choice(EXAMPLES)
    return picked
def _start_text():
    """Return the example text that pre-fills the input box.

    The entry at index 20 of ``EXAMPLES`` is preferred. If ``EXAMPLES`` has
    fewer entries, the first entry is used, and an empty string if there are
    none, so that building the UI does not fail with ``IndexError``.
    """
    preferred_index = 20
    if len(EXAMPLES) > preferred_index:
        return EXAMPLES[preferred_index]
    return EXAMPLES[0] if EXAMPLES else ""
### Create Form interface with Gradio Framework ##################################

# In dark mode, lighten the text labels inside the flavour wheel SVG.
custom_css = """
@media (prefers-color-scheme: dark) {
svg#wheel > text {
fill: rgb(200, 200, 200);
}
}
"""

# Static HTML snippets shown above and inside the form.
_TITLE_HTML = "<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>"
_INTRO_HTML = """
<h3>Automatically turns Whisky Tasting Notes into Flavour Wheels.</h3>
<p>This model is a fine-tuned version of <a href='https://huggingface.co/microsoft/deberta-v3-base'>microsoft/deberta-v3-base</a> designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — <strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, <strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, <strong>winey</strong> and <strong>off-notes</strong> — on a continuous scale from 0 (none) to 4 (extreme).</p>
"""
_CONTACT_HTML = """
<p style='color: var(--block-title-text-color)'>Learn more about use cases and get in touch at <a href='https://www.whisky-wheel.com'>www.whisky-wheel.com</a></p>
"""
_TRANSLATION_HINT_HTML = "<div style='color: gray; font-size: 0.9em;'>Note: Non-English texts will be automatically translated.</div>"
_CLEANUP_HINT_HTML = "<div style='color: gray; font-size: 0.9em;'>Removes non–tasting note parts from text.</div>"

with gr.Blocks(css=custom_css) as demo:
    # Page header
    gr.HTML(_TITLE_HTML)
    gr.HTML(_INTRO_HTML)
    gr.HTML(_CONTACT_HTML)

    with gr.Row():  # everything side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Tasting Note",
                lines=8,
                placeholder="Enter whisky tasting note",
                value=_start_text(),
            )
            gr.HTML(_TRANSLATION_HINT_HTML)

            with gr.Column():
                cleanup_cb = gr.Checkbox(
                    label="BETA: Cleanup Tasting Note",
                    value=False,
                )
                gr.HTML(_CLEANUP_HINT_HTML)

            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):  # right side: output
            html_info_out = gr.HTML(label="Info")
            html_wheel_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Event: "Submit" button runs the prediction pipeline
    submit_btn.click(
        fn=predict,
        inputs=[review_box, cleanup_cb],
        outputs=[html_info_out, html_wheel_out, json_out],
    )

    # Event: "Load Example" button replaces the input with a random example
    replace_btn.click(fn=random_text, outputs=review_box)

demo.launch(show_api=False)