# Commit metadata (not code): xinjie.wang, "init commit", c28dddb
import os
import json
import argparse
import networkx as nx
from tqdm import tqdm
def get_hash(file, key='diffuse_tree'):
    """Return the Weisfeiler-Lehman hash of the tree stored under ``file[key]``.

    ``file[key]`` is iterated as a sequence of node records, each providing an
    ``'id'`` and a ``'parent'`` entry. Every record becomes a node of a
    directed graph; a record whose parent is not ``-1`` additionally
    contributes an edge pointing from the node to its parent.

    Args:
        file: Mapping (e.g. a loaded JSON object) that holds the tree.
        key: Key under which the list of node records is stored.

    Returns:
        The string produced by ``nx.weisfeiler_lehman_graph_hash`` for the
        resulting directed graph.
    """
    graph = nx.DiGraph()
    for record in file[key]:
        node_id = record['id']
        graph.add_node(node_id)
        parent_id = record['parent']
        # -1 marks a record without a parent: no edge to add.
        if parent_id == -1:
            continue
        graph.add_edge(node_id, parent_id)
    return nx.weisfeiler_lehman_graph_hash(graph)
def main():
    """Evaluate how many predicted graphs match their ground-truth graphs.

    Every entry in ``--exp_dir`` is loaded as a predicted JSON file. Its file
    name is split on ``'@'``; all tokens except the last are joined as
    sub-directories under ``--gt_data_root`` to locate the directory holding
    the ground-truth JSON (named ``--gt_json_name``). Prediction and ground
    truth are compared through ``get_hash``.

    The result is written next to the experiment directory as
    ``acc_<experiment directory name>.json`` with two keys: ``'acc'`` (the
    fraction of matching files) and ``'wrong_files'`` (the names of the
    predictions whose hash differs from the ground truth).
    """
    parser = argparse.ArgumentParser(
        description='Script to evaluate the accuracy of the generated graphs')
    parser.add_argument('--exp_dir', type=str, required=True, help='path to the experiment directory')
    parser.add_argument('--gt_data_root', type=str, required=True, help='root directory of the ground-truth data')
    parser.add_argument('--gt_json_name', type=str, default='object.json',
                        help='file name of the ground-truth JSON inside each ground-truth directory')
    args = parser.parse_args()

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not os.path.isdir(args.exp_dir):
        parser.error("The experiment directory does not exist")
    if not os.path.isdir(args.gt_data_root):
        parser.error("The ground-truth data root does not exist")

    # Normalise the path so that a trailing separator (e.g. "runs/exp1/") does
    # not yield an empty basename and an output file inside the experiment
    # directory itself.
    exp_dir = os.path.abspath(args.exp_dir)
    gt_data_dir = args.gt_data_root

    # sorted() returns a new list; keep it so the processing order and the
    # order of `wrong_files` are deterministic.
    files = sorted(os.listdir(exp_dir))
    total = len(files)
    if total == 0:
        # Accuracy is undefined without predictions; avoid dividing by zero.
        parser.error("The experiment directory contains no prediction files")

    acc = 0
    wrong_files = []
    for file in tqdm(files):
        # "a@b@name.json" -> <gt_data_root>/a/b/<gt_json_name>
        gt_dir = os.path.join(gt_data_dir, *file.split('@')[:-1])
        with open(os.path.join(gt_dir, args.gt_json_name), encoding='utf-8') as f:
            gt = json.load(f)
        with open(os.path.join(exp_dir, file), encoding='utf-8') as f:
            pred = json.load(f)

        # Graphs are considered equal when their structural hashes agree.
        if get_hash(pred) == get_hash(gt):
            acc += 1
        else:
            wrong_files.append(file)

    out_path = os.path.join(os.path.dirname(exp_dir), f'acc_{os.path.basename(exp_dir)}.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump({'acc': acc / total, 'wrong_files': wrong_files}, f, indent=4)


if __name__ == "__main__":
    main()