Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import json | |
| import argparse | |
| import networkx as nx | |
| from tqdm import tqdm | |
| def get_hash(file, key='diffuse_tree'): | |
| tree = file[key] | |
| G = nx.DiGraph() | |
| for node in tree: | |
| G.add_node(node['id']) | |
| if node['parent'] != -1: | |
| G.add_edge(node['id'], node['parent']) | |
| hashcode = nx.weisfeiler_lehman_graph_hash(G) | |
| return hashcode | |
| if __name__ == "__main__": | |
| '''Script to evaluate the accuracy of the generated graphs''' | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--exp_dir', type=str, required=True, help='path to the experiment directory') | |
| parser.add_argument('--gt_data_root', type=str, required=True, help='root directory of the ground-truth data') | |
| parser.add_argument('--gt_json_name', type=str, default='object.json', help='Path to the ground truth data') | |
| args = parser.parse_args() | |
| assert os.path.exists(args.exp_dir), "The experiment directory does not exist" | |
| assert os.path.exists(args.gt_data_root), "The ground-truth data root does not exist" | |
| exp_dir = args.exp_dir | |
| gt_data_dir = args.gt_data_root | |
| acc = 0 | |
| files = os.listdir(exp_dir) | |
| sorted(files) | |
| total = len(files) | |
| wrong_files = [] | |
| for file in tqdm(files): | |
| tokens = file.split('@') | |
| gt_dir = f'{gt_data_dir}' | |
| for token in tokens[:-1]: | |
| gt_dir = os.path.join(gt_dir, token) | |
| with open(os.path.join(gt_dir, args.gt_json_name)) as f: | |
| gt = json.load(f) | |
| # load json files | |
| with open(os.path.join(exp_dir, file)) as f: | |
| pred = json.load(f) | |
| # get hash for the graph | |
| pred_hash = get_hash(pred) | |
| gt_hash = get_hash(gt) | |
| # compare hash | |
| if pred_hash == gt_hash: | |
| acc += 1 | |
| else: | |
| wrong_files.append(file) | |
| with open(os.path.join(os.path.dirname(exp_dir), f'acc_{os.path.basename(exp_dir)}.json'), 'w') as f: | |
| json.dump({'acc': acc/total, 'wrong_files': wrong_files}, f, indent=4) | |