Spaces:
Sleeping
Sleeping
| import argparse | |
| import importlib | |
| import os | |
| import threading | |
| import time | |
| import pandas as pd | |
| from runner.game_runner import entry | |
| from utils.config import Config | |
| import configparser | |
| import os | |
# Module-level configuration object; the ``__main__`` block loads the game
# environment and level config files into it before calling ``main``.
config = Config()
def main(args):
    """Resolve the LMM model name from the provider config file.

    Reads the INI file at ``args.llmProviderConfig``. If the ``[lmm]``
    section defines ``model_path``, that value is used with every ``/``
    replaced by ``_``; otherwise ``model_name`` from the same section is used.

    Args:
        args: Parsed command-line namespace; only ``args.llmProviderConfig``
            is read here.

    Returns:
        str: The resolved model name.

    Raises:
        FileNotFoundError: If the config file could not be read
            (``ConfigParser.read`` silently skips such files, which would
            otherwise surface later as a misleading ``NoSectionError``).
        configparser.NoSectionError: If the file has no ``[lmm]`` section.
        configparser.NoOptionError: If ``[lmm]`` has neither ``model_path``
            nor ``model_name``.
    """
    lmm_config = configparser.ConfigParser()
    # read() returns the list of files it actually parsed; an empty list
    # means the path was missing or could not be opened.
    if not lmm_config.read(args.llmProviderConfig):
        raise FileNotFoundError(
            f"LLM provider config not found or unreadable: {args.llmProviderConfig}"
        )
    try:
        model_name = lmm_config.get('lmm', 'model_path')
    except configparser.NoOptionError:
        # Fall back only when the ``model_path`` key is absent; a bare
        # ``except`` would also hide KeyboardInterrupt/SystemExit and
        # unrelated errors.
        model_name = lmm_config.get('lmm', 'model_name')
    else:
        model_name = model_name.replace("/", "_")
    # NOTE(review): ``entry`` is imported from runner.game_runner at the top
    # of this file but is never called, so nothing is run with the resolved
    # model name — confirm whether an ``entry(...)`` call is missing here.
    return model_name
def get_args_parser():
    """Build the command-line parser for the Cradle agent runner.

    All options are optional and fall back to the defaults listed below.

    Returns:
        argparse.ArgumentParser: Parser (prog ``"Cradle Agent Runner"``)
        defining ``--llmProviderConfig``, ``--gameEnvConfig``,
        ``--levelConfig``, ``--generationConfig``, ``--test_rounds`` and
        ``--output_dir``.
    """
    # Each entry: (flag, value type, default value, help text). Order is
    # preserved so the generated --help output lists options identically.
    option_specs = (
        ("--llmProviderConfig", str, "./config/gpt_server_config.ini",
         "The path to the LLM provider config file."),
        ("--gameEnvConfig", str, "./config/env_config/env_config_race.json",
         "The path to the environment config file."),
        ("--levelConfig", str, "./config/level_config/racegame/level1.json",
         "The path to the level config file."),
        ("--generationConfig", str, "./config/generation_config.ini",
         "The path to the swift generation config file."),
        ("--test_rounds", int, 1,
         "Rounds to test the game."),
        ("--output_dir", str, "./runs",
         "The path to output the results and log."),
    )

    arg_parser = argparse.ArgumentParser("Cradle Agent Runner")
    for flag, value_type, default_value, help_text in option_specs:
        arg_parser.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )
    return arg_parser
def get_local_rank():
    """Return this process's local rank from the environment, or ``None``.

    ``LOCAL_RANK`` is checked first, then ``OMPI_COMM_WORLD_LOCAL_RANK``;
    the first one present is parsed with ``int``. Returns ``None`` when
    neither variable is set.
    """
    for env_name in ("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK"):
        raw_value = os.environ.get(env_name)
        # os.environ values are always strings, so None means "not set".
        if raw_value is not None:
            return int(raw_value)
    return None
if __name__ == '__main__':
    # Only a non-distributed run (no rank variable set) or local rank 0
    # parses arguments, loads the configs and calls main(); any other local
    # rank just reports its rank and stops.
    local_rank = get_local_rank()
    # get_local_rank() returns an int or None; both None and 0 select the
    # driving process. (``not local_rank`` already covers 0, so the old
    # trailing ``or local_rank == 0`` was redundant.)
    if local_rank is None or local_rank == 0:
        parser = get_args_parser()
        args = parser.parse_args()
        # Load env/level config files into the shared module-level ``config``.
        config.load_env_config(args.gameEnvConfig)
        config.load_level_config(args.levelConfig)
        main(args)
    else:
        print(local_rank)
        print("process killed.")