# Uploaded by mlbench123 ("Upload 9 files", commit 492772b, verified)
"""
Test script for Binary Segmentation API
Run this to verify the API is working correctly.
"""
import requests
import sys
import time
from pathlib import Path
def _post_test_image(base_url, endpoint, image_path, extra_fields=None):
    """POST the test image to ``base_url + endpoint`` and return the response.

    All upload tests send the same multipart form (model ``u2netp``,
    threshold ``0.5``); ``extra_fields`` appends endpoint-specific form
    fields after those two.
    """
    form = {
        'model': 'u2netp',
        'threshold': '0.5',
    }
    if extra_fields:
        form.update(extra_fields)
    # The request completes inside the ``with`` block, so the file is still
    # open while it is being uploaded and is closed right afterwards.
    with open(image_path, 'rb') as image_file:
        return requests.post(
            f"{base_url}{endpoint}",
            files={'file': image_file},
            data=form,
            timeout=30,
        )


def _remove_test_file(path):
    """Delete ``path`` if it exists and report the outcome on stdout."""
    try:
        if path.exists():
            path.unlink()
            print(f" Removed: {path}")
    except OSError as e:
        print(f" Warning: Cleanup failed: {e}")


def test_api(base_url: str = "http://localhost:7860") -> bool:
    """Run a smoke test against the Binary Segmentation API.

    The suite runs six steps in order: health check, model listing, test
    image creation, ``/segment``, ``/segment/mask`` and ``/segment/base64``.
    Progress and results are printed to stdout. Temporary files are written
    to the current working directory and removed at the end.

    Args:
        base_url: Root URL of the running API, without a trailing slash.

    Returns:
        ``False`` if the health check fails or the test image cannot be
        created (the suite stops at that point). ``True`` otherwise; failures
        of the model listing or of the three segmentation endpoints are
        printed but do not change the return value.
    """
    banner = "=" * 60
    print(banner)
    print("Binary Segmentation API - Test Suite")
    print(banner)
    print(f"\nTesting API at: {base_url}\n")

    # Test 1: Health Check (critical: nothing else can work without it)
    print("Test 1: Health Check")
    # Broad catch is deliberate: this is the top-level boundary of a smoke
    # test, and any error (connection, timeout, bad JSON) means "failed".
    try:
        response = requests.get(f"{base_url}/health", timeout=5)
        if response.status_code == 200:
            print("✓ Health check passed")
            print(f" Response: {response.json()}")
        else:
            print(f"✗ Health check failed: {response.status_code}")
            return False
    except Exception as e:
        print(f"✗ Health check failed: {e}")
        print("\n Make sure the API is running:")
        print(" python app.py")
        print(" or")
        print(" uvicorn app:app --host 0.0.0.0 --port 7860")
        return False
    print()

    # Test 2: List Models (non-critical)
    print("Test 2: List Models")
    try:
        response = requests.get(f"{base_url}/models", timeout=5)
        if response.status_code == 200:
            print("✓ Models endpoint working")
            models = response.json().get('models', [])
            print(f" Available models: {len(models)}")
            for model in models:
                print(f" - {model['name']}: {model['description']}")
        else:
            print(f"✗ Models endpoint failed: {response.status_code}")
    except Exception as e:
        print(f"✗ Models endpoint failed: {e}")
    print()

    # Test 3: Create test image (critical: Tests 4-6 upload it)
    print("Test 3: Create Test Image")
    try:
        # Imported here so a missing numpy/Pillow is reported as a test
        # failure instead of crashing the script at import time.
        import numpy as np
        from PIL import Image

        # 200x200 white canvas with a 100x100 red square in the centre.
        img = np.full((200, 200, 3), 255, dtype=np.uint8)
        img[50:150, 50:150] = [255, 0, 0]
        test_path = Path("test_image.jpg")
        Image.fromarray(img).save(test_path)
        print(f"✓ Test image created: {test_path}")
    except Exception as e:
        print(f"✗ Failed to create test image: {e}")
        return False
    print()

    # From here on the test image is known to exist (Test 3 returned early
    # otherwise), so the upload tests run unconditionally.

    # Test 4: Segmentation
    print("Test 4: Image Segmentation")
    try:
        start_time = time.time()
        response = _post_test_image(base_url, "/segment", test_path)
        elapsed = time.time() - start_time
        if response.status_code == 200:
            output_path = Path("test_output.png")
            output_path.write_bytes(response.content)
            print(f"✓ Segmentation successful ({elapsed:.2f}s)")
            print(f" Output saved to: {output_path}")
            print(f" Output size: {len(response.content)} bytes")
        else:
            print(f"✗ Segmentation failed: {response.status_code}")
            print(f" Response: {response.text}")
    except Exception as e:
        print(f"✗ Segmentation failed: {e}")
    print()

    # Test 5: Mask endpoint
    print("Test 5: Binary Mask")
    try:
        response = _post_test_image(base_url, "/segment/mask", test_path)
        if response.status_code == 200:
            mask_path = Path("test_mask.png")
            mask_path.write_bytes(response.content)
            print("✓ Mask generation successful")
            print(f" Mask saved to: {mask_path}")
        else:
            print(f"✗ Mask generation failed: {response.status_code}")
    except Exception as e:
        print(f"✗ Mask generation failed: {e}")
    print()

    # Test 6: Base64 endpoint
    print("Test 6: Base64 Output")
    try:
        response = _post_test_image(
            base_url,
            "/segment/base64",
            test_path,
            extra_fields={'return_type': 'both'},
        )
        if response.status_code == 200:
            result = response.json()
            print("✓ Base64 output successful")
            print(f" Has RGBA: {'rgba' in result}")
            print(f" Has Mask: {'mask' in result}")
        else:
            print(f"✗ Base64 output failed: {response.status_code}")
    except Exception as e:
        print(f"✗ Base64 output failed: {e}")
    print()

    # Cleanup: each file is handled on its own so that one failed removal
    # does not leave the remaining artifacts behind.
    print("Cleanup:")
    for leftover in (
        test_path,
        Path("test_output.png"),
        Path("test_mask.png"),
    ):
        _remove_test_file(leftover)
    print()

    print(banner)
    print("Test Suite Complete!")
    print(banner)
    return True
if __name__ == "__main__":
    # Script entry point: run the suite against the URL given as the first
    # command-line argument (default: local server on port 7860) and exit
    # with 0 when the critical tests passed, 1 otherwise.
    base_url = sys.argv[1] if len(sys.argv) > 1 else "http://localhost:7860"
    success = test_api(base_url)
    if success:
        print("\n✓ All critical tests passed!")
        print("\nNext steps:")
        # Point at the URL that was actually tested, not a fixed localhost.
        print(f"1. Open {base_url} in your browser")
        print("2. Upload an image and test the web interface")
        print("3. Deploy to Hugging Face Spaces (see DEPLOYMENT.md)")
        sys.exit(0)
    else:
        print("\n✗ Some tests failed!")
        print("\nTroubleshooting:")
        print("1. Make sure the server is running:")
        print(" uvicorn app:app --host 0.0.0.0 --port 7860")
        print("2. Check that u2netp.pth is in .model_cache/")
        print("3. Check logs for errors")
        sys.exit(1)