tryonapi / SegBody.py
dawoodshahzad0707's picture
Upload SegBody.py
6613c62 verified
from transformers import pipeline
import numpy as np
import cv2
from PIL import Image, ImageDraw
# InsightFace is the preferred face detector but is an optional dependency.
# USE_INSIGHTFACE records whether the import worked so the rest of the module
# knows when to fall back to OpenCV.
try:
    from insightface.app import FaceAnalysis
except ImportError:
    USE_INSIGHTFACE = False
    print("⚠️ InsightFace not available, using OpenCV for face detection")
else:
    USE_INSIGHTFACE = True
# Build the InsightFace detector only when the import above succeeded. Any
# failure while constructing or preparing it is reported and downgrades the
# module to the OpenCV fallback instead of aborting the import.
if USE_INSIGHTFACE:
    try:
        # CPU-only execution provider (intended for HF Spaces).
        app = FaceAnalysis(providers=["CPUExecutionProvider"])
        # ctx_id=-1 selects the CPU.
        app.prepare(det_size=(640, 640), ctx_id=-1)
    except Exception as e:
        print(f"⚠️ InsightFace initialization failed: {e}, falling back to OpenCV")
        USE_INSIGHTFACE = False

# Clothes/body-part segmentation pipeline shared by the segment_* functions;
# device=-1 keeps it on the CPU.
segmenter = pipeline(device=-1, model="mattmdjaga/segformer_b2_clothes")
def remove_face_opencv(img, mask):
    """Black out the largest detected face in ``mask`` using OpenCV.

    This is the fallback path used when InsightFace is unavailable or fails.
    Faces are detected with OpenCV's frontal-face Haar cascade; the largest
    detection is enlarged (half a face-width to each side, half a face-height
    above, a fifth of a face-height below) and that rectangle is filled with
    0 in ``mask``.

    Args:
        img: PIL image to search for faces (converted to RGB internally).
        mask: PIL mask image to draw on. It is modified in place.

    Returns:
        The same ``mask`` object, unchanged when no face is detected.
    """
    img_arr = np.asarray(img.convert('RGB'))
    gray = cv2.cvtColor(img_arr, cv2.COLOR_RGB2GRAY)

    # Load OpenCV's pre-trained frontal-face detector.
    face_cascade = cv2.CascadeClassifier(
        cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
    )
    faces = face_cascade.detectMultiScale(gray, 1.1, 4)
    if len(faces) == 0:
        return mask

    # The detections are not documented to come back in any particular order,
    # so select the largest box by area explicitly instead of trusting
    # index 0.
    x, y, w, h = (int(v) for v in max(faces, key=lambda box: box[2] * box[3]))

    # Compute each corner from the original box. Clamping the top-left corner
    # first and then adding an enlarged width/height would push the rectangle
    # too far right/down for faces close to the image border.
    left = max(0, int(x - w * 0.5))
    top = max(0, int(y - h * 0.5))
    right = int(x + w * 1.5)
    bottom = int(y + h * 1.2)

    # Draw a black rectangle over the face region of the mask.
    ImageDraw.Draw(mask).rectangle([(left, top), (right, bottom)], fill=0)
    return mask
def remove_face(img, mask):
    """Black out the face region of ``mask``.

    Uses InsightFace when it was initialised successfully and falls back to
    ``remove_face_opencv`` when it is disabled or raises. The first detected
    face box is enlarged by half its width on each side, half its height
    above and a fifth of its height below, and that rectangle is filled with
    0 in ``mask``.

    Args:
        img: PIL image to search for faces.
        mask: PIL mask image to draw on. It is modified in place.

    Returns:
        The same ``mask`` object, unchanged when no face is detected.
    """
    if not USE_INSIGHTFACE:
        return remove_face_opencv(img, mask)

    try:
        rgb = np.asarray(img.convert('RGB'))
        # InsightFace follows the OpenCV convention and expects BGR channel
        # order, so reverse the channels of the PIL (RGB) image first.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])

        faces = app.get(bgr)
        if len(faces) == 0:
            return mask

        # Copy the coordinates out as plain floats rather than editing the
        # detector's bbox array in place.
        x_left, y_top, x_right, y_bottom = (
            float(v) for v in faces[0]['bbox'][:4]
        )
        w = x_right - x_left
        h = y_bottom - y_top

        # Enlarged box as [(x_left, y_top), (x_right, y_bottom)].
        face_locations = [
            (x_left - w * 0.5, y_top - h * 0.5),
            (x_right + w * 0.5, y_bottom + h * 0.2),
        ]

        # Draw a black rectangle over the face region of the mask.
        ImageDraw.Draw(mask).rectangle(face_locations, fill=0)
        return mask
    except Exception as e:
        # Deliberate best-effort: any InsightFace failure degrades to OpenCV.
        print(f"⚠️ InsightFace face removal failed: {e}, using OpenCV")
        return remove_face_opencv(img, mask)
