|
|
import json |
|
|
from typing import Any, List |
|
|
from agents import Agent, OpenAIChatCompletionsModel, Runner |
|
|
from agents.agent_output import AgentOutputSchemaBase |
|
|
from openai import AsyncOpenAI |
|
|
from config.global_storage import get_model_config |
|
|
from utils.bio_logger import bio_logger as logger |
|
|
from typing import List, Dict |
|
|
from pydantic import BaseModel, Field,ConfigDict |
|
|
|
|
|
|
|
|
class DateRange(BaseModel):
    # Date window used as a filter; both bounds are plain strings that
    # default to "" when not supplied.
    #
    # Documented with comments instead of a docstring on purpose: pydantic
    # copies a model's docstring into the generated JSON schema as its
    # `description`, which would change what `model_json_schema()` emits.
    #
    # `json_schema_extra` lists both fields under "required" in the emitted
    # schema even though each has a Python-level default (presumably for
    # strict structured-output schemas -- confirm against the consumer).
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["start", "end"]},
    )

    start: str = Field(default="", description="Start date in YYYY-MM-DD format")
    end: str = Field(default="", description="End date in YYYY-MM-DD format")
|
|
|
|
|
class Journal(BaseModel):
    # A journal reference made of its name and its EISSN; both fields are
    # required (no default is provided).
    #
    # Comments are used rather than a docstring so the JSON schema produced
    # by pydantic does not gain a `description` entry.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["name", "EISSN"]},
    )

    name: str = Field(default=..., description="Journal name")
    EISSN: str = Field(default=..., description="Journal EISSN")
|
|
|
|
|
class AuthorFilter(BaseModel):
    # Author-based filter: a name plus two flags saying whether that author
    # must be the first and/or the last author. Defaults are "" / False.
    #
    # Comments are used rather than a docstring so the JSON schema produced
    # by pydantic does not gain a `description` entry.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["name", "first_author", "last_author"]},
    )

    name: str = Field(default="", description="Author name to filter")
    first_author: bool = Field(default=False, description="Is first author?")
    last_author: bool = Field(default=False, description="Is last author?")
|
|
|
|
|
|
|
|
class Filters(BaseModel):
    # Aggregate of every filter attached to a rewritten query: date range,
    # article types, languages, subjects, journals and an author filter.
    #
    # Comments are used rather than a docstring so the JSON schema produced
    # by pydantic does not gain a `description` entry.
    #
    # NOTE(review): several fields pass `...` together with
    # `default_factory`. Under pydantic v2 the factory supplies the value
    # when the field is omitted; confirm that "optional at validation time,
    # required in the emitted schema" is the intended contract.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={
            "required": [
                "date_range",
                "article_types",
                "languages",
                "subjects",
                "journals",
                "author",
            ]
        },
    )

    date_range: DateRange = Field(..., default_factory=DateRange)
    article_types: List[str] = Field(..., default_factory=list)
    # Defaults to a single-language list.
    languages: List[str] = Field(default=["English"])
    subjects: List[str] = Field(..., default_factory=list)
    # Default is a list holding one empty string (not an empty list).
    journals: List[str] = Field(default=[""])
    author: AuthorFilter = Field(..., default_factory=AuthorFilter)
|
|
|
|
|
class RewriteJsonOutput(BaseModel):
    # Full structured result of a query rewrite: a category, keyword list,
    # key journals, rewritten queries and the filters to apply.
    #
    # Comments are used rather than a docstring so the JSON schema produced
    # by pydantic does not gain a `description` entry.
    #
    # NOTE(review): list/object fields pass `...` together with
    # `default_factory`; under pydantic v2 the factory supplies the value
    # when the field is omitted -- confirm this is intended.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={
            "required": [
                "category",
                "key_words",
                "key_journals",
                "queries",
                "filters",
            ]
        },
    )

    category: str = Field(default=..., description="Query category")
    key_words: List[str] = Field(..., default_factory=list)
    key_journals: List[Journal] = Field(..., default_factory=list)
    queries: List[str] = Field(..., default_factory=list)
    filters: Filters = Field(..., default_factory=Filters)
|
|
|
|
|
|
|
|
class SimpleJsonOutput(BaseModel):
    # Reduced output shape carrying only a list of keywords.
    #
    # Comments are used rather than a docstring so the JSON schema produced
    # by pydantic does not gain a `description` entry.
    key_words: List[str] = Field(
        ...,
        default_factory=list,
    )
