# try-on / app.py
# Last commit by Alaiy: "Update app.py" (ebcb30e, verified)
import gradio as gr
from dotenv import load_dotenv
import requests
from flask import Flask, jsonify, request, send_file
from botocore.exceptions import ClientError
from botocore.client import Config
import boto3
from urllib.parse import urlparse
import os
from PIL import Image
from io import BytesIO
import uuid
# Populate os.environ from a .env file (if any) before reading settings below.
load_dotenv()

# --- AWS / S3 ---------------------------------------------------------------
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
BUCKET_NAME = "tech-tailor"

# Shared S3 client used for uploads and presigned URLs; signs with SigV4.
s3_client = boto3.client(
    "s3",
    region_name="ap-south-1",
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    config=Config(signature_version="s3v4"),
)

# --- Inference backend -----------------------------------------------------
# May be None if the variable is not set in the environment / .env file.
MODAL_INFERENCE_ENDPOINT_URL = os.environ.get("MODAL_INFERENCE_ENDPOINT_URL")

app = Flask(__name__)

# Save-directory names; not referenced by the visible code in this file.
GARM_SAVE_DIR = "garment_images"
MODE_SAVE_DIR = "model_images"

# S3 key prefixes under which the processed garment / model JPEGs are stored.
garment_upload_dir = "gradio_demo_garment/"
model_upload_dir = "gradio_demo_model/"
def load_image_from_url(image_url):
    """Fetch an image from ``image_url`` for the garment preview.

    Wired to the URL textbox's ``change`` event, so it runs on every edit of
    the field, including when the field is cleared.

    Args:
        image_url: URL entered by the user; may be empty or None.

    Returns:
        A ``PIL.Image.Image`` when the request succeeds and the response's
        ``Content-Type`` contains ``"image"``; otherwise ``None``, which leaves
        the preview empty. Errors are printed, never raised.
    """
    # An empty/whitespace field is the normal "nothing entered yet" state:
    # skip the request that would certainly fail and spam the log.
    if not image_url or not image_url.strip():
        return None
    try:
        # Timeout so an unresponsive host cannot hang the Gradio worker.
        response = requests.get(image_url, timeout=15)
        response.raise_for_status()
        # The header may be absent; treat that as "not an image".
        content_type = response.headers.get("Content-Type", "")
        if "image" not in content_type:
            return None
        img = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data is caught here.
        img.load()
        return img
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
def process_cloth_image(image_url):
    """Download a garment image, letterbox it to 768x1024 and upload it to S3.

    The image is converted to RGB and scaled (aspect ratio preserved) to fit
    inside 768x1024. It is centred on a black canvas of exactly that size,
    encoded as JPEG and stored under ``garment_upload_dir`` in ``BUCKET_NAME``.

    Args:
        image_url: URL of the garment image; may be empty or None.

    Returns:
        On success, a presigned GET URL for the uploaded JPEG, valid for
        3600 seconds. Otherwise a message: ``"No image provided"`` for an
        empty input, or a string starting with ``"Error"``. Nothing is raised.
    """
    if not image_url:
        return "No image provided"

    target_width, target_height = 768, 1024
    try:
        # Timeout so an unresponsive host cannot hang the request handler.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")

        # Fit inside the target box without distorting the aspect ratio.
        img_width, img_height = img.size
        scale_factor = min(target_width / img_width, target_height / img_height)
        # Clamp to >= 1 px: int() truncation yields 0 for extreme aspect
        # ratios, and resize() rejects zero-sized targets.
        new_width = max(1, int(img_width * scale_factor))
        new_height = max(1, int(img_height * scale_factor))
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Centre the scaled image on a black canvas of the exact target size.
        new_img = Image.new("RGB", (target_width, target_height), (0, 0, 0))
        left_padding = (target_width - new_width) // 2
        top_padding = (target_height - new_height) // 2
        new_img.paste(img, (left_padding, top_padding))

        img_byte_array = BytesIO()
        new_img.save(img_byte_array, format="JPEG")
        img_byte_array.seek(0)

        key = garment_upload_dir + f"{uuid.uuid4().hex}.jpg"
        s3_client.put_object(
            Body=img_byte_array,
            Bucket=BUCKET_NAME,
            Key=key,
            ContentType="image/jpeg",
        )
        return s3_client.generate_presigned_url(
            "get_object",
            Params={"Bucket": BUCKET_NAME, "Key": key},
            ExpiresIn=3600,
        )
    except requests.exceptions.RequestException as e:
        return f"Error downloading image: {e}"
    except ClientError as e:
        # Report S3 failures as such rather than as an image-processing error.
        return f"Error uploading image: {e}"
    except Exception as e:
        return f"Error processing image: {e}"
