import os
import sys

# Make the repository root importable before pulling in project modules.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import re
import json
import base64
import argparse
from io import BytesIO

from PIL import Image
from openai import AzureOpenAI
import json_repair

from scripts.graph_pred.prompt_workflow_new import messages

# Azure OpenAI connection settings are taken from the environment.
endpoint = os.environ.get("ENDPOINT")
api_key = os.environ.get("API_KEY")
api_version = os.environ.get("API_VERSION")
model_name = os.environ.get("MODEL_NAME")

client = AzureOpenAI(
    azure_endpoint=endpoint,
    api_key=api_key,
    api_version=api_version,
)


def encode_image(image_path: str, center_crop: bool = False) -> str:
    """Resize an image to 224x224 and return it as a base64-encoded PNG.

    Args:
        image_path: Path to the image file on disk.
        center_crop: If True, resize to 256x256 and then center-crop to
            224x224; otherwise resize directly to 224x224.

    Returns:
        The PNG bytes of the resized image, base64-encoded as a UTF-8 string.
    """
    # Use a context manager so the underlying file handle is always closed
    # (the original left it open).
    with Image.open(image_path) as image:
        if center_crop:
            # Resize to 256x256, then take the central 224x224 window.
            image = image.resize((256, 256))
            width, height = image.size
            left = (width - 224) / 2
            top = (height - 224) / 2
            right = (width + 224) / 2
            bottom = (height + 224) / 2
            image = image.crop((left, top, right, bottom))
        else:
            image = image.resize((224, 224))

        # Convert the image to PNG bytes in memory.
        buffer = BytesIO()
        image.save(buffer, format="PNG")

    # Encode the PNG bytes as base64 text.
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def display_image(image_data):
    """Display a base64-encoded image on screen (debugging helper)."""
    img = Image.open(BytesIO(base64.b64decode(image_data)))
    img.show()
    img.close()


def convert_format(src):
    """Convert the nested-dict JSON from the model into a flat tree format.

    Args:
        src: A nested dict of ``{part_name: [children, ...]}`` as parsed from
            the model's JSON reply.

    Returns:
        A list of node dicts (``id``, ``parent``, ``name``, ``children``)
        sorted so that ``tree[i]["id"] == i``.
    """

    def _sort_nodes(tree):
        # Re-order nodes so each node sits at the index equal to its id.
        num_nodes = len(tree)
        sorted_tree = [dict() for _ in range(num_nodes)]
        for node in tree:
            sorted_tree[node["id"]] = node
        return sorted_tree

    def _traverse(node, parent_id, current_id):
        # Depth-first walk; `current_id` is a one-element list used as a
        # mutable counter shared across recursive calls.
        # NOTE(review): if `node` has several keys, only the id of the LAST
        # key is returned to the caller (and an empty dict would leave
        # `node_id` unbound) — presumably each dict holds exactly one part
        # name; confirm against the prompt's expected schema.
        for key, value in node.items():
            node_id = current_id[0]
            current_id[0] += 1
            # Create the node
            tree_node = {
                "id": node_id,
                "parent": parent_id,
                "name": key,
                "children": [],
            }
            # Traverse children if they exist
            if isinstance(value, list):
                for child in value:
                    child_id = _traverse(child, node_id, current_id)
                    tree_node["children"].append(child_id)
            # Add this node to the tree
            tree.append(tree_node)
        return node_id

    tree = []
    current_id = [0]
    _traverse(src, -1, current_id)
    diffuse_tree = _sort_nodes(tree)
    return diffuse_tree