|
|
|
|
|
|
|
|
class RewriteJsonOutputSchema(AgentOutputSchemaBase):
    """Agent output schema backed by the ``RewriteJsonOutput`` pydantic model.

    Tells the agent runtime that the output is structured JSON, exposes the
    JSON schema generated from ``RewriteJsonOutput`` and validates model
    output against it.
    """

    def is_plain_text(self) -> bool:
        """Return ``False``: the output is structured JSON, not plain text."""
        return False

    def name(self) -> str:
        """Return the name of the output type."""
        return "RewriteJsonOutput"

    def json_schema(self) -> Dict[str, Any]:
        """Return the JSON schema generated from ``RewriteJsonOutput``."""
        return RewriteJsonOutput.model_json_schema()

    def is_strict_json_schema(self) -> bool:
        """Return ``True``: the schema is declared as strict."""
        return True

    def validate_json(self, json_data: Any) -> Any:
        """Validate model output and return the parsed ``RewriteJsonOutput``.

        Args:
            json_data: Either a JSON string or an already-decoded mapping.

        Returns:
            The validated ``RewriteJsonOutput`` instance.

        Raises:
            ModelBehaviorError: If the payload is not valid JSON or does not
                satisfy the ``RewriteJsonOutput`` model. The failure is
                logged before being raised.
        """
        try:
            if isinstance(json_data, str):
                json_data = json.loads(json_data)
            return RewriteJsonOutput.model_validate(json_data)
        except Exception as e:
            logger.error(f"Validation error: {e}")
            # Previously the error was swallowed and None was returned
            # implicitly, which only surfaced later as an unrelated failure
            # when the caller used the missing final output. Signal the
            # invalid output explicitly instead.
            from agents.exceptions import ModelBehaviorError

            raise ModelBehaviorError(
                f"Output does not match RewriteJsonOutput: {e}"
            ) from e

    def parse(self, json_data: Any) -> Any:
        """Decode ``json_data`` if it is a JSON string; otherwise return it as is.

        No schema validation is performed here.
        """
        if isinstance(json_data, str):
            json_data = json.loads(json_data)
        return json_data
