V-MAGE-DEMO / agent /game_agent.py
Fengx1n's picture
Initial DEMO
e53fda1
import logging
import queue
import random
import re
from time import sleep
from typing import Dict
from utils.config import Config
from utils.encoding_utils import encode_data_to_base64_path, encode_image_path
from utils.file_utils import assemble_project_path, get_all_files
from utils.json_utils import parse_semi_formatted_text
from utils.lmm_utils import assemble_prompt
from utils.planner_utils import _extract_keys_from_template
import json
config = Config()
class game_agent:
def __init__(self, llm_provider=None):
# if config has attribute level_prompt, use it, otherwise use default prompt
if hasattr(config, "level_prompt") and config.level_prompt is not None:
self.prompt = config.level_prompt
print("Using level prompt from config file: " + self.prompt)
else:
self.prompt = config.prompt
print("Using default prompt: " + self.prompt)
self.prompt_template_origin, _, _ = _extract_keys_from_template(self.prompt)
self.prompt_template = self.prompt_template_origin
self.use_instruction = config.use_instruction
if self.use_instruction:
self.instruction_template = config.instruction
self.use_history = config.use_history
if self.use_history:
self.history_template = config.history
self.history_size = len([image_history for image_history in self.history_template if "image" in image_history])
print(f"history_size: {self.history_size}")
self.history = []
self.use_sample_history = config.use_sample_history
if self.use_sample_history:
self.sample_size = config.sample_size
self.sample_histroy_template = config.sample_histroy_template
else:
self.history_size = 1
self.history = []
self.reset_provider(llm_provider)
logging.info("prompt: " + self.prompt_template_origin)
print("prompt: " + self.prompt_template_origin)
def reset_provider(self, llm_provider):
print("Resetting provider...")
self.llm_provider = llm_provider
if self.history is not None and len(self.history) > 0:
# pop the last history
self.history.pop(-1)
def produce_instruction(self):
"""
Generates and inserts an instruction string into the prompt template.
This method constructs an instruction string based on the `instruction_template` attribute.
It replaces the placeholder "<$instruction$>" in the `prompt_template` with the generated instruction string.
The instruction string is built by iterating over the `instruction_template` list, appending text and encoded image placeholders as needed.
Raises:
AssertionError: If the placeholder "<$instruction$>" is not found in `prompt_template`.
Side Effects:
Modifies `self.prompt_template` by replacing the "<$instruction$>" placeholder with the generated instruction string.
Updates `self.input` with encoded image paths using placeholders like "image_instruction_{counter}".
"""
assert "<$instruction$>" in self.prompt_template
instruction_str = ""
instruction_str += "\n\n" + self.instruction_template[0]["text"]
counter = 1
for item in self.instruction_template[1:]:
instruction_str += "\n\n"
if "image" in item:
placeholder_token = f"image_instruction_{counter}"
self.input[placeholder_token] = encode_image_path(item["image"])
instruction_str += f"<${placeholder_token}$>"
if "text" in item:
instruction_str += item["text"]
counter += 1
self.prompt_template = self.prompt_template.replace("<$instruction$>", instruction_str + "\n\n")
def produce_history(self):
"""
Generates a history string based on the provided history template and updates the prompt template with this history.
The method processes the `history_template` in reverse order (excluding the first element) and constructs a history string by replacing placeholders with corresponding values from the `history` list. It also updates the `input` dictionary with image history placeholders.
The constructed history string is then inserted into the `prompt_template` at the placeholder "<$history$>".
Raises:
AssertionError: If the placeholder "<$history$>" is not found in `prompt_template`.
"""
assert "<$history$>" in self.prompt_template
history_str = ""
# Note: The history is stored in reverse order, with the most recent step at the end of the list.
########################################################################################
# produce recent history
# skip current step
counter = 2
for item in reversed(self.history_template[1:]):
if counter > len(self.history):
break
# reversed
if "text" in item:
history_text_template = item["text"]
for history_variable in self.history[-counter]:
if history_variable == "image":
continue
# history_variable_X is the X step in the past (X == 1 means the previous step)
history_variable_X = f"{history_variable}_{counter-1}"
placeholder_token = f"<${history_variable_X}$>"
if placeholder_token in history_text_template:
print(f"history_variable: {history_variable_X}")
history_text_template = \
history_text_template.replace(placeholder_token, self.history[-counter][history_variable])
history_str = history_text_template + history_str
if "image" in item:
placeholder_token = f"image_history_{counter}"
self.input[placeholder_token] = self.history[-counter]["image"]
history_str = f"<${placeholder_token}$>" + history_str
history_str = "\n\n" + history_str
counter += 1
########################################################################################
# produce sample history
# a naive implementation
# just randomly select a sample from the history before self.history_size steps of the current step
if self.use_sample_history:
sample_size = min(self.sample_size, max(len(self.history) - self.history_size, 0))
sample_index = random.sample(range(0, len(self.history)- self.history_size), sample_size)
sample_index.sort(reverse=True)
sample_history_str = ""
for index in sample_index:
'''
This screenshot is <$sample_step$> steps before the current step of the game. After this frame, your reasoning message was \"<$sample_history_reasoning$>\". After the action was excuted, the game info was \"<$sample_history_action_info$>\"
'''
sample_history_template = self.sample_histroy_template["text"]
# 0 1 2 [3 4] 5 (cur) (index)
for history_variable in self.history[index]:
if history_variable == "image":
continue
sample_history_variable = f"sample_{history_variable}"
placeholder_token = f"<${sample_history_variable}$>"
if placeholder_token in sample_history_template:
sample_history_template = \
sample_history_template.replace(placeholder_token, self.history[index][history_variable])
sample_history_template = sample_history_template.replace("<$sample_step$>", str(len(self.history) - index))
placeholder_token_image = f"image_sample_{index}"
sample_history_image = self.history[index]["image"]
self.input[placeholder_token_image] = sample_history_image
history_str = "\n\n" + f"<${placeholder_token_image}$>" + sample_history_template + history_str
########################################################################################
if len(self.history) != 0:
history_str = "\n\n" + self.history_template[0]["text"] + history_str
self.prompt_template = self.prompt_template.replace("<$history$>", history_str + "\n\n")
# print("history_str: ", history_str)
if len(self.history) == 10:
sleep(100)
def update_recent_history(self, info: Dict):
# Update the last step with the action taken
if len(self.history) == 0:
return
# for key in ["history_action", "history_action_info", "history_reasoning"]:
for key in info.keys():
if info.get(key) is not None:
self.history[-1][key] = info[key]
def update_new_history(self, info: Dict):
self.history.append({
# Current Step
'image': info["last_frame_base64"],
'image_path': info["last_frame_path"],
'history_action': None,
'history_action_info': None,
'history_reasoning': None
})
if self.use_history and not self.use_sample_history and len(self.history) > self.history_size + 1:
self.history.pop(0)
def update_game_info(self, game_info: Dict):
# TODO: working memory module
# e.g.
# self.memory.update(info)
self.update_recent_history(game_info)
self.update_new_history(game_info)
def generate_input(self):
self.input = {}
self.prompt_template = self.prompt_template_origin
# current step image is at the end of the history
self.input['image_current_step'] = self.history[-1]["image"]
# Instruction
if self.use_instruction:
# replace <$instruction$> with images and texts.
self.produce_instruction()
# History
if self.use_history:
# replace <$history$> with images and texts.
self.produce_history()
def generate_action(self, data):
if data.get("action") is None:
data["action"] = "None"
action = data["action"]
return action
def execute_action(self):
print(f"Agent execcuting action...")
# Generate self.input
self.generate_input()
# Generate prompt
message_prompts = assemble_prompt(template_str=self.prompt_template, params=self.input, image_prompt_format=self.llm_provider.image_prompt_format)
# Replace base64 image data with values from history array
readable_message_prompts = json.dumps(message_prompts, indent=2)
pattern = re.compile(r"\"data:image/png;base64,[^\"]*\"")
print(f"len(self.history): {len(self.history)}")
print(f"self.history_size: {self.history_size}")
# for i, history_item in enumerate(self.history[-self.history_size:]):
# match = pattern.search(readable_message_prompts)
# if match:
# base64_image = match.group(0)
# expected_image_path = f"\"Image {i+1}: {history_item['image_path']}\""
# assert encode_image_path(history_item['image_path']) in base64_image, f"Base64 encoding does not match for i={i}, {history_item['image_path']}"
# readable_message_prompts = readable_message_prompts[:match.start()] + expected_image_path + readable_message_prompts[match.end():]
# print("Base64 image encoding matches history image paths.")
logging.info("message_prompts: " + readable_message_prompts)
# print the message prompts in JSON format
# print("message_prompts: " + readable_message_prompts.encode("utf-8").decode("unicode_escape"))
# Call the LLM provider for decision making
success, response = self.llm_provider.create_completion(message_prompts)
if not success:
print("Failed to generate response., error: " + response)
error_msg = "Failed to generate response, error: " + response
return False, error_msg
print("--------------------------------------------------------------------------------------")
response = re.sub(r'\n+', '\n', response)
# Convert the response to dict
response = response.replace(":", ":\n")
logging.info("response: " + str(response))
print("response: " + response)
data = parse_semi_formatted_text(response)
self.update_recent_history({"history_reasoning": str(data)})
action = self.generate_action(data)
return True, action