Spaces:
Running
Running
| """ | |
| Test script for Binary Segmentation API | |
| Run this to verify the API is working correctly. | |
| """ | |
| import requests | |
| import sys | |
| import time | |
| from pathlib import Path | |
# Console markers for passed / failed checks.
_PASS = "✓"
_FAIL = "✗"

# Scratch files written to the current working directory; removed after the run.
_TEST_IMAGE = Path("test_image.jpg")
_OUTPUT_IMAGE = Path("test_output.png")
_MASK_IMAGE = Path("test_mask.png")

# Form fields sent with every segmentation request.
_SEGMENT_FORM = {"model": "u2netp", "threshold": "0.5"}


def _check_health(base_url: str) -> bool:
    """Test 1: ``GET /health``; return True only on HTTP 200 with a JSON body."""
    print("Test 1: Health Check")
    try:
        response = requests.get(f"{base_url}/health", timeout=5)
        if response.status_code != 200:
            print(f"{_FAIL} Health check failed: {response.status_code}")
            return False
        print(f"{_PASS} Health check passed")
        print(f" Response: {response.json()}")
        return True
    except Exception as e:  # report any network/JSON error as a failed check
        print(f"{_FAIL} Health check failed: {e}")
        print("\n Make sure the API is running:")
        print(" python app.py")
        print(" or")
        print(" uvicorn app:app --host 0.0.0.0 --port 7860")
        return False


def _check_models(base_url: str) -> bool:
    """Test 2: ``GET /models`` and print each model's name and description."""
    print("Test 2: List Models")
    try:
        response = requests.get(f"{base_url}/models", timeout=5)
        if response.status_code != 200:
            print(f"{_FAIL} Models endpoint failed: {response.status_code}")
            return False
        print(f"{_PASS} Models endpoint working")
        models = response.json().get("models", [])
        print(f" Available models: {len(models)}")
        for model in models:
            print(f" - {model['name']}: {model['description']}")
        return True
    except Exception as e:
        print(f"{_FAIL} Models endpoint failed: {e}")
        return False


def _create_test_image(path: Path) -> bool:
    """Test 3: write a 200x200 white image with a centred 100x100 red square."""
    print("Test 3: Create Test Image")
    try:
        # Imported lazily so a missing numpy/Pillow is reported as a failed check.
        import numpy as np
        from PIL import Image

        img = np.full((200, 200, 3), 255, dtype=np.uint8)
        img[50:150, 50:150] = [255, 0, 0]  # red square
        Image.fromarray(img).save(path)
        print(f"{_PASS} Test image created: {path}")
        return True
    except Exception as e:
        print(f"{_FAIL} Failed to create test image: {e}")
        return False


def _post_image(base_url: str, endpoint: str, image_path: Path, form: dict) -> "requests.Response":
    """POST ``image_path`` as the multipart ``file`` field plus ``form`` fields.

    The file stays open for the duration of the upload; network errors
    propagate to the caller.
    """
    with open(image_path, "rb") as f:
        return requests.post(
            f"{base_url}{endpoint}",
            files={"file": f},
            data=form,
            timeout=30,
        )


def _check_segment(base_url: str, image_path: Path, output_path: Path) -> bool:
    """Test 4: ``POST /segment`` and save the returned image bytes."""
    print("Test 4: Image Segmentation")
    try:
        start = time.perf_counter()  # monotonic clock for measuring durations
        response = _post_image(base_url, "/segment", image_path, _SEGMENT_FORM)
        elapsed = time.perf_counter() - start
        if response.status_code != 200:
            print(f"{_FAIL} Segmentation failed: {response.status_code}")
            print(f" Response: {response.text}")
            return False
        output_path.write_bytes(response.content)
        print(f"{_PASS} Segmentation successful ({elapsed:.2f}s)")
        print(f" Output saved to: {output_path}")
        print(f" Output size: {len(response.content)} bytes")
        return True
    except Exception as e:
        print(f"{_FAIL} Segmentation failed: {e}")
        return False


