from flask import Flask, request, jsonify
from flask_jwt_extended import JWTManager
import hashlib
import secrets
import os
from datetime import timedelta


def _response_sets_cookie(response, name):
    """Return True if *response* already carries a ``Set-Cookie`` for *name*.

    Used so the ``after_request`` hook never overwrites a cookie that the
    view itself put on the response (for example a freshly issued token).
    """
    prefix = f'{name}='
    return any(
        header.startswith(prefix)
        for header in response.headers.getlist('Set-Cookie')
    )


def setup_secure_cookies(app: Flask):
    """Configure JWT cookie settings and register a cookie-setting hook.

    Writes the flask_jwt_extended cookie options into ``app.config`` and
    registers an ``after_request`` hook. For POST/PUT requests, the hook
    mirrors a ``Bearer`` token from the ``Authorization`` header into an
    ``access_token`` cookie. When the JSON body has a truthy ``rememberMe``,
    it also sets a ``refresh_token`` cookie. Cookies already set by the
    view are never overwritten.

    Args:
        app: The Flask application to configure.

    Returns:
        The same ``app`` instance, for chaining.
    """
    # Any of these flags marks the app as "development". In that case cookies
    # are sent without the Secure flag so they also work over plain HTTP.
    is_development = (
        app.config.get('DEBUG', False)
        or app.config.get('ENV') == 'development'
        or app.config.get('ENVIRONMENT') == 'development'
    )

    # Check if we're running in Hugging Face Spaces (SPACE_ID is set there).
    is_huggingface = os.environ.get('SPACE_ID') is not None

    # Configure JWT cookie settings in app config.
    app.config['JWT_TOKEN_LOCATION'] = ['cookies', 'headers']
    app.config['JWT_ACCESS_COOKIE_PATH'] = '/api'
    app.config['JWT_REFRESH_COOKIE_PATH'] = '/api/auth/refresh'
    app.config['JWT_COOKIE_CSRF_PROTECT'] = True  # Enable CSRF protection
    app.config['JWT_CSRF_IN_COOKIES'] = True
    app.config['JWT_ACCESS_COOKIE_NAME'] = 'access_token'
    app.config['JWT_REFRESH_COOKIE_NAME'] = 'refresh_token'
    app.config['JWT_COOKIE_SAMESITE'] = 'Lax'  # CSRF protection
    app.config['JWT_COOKIE_SECURE'] = not is_development  # Secure in production only

    @app.after_request
    def set_secure_cookies(response):
        """Set access/refresh cookies on POST/PUT responses.

        Only acts on POST/PUT, since those typically carry JSON data.
        """
        if request.method not in ('POST', 'PUT'):
            return response

        # Determine cookie security settings.
        secure_cookie = not is_development
        # NOTE(review): 'Strict' outside HF Spaces differs from
        # JWT_COOKIE_SAMESITE='Lax', which flask_jwt_extended's own cookie
        # helpers use. Confirm the mismatch is intentional.
        samesite_policy = 'Lax' if is_huggingface else 'Strict'

        # Only a well-formed Bearer token is mirrored into the cookie. A missing
        # header previously passed None to set_cookie, and a non-Bearer
        # header (e.g. "Basic ...") was written into the cookie verbatim.
        auth_header = request.headers.get('Authorization', '')
        token = auth_header[len('Bearer '):] if auth_header.startswith('Bearer ') else None

        if token and not _response_sets_cookie(response, 'access_token'):
            response.set_cookie(
                'access_token',
                token,
                httponly=True,  # Not readable from JavaScript (limits XSS token theft)
                secure=secure_cookie,  # HTTPS-only outside development
                samesite=samesite_policy,  # CSRF protection
                # 1 hour. NOTE: flask_jwt_extended's default
                # JWT_ACCESS_TOKEN_EXPIRES is 15 minutes; keep this in sync
                # with whatever the app actually configures.
                max_age=3600,
                path='/api',  # Restrict to API routes
                domain=None,  # Host-only cookie: not shared with subdomains
            )

        # Safely check for rememberMe in the JSON body. get_json(silent=True)
        # returns None on a parse failure; the try is a best-effort guard
        # against anything else going wrong while reading the body.
        remember_me = False
        try:
            if request.is_json:
                json_data = request.get_json(silent=True)
                if isinstance(json_data, dict):
                    remember_me = bool(json_data.get('rememberMe', False))
        except Exception:
            remember_me = False

        if remember_me and not _response_sets_cookie(response, 'refresh_token'):
            # NOTE(review): this value is random, not a JWT. flask_jwt_extended
            # reads the 'refresh_token' cookie on /api/auth/refresh (see config
            # above), so it will likely reject this value as an invalid token.
            # A real refresh token set by the view is preserved by the check
            # above. Consider issuing a real refresh JWT here instead.
            response.set_cookie(
                'refresh_token',
                secrets.token_urlsafe(32),
                httponly=True,
                secure=secure_cookie,
                samesite=samesite_policy,
                max_age=7 * 24 * 60 * 60,  # 7 days
                path='/api/auth/refresh',  # Restrict to refresh endpoint
                domain=None,  # Host-only cookie: not shared with subdomains
            )

        return response

    return app