def process_model_image(image):
    """Letterbox the person ("human") photo to 768x1024 and upload it to S3.

    The image is converted to RGB and scaled (aspect ratio preserved) to fit
    inside 768x1024. It is centred on a black canvas of exactly that size,
    encoded as JPEG and stored under ``model_upload_dir`` in ``BUCKET_NAME``.

    Args:
        image: ``PIL.Image.Image`` from the upload/capture component, or
            ``None`` when the user has not provided one.

    Returns:
        A presigned GET URL for the uploaded JPEG, valid for 3600 seconds.

    Raises:
        ValueError: if ``image`` is ``None``.
        Any exception from ``s3_client.put_object`` /
        ``generate_presigned_url`` propagates unchanged.
    """
    # Fail with a clear message instead of AttributeError on None.convert().
    if image is None:
        raise ValueError("No image provided")

    target_width, target_height = 768, 1024
    img = image.convert("RGB")

    # Fit inside the target box without distorting the aspect ratio.
    img_width, img_height = img.size
    scale_factor = min(target_width / img_width, target_height / img_height)
    # Clamp to >= 1 px: int() truncation yields 0 for extreme aspect ratios,
    # and resize() rejects zero-sized targets.
    new_width = max(1, int(img_width * scale_factor))
    new_height = max(1, int(img_height * scale_factor))
    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Centre the scaled image on a black canvas of the exact target size.
    new_img = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    left_padding = (target_width - new_width) // 2
    top_padding = (target_height - new_height) // 2
    new_img.paste(img, (left_padding, top_padding))

    img_byte_array = BytesIO()
    new_img.save(img_byte_array, format="JPEG")
    img_byte_array.seek(0)

    key = model_upload_dir + f"{uuid.uuid4().hex}.jpg"
    s3_client.put_object(
        Body=img_byte_array,
        Bucket=BUCKET_NAME,
        Key=key,
        ContentType="image/jpeg",
    )
    return s3_client.generate_presigned_url(
        "get_object",
        Params={"Bucket": BUCKET_NAME, "Key": key},
        ExpiresIn=3600,
    )
