File size: 3,308 Bytes
c28dddb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import os
import subprocess
import sys
from functools import partial

from tqdm.contrib.concurrent import process_map

def run_retrieve(src_dir, json_name, data_root, fallback_data_root='../acd_data/merged-data'):
    """Run ``scripts/mesh_retrieval/retrieve.py`` on a single object directory.

    The retrieval script is launched as a child process. A non-zero exit, or a
    failure to start the process at all, is reported on stdout but not raised,
    so one bad object does not abort a batch run.

    Args:
        src_dir: Object directory passed through as ``--src_dir``.
        json_name: JSON file name passed through as ``--json_name``.
        data_root: Ground-truth root used when ``src_dir`` contains
            'StorageFurniture' or 'Table'.
        fallback_data_root: Ground-truth root used for every other ``src_dir``.
            Defaults to the previously hard-coded '../acd_data/merged-data'.

    Returns:
        The command that was run, joined with single spaces.
    """
    # Only StorageFurniture/Table objects are looked up under ``data_root``;
    # every other category uses the fallback root.
    if 'StorageFurniture' not in src_dir and 'Table' not in src_dir:
        data_root = fallback_data_root
    # sys.executable ensures the child runs under the same interpreter/venv as
    # this script, unlike a bare 'python' resolved from PATH.
    # NOTE(review): the script path is relative, so this assumes the working
    # directory is the repository root.
    fn_call = [
        sys.executable, 'scripts/mesh_retrieval/retrieve.py',
        '--src_dir', src_dir,
        '--json_name', json_name,
        '--gt_data_root', data_root,
    ]
    try:
        # Capture stdout+stderr instead of discarding them, so that the child's
        # traceback can be shown when it fails; on success it is dropped.
        subprocess.run(fn_call, check=True, stdout=subprocess.PIPE,
                       stderr=subprocess.STDOUT, text=True, errors='replace')
    except subprocess.CalledProcessError as e:
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')
        if e.output:
            print(e.output.rstrip())
    except OSError as e:
        # The process could not be started (e.g. interpreter missing or not executable).
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')
    return ' '.join(fn_call)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, required=True, help="path to the experiment folder")
    parser.add_argument("--json_name", type=str, default="object.json", help="name of the json file")
    parser.add_argument("--gt_data_root", type=str, default="../data", help="path to the ground truth data")
    parser.add_argument("--max_workers", type=int, default=6, help="number of images to render for each object")
    args = parser.parse_args()
    
    assert os.path.exists(args.src), f"Src path does not exist: {args.src}"
    assert os.path.exists(args.gt_data_root), f"GT data root does not exist: {args.gt_data_root}"

    exp_path = args.src
    # len_root = len(exp)
    print('----------- Retrieve Part Meshes -----------')
    src_dirs = []
    # exps = os.listdir(root)
    # for exp in exps:
    # exp_path = os.path.join(root, exp)
    for model_id in os.listdir(exp_path):
        model_id_path = os.path.join(exp_path, model_id)
        # print(model_id_path)
        if '.' in model_id:
            continue
        for model_id_id in os.listdir(model_id_path):
            if '.' not in model_id_id:
                model_id_id_path = os.path.join(model_id_path, model_id_id)
                for json_file in os.listdir(model_id_id_path):
                    if json_file.endswith(args.json_name):
                        if os.path.exists(os.path.join(model_id_id_path, 'object.ply')):
                            print(f"Found {model_id_id_path} with object.ply")
                        else:
                            # run_retrieve(model_id_id_path, json_name=args.json_name, data_root=args.gt_data_root)
                            src_dirs.append(model_id_id_path)
                            print(len(src_dirs), model_id_id_path)
# for dirpath, dirname, fnames in os.walk(root):
    #     for fname in fnames:
    #         if fname.endswith(args.json_name):
    #             src_dirs.append(dirpath) # save the relative directory path
    #             print(root)
    print(f"Found {len(src_dirs)} jsons to retrieve part meshes")
    # print(src_dirs)
    # import ipdb
    # ipdb.set_trace()

    # for src_dir in src_dirs:
    #     print(src_dir)
    #     command = run_retrieve(src_dir, json_name=args.json_name, data_root=args.gt_data_root)
    #     command_file = open('retrieve_commands.sh', 'a')
    #     command_file.write(command + '\n')
    #     command_file.close()
    process_map(partial(run_retrieve, json_name=args.json_name, data_root=args.gt_data_root), src_dirs, max_workers=6, chunksize=1)