def _check_mask(base_url: str, image_path: Path, mask_path: Path) -> bool:
    """Test 5: ``POST /segment/mask`` and save the returned mask bytes."""
    print("Test 5: Binary Mask")
    try:
        response = _post_image(base_url, "/segment/mask", image_path, _SEGMENT_FORM)
        if response.status_code != 200:
            print(f"{_FAIL} Mask generation failed: {response.status_code}")
            return False
        mask_path.write_bytes(response.content)
        print(f"{_PASS} Mask generation successful")
        print(f" Mask saved to: {mask_path}")
        return True
    except Exception as e:
        print(f"{_FAIL} Mask generation failed: {e}")
        return False


def _check_base64(base_url: str, image_path: Path) -> bool:
    """Test 6: ``POST /segment/base64`` and report which keys the JSON contains."""
    print("Test 6: Base64 Output")
    try:
        form = {**_SEGMENT_FORM, "return_type": "both"}
        response = _post_image(base_url, "/segment/base64", image_path, form)
        if response.status_code != 200:
            print(f"{_FAIL} Base64 output failed: {response.status_code}")
            return False
        result = response.json()
        print(f"{_PASS} Base64 output successful")
        print(f" Has RGBA: {'rgba' in result}")
        print(f" Has Mask: {'mask' in result}")
        return True
    except Exception as e:
        print(f"{_FAIL} Base64 output failed: {e}")
        return False


def _cleanup(*paths: Path) -> None:
    """Delete each scratch file that exists; a failure on one does not skip the rest."""
    print("Cleanup:")
    for path in paths:
        try:
            if path.exists():
                path.unlink()
                print(f" Removed: {path}")
        except OSError as e:
            print(f" Warning: Cleanup failed: {e}")


def test_api(base_url: str = "http://localhost:7860") -> bool:
    """Run the smoke-test suite against a running Binary Segmentation API.

    Checks /health, /models, /segment, /segment/mask and /segment/base64,
    printing a line per result. Scratch images are written to the current
    working directory and removed afterwards, even if a check errors.

    Args:
        base_url: Root URL of the API, without a trailing slash.

    Returns:
        True if every check passed, False if any failed. The suite stops
        early when the health check fails or the test image cannot be
        created, because later checks would fail as well.
    """
    print("=" * 60)
    print("Binary Segmentation API - Test Suite")
    print("=" * 60)
    print(f"\nTesting API at: {base_url}\n")

    # Without a reachable server none of the remaining checks can pass.
    if not _check_health(base_url):
        return False
    print()

    models_ok = _check_models(base_url)
    print()

    try:
        if not _create_test_image(_TEST_IMAGE):
            return False
        print()

        # Run every endpoint check even if an earlier one fails, then combine.
        segment_ok = _check_segment(base_url, _TEST_IMAGE, _OUTPUT_IMAGE)
        print()
        mask_ok = _check_mask(base_url, _TEST_IMAGE, _MASK_IMAGE)
        print()
        base64_ok = _check_base64(base_url, _TEST_IMAGE)
        print()
    finally:
        _cleanup(_TEST_IMAGE, _OUTPUT_IMAGE, _MASK_IMAGE)
        print()

    print("=" * 60)
    print("Test Suite Complete!")
    print("=" * 60)
    return models_ok and segment_ok and mask_ok and base64_ok
if __name__ == "__main__":
    # Usage: python <this script> [BASE_URL]; defaults to the local dev server.
    base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:7860"
    success = test_api(base_url)
    if success:
        print("\n✓ All critical tests passed!")
        print("\nNext steps:")
        # Point at the server that was actually tested, not a hard-coded host.
        print(f"1. Open {base_url} in your browser")
        print("2. Upload an image and test the web interface")
        print("3. Deploy to Hugging Face Spaces (see DEPLOYMENT.md)")
        sys.exit(0)
    else:
        print("\n✗ Some tests failed!")
        print("\nTroubleshooting:")
        print("1. Make sure the server is running:")
        print(" uvicorn app:app --host 0.0.0.0 --port 7860")
        print("2. Check that u2netp.pth is in .model_cache/")
        print("3. Check logs for errors")
        # Non-zero exit so CI / shell scripts can detect the failure.
        sys.exit(1)