def display_image(image, image_url):
    """Send the person photo and garment to the Modal endpoint; return the result.

    Both images are letterboxed and uploaded to S3 by the processing helpers.
    Their presigned URLs are POSTed to ``MODAL_INFERENCE_ENDPOINT_URL``. The
    image at the ``"url"`` field of the JSON reply is then downloaded and
    resized for display.

    Args:
        image: Person photo (``PIL.Image.Image``) from the upload component,
            or None if nothing was uploaded.
        image_url: Garment image URL from the textbox.

    Returns:
        The result image resized to 512x682.

    Raises:
        gr.Error: for missing inputs or any failure along the pipeline, so
            Gradio shows the message instead of leaving the output blank.
    """
    # Validate up front so placeholder/error strings are never POSTed
    # to the inference endpoint.
    if image is None:
        raise gr.Error("Please upload or capture a model image.")
    if not image_url:
        raise gr.Error("Please enter a garment image URL.")
    if not MODAL_INFERENCE_ENDPOINT_URL:
        raise gr.Error("MODAL_INFERENCE_ENDPOINT_URL is not configured.")

    # process_cloth_image reports failures by returning a message, not a URL.
    garment_url = process_cloth_image(image_url)
    if urlparse(str(garment_url)).scheme not in ("http", "https"):
        raise gr.Error(str(garment_url))

    try:
        model_url = process_model_image(image)
    except Exception as e:  # UI boundary: surface any upload failure to the user
        raise gr.Error(f"Failed to process the model image: {e}") from e
    if urlparse(str(model_url)).scheme not in ("http", "https"):
        raise gr.Error(str(model_url))

    payload = {
        "human_image_url": model_url,
        "garment_image_url": garment_url,
    }
    # Presigned URLs embed signatures, so the payload itself is not logged.
    print("Entering Modal block")
    try:
        # NOTE(review): no timeout because inference latency is unknown here;
        # consider setting a generous one.
        response = requests.post(MODAL_INFERENCE_ENDPOINT_URL, json=payload)
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Request failed for the garment image. Error: {e}") from e
    if response.status_code != 200:
        raise gr.Error(
            f"Failed to process the garment image. Status Code: {response.status_code}"
        )

    try:
        result_url = response.json()["url"]
        result = requests.get(result_url, timeout=60)
        result.raise_for_status()
        img = Image.open(BytesIO(result.content))
        return img.resize((512, 682))
    except (
        requests.exceptions.RequestException,  # result download failed
        ValueError,  # body is not JSON
        KeyError,  # no "url" in the JSON reply
        TypeError,  # JSON reply is not an object
        OSError,  # PIL could not decode the result image
    ) as e:
        raise gr.Error(f"Could not fetch the try-on result: {e}") from e
def generate_presigned_url(object_url):
    """Create a presigned GET URL (3600 s) for an object in ``BUCKET_NAME``.

    Both S3 URL styles are supported:

    * virtual-hosted, ``https://<bucket>.s3.<region>.amazonaws.com/<key>``:
      the bucket is in the hostname, so the whole path is the key;
    * path-style, ``https://s3.<region>.amazonaws.com/<bucket>/<key>``:
      the first path segment is the bucket and is dropped.

    Args:
        object_url: Plain (unsigned) URL of an object in ``BUCKET_NAME``.

    Returns:
        The presigned URL, or ``None`` if generation raises (the error is printed).
    """
    from urllib.parse import unquote

    parsed_url = urlparse(object_url)
    raw_path = parsed_url.path.lstrip("/")
    host = parsed_url.hostname or ""
    if host.startswith(BUCKET_NAME + "."):
        # Virtual-hosted style: the bucket lives in the hostname.
        raw_key = raw_path
    else:
        # Path-style: strip the leading bucket segment.
        path_parts = raw_path.split("/", 1)
        raw_key = path_parts[1] if len(path_parts) > 1 else ""
    # URL paths are percent-encoded; S3 keys are not.
    object_key = unquote(raw_key)
    print(f"Extracted Object Key: {object_key}")
    try:
        presigned_url = s3_client.generate_presigned_url(
            "get_object",
            Params={
                "Bucket": BUCKET_NAME,
                "Key": object_key,
            },
            ExpiresIn=3600,
        )
        return presigned_url
    except Exception as e:
        print(f"Error generating pre-signed URL: {e}")
        return None
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        image_url_input = gr.Textbox(label="Image URL", placeholder="Enter image URL here")
        input_garment_image = gr.Image(label="Garment Image", type="pil", width="384px", height="512px")
        uploaded_image = gr.Image(label="Upload or Capture Image", type="pil", width="384px", height="512px")
        output_display = gr.Image(label="Displayed Image or URL Result", width="384px", height="512px")

    # Refresh the garment preview whenever the URL field changes.
    image_url_input.change(
        load_image_from_url,
        inputs=image_url_input,
        outputs=input_garment_image,
    )

    submit_btn = gr.Button("Submit")
    # Run the try-on pipeline on the uploaded photo + garment URL.
    submit_btn.click(
        display_image,
        inputs=[uploaded_image, image_url_input],
        outputs=output_display,
    )

# Launch only when run as a script, so importing this module (e.g. `gradio app.py`
# reload mode or tests) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)