Spaces:
Sleeping
Sleeping
Rivalcoder
commited on
Commit
·
a85ce4c
1
Parent(s):
884137e
Edit
Browse files
app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import torch
|
| 2 |
-
import gradio as gr
|
| 3 |
from fastapi import FastAPI, File, UploadFile
|
| 4 |
from fastapi.responses import JSONResponse
|
| 5 |
from transformers import ConvNextForImageClassification, AutoImageProcessor
|
|
@@ -29,49 +28,28 @@ processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
|
|
| 29 |
# FastAPI app
|
| 30 |
app = FastAPI()
|
| 31 |
|
| 32 |
-
#
|
| 33 |
def predict(image: Image.Image):
|
| 34 |
-
# Preprocess the image
|
| 35 |
inputs = processor(images=image, return_tensors="pt")
|
| 36 |
-
|
| 37 |
-
# Perform inference
|
| 38 |
with torch.no_grad():
|
| 39 |
outputs = model(**inputs)
|
| 40 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
| 41 |
-
|
| 42 |
return predicted_class, class_names[predicted_class]
|
| 43 |
|
| 44 |
-
#
|
| 45 |
@app.post("/predict/")
|
| 46 |
async def predict_endpoint(file: UploadFile = File(...)):
|
| 47 |
try:
|
| 48 |
-
# Read and process the image
|
| 49 |
img_bytes = await file.read()
|
| 50 |
-
img = Image.open(io.BytesIO(img_bytes))
|
| 51 |
-
|
| 52 |
-
# Get the prediction
|
| 53 |
predicted_class, predicted_name = predict(img)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
except Exception as e:
|
| 59 |
return JSONResponse(content={"error": str(e)}, status_code=500)
|
| 60 |
|
| 61 |
-
#
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
return f"Predicted Class: {predicted_name}"
|
| 65 |
-
|
| 66 |
-
# Gradio Interface
|
| 67 |
-
iface = gr.Interface(fn=gradio_predict, inputs=gr.Image(type="pil"), outputs=gr.Textbox())
|
| 68 |
-
|
| 69 |
-
# Serve Gradio interface on FastAPI
|
| 70 |
-
@app.get("/gradio/")
|
| 71 |
-
async def gradio_interface():
|
| 72 |
-
return iface.launch(share=True, inline=True)
|
| 73 |
-
|
| 74 |
-
# Run the FastAPI app using Uvicorn
|
| 75 |
-
if __name__ == "__main__":
|
| 76 |
-
import uvicorn
|
| 77 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
| 1 |
import torch
|
|
|
|
| 2 |
from fastapi import FastAPI, File, UploadFile
|
| 3 |
from fastapi.responses import JSONResponse
|
| 4 |
from transformers import ConvNextForImageClassification, AutoImageProcessor
|
|
|
|
| 28 |
# FastAPI app
|
| 29 |
app = FastAPI()
|
| 30 |
|
| 31 |
+
# Prediction helper
|
| 32 |
def predict(image: Image.Image):
|
|
|
|
| 33 |
inputs = processor(images=image, return_tensors="pt")
|
|
|
|
|
|
|
| 34 |
with torch.no_grad():
|
| 35 |
outputs = model(**inputs)
|
| 36 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
|
|
|
| 37 |
return predicted_class, class_names[predicted_class]
|
| 38 |
|
| 39 |
+
# Endpoint: /predict
|
| 40 |
@app.post("/predict/")
|
| 41 |
async def predict_endpoint(file: UploadFile = File(...)):
|
| 42 |
try:
|
|
|
|
| 43 |
img_bytes = await file.read()
|
| 44 |
+
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
|
|
|
|
|
|
| 45 |
predicted_class, predicted_name = predict(img)
|
| 46 |
+
return JSONResponse(content={
|
| 47 |
+
"predicted_class": predicted_class,
|
| 48 |
+
"predicted_name": predicted_name
|
| 49 |
+
})
|
| 50 |
except Exception as e:
|
| 51 |
return JSONResponse(content={"error": str(e)}, status_code=500)
|
| 52 |
|
| 53 |
+
# Required for Hugging Face Spaces (do NOT run uvicorn manually)
|
| 54 |
+
# Just expose the app
|
| 55 |
+
app = app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|