"""Batch driver for part-mesh retrieval.

Scans an experiment folder for generated part jsons and runs
``scripts/mesh_retrieval/retrieve.py`` on each result directory in parallel,
skipping directories that already contain a retrieved ``object.ply``.
"""
import argparse
import os
import subprocess
from functools import partial

from tqdm.contrib.concurrent import process_map
def run_retrieve(src_dir, json_name, data_root):
    """Run the mesh-retrieval script on one result directory as a subprocess.

    Args:
        src_dir: directory containing the part json to retrieve meshes for.
        json_name: filename of the json inside ``src_dir``.
        data_root: root of the ground-truth mesh data. NOTE(review): ignored
            (overridden with a hard-coded path) for every category except
            StorageFurniture and Table — presumably those two categories live
            in a different dataset; confirm with the data layout.

    Returns:
        The executed command line as a single string (useful for logging or
        replaying failed jobs by hand).
    """
    if 'StorageFurniture' not in src_dir and 'Table' not in src_dir:
        data_root = '../acd_data/merged-data'
    fn_call = ['python', 'scripts/mesh_retrieval/retrieve.py',
               '--src_dir', src_dir,
               '--json_name', json_name,
               '--gt_data_root', data_root]
    try:
        # Child output is silenced (stderr folded into the discarded stdout);
        # failures are reported via the prints below instead.
        subprocess.run(fn_call, check=True,
                       stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')
    except OSError as e:
        # BUG FIX: a missing/unspawnable `python` interpreter raised
        # FileNotFoundError out of subprocess.run and killed the whole
        # process_map worker; report it like any other per-job failure.
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')
    return ' '.join(fn_call)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, required=True, help="path to the experiment folder")
    parser.add_argument("--json_name", type=str, default="object.json", help="name of the json file")
    parser.add_argument("--gt_data_root", type=str, default="../data", help="path to the ground truth data")
    # BUG FIX: help text was a copy-paste error ("number of images to render").
    parser.add_argument("--max_workers", type=int, default=6, help="number of parallel retrieval workers")
    args = parser.parse_args()

    assert os.path.exists(args.src), f"Src path does not exist: {args.src}"
    assert os.path.exists(args.gt_data_root), f"GT data root does not exist: {args.gt_data_root}"

    exp_path = args.src
    print('----------- Retrieve Part Meshes -----------')

    # Collect every <src>/<model_id>/<run_id>/ directory that contains the
    # target json but no object.ply yet (object.ply marks a finished job).
    # Entries with a '.' in the name are treated as files and skipped.
    src_dirs = []
    for model_id in os.listdir(exp_path):
        model_id_path = os.path.join(exp_path, model_id)
        if '.' in model_id:
            continue
        for run_id in os.listdir(model_id_path):
            if '.' in run_id:
                continue
            run_path = os.path.join(model_id_path, run_id)
            for fname in os.listdir(run_path):
                if fname.endswith(args.json_name):
                    if os.path.exists(os.path.join(run_path, 'object.ply')):
                        print(f"Found {run_path} with object.ply")
                    else:
                        src_dirs.append(run_path)
                        print(len(src_dirs), run_path)

    print(f"Found {len(src_dirs)} jsons to retrieve part meshes")
    # BUG FIX: max_workers was hard-coded to 6, silently ignoring --max_workers.
    process_map(partial(run_retrieve, json_name=args.json_name, data_root=args.gt_data_root),
                src_dirs, max_workers=args.max_workers, chunksize=1)