Spaces:
Running on CPU Upgrade

File size: 11,439 Bytes
9a5d8ec
ba84108
 
 
 
 
 
9a5d8ec
 
ba84108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a5d8ec
 
ba84108
406793a
bb3d05e
e31ea41
bb3d05e
ba84108
9b9fd34
bb3d05e
ba84108
 
bb3d05e
9b9fd34
 
 
 
9cfe75e
 
b6df863
265d5e8
 
e6317fa
 
265d5e8
e6317fa
3d475b8
 
 
406793a
3d475b8
9b9fd34
 
 
 
 
 
 
 
 
b6df863
e12a61c
9b9fd34
b6df863
9b9fd34
 
e31ea41
9b9fd34
 
db5e662
b6df863
9b9fd34
 
e12a61c
87f6a56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cfe75e
325ed03
b6ecc6e
6370ba9
ba84108
b6ecc6e
 
 
 
 
9ea7979
ba84108
406793a
71c9493
 
406793a
9a2984d
 
 
 
87cb4a4
9a2984d
 
 
 
 
 
 
f52f2a8
9a2984d
f52f2a8
9a2984d
 
 
9cfe75e
defc447
0856c76
9a2984d
9b9fd34
406793a
1e3cd55
9ea7979
9a2984d
9b9fd34
1e3cd55
9a2984d
 
9b9fd34
f1b3d54
9a2984d
f1b3d54
9b9fd34
9a2984d
 
 
 
406793a
9a2984d
 
177a8e6
d5e14ad
177a8e6
9a2984d
f1b3d54
ba1e296
9a2984d
 
b81d9a3
 
9a2984d
b81d9a3
ae8c8c9
0856c76
9a2984d
0856c76
 
924a06f
0856c76
 
6af94a0
ae8c8c9
 
 
87cb4a4
9a2984d
 
ae8c8c9
0eb70bb
87cb4a4
ae8c8c9
 
0eb70bb
ba1e296
9a2984d
 
b81d9a3
 
 
 
 
 
 
ba1e296
9a2984d
f1b3d54
f35e5d2
9a2984d
1e3cd55
c921016
9a2984d
c921016
598b6e6
9a2984d
ba1e296
 
 
db5e662
9a2984d
ba1e296
23ced0f
 
 
 
4c56373
5a001f3
4c56373
c921016
 
9a2984d
1e3cd55
cd07d29
79079d1
 
fa07bfe
2954446
79079d1
4612b23
f2be8d8
4612b23
9cfe75e
b0c4a5b
 
 
 
 
 
 
 
f0fd942
06353d7
296279c
1f7a544
296279c
06353d7
be1d31d
 
 
 
2954446
c50c843
 
 
518369b
c50c843
518369b
4612b23
c50c843
5a001f3
ec632c7
0856c76
 
 
 
 
 
 
c50c843
ad4809c
c50c843
 
 
1e3cd55
 
c50c843
79079d1
0856c76
 
 
 
 
 
 
 
fa07bfe
ac240ce
f0fd942
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
# Standardbibliotheken
import os          # Umgebungsvariablen (z.B. HF_TOKEN)
import time        # Timing / Performance-Messung
import random      # Zufallswerte (z.B. Beispiel-Reviews)
import html        # HTML-Escaping für sichere Ausgabe in Gradio
import types       # Monkeypatching von Instanzen (fastText .predict)
import numpy as np # Numerische Arrays und Wahrscheinlichkeiten

# Machine Learning / NLP
import torch       # PyTorch: Modelle, Tensoren, Devices
import fasttext    # Sprach-ID-Modell (lid.176)
# Folgende sind notwendig, auch wenn sie nicht explizit genutzt werden:
import sentencepiece  # Pflicht für SentencePiece-basierte Tokenizer (z.B. DeBERTa v3)
import tiktoken       # Optionaler Converter (verhindert Fallback-Fehler bei Tokenizer)
from langid.langid import LanguageIdentifier, model  # Alternative Sprach-ID

# Hugging Face Ökosystem
import spaces                                     # HF Spaces-Dekoratoren (@spaces.GPU)
from transformers import AutoTokenizer            # Tokenizer laden (use_fast=False für DeBERTa v3)
from huggingface_hub import hf_hub_download       # Download von Dateien/Weights aus dem HF Hub
from safetensors.torch import load_file           # Sicheres & schnelles Laden von Weights (.safetensors)

# Übersetzung
import deepl  # DeepL API für automatische Übersetzung

# UI / Serving
import gradio as gr  # Web-UI für Demo/Spaces

# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor, BertBinaryClassifier
from lib.bert_regressor_utils import (
    predict_flavours,  # Hauptfunktion: Vorhersage der 8 Aromenachsen
    cleanup_tasting_note
)
from lib.wheel import build_svg_with_values       # SVG-Rendering für Flavour Wheel
from lib.examples import EXAMPLES                 # Beispiel-Reviews (vordefiniert)

