# DIPO / scripts/mesh_retrieval/run_retrieve.py
# xinjie.wang — init commit (c28dddb)
import argparse
import os
import shlex
import subprocess
import sys
from functools import partial

from tqdm.contrib.concurrent import process_map
def run_retrieve(src_dir, json_name, data_root, *,
                 own_root_categories=('StorageFurniture', 'Table'),
                 fallback_data_root='../acd_data/merged-data'):
    """Run the part-mesh retrieval script on a single experiment directory.

    Launches ``scripts/mesh_retrieval/retrieve.py`` as a child process with the
    interpreter that is running this script. The script path is relative, so
    this presumably has to be invoked from the repository root. The child's
    stdout is discarded. If the child exits non-zero, the error and the
    child's stderr are printed and the function still returns normally, so one
    failing directory does not abort a parallel batch.

    Args:
        src_dir: Directory forwarded as ``--src_dir``.
        json_name: Forwarded as ``--json_name``.
        data_root: Forwarded as ``--gt_data_root``, but only when ``src_dir``
            contains one of ``own_root_categories`` as a substring.
        own_root_categories: Substrings of ``src_dir`` for which ``data_root``
            is honoured. Defaults to the previously hard-coded categories.
        fallback_data_root: Ground-truth root used for every other
            ``src_dir``; it replaces ``data_root`` in that case.

    Returns:
        The executed command as a shell-quoted string.
    """
    # Only src_dirs naming one of these categories keep the caller's data_root;
    # all others are redirected to the fallback root, ignoring --gt_data_root.
    if not any(category in src_dir for category in own_root_categories):
        data_root = fallback_data_root
    # sys.executable instead of a bare 'python': the latter may be absent from
    # PATH (uncaught FileNotFoundError) or belong to a different environment.
    interpreter = sys.executable or 'python'
    fn_call = [
        interpreter, 'scripts/mesh_retrieval/retrieve.py',
        '--src_dir', src_dir,
        '--json_name', json_name,
        '--gt_data_root', data_root,
    ]
    try:
        # stderr is captured (it used to be discarded along with stdout) so a
        # failure report includes the child's actual error output.
        subprocess.run(fn_call, check=True, stdout=subprocess.DEVNULL,
                       stderr=subprocess.PIPE, encoding='utf-8', errors='replace')
    except subprocess.CalledProcessError as e:
        print(f'Error from run_retrieve: {src_dir}')
        print(f'Error: {e}')
        if e.stderr:
            print(e.stderr.rstrip())
    # Quote each argument so the returned string is safe to paste into a shell.
    return ' '.join(shlex.quote(arg) for arg in fn_call)
def _collect_src_dirs(exp_path, json_name):
    """Return the leaf directories under ``exp_path`` that still need retrieval.

    Walks exactly two directory levels (``<exp_path>/<model_id>/<sub_id>``).
    A leaf is selected when it contains at least one file whose name ends with
    ``json_name`` and has no ``object.ply`` yet. Each leaf appears at most once,
    and the result is in sorted ``(model_id, sub_id)`` order.

    Args:
        exp_path: Root experiment folder.
        json_name: Filename suffix that marks a leaf directory as retrievable.

    Returns:
        List of leaf directory paths to process.
    """
    src_dirs = []
    for model_id in sorted(os.listdir(exp_path)):
        model_id_path = os.path.join(exp_path, model_id)
        # Dotted names are skipped as files/hidden entries; isdir also guards
        # against extension-less files, which used to crash os.listdir.
        if '.' in model_id or not os.path.isdir(model_id_path):
            continue
        for sub_id in sorted(os.listdir(model_id_path)):
            sub_path = os.path.join(model_id_path, sub_id)
            if '.' in sub_id or not os.path.isdir(sub_path):
                continue
            # any(): a leaf is queued once even if several files match the
            # suffix (e.g. object.json and pred_object.json). Otherwise the
            # same directory would be retrieved twice, concurrently.
            if not any(fname.endswith(json_name) for fname in os.listdir(sub_path)):
                continue
            if os.path.exists(os.path.join(sub_path, 'object.ply')):
                print(f"Found {sub_path} with object.ply")
                continue
            src_dirs.append(sub_path)
            print(len(src_dirs), sub_path)
    return src_dirs


def main(argv=None):
    """Retrieve part meshes for every pending leaf directory under ``--src``.

    Args:
        argv: Argument list to parse. ``None`` means argparse reads
            ``sys.argv[1:]``, so command-line behaviour is unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str, required=True, help="path to the experiment folder")
    parser.add_argument("--json_name", type=str, default="object.json", help="name of the json file")
    parser.add_argument("--gt_data_root", type=str, default="../data", help="path to the ground truth data")
    parser.add_argument("--max_workers", type=int, default=6, help="number of retrieval processes to run in parallel")
    args = parser.parse_args(argv)
    # parser.error (exit status 2) instead of assert, which `python -O` strips.
    if not os.path.exists(args.src):
        parser.error(f"Src path does not exist: {args.src}")
    if not os.path.exists(args.gt_data_root):
        parser.error(f"GT data root does not exist: {args.gt_data_root}")
    if args.max_workers < 1:
        parser.error(f"--max_workers must be >= 1, got {args.max_workers}")

    print('----------- Retrieve Part Meshes -----------')
    src_dirs = _collect_src_dirs(args.src, args.json_name)
    print(f"Found {len(src_dirs)} jsons to retrieve part meshes")
    if not src_dirs:
        return
    # Honour --max_workers; it used to be parsed and then ignored in favour of a hard-coded 6.
    process_map(
        partial(run_retrieve, json_name=args.json_name, data_root=args.gt_data_root),
        src_dirs,
        max_workers=args.max_workers,
        chunksize=1,
    )


if __name__ == '__main__':
    main()