AhsanAftab committed on
Commit
c44a3c9
·
verified ·
1 Parent(s): 98f2438

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from flask import Flask, request, jsonify
3
+ from flask_cors import CORS
4
+ import numpy as np
5
+ import cv2
6
+ import base64
7
+ from io import BytesIO
8
+ from PIL import Image
9
+ from transformers import SegformerForSemanticSegmentation
10
+ import albumentations as A
11
+ from albumentations.pytorch import ToTensorV2
12
+ import torch.nn as nn
13
+ import os
14
+
15
# --- Application setup and one-time model loading (runs at import) ---
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the browser frontend

# CONFIG
DEVICE = torch.device('cpu')  # CPU-only host (e.g. Hugging Face free tier)
MODEL_PATH = "best_model.pth"
MODEL_NAME = "nvidia/segformer-b2-finetuned-ade-512-512"
NUM_CLASSES = 6

# Load Model
print("Loading model...")
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_NAME, num_labels=NUM_CLASSES, ignore_mismatched_sizes=True
)
# weights_only=True restricts unpickling to tensors/containers, preventing
# arbitrary-code execution if the checkpoint file were tampered with.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
state_dict = checkpoint['model_state_dict']
# Strip the 'module.' prefix that nn.DataParallel training adds to keys,
# so the state dict matches the bare model's parameter names.
new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)
model.to(DEVICE)
model.eval()
print("Model loaded!")

# Class label index -> RGB color used to render the predicted mask.
MASK_COLOR_MAP = {
    0: (0, 0, 0), 1: (255, 0, 0), 2: (0, 255, 0),
    3: (0, 0, 255), 4: (255, 255, 0), 5: (255, 0, 255)
}
41
+
42
# Built once at import time: the preprocessing pipeline is stateless
# (ImageNet mean/std normalization), so rebuilding it per request was
# pure overhead.
_PREPROCESS = A.Compose([
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

def transform_image(image_bytes):
    """Decode raw image bytes and preprocess them for the model.

    Args:
        image_bytes: Encoded image file contents (e.g. PNG/JPEG bytes).

    Returns:
        Tuple of (float tensor of shape (1, 3, H, W) on DEVICE,
        original (height, width) of the decoded image).

    Raises:
        ValueError: If the bytes cannot be decoded as an image.
    """
    nparr = np.frombuffer(image_bytes, np.uint8)
    image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    # imdecode signals failure by returning None rather than raising;
    # fail loudly here instead of crashing later in cvtColor.
    if image is None:
        raise ValueError("Could not decode image data")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_size = image.shape[:2]
    return _PREPROCESS(image=image)['image'].unsqueeze(0).to(DEVICE), original_size
52
+
53
def colorize_mask(mask, color_map=None):
    """Convert a 2-D label mask into an RGB image.

    Args:
        mask: (H, W) integer array of class labels.
        color_map: Optional {label: (r, g, b)} mapping. Defaults to the
            module-level MASK_COLOR_MAP. Labels absent from the map are
            left black.

    Returns:
        (H, W, 3) uint8 RGB array.
    """
    if color_map is None:
        color_map = MASK_COLOR_MAP
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for label, color in color_map.items():
        color_mask[mask == label] = color
    return color_mask
59
+
60
def to_base64(image_array):
    """Encode an RGB numpy array as a base64 string of its PNG bytes."""
    buffer = BytesIO()
    Image.fromarray(image_array).save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode('utf-8')
65
+
66
@app.route('/')
def home():
    """Health-check endpoint confirming the service is up."""
    status_message = "Damage Detection API is Running!"
    return status_message
69
+
70
@app.route('/predict', methods=['POST'])
def predict():
    """Run semantic segmentation on an uploaded image.

    Expects a multipart form upload under the 'file' key. Returns JSON
    {'mask': 'data:image/png;base64,...'} containing the colorized
    prediction at the original resolution, or {'error': ...} with a
    400 (bad request) / 500 (processing failure) status.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file'}), 400
    file = request.files['file']
    # Browsers submit a part with an empty filename when no file was
    # selected; reject it as a 400 instead of failing later with a 500.
    if file.filename == '':
        return jsonify({'error': 'No file'}), 400
    try:
        input_tensor, original_size = transform_image(file.read())
        with torch.no_grad():
            outputs = model(pixel_values=input_tensor)
            # Segformer outputs logits at 1/4 resolution; upsample back
            # to the original size before taking the per-pixel argmax.
            logits = nn.functional.interpolate(
                outputs.logits, size=original_size,
                mode='bilinear', align_corners=False
            )
            pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

        rgb_mask = colorize_mask(pred_mask)
        return jsonify({'mask': f"data:image/png;base64,{to_base64(rgb_mask)}"})
    except Exception as e:
        # Top-level boundary handler: surface the failure as JSON so the
        # client gets a structured error instead of an HTML stack trace.
        return jsonify({'error': str(e)}), 500
85
+
86
# --- CRITICAL CHANGE FOR HUGGING FACE ---
# Script entry point: run Flask's built-in server directly.
if __name__ == '__main__':
    # Hugging Face runs on port 7860; bind to 0.0.0.0 so the container's
    # mapped port is reachable from outside.
    app.run(host='0.0.0.0', port=7860)