# Spaces: Sleeping
import logging
from typing import List

import joblib
import numpy as np
import pandas as pd
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# Application object on which the API is served.
app = FastAPI(title="Comment Validator API")

# =====================================
# 🔹 Model loading
# =====================================
# Everything below runs once, at import time, so the models are ready before
# the first call to analyze_comment.

# Backend preference: CUDA first, then MPS (useful on a local Mac), else CPU.
# The checks are evaluated lazily, in that order.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"🧠 Using device: {device}")

# Sentence-embedding model used to vectorise the comment text.
print("Loading model embedding")
text_model = SentenceTransformer(
    "paraphrase-multilingual-MiniLM-L12-v2",
    device=device,
)

# NOTE: joblib.load unpickles the file; only load artifacts you trust.
print("Loading model classifier")
clf = joblib.load("models/classifier.joblib")

print("Loading model encoder")
encoder = joblib.load("models/encoder.joblib")

print("Loading model sentiment-analysis")
sentiment_analyzer = pipeline(
    "sentiment-analysis",
    model="nlptown/bert-base-multilingual-uncased-sentiment",
    device=device,
)

# NOTE(review): `return_all_scores` appears to be deprecated in recent
# transformers releases in favour of `top_k=None` — confirm before upgrading,
# since analyze_comment indexes the result with `[0]`.
print("Loading model toxicity")
toxicity_analyzer = pipeline(
    "text-classification",
    model="unitary/toxic-bert",
    return_all_scores=True,
    device=device,
)
_logger = logging.getLogger(__name__)

# Number of leading characters of a comment handed to the transformer
# pipelines. This is a character cut, not a token cut.
_MAX_PIPELINE_CHARS = 512

# Tag names used only when the toxicity model's own label list cannot be read.
# NOTE(review): these do not look like the labels published for
# unitary/toxic-bert (e.g. `toxic`, `identity_hate`) — confirm.
_DEFAULT_TOXICITY_LABELS = (
    "toxicity",
    "severe_toxicity",
    "obscene",
    "identity_attack",
    "insult",
    "threat",
)


def _toxicity_label_names() -> list:
    """Return the label names to use for zero-filled toxicity tags.

    Prefers the labels declared by the loaded toxicity model so that the
    fallback response carries the same ``tag_*`` keys as a successful one.
    """
    try:
        names = list(toxicity_analyzer.model.config.id2label.values())
    except AttributeError:
        names = []
    return names or list(_DEFAULT_TOXICITY_LABELS)


def analyze_comment(comment: str, category: str, country: str) -> dict:
    """Score one comment for validity, sentiment and toxicity.

    Args:
        comment: Raw comment text.
        category: Category value fed to the fitted ``encoder``.
        country: Country value fed to the fitted ``encoder``.

    Returns:
        A dict with the keys ``is_valid`` (bool, probability >= 0.5),
        ``confidence`` (classifier probability, 3 decimals), ``sentiment``
        (label returned by the sentiment pipeline, or ``"unknown"``),
        ``sentiment_score`` (the pipeline's score for that label, 3 decimals),
        ``reasons`` (``"; "``-joined explanations, or a default message), plus
        one ``tag_<label>`` entry per toxicity label.

    Failures of the sentiment or toxicity pipelines are logged and replaced by
    neutral defaults; they never propagate to the caller.
    """
    reasons = []

    # --- Sentiment analysis ---
    try:
        sentiment = sentiment_analyzer(comment[:_MAX_PIPELINE_CHARS])[0]
        label = sentiment["label"]
        score = sentiment["score"]
    except Exception:
        _logger.exception("Sentiment analysis failed; using 'unknown'")
        label, score = "unknown", 0.0

    # Map the label to a polarity by the digit it contains.
    # Presumably labels look like "1 star" .. "5 stars" — verify against the model.
    if "1" in label or "2" in label:
        polarity = -1
        reasons.append("Le ton semble négatif ou insatisfait.")
    elif "4" in label or "5" in label:
        polarity = 1
    else:
        polarity = 0

    # --- Text embedding ---
    X_text = text_model.encode([comment])

    # --- Category / country encoding ---
    df_cat = pd.DataFrame([[category, country]], columns=["category", "country"])
    try:
        X_cat = encoder.transform(df_cat)
    except ValueError:
        reasons.append(f"Catégorie ou pays inconnus : {category}, {country}")
        # All-zero row, one column per category known to the encoder.
        n_features = sum(len(cats) for cats in encoder.categories_)
        X_cat = np.zeros((1, n_features))
    else:
        # Encoders may return a SciPy sparse matrix, which np.concatenate
        # cannot combine with the dense embedding; densify it first.
        if hasattr(X_cat, "toarray"):
            X_cat = X_cat.toarray()

    # --- Feature concatenation ---
    X = np.concatenate([X_text, X_cat], axis=1)

    # --- Validity prediction ---
    # Column 1 of predict_proba is read as the "valid" probability
    # (assumes a binary classifier — TODO confirm).
    proba = clf.predict_proba(X)[0][1]
    prediction = proba >= 0.5

    if len(comment.split()) < 3:
        reasons.append("Le commentaire est trop court.")
    # NOTE(review): this fires in exactly the cases where the negative-tone
    # reason above was already added, so both messages appear together. Kept
    # as-is for backward compatibility of the `reasons` string.
    if polarity < 0:
        reasons.append("Le ton global est négatif.")
    if proba < 0.4:
        reasons.append("Le modèle estime une faible probabilité de validité.")

    # --- Toxicity analysis ---
    try:
        tox_scores = toxicity_analyzer(comment[:_MAX_PIPELINE_CHARS])[0]
        tags = {
            f"tag_{item['label']}": round(item["score"], 3)
            for item in tox_scores
        }
    except Exception:
        _logger.exception("Toxicity analysis failed; reporting zero scores")
        tags = {f"tag_{name}": 0.0 for name in _toxicity_label_names()}

    # --- Final result ---
    result = {
        "is_valid": bool(prediction),
        "confidence": round(float(proba), 3),
        "sentiment": label,
        "sentiment_score": round(float(score), 3),
        "reasons": "; ".join(reasons) if reasons else "Aucune anomalie détectée.",
    }
    result.update(tags)
    return result
# =====================================
# 🔸 Request/response models
# =====================================
class CommentRequest(BaseModel):
    """Request payload describing a single comment to analyse."""

    # Raw text of the comment.
    comment: str
    # Category the comment belongs to.
    category: str
    # Country associated with the comment.
    country: str
class BatchRequest(BaseModel):
    """Request payload carrying several comments to analyse in one call."""

    # Comments to analyse.
    items: List[CommentRequest]
# =====================================
# 🔹 Routes
# =====================================
# Register the handler on the application; without a path-operation decorator
# the function is never exposed over HTTP. The path mirrors the function name.
@app.post("/predict")
def predict(item: CommentRequest):
    """Analyse a single comment.

    Args:
        item: Request body holding the comment text, its category and country.

    Returns:
        The result dict produced by ``analyze_comment`` for this comment.
    """
    return analyze_comment(item.comment, item.category, item.country)
# Register the handler on the application; without a path-operation decorator
# the function is never exposed over HTTP. The path mirrors the function name.
@app.post("/batch_predict")
def batch_predict(request: BatchRequest):
    """Analyse several comments in one call.

    Args:
        request: Request body whose ``items`` are analysed one by one, in order.

    Returns:
        ``{"results": [...]}`` with one ``analyze_comment`` result per item,
        in the same order as the input (empty list for an empty batch).
    """
    return {
        "results": [
            analyze_comment(item.comment, item.category, item.country)
            for item in request.items
        ]
    }