File size: 2,924 Bytes
c44a3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea4cbcc
c44a3c9
 
 
 
 
 
 
 
 
ea4cbcc
 
 
c44a3c9
ea4cbcc
 
c44a3c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
import numpy as np
import cv2
import base64
from io import BytesIO
from PIL import Image
from transformers import SegformerForSemanticSegmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn as nn
import os

# Flask application object. flask_cors.CORS is applied with its default
# settings so browser clients served from other origins can call the API.
app = Flask(__name__)
CORS(app)

# CONFIG
# DEVICE, MODEL_PATH and MODEL_NAME can be overridden through environment
# variables of the same name; the defaults are the original deployment values.
DEVICE = torch.device(os.environ.get('DEVICE', 'cpu'))  # Hugging Face Free Tier is CPU
MODEL_PATH = os.environ.get('MODEL_PATH', "best_model.pth")
MODEL_NAME = os.environ.get('MODEL_NAME', "nvidia/segformer-b2-finetuned-ade-512-512")
# Size of the segmentation head; must match the fine-tuned checkpoint and the
# number of entries in MASK_COLOR_MAP, so it is intentionally not configurable.
NUM_CLASSES = 6

# Load Model: build SegFormer from MODEL_NAME with a NUM_CLASSES-way head
# (mismatched head weights are re-initialised), then overwrite the weights
# with the fine-tuned checkpoint at MODEL_PATH.
print("Loading model...")
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_NAME, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True
)

# NOTE(security): weights_only=False unpickles arbitrary Python objects, so
# MODEL_PATH must point to a trusted file. It is kept presumably because the
# checkpoint stores non-tensor metadata next to the weights — TODO confirm
# whether weights_only=True can load it.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

# Accept both a training checkpoint dict ({'model_state_dict': ...}) and a
# bare state_dict saved directly with torch.save(model.state_dict()).
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint

# Fix key names: drop the leading 'module.' prefix present when trained on
# multi-GPU. Only the prefix is stripped; str.replace would also mangle keys
# that merely contain "module." elsewhere (e.g. "submodule.weight").
new_state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(new_state_dict)
model.to(DEVICE)
model.eval()
print("Model loaded!")

# Class index -> RGB color used when rendering the predicted mask.
# Indices are assigned in list order: 0 is black, then red, green, blue,
# yellow and magenta for classes 1-5.
MASK_COLOR_MAP = dict(enumerate([
    (0, 0, 0),
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
]))

# Preprocessing pipeline, built once at import rather than on every request.
# Mean/std are the widely used ImageNet statistics — presumably matching the
# training preprocessing; confirm if the training pipeline changes.
_TRANSFORM = A.Compose([
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])


def transform_image(image_bytes):
    """Decode raw image bytes into a normalized, batched model input.

    Args:
        image_bytes: Encoded image file contents (any format cv2.imdecode reads).

    Returns:
        A ``(tensor, original_size)`` pair: a (1, 3, H, W) tensor on DEVICE,
        and the decoded image's ``(height, width)`` so predictions can be
        resized back to it.

    Raises:
        ValueError: If the bytes are empty or cannot be decoded as an image.
    """
    nparr = np.frombuffer(image_bytes, np.uint8)
    # cv2.imdecode signals failure by returning None (and may raise on an empty
    # buffer), so check explicitly instead of letting cvtColor fail cryptically.
    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None
    if image is None:
        raise ValueError("Could not decode image")
    # OpenCV decodes to BGR; convert to RGB before normalization.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_size = image.shape[:2]
    return _TRANSFORM(image=image)['image'].unsqueeze(0).to(DEVICE), original_size

def colorize_mask(mask, color_map=None):
    """Render an integer class-index mask as an RGB image.

    Args:
        mask: Array of integer class labels, e.g. shape (H, W).
        color_map: Optional mapping of class label -> (R, G, B). Defaults to
            the module-level MASK_COLOR_MAP (looked up at call time).

    Returns:
        A uint8 array of shape ``mask.shape + (3,)``. Pixels whose label is not
        a key of the color map stay black (0, 0, 0).
    """
    if color_map is None:
        color_map = MASK_COLOR_MAP
    mask = np.asarray(mask)
    color_mask = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for label, color in color_map.items():
        color_mask[mask == label] = color
    return color_mask

def to_base64(image_array):
    """Encode an image array as PNG and return it as base64 text.

    The returned string carries no ``data:`` URI prefix.
    """
    with BytesIO() as png_buffer:
        Image.fromarray(image_array).save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')

@app.route('/')
def home():
    """Root endpoint: reply with a fixed plain-text status message."""
    status_message = "Damage Detection API is Running!"
    return status_message

@app.route('/predict', methods=['POST'])
def predict():
    """Segment an uploaded image and return the colorized mask.

    Expects a multipart form upload with the image under the ``file`` field.

    Returns:
        200 with ``{'mask': 'data:image/png;base64,...'}`` on success;
        400 with ``{'error': ...}`` when the upload is missing or empty, or
        preprocessing rejects it with ValueError (e.g. undecodable image);
        500 with ``{'error': ...}`` on any other failure.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file'}), 400
    file = request.files['file']
    try:
        image_bytes = file.read()
        if not image_bytes:
            return jsonify({'error': 'Empty file'}), 400
        try:
            input_tensor, original_size = transform_image(image_bytes)
        except ValueError as e:
            # Invalid input data is a client error, not a server fault.
            return jsonify({'error': str(e)}), 400
        with torch.no_grad():
            outputs = model(pixel_values=input_tensor)
            # Resize logits to the original (H, W) before the per-pixel argmax
            # so the mask lines up with the uploaded image.
            logits = nn.functional.interpolate(
                outputs.logits, size=original_size, mode='bilinear', align_corners=False
            )
            pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

        rgb_mask = colorize_mask(pred_mask)
        return jsonify({'mask': f"data:image/png;base64,{to_base64(rgb_mask)}"})
    except Exception as e:
        # Request boundary: log the full traceback server-side, then report
        # the failure to the client as JSON.
        app.logger.exception("Prediction failed")
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Start Flask's built-in server on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')