def _clear_auth_cookies(response):
    """Expire the access and refresh token cookies on *response*."""
    response.set_cookie('access_token', '', expires=0, path='/api', httponly=True, samesite='Lax')
    response.set_cookie('refresh_token', '', expires=0, path='/api/auth/refresh', httponly=True, samesite='Lax')


def _add_cors_headers(response, allowed_origins):
    """Add CORS headers to *response* if the request Origin is allow-listed.

    ``Access-Control-Allow-Origin`` may carry only one origin. With
    credentials it must be the exact requesting origin, so the request's
    Origin is echoed back instead of listing every allowed origin. Browsers
    reject a multi-valued header.
    """
    origin = request.headers.get('Origin')
    if origin and origin in allowed_origins:
        response.headers['Access-Control-Allow-Origin'] = origin
        response.headers['Access-Control-Allow-Credentials'] = 'true'
        # The response now varies by Origin; tell caches about it.
        response.headers.add('Vary', 'Origin')
    return response


def configure_jwt_with_cookies(app: Flask):
    """Create a JWTManager for *app* and register its error callbacks.

    The expired/invalid-token callbacks clear the auth cookies. All three
    error callbacks return a JSON 401 with CORS headers for allow-listed
    origins.

    Args:
        app: The Flask application to attach the JWT manager to.

    Returns:
        The configured ``JWTManager`` instance.
    """
    jwt = JWTManager(app)

    # Origins allowed to receive credentialed error responses. Presumably
    # this mirrors the app's CORS configuration; keep them in sync.
    allowed_origins = [
        'http://localhost:3000',
        'http://localhost:5000',
        'http://127.0.0.1:3000',
        'http://127.0.0.1:5000',
        'http://192.168.1.4:3000',
        'https://zelyanoth-lin-cbfcff2.hf.space',
    ]

    @jwt.token_verification_loader
    def verify_token_on_refresh_callback(jwt_header, jwt_payload):
        """Accept every decoded token (no extra verification is performed)."""
        # This is a simplified version - in production, you'd check a refresh token
        return True

    @jwt.expired_token_loader
    def expired_token_callback(jwt_header, jwt_payload):
        """Handle expired tokens: clear auth cookies and return 401."""
        response = jsonify({'success': False, 'message': 'Token has expired'})
        _clear_auth_cookies(response)
        _add_cors_headers(response, allowed_origins)
        return response, 401

    @jwt.invalid_token_loader
    def invalid_token_callback(error):
        """Handle invalid tokens: clear auth cookies and return 401."""
        response = jsonify({'success': False, 'message': 'Invalid token'})
        _clear_auth_cookies(response)
        _add_cors_headers(response, allowed_origins)
        return response, 401

    @jwt.unauthorized_loader
    def missing_token_callback(error):
        """Handle requests that carry no token: return 401."""
        response = jsonify({'success': False, 'message': 'Authorization token required'})
        _add_cors_headers(response, allowed_origins)
        return response, 401

    return jwt