xinjie.wang
init commit
c28dddb
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
import re
import json
import base64
import argparse
from PIL import Image
from io import BytesIO
from openai import AzureOpenAI
from scripts.graph_pred.prompt_workflow_new import messages
import json_repair
# Initialize the OpenAI client
endpoint = os.environ.get("ENDPOINT")
api_key = os.environ.get("API_KEY")
api_version = os.environ.get("API_VERSION")
model_name = os.environ.get("MODEL_NAME")
client = AzureOpenAI(
azure_endpoint=endpoint,
api_key=api_key,
api_version=api_version,
)
def encode_image(image_path: str, center_crop=False):
"""Resize and encode the image as base64"""
# load the image
image = Image.open(image_path)
# resize the image to 224x224
if center_crop: # (resize to 256x256 and then center crop to 224x224)
image = image.resize((256, 256))
width, height = image.size
left = (width - 224) / 2
top = (height - 224) / 2
right = (width + 224) / 2
bottom = (height + 224) / 2
image = image.crop((left, top, right, bottom))
else:
image = image.resize((224, 224))
# conver the image to bytes
buffer = BytesIO()
image.save(buffer, format="PNG")
buffer.seek(0)
# encode the image as base64
encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
return encoded_image
def display_image(image_data):
"""Display the image from the base64 encoded image data"""
img = Image.open(BytesIO(base64.b64decode(image_data)))
img.show()
img.close()
def convert_format(src):
'''Convert the JSON format from the response to a tree format'''
def _sort_nodes(tree):
num_nodes = len(tree)
sorted_tree = [dict() for _ in range(num_nodes)]
for node in tree:
sorted_tree[node["id"]] = node
return sorted_tree
def _traverse(node, parent_id, current_id):
for key, value in node.items():
node_id = current_id[0]
current_id[0] += 1
# Create the node
tree_node = {
"id": node_id,
"parent": parent_id,
"name": key,
"children": [],
}
# Traverse children if they exist
if isinstance(value, list):
for child in value:
child_id = _traverse(child, node_id, current_id)
tree_node["children"].append(child_id)
# Add this node to the tree
tree.append(tree_node)
return node_id
tree = []
current_id = [0]
_traverse(src, -1, current_id)
diffuse_tree = _sort_nodes(tree)
return diffuse_tree
def predict_graph_twomode(image_path, first_img_data=None, second_img_data=None, debug=False, center_crop=False):
'''Predict the part connectivity graph from the image'''
# Encode the image
if first_img_data is None or second_img_data is None:
first_img_data = encode_image(image_path, center_crop)
second_img_data = encode_image(image_path.replace('close', 'open'), center_crop)
# if debug:
# display_image(image_data) # for double checking the image
# breakpoint()
new_message = messages.copy()
new_message.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{first_img_data}"},
},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{second_img_data}"},
}
],
},
)
# Get the completion from the model
completion = client.chat.completions.create(
model=model_name,
messages=new_message,
response_format={"type": "text"},
temperature=1,
max_tokens=4096,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
)
print('processing the response...')
# Extract the response
content = completion.choices[0].message.content
src = json.loads(re.search(r"```json\n(.*?)\n```", content, re.DOTALL).group(1))
print(src)
# Convert the JSON format to tree format
diffuse_tree = convert_format(src)
return {"diffuse_tree": diffuse_tree, "original_response": content}
def save_response(save_path, response):
'''Save the response to a json file'''
with open(save_path, "w") as file:
json.dump(response, file, indent=4)
def gpt_infer_image_category(image1, image2):
system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."
text_prompt = (
"Given two images of an object, determine its category. "
"The category must be one of the following: Table, Dishwasher, StorageFurniture, "
"Refrigerator, WashingMachine, Microwave, Oven. "
"Output only the category name and nothing else. Do not include any other text."
)
content_user = [
{
"type": "text",
"text": text_prompt,
},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image1}"},
},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image2}"},
},
]
payload = {
"messages": [
{"role": "system", "content": system_role},
{"role": "user", "content": content_user},
],
"temperature": 0.1,
"max_tokens": 500,
"top_p": 0.1,
"frequency_penalty": 0,
"presence_penalty": 0,
"stop": None,
"model": model_name,
}
completion = client.chat.completions.create(**payload)
response = completion.choices[0].message.content
json_repair.loads(response)
return response
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Predict the part connectivity graph from an image")
parser.add_argument("--img_path", type=str, required=True, help="path to the image")
parser.add_argument("--save_path", type=str, required=True, help="path to the save the response")
parser.add_argument("--center_crop", action="store_true", help="whether to center crop the image to 224x224, otherwise resize to 224x224")
args = parser.parse_args()
try:
response = predict_graph(args.img_path, args.center_crop)
save_response(args.save_path, response)
response = predict_graph_twomode(args.img_path, args.center_crop)
save_response(args.save_path[:-5] + 'twomode.json', response)
except Exception as e:
with open('openai_err.log', 'a') as f:
f.write('---------------------------\n')
f.write(f'{args.img_path}\n')
f.write(f'{e}\n')