| import gradio as gr | |
| from dotenv import load_dotenv | |
| import requests | |
| from flask import Flask, jsonify, request, send_file | |
| from botocore.exceptions import ClientError | |
| from botocore.client import Config | |
| import boto3 | |
| from urllib.parse import urlparse | |
| import os | |
| from PIL import Image | |
| from io import BytesIO | |
| import uuid | |
load_dotenv()

# AWS credentials come from the environment, optionally populated from a .env
# file by load_dotenv() above.
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")

# Bucket and region can be overridden per deployment. The defaults are the
# values this app was originally hard-coded to. Project-specific variable names
# are used so a generic AWS_REGION set for other tools does not silently
# retarget this bucket.
BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "tech-tailor")
S3_REGION = os.getenv("S3_REGION", "ap-south-1")

# Shared S3 client. Presigned URLs generated from it are signed with SigV4.
s3_client = boto3.client(
    "s3",
    region_name=S3_REGION,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    config=Config(signature_version="s3v4"),
)

# Try-on inference service. display_image posts the presigned input URLs here.
MODAL_INFERENCE_ENDPOINT_URL = os.getenv("MODAL_INFERENCE_ENDPOINT_URL")

# NOTE(review): the Flask app and the two *_SAVE_DIR constants do not appear to
# be used by the Gradio flow in this file. Confirm before removing.
app = Flask(__name__)
GARM_SAVE_DIR = "garment_images"
MODE_SAVE_DIR = "model_images"

# S3 key prefixes under which the letterboxed input images are uploaded.
garment_upload_dir = "gradio_demo_garment/"
model_upload_dir = "gradio_demo_model/"
def load_image_from_url(image_url, timeout=10):
    """Fetch ``image_url`` and return it as a PIL image for the garment preview.

    This is wired to the URL textbox's ``change`` event, so it runs on every
    edit of the field. Partial or invalid URLs are expected there, and every
    failure is reported by returning ``None``, which leaves the preview empty.

    Args:
        image_url: URL typed by the user. May be empty or ``None``.
        timeout: Seconds to wait for the remote server before giving up.

    Returns:
        A ``PIL.Image.Image`` when the server answers successfully with an
        image content type, otherwise ``None``.
    """
    # Empty textbox: there is nothing to fetch, so skip a guaranteed failure.
    if not image_url:
        return None
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        # A missing Content-Type header means "not an image", not an error.
        if "image" not in response.headers.get("Content-Type", ""):
            return None
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # UI boundary: log and show no preview rather than crash the handler.
        print(f"Error loading image: {e}")
        return None
def _letterbox(img, target_width=768, target_height=1024):
    """Fit ``img`` inside a black target-sized RGB canvas, keeping aspect ratio.

    The image is converted to RGB, scaled so that it fits entirely within
    ``target_width`` x ``target_height``, and pasted centered onto a black
    canvas of exactly that size.
    """
    img = img.convert("RGB")
    img_width, img_height = img.size
    scale = min(target_width / img_width, target_height / img_height)
    # max(1, ...): for extreme aspect ratios one side can round down to 0,
    # which Image.resize rejects.
    new_width = max(1, int(img_width * scale))
    new_height = max(1, int(img_height * scale))
    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    left_padding = (target_width - new_width) // 2
    top_padding = (target_height - new_height) // 2
    canvas.paste(img, (left_padding, top_padding))
    return canvas


def _upload_jpeg_and_presign(img, key_prefix, expires_in=3600):
    """Upload ``img`` as ``<key_prefix><uuid>.jpg`` and return a presigned GET URL."""
    buffer = BytesIO()
    img.save(buffer, format="JPEG")
    buffer.seek(0)
    key = f"{key_prefix}{uuid.uuid4().hex}.jpg"
    s3_client.put_object(
        Body=buffer, Bucket=BUCKET_NAME, Key=key, ContentType="image/jpeg"
    )
    return s3_client.generate_presigned_url(
        "get_object",
        Params={"Bucket": BUCKET_NAME, "Key": key},
        ExpiresIn=expires_in,
    )


def process_cloth_image(image_url, timeout=30):
    """Download a garment image, letterbox it to 768x1024 and upload it to S3.

    Args:
        image_url: Public URL of the garment image.
        timeout: Seconds to wait for the download before giving up.

    Returns:
        A presigned S3 URL (valid for one hour) on success. On failure this
        returns a human-readable message instead of raising:
        ``"No image provided"``, ``"Error downloading image: ..."`` or
        ``"Error processing image: ..."``. Callers must distinguish these
        from a URL.
    """
    if not image_url:
        return "No image provided"
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        return _upload_jpeg_and_presign(_letterbox(img), garment_upload_dir)
    except requests.exceptions.RequestException as e:
        return f"Error downloading image: {e}"
    except Exception as e:
        return f"Error processing image: {e}"