##################################################################################

# Select GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

### Settings #####################################################################

MODEL_BASE_CLEANUP  = "distilbert-base-uncased"
REPO_ID_CLEANUP     = "ziem-io/binary_classifier_is_review_sentence"
MODEL_FILE_CLEANUP  = os.getenv("MODEL_FILE_CLEANUP")  # set in Space secrets

MODEL_BASE_CLASSIFY = "microsoft/deberta-v3-base"
REPO_ID_CLASSIFY    = "ziem-io/flavour_regressor_multi_head"
MODEL_FILE_CLASSIFY = os.getenv("MODEL_FILE_CLASSIFY")  # set in Space secrets

# (optional) in case the model repo is private:
HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space secrets
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")  # set in Space secrets

##################################################################################

# --- Download Weights Model Cleanup ---
# Fetch the safetensors weights for the sentence-level "is this a review
# sentence?" binary classifier from the HF Hub (token needed for private repos).
weights_path_cleanup = hf_hub_download(
    repo_id=REPO_ID_CLEANUP, 
    filename=MODEL_FILE_CLEANUP, 
    token=HF_TOKEN
)

# --- Tokenizer Model Cleanup ---
tokenizer_cleanup = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLEANUP, 
    use_fast=True  # DistilBERT works fine with the fast tokenizer
)

model_cleanup = BertBinaryClassifier(
    pretrained_model_name=MODEL_BASE_CLEANUP
)
state_cleanup = load_file(weights_path_cleanup)                          # safetensors -> dict[str, Tensor]
res = model_cleanup.load_state_dict(state_cleanup, strict=True)  # strict=True: state-dict keys must match exactly

# Move to the selected device and switch to inference mode.
model_cleanup.to(device).eval()

##################################################################################

# --- Download Weights Model Classify ---
# Fetch the safetensors weights for the 8-head flavour regressor.
weights_path_classify = hf_hub_download(
    repo_id=REPO_ID_CLASSIFY, 
    filename=MODEL_FILE_CLASSIFY, 
    token=HF_TOKEN
)

# --- Tokenizer Model Classify (SentencePiece!) ---
# use_fast=False: DeBERTa v3 needs the slow SentencePiece-based tokenizer.
tokenizer_classify = AutoTokenizer.from_pretrained(
    MODEL_BASE_CLASSIFY, 
    use_fast=False
)

model_classify = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE_CLASSIFY
)
state_classify = load_file(weights_path_classify)                          # safetensors -> dict[str, Tensor]
res = model_classify.load_state_dict(state_classify, strict=False)  # NOTE(review): strict=False tolerates missing/unexpected keys — confirm this is intended

# Move to the selected device and switch to inference mode.
model_classify.to(device).eval()

### Check if lang is english #####################################################

# langid language identifier; norm_probs=True yields probabilities in [0, 1].
ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)

def _is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    """Return ``(is_english, probability)`` for *text* using langid.

    Inputs shorter than *min_chars* (after stripping) are treated as
    English with probability 0.0, so they skip translation.
    """
    stripped = (text or "").strip()
    if len(stripped) < min_chars:
        return True, 0.0
    lang, prob = ID.classify(stripped)  # prob is normalized to [0, 1]
    is_english = lang == "en" and prob >= threshold
    return is_english, float(prob)

def _translate_en(text: str, target_lang: str = "EN-GB"):
    """Translate *text* into English via the DeepL API and return the string."""
    translator = deepl.Translator(DEEPL_API_KEY)
    translated = translator.translate_text(text, target_lang=target_lang)
    return translated.text

def _render_cleanup_html(cleanup_meta):
    """
    Renders cleanup_meta into HTML with struck-through non-note sentences.
    """
    html_out = "<p>"

    for s in cleanup_meta:
        sent = html.escape(s.get("sentence", ""))

        if s.get("is_note"):
            html_out += f"{sent} "
        else:
            html_out += f"<span style='text-decoration: line-through; color: gray;'>{sent}</span> "

    html_out += "</p>"

    return html_out