def segment_body(original_img, face=True):
    """Segment the person (body, clothes and accessories) out of an image.

    Runs the module-level ``segmenter`` on a copy of the image, merges the
    masks of every wanted label into one mask and applies it as the alpha
    channel of the returned image. When no wanted label is found a fully
    opaque mask is used instead.

    Args:
        original_img: PIL Image. It is not modified.
        face: If truthy, the face stays in the mask. If falsy, the face
            region is blacked out of the mask via ``remove_face``.

    Returns:
        tuple: (segmented_image, mask_image) where ``segmented_image`` is a
        copy of the input with ``mask_image`` as its alpha channel.
    """
    img = original_img.copy()
    segments = segmenter(img)

    segment_include = {
        "Hat", "Hair", "Sunglasses", "Upper-clothes", "Skirt", "Pants",
        "Dress", "Belt", "Left-shoe", "Right-shoe", "Face", "Left-leg",
        "Right-leg", "Left-arm", "Right-arm", "Bag", "Scarf",
    }
    mask_list = [s['mask'] for s in segments if s['label'] in segment_include]

    if not mask_list:
        # If no segments found, fall back to a full mask.
        print("⚠️ No body segments detected, using full mask")
        final_mask = Image.new('L', img.size, 255)
    else:
        # Union the masks with an element-wise maximum. Adding them would
        # wrap around wherever two 8-bit masks overlap (255 + 255 -> 254).
        merged = np.maximum.reduce([np.asarray(m) for m in mask_list])
        final_mask = Image.fromarray(merged)

    if not face:
        final_mask = remove_face(img.convert('RGB'), final_mask)

    # Apply the mask to a copy of the image as its alpha channel.
    img_copy = img.copy()
    img_copy.putalpha(final_mask)
    return img_copy, final_mask
def segment_torso(original_img):
    """Segment only the torso / upper body out of an image.

    Runs the module-level ``segmenter`` on a copy of the image, merges the
    masks of the upper-body labels into one mask, blacks out the face region
    via ``remove_face`` and applies the result as the alpha channel of the
    returned image. When no wanted label is found a fully opaque mask is
    used as the starting point instead.

    Args:
        original_img: PIL Image. It is not modified.

    Returns:
        tuple: (segmented_image, mask_image) where ``segmented_image`` is a
        copy of the input with ``mask_image`` as its alpha channel.
    """
    img = original_img.copy()
    segments = segmenter(img)

    # Torso-related labels only.
    segment_include = {
        "Upper-clothes", "Dress", "Belt", "Face", "Left-arm", "Right-arm",
    }
    mask_list = [s['mask'] for s in segments if s['label'] in segment_include]

    if not mask_list:
        # If no segments found, fall back to a full mask.
        print("⚠️ No torso segments detected, using full mask")
        final_mask = Image.new('L', img.size, 255)
    else:
        # Union the masks with an element-wise maximum. Adding them would
        # wrap around wherever two 8-bit masks overlap (255 + 255 -> 254).
        merged = np.maximum.reduce([np.asarray(m) for m in mask_list])
        final_mask = Image.fromarray(merged)

    # The face is always removed from the torso mask.
    final_mask = remove_face(img.convert('RGB'), final_mask)

    # Apply the mask to a copy of the image as its alpha channel.
    img_copy = img.copy()
    img_copy.putalpha(final_mask)
    return img_copy, final_mask