def process_model_image(image):
    """Letterbox the model (person) photo to 768x1024 and upload it to S3.

    Args:
        image: PIL image from the upload/capture component.

    Returns:
        A presigned S3 URL valid for one hour.

    Raises:
        ValueError: if ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No model image provided")
    return _upload_jpeg_and_presign(_letterbox(image), model_upload_dir)
def display_image(image, image_url):
    """Run the virtual try-on for one model photo and one garment URL.

    The garment is prepared by ``process_cloth_image`` and the photo by
    ``process_model_image``. Their presigned URLs are then posted to the Modal
    inference endpoint, whose JSON reply carries a ``"url"`` for the result
    image.

    Args:
        image: PIL image from the "Upload or Capture Image" component, or
            ``None`` when nothing was provided.
        image_url: Garment image URL from the textbox.

    Returns:
        The result image resized to 512x682 for display.

    Raises:
        gr.Error: when an input is missing, the garment cannot be downloaded
            or processed, the endpoint is unconfigured or answers with a
            non-200 status, or the result cannot be fetched. Gradio shows
            the message to the user instead of a blank output.
    """
    if image is None:
        raise gr.Error("Please upload or capture a model image.")
    if not MODAL_INFERENCE_ENDPOINT_URL:
        raise gr.Error("MODAL_INFERENCE_ENDPOINT_URL is not configured.")

    garment_url = process_cloth_image(image_url)
    # process_cloth_image reports failures by returning a message rather than
    # raising. Anything that is not an http(s) URL is such a message, and it
    # must not be forwarded to the endpoint as an image URL.
    if not garment_url.startswith(("http://", "https://")):
        raise gr.Error(garment_url)
    model_url = process_model_image(image)

    payload = {
        "human_image_url": model_url,
        "garment_image_url": garment_url,
    }
    print(payload)

    # No timeout here: how long inference may legitimately take is not known
    # from this file. NOTE(review): consider a generous explicit limit.
    try:
        response = requests.post(MODAL_INFERENCE_ENDPOINT_URL, json=payload)
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Request failed for the garment image. Error: {e}") from e
    if response.status_code != 200:
        raise gr.Error(
            f"Failed to process the garment image. Status Code: {response.status_code}"
        )

    try:
        result_url = response.json()["url"]
    except (KeyError, TypeError, ValueError) as e:
        raise gr.Error(f"Unexpected response from the inference endpoint: {e}") from e

    try:
        result = requests.get(result_url, timeout=60)
        result.raise_for_status()
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Could not download the result image. Error: {e}") from e

    img = Image.open(BytesIO(result.content))
    return img.resize((512, 682))
| def generate_presigned_url(object_url): | |
| parsed_url = urlparse(object_url) | |
| path_parts = parsed_url.path.lstrip('/').split('/', 1) | |
| object_key = path_parts[1] if len(path_parts) > 1 else '' | |
| print(f"Extracted Object Key: {object_key}") | |
| try: | |
| presigned_url = s3_client.generate_presigned_url( | |
| 'get_object', | |
| Params={ | |
| 'Bucket': BUCKET_NAME, | |
| 'Key': object_key | |
| }, | |
| ExpiresIn=3600 | |
| ) | |
| return presigned_url | |
| except Exception as e: | |
| print(f"Error generating pre-signed URL: {e}") | |
| return None | |
# Shared display size for every image panel in the layout.
_PANEL_SIZE = {"width": "384px", "height": "512px"}

# Gradio front end: a garment URL box with live preview, the model photo
# input, and the try-on result, followed by a Submit button.
with gr.Blocks() as demo:
    with gr.Row():
        image_url_input = gr.Textbox(
            label="Image URL",
            placeholder="Enter image URL here",
        )
        input_garment_image = gr.Image(
            label="Garment Image", type="pil", **_PANEL_SIZE
        )
        uploaded_image = gr.Image(
            label="Upload or Capture Image", type="pil", **_PANEL_SIZE
        )
        output_display = gr.Image(
            label="Displayed Image or URL Result", **_PANEL_SIZE
        )

    # Refresh the garment preview whenever the URL text changes.
    image_url_input.change(
        fn=load_image_from_url,
        inputs=image_url_input,
        outputs=input_garment_image,
    )

    submit_btn = gr.Button("Submit")
    # Run the try-on pipeline on the uploaded photo plus the garment URL.
    submit_btn.click(
        fn=display_image,
        inputs=[uploaded_image, image_url_input],
        outputs=output_display,
    )

demo.launch(share=True)