import os, sys sys.path.append(os.path.join(os.path.dirname(__file__), '..')) import numpy as np from PIL import Image from my_utils.refs import joint_ref, sem_ref def rescale_axis(jtype, axis_d, axis_o, box_center): ''' Function to rescale the axis for rendering Args: - jtype (int): joint type - axis_d (np.array): axis direction - axis_o (np.array): axis origin - box_center (np.array): bounding box center Returns: - center (np.array): rescaled axis origin - axis_d (np.array): rescaled axis direction ''' if jtype == 0 or jtype == 1: return [0., 0., 0.], [0., 0., 0.] if jtype == 3 or jtype == 4: center = box_center else: center = axis_o + np.dot(axis_d, box_center-axis_o) * axis_d return center.tolist(), axis_d.tolist() def make_white_background(src_img): '''Make the white background for the input RGBA image.''' src_img.load() background = Image.new("RGB", src_img.size, (255, 255, 255)) background.paste(src_img, mask=src_img.split()[3]) # 3 is the alpha channel return background def build_graph(tree, K=32): ''' Function to build graph from the node list. Args: nodes: list of nodes K: the maximum number of nodes in the graph Returns: adj: adjacency matrix, records the 1-ring relationship (parent+children) between nodes edge_list: list of edges, for visualization ''' adj = np.zeros((K, K), dtype=np.float32) parents = [] for node in tree: # 1-ring relationship if node['parent'] != -1: adj[node['id'], node['parent']] = 1 parents.append(node['parent']) else: adj[node['id'], node['id']] = 1 parents.append(-1) for child_id in node['children']: adj[node['id'], child_id] = 1 return { 'adj': adj, 'parents': np.array(parents, dtype=np.int8) } def load_input_from(pred_file, K=32): ''' Function to parse input item from a file containing the predicted graph ''' cond = {} # conditional information and axillary data # prepare node data n_nodes = len(pred_file['diffuse_tree']) # prepare graph pred_graph = build_graph(pred_file['diffuse_tree'], K) # attr mask (for Local Attention) attr_mask = np.eye(K, K, dtype=bool) attr_mask = attr_mask.repeat(5, axis=0).repeat(5, axis=1) cond['attr_mask'] = attr_mask # key padding mask (for Global Attention) pad_mask = np.zeros((K*5, K*5), dtype=bool) pad_mask[:, :n_nodes*5] = 1 cond['key_pad_mask'] = pad_mask # adj mask (for Graph Relation Attention) adj_mask = pred_graph['adj'][:].astype(bool) adj_mask = adj_mask.repeat(5, axis=0).repeat(5, axis=1) adj_mask[n_nodes*5:, :] = 1 cond['adj_mask'] = adj_mask # placeholder data = np.zeros((K*5, 6), dtype=bool) cond['cat'] = 2 # axillary info cond['adj'] = pred_graph['adj'] cond['parents'] = np.zeros(K, dtype=np.int8) cond['parents'][:n_nodes] = pred_graph['parents'] cond['n_nodes'] = n_nodes return data, cond def convert_data_range(x): '''postprocessing: convert the raw model output to the original range, following CAGE''' x = x.reshape(-1, 30) # (K, 36) aabb_max = x[:, 0:3] aabb_min = x[:, 3:6] center = (aabb_max + aabb_min) / 2.0 size = (aabb_max - aabb_min).clip(min=5e-3) j_type = np.mean(x[:, 6:12], axis=1) j_type = ((j_type + 0.5) * 5).clip(min=1.0, max=5.0).round() axis_d = x[:, 12:15] axis_d = axis_d / ( np.linalg.norm(axis_d, axis=1, keepdims=True) + np.finfo(float).eps ) axis_o = x[:, 15:18] j_range = (x[:, 18:20] + x[:, 20:22] + x[:, 22:24]) / 3 j_range = j_range.clip(min=-1.0, max=1.0) j_range[:, 0] = j_range[:, 0] * 360 j_range[:, 1] = j_range[:, 1] label = np.mean(x[:, 24:30], axis=1) label = ((label + 0.8) * 5).clip(min=0.0, max=7.0).round() return { "center": center, "size": size, "type": j_type, "axis_d": axis_d, "axis_o": axis_o, "range": j_range, "label": label, } def parse_tree(data, n_nodes, par, adj): tree = [] # convert to json format for i in range(n_nodes): node = {"id": i} node["name"] = sem_ref["bwd"][int(data["label"][i].item())] node["parent"] = int(par[i]) node["children"] = [ int(child) for child in np.where(adj[i] == 1)[0] if child != par[i] ] node["aabb"] = {} node["aabb"]["center"] = data["center"][i].tolist() node["aabb"]["size"] = data["size"][i].tolist() node["joint"] = {} if node['name'] == 'base': node["joint"]["type"] = 'fixed' else: node["joint"]["type"] = joint_ref["bwd"][int(data["type"][i].item())] if node["joint"]["type"] == "fixed": node["joint"]["range"] = [0.0, 0.0] elif node["joint"]["type"] == "revolute": node["joint"]["range"] = [0.0, float(data["range"][i][0])] elif node["joint"]["type"] == "continuous": node["joint"]["range"] = [0.0, 360.0] elif ( node["joint"]["type"] == "prismatic" or node["joint"]["type"] == "screw" ): node["joint"]["range"] = [0.0, float(data["range"][i][1])] node["joint"]["axis"] = {} # relocate the axis to visualize well axis_o, axis_d = rescale_axis( int(data["type"][i].item()), data["axis_d"][i], data["axis_o"][i], data["center"][i], ) node["joint"]["axis"]["direction"] = axis_d node["joint"]["axis"]["origin"] = axis_o # append node to the tree tree.append(node) return tree def convert_json(x, c, prefix=''): out = {"meta": {}, "diffuse_tree": []} n_nodes = c[f"{prefix}n_nodes"][0].item() par = c[f"{prefix}parents"][0].cpu().numpy().tolist() adj = c[f"{prefix}adj"][0].cpu().numpy() np.fill_diagonal(adj, 0) # remove self-loop for the root node if f"{prefix}obj_cat" in c: out["meta"]["obj_cat"] = c[f"{prefix}obj_cat"][0] # convert the data to original range data = convert_data_range(x) # parse the tree tree = parse_tree(data, n_nodes, par, adj) out["diffuse_tree"] = tree return out