wjm55 committed
Commit 1955b0a · Parent: 066ecb2
refactor init_model function to accept model_id parameter and update predict endpoint to use dynamic model initialization; added supervision library to requirements
Files changed:
- app.py (+30, -7)
- requirements.txt (+1, -0)
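For reference, a minimal client sketch of the reworked endpoint (the host, port, and sample image URL are assumptions for illustration, not part of this commit; the query parameters mirror the new predict() signature in app.py below):

import requests

# Hypothetical call against a locally running uvicorn instance
# (host/port assumed). FastAPI maps these scalar parameters to
# the query string of the POST request.
resp = requests.post(
    "http://localhost:8000/predict",  # assumed host/port
    params={
        "image_url": "https://example.com/street.jpg",  # assumed sample image
        "texts": "hat, backpack",
        "model_id": "yoloe-11s",
        "conf": 0.25,
        "iou": 0.7,
    },
)
print(resp.json()["detections"])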
app.py
CHANGED
@@ -7,7 +7,7 @@ import os
 from huggingface_hub import hf_hub_download
 from ultralytics import YOLO
 import requests
-
+import supervision as sv
 
 ###
 
@@ -19,9 +19,8 @@ import requests
 #wget -q https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt
 
 
-def init_model():
+def init_model(model_id: str):
     is_pf=True
-    model_id = "yoloe-11s"
     # Create a models directory if it doesn't exist
     os.makedirs("models", exist_ok=True)
     filename = f"{model_id}-seg.pt" if not is_pf else f"{model_id}-seg-pf.pt"
@@ -34,11 +33,16 @@ def init_model():
 
 app = FastAPI()
 
-# Initialize model at startup
-model = init_model()
 
 @app.post("/predict")
-async def predict(image_url: str, texts: str = "hat"):
+async def predict(image_url: str,
+                  texts: str = "hat",
+                  model_id: str = "yoloe-11l",
+                  conf: float = 0.25,
+                  iou: float = 0.7
+                  ):
+    # Initialize model at startup
+    model = init_model(model_id)
     # Set classes to filter
     class_list = [text.strip() for text in texts.split(',')]
 
@@ -51,7 +55,7 @@ async def predict(image_url: str, texts: str = "hat"):
     model.set_classes(class_list, text_embeddings)
 
     # Run inference with the PIL Image
-    results = model.predict(source=image, conf=
+    results = model.predict(source=image, conf=conf, iou=iou)
 
     # Extract detection results
     result = results[0]
@@ -66,6 +70,25 @@ async def predict(image_url: str, texts: str = "hat"):
         }
         detections.append(detection)
     print(detections)
+    # detections = sv.Detections.from_ultralytics(results[0])
+
+    # resolution_wh = image.size
+    # thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)
+    # text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh)
+
+    # labels = [
+    #     f"{class_name} {confidence:.2f}"
+    #     for class_name, confidence
+    #     in zip(detections['class_name'], detections.confidence)
+    # ]
+
+    # annotated_image = image.copy()
+    # annotated_image = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX, opacity=0.4).annotate(
+    #     scene=annotated_image, detections=detections)
+    # annotated_image = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness).annotate(
+    #     scene=annotated_image, detections=detections)
+    # annotated_image = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale, smart_position=True).annotate(
+    #     scene=annotated_image, detections=detections, labels=labels)
     return {"detections": detections}
 
 if __name__ == "__main__":
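The block commented out at the end of predict() sketches a supervision-based annotation path. A self-contained version of the same idea, as a sketch only (the weights filename and image path are placeholder assumptions; the sv calls are the ones the commented-out code already references):

import supervision as sv
from PIL import Image
from ultralytics import YOLO

# Placeholders: any local YOLOE segmentation weights and test image will do.
model = YOLO("models/yoloe-11l-seg-pf.pt")
image = Image.open("input.jpg")

results = model.predict(source=image, conf=0.25, iou=0.7)
dets = sv.Detections.from_ultralytics(results[0])

# Scale line thickness and label size to the image resolution.
thickness = sv.calculate_optimal_line_thickness(resolution_wh=image.size)
text_scale = sv.calculate_optimal_text_scale(resolution_wh=image.size)
labels = [
    f"{class_name} {confidence:.2f}"
    for class_name, confidence in zip(dets["class_name"], dets.confidence)
]

# Layer mask, box, and label annotations onto a copy of the input image.
annotated = image.copy()
annotated = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX, opacity=0.4).annotate(
    scene=annotated, detections=dets)
annotated = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness).annotate(
    scene=annotated, detections=dets)
annotated = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale,
                              smart_position=True).annotate(
    scene=annotated, detections=dets, labels=labels)
annotated.save("annotated.jpg")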
requirements.txt
CHANGED
@@ -1,5 +1,6 @@
 fastapi
 uvicorn[standard]
+supervision
 git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/CLIP
 git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/ml-mobileclip
 git+https://github.com/THU-MIG/yoloe.git#subdirectory=third_party/lvis-api