### Do actual prediction #########################################################
@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review_raw: str, do_cleanup: bool):
    """Predict flavour intensities for a whisky tasting note.

    Args:
        review_raw: Free-text tasting note. Non-English input is
            auto-translated to English via DeepL before classification.
        do_cleanup: When True, strip non-tasting-note sentences with the
            cleanup model before running the flavour regressor.

    Returns:
        Tuple of (info HTML, flavour-wheel SVG HTML, result JSON dict).
    """
    # Normalize input (handle None and trim whitespace)
    review_raw = (review_raw or "").strip()
    is_translated = False
    html_info_out = ""

    # Abort early if no text is provided
    if not review_raw:
        return "Please enter a review.", "", {}

    # Detect language of the input text
    review_is_eng, review_lang_prob = _is_eng(review_raw)

    # Automatically translate non-English text
    if not review_is_eng:
        review_raw = _translate_en(review_raw)
        html_info_out += (
            "<strong style='display:block'>Your text has been automatically translated:</strong>"
            f"<p>{html.escape(review_raw)}</p>"
        )
        is_translated = True

    # Initialize prediction outputs
    prediction_flavours = {}
    prediction_flavours_list = [0, 0, 0, 0, 0, 0, 0, 0]

    # Start timing the model inference
    t_start_flavours = time.time()

    # Default values to ensure all variables are always defined
    # Without cleanup enabled, the full text is treated as a tasting note
    review_clean = review_raw
    cleanup_meta = []
    review_status = "review_only"

    # Apply text cleanup only when the checkbox is enabled
    if do_cleanup:
        review_clean, cleanup_meta, review_status = cleanup_tasting_note(
            review_raw,
            model_cleanup,
            tokenizer_cleanup,
            device
        )

        # Display cleanup visualization based on detected content
        if review_status == "mixed":
            # Show which parts were kept vs. removed.
            # BUGFIX: append (+=) instead of assign (=) so a preceding
            # translation notice is not silently overwritten.
            html_info_out += "<strong>Your text has been cleaned up</strong>"
            html_info_out += _render_cleanup_html(cleanup_meta)
        elif review_status == "noise_only":
            # No tasting notes detected in the text
            html_info_out += "<strong>No tasting notes detected</strong>"
            html_info_out += _render_cleanup_html(cleanup_meta)
        else:  # review_status == "review_only"
            # Text consists entirely of tasting notes; no cleanup needed
            html_info_out += "<strong>No cleanup was necessary</strong>"

    # Run flavour prediction only if review content is present
    if (not do_cleanup) or (review_status in ("review_only", "mixed")):
        prediction_flavours = predict_flavours(
            review_clean,
            model_classify,
            tokenizer_classify,
            device
        )
        prediction_flavours_list = list(prediction_flavours.values())

    # Stop timing inference
    t_end_flavours = time.time()

    # Build the flavour wheel SVG
    html_wheel_out = build_svg_with_values(prediction_flavours_list)

    # Prepare structured JSON output
    json_out = {
        "result": dict(prediction_flavours),
        "range": {"min": 0, "max": 4},
        "review": {
            "raw": review_raw,
            "clean": review_clean,
            "clean_meta": cleanup_meta,
            "status": review_status
        },
        "models": {
            "cleanup": MODEL_FILE_CLEANUP,
            "classify": MODEL_FILE_CLASSIFY
        },
        "device": device,
        "translated": is_translated,
        "duration": round((t_end_flavours - t_start_flavours), 3),
    }

    # Return HTML info, flavour wheel, and JSON output
    return html_info_out, html_wheel_out, json_out

##################################################################################

def random_text():
    """Pick one of the predefined example reviews at random."""
    idx = random.randrange(len(EXAMPLES))
    return EXAMPLES[idx]

def _start_text():
    """Return the example review shown in the textbox on page load."""
    return EXAMPLES[20]

### Create Form interface with Gradio Framework ##################################
# Dark-mode tweak: lighten the flavour-wheel SVG text labels so they remain
# readable against a dark background.
custom_css = """
@media (prefers-color-scheme: dark) {
    svg#wheel > text {
        fill: rgb(200, 200, 200);
    }
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.HTML("<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>")
    gr.HTML("""
    <h3>Automatically turns Whisky Tasting Notes into Flavour Wheels.</h3>
    <p>This model is a fine-tuned version of <a href='https://huggingface.co/microsoft/deberta-v3-base'>microsoft/deberta-v3-base</a> designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — <strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, <strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, <strong>winey</strong> and <strong>off-notes</strong> — on a continuous scale from 0 (none) to 4 (extreme).</p>
""")

    gr.HTML("""
    <p style='color: var(--block-title-text-color)'>Learn more about use cases and get in touch at <a href='https://www.whisky-wheel.com'>www.whisky-wheel.com</a></p>
""")

    with gr.Row():  # input and output columns side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Tasting Note",
                lines=8,
                placeholder="Enter whisky tasting note",
                value=_start_text(),
            )
            gr.HTML("<div style='color: gray; font-size: 0.9em;'>Note: Non-English texts will be automatically translated.</div>")
            
            with gr.Column():
                cleanup_cb = gr.Checkbox(
                    label="BETA: Cleanup Tasting Note", 
                    value=False
                )
                gr.HTML("<div style='color: gray; font-size: 0.9em;'>Removes non–tasting note parts from text.</div>")

            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn  = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):  # right side: output
            html_info_out = gr.HTML(label="Info")
            html_wheel_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Event: Submit button runs the prediction
    submit_btn.click(
        predict, 
        inputs=[review_box, cleanup_cb],
        outputs=[html_info_out, html_wheel_out, json_out]
    )

    # Event: Load Example button fills the textbox with a random example
    replace_btn.click(random_text, outputs=review_box)

demo.launch(show_api=False)