File size: 2,027 Bytes
c28dddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import json
import argparse
import networkx as nx
from tqdm import tqdm

def get_hash(file, key='diffuse_tree'):
    """Return the Weisfeiler-Lehman hash of the tree stored under ``file[key]``.

    ``file[key]`` is iterated as a sequence of node dicts, each read for its
    ``'id'`` and ``'parent'`` entries. Every node becomes a vertex of a directed
    graph; a node whose parent is not ``-1`` also gets an edge node -> parent.
    Node/edge attributes are not used, so only the graph structure is hashed.
    """
    graph = nx.DiGraph()
    for entry in file[key]:
        node_id, parent_id = entry['id'], entry['parent']
        graph.add_node(node_id)
        # -1 marks a node without a parent (root): no edge to add.
        if parent_id != -1:
            graph.add_edge(node_id, parent_id)
    return nx.weisfeiler_lehman_graph_hash(graph)

def _gt_json_path(gt_data_root, pred_file, gt_json_name):
    """Map a prediction file name to the path of its ground-truth JSON.

    The prediction file name is split on ``'@'``; every token except the last
    is a directory component below *gt_data_root*, e.g.
    ``'a@b@x.json'`` -> ``<gt_data_root>/a/b/<gt_json_name>``.
    """
    dir_parts = pred_file.split('@')[:-1]
    return os.path.join(gt_data_root, *dir_parts, gt_json_name)


def main():
    """Evaluate the accuracy of the generated graphs.

    Every regular file in ``--exp_dir`` is loaded as a predicted JSON and its
    graph hash (``get_hash``) is compared with that of the matching
    ground-truth JSON under ``--gt_data_root``. The accuracy (fraction of
    matching hashes) and the names of mismatching files are written to
    ``acc_<exp_dir name>.json`` in the parent directory of ``--exp_dir``.
    """
    parser = argparse.ArgumentParser(description='Evaluate the accuracy of the generated graphs')
    parser.add_argument('--exp_dir', type=str, required=True, help='path to the experiment directory')
    parser.add_argument('--gt_data_root', type=str, required=True, help='root directory of the ground-truth data')
    parser.add_argument('--gt_json_name', type=str, default='object.json',
                        help='file name of the ground-truth JSON inside each object directory')
    args = parser.parse_args()

    # Use parser.error instead of assert: asserts are stripped under `python -O`.
    # isdir (not exists) because both paths are used as directories below.
    if not os.path.isdir(args.exp_dir):
        parser.error("The experiment directory does not exist")
    if not os.path.isdir(args.gt_data_root):
        parser.error("The ground-truth data root does not exist")

    # abspath drops a trailing separator. Without it, dirname('exp/') == 'exp'
    # and basename('exp/') == '', so the report would be written *inside*
    # exp_dir as 'acc_.json' and picked up as a prediction on the next run.
    exp_dir = os.path.abspath(args.exp_dir)
    gt_data_root = args.gt_data_root

    # sorted() returns a new list (the result was previously discarded).
    # Sub-directories are skipped because they cannot be loaded as JSON.
    files = sorted(
        name for name in os.listdir(exp_dir)
        if os.path.isfile(os.path.join(exp_dir, name))
    )
    if not files:
        # Accuracy over zero files is undefined (previously ZeroDivisionError).
        parser.error(f"No prediction files found in {exp_dir}")

    correct = 0
    wrong_files = []
    for file_name in tqdm(files):
        gt_path = _gt_json_path(gt_data_root, file_name, args.gt_json_name)
        with open(gt_path, encoding='utf-8') as f:
            gt = json.load(f)
        with open(os.path.join(exp_dir, file_name), encoding='utf-8') as f:
            pred = json.load(f)
        # A prediction counts as correct when its graph hash equals the ground
        # truth's. WL hashing is necessary but not sufficient for isomorphism.
        if get_hash(pred) == get_hash(gt):
            correct += 1
        else:
            wrong_files.append(file_name)

    out_path = os.path.join(os.path.dirname(exp_dir), f'acc_{os.path.basename(exp_dir)}.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump({'acc': correct / len(files), 'wrong_files': wrong_files}, f, indent=4)


if __name__ == "__main__":
    main()