def predict_graph_twomode(image_path, first_img_data=None, second_img_data=None,
                          debug=False, center_crop=False):
    """Predict the part connectivity graph from a closed/open image pair.

    Args:
        image_path: Path to the "closed"-state image; the "open"-state image
            path is derived by replacing ``'close'`` with ``'open'``.
        first_img_data: Optional pre-encoded base64 image (closed state).
        second_img_data: Optional pre-encoded base64 image (open state).
        debug: Unused; kept for interface compatibility.
        center_crop: Forwarded to :func:`encode_image`.

    Returns:
        Dict with the converted ``diffuse_tree`` and the raw model reply
        under ``original_response``.

    Raises:
        ValueError: If the model reply contains no ```json fenced block.
    """
    # Encode both views unless the caller supplied pre-encoded data.
    if first_img_data is None or second_img_data is None:
        first_img_data = encode_image(image_path, center_crop)
        second_img_data = encode_image(image_path.replace('close', 'open'), center_crop)

    # Copy the canned prompt and append the two images as a new user turn.
    new_message = messages.copy()
    new_message.append(
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{first_img_data}"},
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{second_img_data}"},
                },
            ],
        },
    )

    # Get the completion from the model
    completion = client.chat.completions.create(
        model=model_name,
        messages=new_message,
        response_format={"type": "text"},
        temperature=1,
        max_tokens=4096,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )

    print('processing the response...')

    # Extract the JSON payload from the fenced ```json block in the reply.
    content = completion.choices[0].message.content
    match = re.search(r"```json\n(.*?)\n```", content, re.DOTALL)
    if match is None:
        # Fail with a descriptive error instead of the original's opaque
        # AttributeError on `.group` when the fence is missing.
        raise ValueError(f"No ```json block found in model response: {content!r}")
    src = json.loads(match.group(1))
    print(src)

    # Convert the JSON format to tree format
    diffuse_tree = convert_format(src)
    return {"diffuse_tree": diffuse_tree, "original_response": content}


def save_response(save_path, response):
    """Save the response dict to a JSON file at ``save_path``."""
    with open(save_path, "w") as file:
        json.dump(response, file, indent=4)


def gpt_infer_image_category(image1, image2):
    """Ask the model to classify an object from two base64-encoded images.

    Args:
        image1: Base64-encoded PNG of the first view.
        image2: Base64-encoded PNG of the second view.

    Returns:
        The raw model reply — expected to be exactly one of the listed
        category names.
    """
    system_role = (
        "You are a highly knowledgeable assistant specializing in physics, "
        "engineering, and object properties."
    )
    text_prompt = (
        "Given two images of an object, determine its category. "
        "The category must be one of the following: Table, Dishwasher, StorageFurniture, "
        "Refrigerator, WashingMachine, Microwave, Oven. "
        "Output only the category name and nothing else. Do not include any other text."
    )
    content_user = [
        {
            "type": "text",
            "text": text_prompt,
        },
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image1}"},
        },
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{image2}"},
        },
    ]
    payload = {
        "messages": [
            {"role": "system", "content": system_role},
            {"role": "user", "content": content_user},
        ],
        "temperature": 0.1,
        "max_tokens": 500,
        "top_p": 0.1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "stop": None,
        "model": model_name,
    }
    completion = client.chat.completions.create(**payload)
    response = completion.choices[0].message.content
    # Sanity-parse the reply; the parsed value is intentionally discarded and
    # the raw string returned, matching the original behavior.
    json_repair.loads(response)
    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Predict the part connectivity graph from an image")
    parser.add_argument("--img_path", type=str, required=True, help="path to the image")
    parser.add_argument("--save_path", type=str, required=True, help="path to the save the response")
    parser.add_argument("--center_crop", action="store_true",
                        help="whether to center crop the image to 224x224, otherwise resize to 224x224")
    args = parser.parse_args()

    try:
        # BUG FIX: the original first called an undefined `predict_graph`,
        # which raised NameError, was swallowed by the broad except below,
        # and prevented anything from ever being saved. It also passed
        # `args.center_crop` positionally, where it landed in the
        # `first_img_data` slot and was ignored — pass it by keyword.
        response = predict_graph_twomode(args.img_path, center_crop=args.center_crop)
        save_response(args.save_path, response)
        # Also write the *twomode.json path the original intended.
        # NOTE(review): assumes save_path ends in ".json" (the [:-5] slice
        # strips exactly that suffix) — confirm against callers.
        save_response(args.save_path[:-5] + 'twomode.json', response)
    except Exception as e:
        # Best-effort error log so a batch run over many images keeps going.
        with open('openai_err.log', 'a') as f:
            f.write('---------------------------\n')
            f.write(f'{args.img_path}\n')
            f.write(f'{e}\n')