# DA-2 / da2/utils/base.py
# Last commit: d82e7f9 ("update") by haodongli
import json
import argparse
import os
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import (
InitProcessGroupKwargs,
ProjectConfiguration,
set_seed
)
import logging
from datetime import (
timedelta,
datetime
)
def load_config(config_path):
    """Load a JSON configuration file and return the parsed object.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The decoded JSON value (a dict for the configs this project uses).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale encoding (e.g. cp1252 on Windows), which can mis-decode non-ASCII.
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def arg_parser(argv=None):
    """Parse the command-line arguments for a run.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, matching the previous
            behavior. Pass an explicit list to parse programmatically.

    Returns:
        argparse.Namespace: with a required ``config_path`` string attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, required=True)
    return parser.parse_args(argv)
def prepare_to_run():
    """Parse CLI args, load the config and build the Accelerate runtime.

    Steps: read ``--config_path`` from the command line, configure root
    logging, load the JSON config, create a timestamped run directory under
    ``output/``, construct an ``Accelerator`` from ``config['accelerator']``,
    store a logger in ``config['env']['logger']`` and seed RNGs via
    ``set_seed(config['env']['seed'])``.

    Returns:
        tuple: ``(config, accelerator, output_dir)``. ``config`` is the loaded
        dict (mutated to hold the logger), ``accelerator`` is the configured
        ``Accelerator`` and ``output_dir`` is the run directory path.

    Raises:
        KeyError: If a required config key is missing (``accelerator``:
            ``timeout``, ``accumulation_nsteps``, ``mixed_precision``,
            ``report_to``; ``env``: ``seed``, ``verbose``).
    """
    args = arg_parser()
    logging.basicConfig(
        format='%(asctime)s --> %(message)s',
        datefmt='%m/%d %H:%M:%S',
        level=logging.INFO,
    )
    config = load_config(args.config_path)
    accel_cfg = config['accelerator']
    pg_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=accel_cfg['timeout']))

    # Run name = config file name without its extension. splitext handles any
    # extension; the old [:-5] slice silently assumed a 5-char ".json" suffix.
    version = os.path.splitext(os.path.basename(args.config_path))[0]
    # NOTE(review): each process calls datetime.now() independently; a
    # multi-process launch straddling a second boundary could yield different
    # output_dir values per process — confirm whether this matters here.
    output_dir = f'output/{version}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
    os.makedirs(output_dir, exist_ok=True)  # exist_ok makes a pre-check redundant

    accu_steps = accel_cfg['accumulation_nsteps']
    accelerator = Accelerator(
        gradient_accumulation_steps=accu_steps,
        mixed_precision=accel_cfg['mixed_precision'],
        log_with=accel_cfg['report_to'],
        project_config=ProjectConfiguration(project_dir=output_dir),
        kwargs_handlers=[pg_kwargs],
    )

    # The logger is created after the Accelerator and handed to downstream
    # code through the config dict.
    logger = get_logger(__name__, log_level='INFO')
    config['env']['logger'] = logger
    set_seed(config['env']['seed'])

    if config['env']['verbose']:
        n_proc = accelerator.num_processes
        plural = 's' if n_proc > 1 else ''
        logger.info(f'Version: {version} (from {args.config_path})')
        logger.info(f'Output dir: {output_dir}')
        logger.info(f'Using {n_proc} GPU{plural}')
    return config, accelerator, output_dir