import json
import logging
import random
import re
from time import sleep
from typing import Dict

from utils.config import Config
from utils.encoding_utils import encode_image_path
from utils.json_utils import parse_semi_formatted_text
from utils.lmm_utils import assemble_prompt
from utils.planner_utils import _extract_keys_from_template
config = Config()
class game_agent:
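    """LMM-driven game-playing agent.

    Assembles a multimodal (text + image) prompt from configurable
    instruction/history templates and asks an LLM provider for the next action.
    """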
def __init__(self, llm_provider=None):
        # Use the level-specific prompt from the config if present; otherwise fall back to the default prompt.
if hasattr(config, "level_prompt") and config.level_prompt is not None:
self.prompt = config.level_prompt
print("Using level prompt from config file: " + self.prompt)
else:
self.prompt = config.prompt
print("Using default prompt: " + self.prompt)
self.prompt_template_origin, _, _ = _extract_keys_from_template(self.prompt)
self.prompt_template = self.prompt_template_origin
self.use_instruction = config.use_instruction
if self.use_instruction:
self.instruction_template = config.instruction
self.use_history = config.use_history
if self.use_history:
self.history_template = config.history
            self.history_size = sum(1 for item in self.history_template if "image" in item)  # one history slot per image entry
print(f"history_size: {self.history_size}")
self.history = []
self.use_sample_history = config.use_sample_history
if self.use_sample_history:
self.sample_size = config.sample_size
                self.sample_history_template = config.sample_histroy_template  # the config key keeps its original "histroy" spelling
else:
self.history_size = 1
self.history = []
self.reset_provider(llm_provider)
logging.info("prompt: " + self.prompt_template_origin)
print("prompt: " + self.prompt_template_origin)
def reset_provider(self, llm_provider):
print("Resetting provider...")
self.llm_provider = llm_provider
if self.history is not None and len(self.history) > 0:
            # drop the most recent history entry
self.history.pop(-1)
def produce_instruction(self):
"""
Generates and inserts an instruction string into the prompt template.
This method constructs an instruction string based on the `instruction_template` attribute.
It replaces the placeholder "<$instruction$>" in the `prompt_template` with the generated instruction string.
The instruction string is built by iterating over the `instruction_template` list, appending text and encoded image placeholders as needed.
Raises:
AssertionError: If the placeholder "<$instruction$>" is not found in `prompt_template`.
Side Effects:
Modifies `self.prompt_template` by replacing the "<$instruction$>" placeholder with the generated instruction string.
Updates `self.input` with encoded image paths using placeholders like "image_instruction_{counter}".
"""
assert "<$instruction$>" in self.prompt_template
instruction_str = ""
instruction_str += "\n\n" + self.instruction_template[0]["text"]
counter = 1
for item in self.instruction_template[1:]:
instruction_str += "\n\n"
if "image" in item:
placeholder_token = f"image_instruction_{counter}"
self.input[placeholder_token] = encode_image_path(item["image"])
instruction_str += f"<${placeholder_token}$>"
if "text" in item:
instruction_str += item["text"]
counter += 1
self.prompt_template = self.prompt_template.replace("<$instruction$>", instruction_str + "\n\n")
def produce_history(self):
"""
Generates a history string based on the provided history template and updates the prompt template with this history.
The method processes the `history_template` in reverse order (excluding the first element) and constructs a history string by replacing placeholders with corresponding values from the `history` list. It also updates the `input` dictionary with image history placeholders.
The constructed history string is then inserted into the `prompt_template` at the placeholder "<$history$>".
Raises:
AssertionError: If the placeholder "<$history$>" is not found in `prompt_template`.
"""
assert "<$history$>" in self.prompt_template
history_str = ""
        # Note: history is chronological; the most recent step sits at the end of the list.
########################################################################################
# produce recent history
# skip current step
counter = 2
for item in reversed(self.history_template[1:]):
if counter > len(self.history):
break
# reversed
if "text" in item:
history_text_template = item["text"]
for history_variable in self.history[-counter]:
if history_variable == "image":
continue
                    # history_variable_X refers to the step X steps in the past (X == 1 is the previous step)
history_variable_X = f"{history_variable}_{counter-1}"
placeholder_token = f"<${history_variable_X}$>"
if placeholder_token in history_text_template:
print(f"history_variable: {history_variable_X}")
history_text_template = \
history_text_template.replace(placeholder_token, self.history[-counter][history_variable])
history_str = history_text_template + history_str
if "image" in item:
placeholder_token = f"image_history_{counter}"
self.input[placeholder_token] = self.history[-counter]["image"]
history_str = f"<${placeholder_token}$>" + history_str
history_str = "\n\n" + history_str
counter += 1
########################################################################################
        # produce sample history
        # Naive implementation: randomly pick samples from history entries that are
        # older than the recent-history window, i.e. more than self.history_size
        # steps before the current step.
if self.use_sample_history:
sample_size = min(self.sample_size, max(len(self.history) - self.history_size, 0))
            sample_index = random.sample(range(0, len(self.history) - self.history_size), sample_size)
sample_index.sort(reverse=True)
sample_history_str = ""
for index in sample_index:
                # The sample template text reads roughly:
                #   This screenshot is <$sample_step$> steps before the current step of the game.
                #   After this frame, your reasoning message was "<$sample_history_reasoning$>".
                #   After the action was executed, the game info was "<$sample_history_action_info$>"
                sample_history_template = self.sample_history_template["text"]
                # Index layout example: history 0 1 2 [3 4] 5(cur), where [3 4] is the
                # recent-history window and samples are drawn from earlier indices.
for history_variable in self.history[index]:
if history_variable == "image":
continue
sample_history_variable = f"sample_{history_variable}"
placeholder_token = f"<${sample_history_variable}$>"
if placeholder_token in sample_history_template:
sample_history_template = \
sample_history_template.replace(placeholder_token, self.history[index][history_variable])
sample_history_template = sample_history_template.replace("<$sample_step$>", str(len(self.history) - index))
placeholder_token_image = f"image_sample_{index}"
sample_history_image = self.history[index]["image"]
self.input[placeholder_token_image] = sample_history_image
history_str = "\n\n" + f"<${placeholder_token_image}$>" + sample_history_template + history_str
########################################################################################
if len(self.history) != 0:
history_str = "\n\n" + self.history_template[0]["text"] + history_str
self.prompt_template = self.prompt_template.replace("<$history$>", history_str + "\n\n")
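        # Worked example: with history [s0, s1, s2] (s2 = current step), the recent
        # loop above fills <$..._1$> placeholders from s1 and <$..._2$> from s0; the
        # current frame is injected separately as image_current_step in generate_input.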
# print("history_str: ", history_str)
if len(self.history) == 10:
sleep(100)
def update_recent_history(self, info: Dict):
# Update the last step with the action taken
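        # Example call with hypothetical values:
        #   agent.update_recent_history({"history_action": "move_left",
        #                                "history_action_info": "bumped into a wall"})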
if len(self.history) == 0:
return
# for key in ["history_action", "history_action_info", "history_reasoning"]:
for key in info.keys():
if info.get(key) is not None:
self.history[-1][key] = info[key]
def update_new_history(self, info: Dict):
self.history.append({
# Current Step
'image': info["last_frame_base64"],
'image_path': info["last_frame_path"],
'history_action': None,
'history_action_info': None,
'history_reasoning': None
})
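        # Keep a sliding window: at most history_size past steps plus the current
        # one (e.g. history_size == 2 -> at most 3 entries), unless older entries
        # are needed for sampling.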
if self.use_history and not self.use_sample_history and len(self.history) > self.history_size + 1:
self.history.pop(0)
def update_game_info(self, game_info: Dict):
# TODO: working memory module
# e.g.
# self.memory.update(info)
self.update_recent_history(game_info)
self.update_new_history(game_info)
def generate_input(self):
self.input = {}
self.prompt_template = self.prompt_template_origin
# current step image is at the end of the history
self.input['image_current_step'] = self.history[-1]["image"]
# Instruction
if self.use_instruction:
# replace <$instruction$> with images and texts.
self.produce_instruction()
# History
if self.use_history:
# replace <$history$> with images and texts.
self.produce_history()
def generate_action(self, data):
if data.get("action") is None:
data["action"] = "None"
action = data["action"]
return action
def execute_action(self):
print(f"Agent execcuting action...")
# Generate self.input
self.generate_input()
# Generate prompt
message_prompts = assemble_prompt(template_str=self.prompt_template, params=self.input, image_prompt_format=self.llm_provider.image_prompt_format)
        # For readable logs, base64 image data could be swapped for the image paths
        # stored in the history (the replacement loop below is currently disabled)
readable_message_prompts = json.dumps(message_prompts, indent=2)
pattern = re.compile(r"\"data:image/png;base64,[^\"]*\"")
print(f"len(self.history): {len(self.history)}")
print(f"self.history_size: {self.history_size}")
# for i, history_item in enumerate(self.history[-self.history_size:]):
# match = pattern.search(readable_message_prompts)
# if match:
# base64_image = match.group(0)
# expected_image_path = f"\"Image {i+1}: {history_item['image_path']}\""
# assert encode_image_path(history_item['image_path']) in base64_image, f"Base64 encoding does not match for i={i}, {history_item['image_path']}"
# readable_message_prompts = readable_message_prompts[:match.start()] + expected_image_path + readable_message_prompts[match.end():]
# print("Base64 image encoding matches history image paths.")
logging.info("message_prompts: " + readable_message_prompts)
# print the message prompts in JSON format
# print("message_prompts: " + readable_message_prompts.encode("utf-8").decode("unicode_escape"))
# Call the LLM provider for decision making
success, response = self.llm_provider.create_completion(message_prompts)
if not success:
print("Failed to generate response., error: " + response)
error_msg = "Failed to generate response, error: " + response
return False, error_msg
print("--------------------------------------------------------------------------------------")
        response = re.sub(r'\n+', '\n', response)
        # Put each value on its own line after its key so the semi-formatted parser can split the response
        response = response.replace(":", ":\n")
logging.info("response: " + str(response))
print("response: " + response)
data = parse_semi_formatted_text(response)
self.update_recent_history({"history_reasoning": str(data)})
action = self.generate_action(data)
        return True, action
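

if __name__ == "__main__":
    # Minimal smoke-test sketch. The stub below mirrors only the parts of the
    # provider interface used above (`image_prompt_format` and
    # `create_completion`); the format tag and the "key: value" response shape
    # are assumptions for illustration, not the real provider contract.
    class _StubProvider:
        image_prompt_format = "openai"  # assumed format tag

        def create_completion(self, message_prompts):
            # Pretend the LLM answered with a single action field.
            return True, "action: idle"

    agent = game_agent(llm_provider=_StubProvider())
    agent.update_game_info({
        "last_frame_base64": "data:image/png;base64,...",  # placeholder frame data
        "last_frame_path": "frames/0000.png",              # hypothetical path
    })
    success, action = agent.execute_action()
    print("success:", success, "action:", action)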