|
|
|
|
|
class RewriteAgent:
    """Rewrite a free-text query into structured JSON through an LLM agent.

    The agent is first run with the ``main`` model of the ``rewrite-llm``
    configuration section. If that attempt raises for any reason, the
    ``backup`` model of the same section is tried once.
    """

    # Agent instructions shared by every run. The leading space is part of
    # the original prompt text and is kept verbatim.
    _AGENT_INSTRUCTIONS = ' Your task is to rewrite the query into a structured JSON format. Please do not answer the question.'

    def __init__(self):
        """Load the model configuration and build the main model client."""
        self.model_config = get_model_config()
        self.agent_name = "rewrite agent"
        self.selected_model = self._build_model("main")

        # Logging here is best-effort: a logger failure must not prevent
        # the agent from being constructed.
        try:
            logger.info(
                f"Rewrite main model: {self.model_config['rewrite-llm']['main']['model']} | "
                f"base_url: {self.model_config['rewrite-llm']['main']['base_url']}"
            )
        except Exception:
            pass

    def _build_model(self, tier: str) -> Any:
        """Build the chat-completions model for ``tier`` ("main" or "backup")."""
        settings = self.model_config["rewrite-llm"][tier]
        return OpenAIChatCompletionsModel(
            model=settings["model"],
            openai_client=AsyncOpenAI(
                api_key=settings["api_key"],
                base_url=settings["base_url"],
                timeout=120.0,
                max_retries=2,
            ),
        )

    async def _run_agent(self, model: Any, query: str, INSTRUCTIONS: str, simple_version: bool) -> Any:
        """Run one rewrite attempt against ``model`` and return the raw run result."""
        # The simple variant only asks for keywords; the full variant uses
        # the custom schema wrapper around RewriteJsonOutput.
        output_type = SimpleJsonOutput if simple_version else RewriteJsonOutputSchema()
        rewrite_agent = Agent(
            name=self.agent_name,
            instructions=self._AGENT_INSTRUCTIONS,
            model=model,
            output_type=output_type,
        )
        return await Runner.run(
            rewrite_agent,
            input=INSTRUCTIONS + 'Here is the question: ' + query,
        )

    def _extract_output(self, result: Any) -> Any:
        """Turn a run result into plain JSON data; log and re-raise on failure."""
        # Some call paths may hand back a tuple whose first item is the result.
        normalized = result[0] if isinstance(result, tuple) and len(result) > 0 else result
        try:
            return self.parse_json_output(normalized.final_output.model_dump_json())
        except Exception as e:
            logger.error(f"Failed to parse JSON output: {e}")
            raise

    async def rewrite_query(self, query: str, INSTRUCTIONS: str, simple_version=False) -> Dict[str, Any]:
        """Rewrite ``query`` into structured JSON, falling back to the backup model.

        Args:
            query: The user question to rewrite.
            INSTRUCTIONS: Prompt text prepended to the question in the agent input.
            simple_version: When true, request the keyword-only output
                (``SimpleJsonOutput``) instead of the full ``RewriteJsonOutput``.

        Returns:
            The decoded JSON produced by the agent. If the backup model runs
            but its output cannot be converted, ``{"queries": []}`` is
            returned (the same fallback ``parse_json_output`` uses).

        Raises:
            Exception: Whatever the backup run itself raises (configuration
                or API errors) when the main attempt has already failed.
        """
        try:
            logger.info("Rewriting query with main configuration.")
            result = await self._run_agent(self.selected_model, query, INSTRUCTIONS, simple_version)
            # A conversion failure raises here so the backup model is tried.
            return self._extract_output(result)
        except Exception as main_error:
            # Log before touching the backup configuration so the original
            # failure is never lost if building the backup model also fails.
            logger.error(f"Error with main model: {main_error}", exc_info=main_error)
            logger.info("Trying backup model for rewriting query.")

        self.selected_model_backup = self._build_model("backup")
        result = await self._run_agent(self.selected_model_backup, query, INSTRUCTIONS, simple_version)
        try:
            return self._extract_output(result)
        except Exception:
            # Already logged by _extract_output. Previously this path hit an
            # unbound local variable; return the documented empty fallback.
            return {"queries": []}

    def parse_json_output(self, output: str) -> Any:
        """Decode ``output`` as JSON, tolerating code fences and surrounding text.

        Strategies are tried in order: the whole string, the contents of the
        first triple-backtick block, then the first brace-balanced span found
        by ``find_json_in_string``. If none of them decodes,
        ``{"queries": []}`` is returned.
        """
        try:
            return json.loads(output)
        except json.JSONDecodeError as e:
            logger.info(f"Output is not valid JSON: {output}")
            logger.error(f"Failed to parse output as direct JSON: {e}")

        if "```" in output:
            parts = output.split("```")
            # With an opening and a closing fence the payload is parts[1];
            # with a single fence, fall back to the whole string.
            candidate = parts[1] if len(parts) >= 3 else output
            if candidate.startswith(("json", "JSON")):
                # Drop the language tag that follows the opening fence.
                candidate = candidate[4:].strip()
            try:
                return json.loads(candidate)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse output from code block: {e}")

        extracted = self.find_json_in_string(output)
        if not extracted:
            logger.error(f"No valid JSON found in the output:{output}")
            return {"queries": []}

        try:
            return json.loads(extracted)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse extracted JSON: {e}")
            logger.error(f"Extracted JSON: {extracted}")
            return {"queries": []}

    def find_json_in_string(self, string: str) -> str:
        """
        Return the first brace-balanced ``{...}`` span found in ``string``.

        Used to extract JSON from a string (note that this function does not
        validate the JSON, and braces inside string literals are counted like
        any other brace). Returns ``""`` when no balanced span exists.

        Example:
            string = "bla bla bla {this is {some} text{{}and it's sneaky}} because {it's} confusing"
            output = "{this is {some} text{{}and it's sneaky}}"
        """
        depth = 0
        start_index = None

        for i, c in enumerate(string):
            if c == "{":
                if depth == 0:
                    start_index = i
                depth += 1
            elif c == "}":
                if depth == 0:
                    # Stray closing brace before any opening one: ignore it
                    # so it cannot push the depth negative and hide a later
                    # valid object.
                    continue
                depth -= 1
                if depth == 0:
                    return string[start_index : i + 1]

        return ""
|
|
|