Mehedi2 committed on
Commit
28c303d
·
verified ·
1 Parent(s): 58052c9

Update gaia_api.py

Browse files
Files changed (1) hide show
  1. gaia_api.py +200 -98
gaia_api.py CHANGED
@@ -1,105 +1,207 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from app import run_agent # Import from your main app
4
- import logging
5
-
6
# Configure logging once at import time; INFO level makes the per-request
# log lines emitted by the endpoints visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Application object that the route decorators in this module attach to.
app = FastAPI(
    version="1.0.0",
    title="GAIA Test Agent API",
    description="API endpoint for GAIA benchmark evaluation",
)
15
-
16
# Request body accepted by POST /predict.
class GaiaRequest(BaseModel):
    # The question text that is forwarded to the agent.
    prompt: str
18
-
19
# Response body returned by POST /predict.
class GaiaResponse(BaseModel):
    # The answer string produced by the agent.
    output: str
 
 
21
 
22
@app.post("/predict", response_model=GaiaResponse)
async def predict(request: GaiaRequest):
    """
    Main prediction endpoint for GAIA evaluation
    This is the endpoint that GAIA will call to get answers
    """
    # NOTE: the docstring above is kept verbatim because FastAPI publishes it
    # as the endpoint description.
    try:
        question = request.prompt
        # Only the first 100 characters are logged to keep log lines short.
        logger.info(f"Received question: {question[:100]}...")

        # Delegate to the agent entry point imported from the main app module.
        answer = run_agent(question)
        logger.info(f"Generated answer: {answer[:100]}...")

        return GaiaResponse(output=answer)
    except Exception as e:
        # Any failure (agent error, response validation, ...) becomes HTTP 500.
        reason = str(e)
        logger.error(f"Error processing request: {reason}")
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {reason}",
        )
41
-
42
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # Static payload: reaching this handler is itself the liveness signal.
    payload = {}
    payload["status"] = "healthy"
    payload["message"] = "GAIA Test Agent is running"
    return payload
49
-
50
@app.get("/")
async def root():
    """Root endpoint with API information"""
    # One-line description of every route exposed by this service.
    endpoints = {
        "predict": "/predict - Main prediction endpoint (POST)",
        "health": "/health - Health check (GET)",
        "docs": "/docs - Interactive API documentation (GET)",
    }

    # Example request/response shapes for the prediction route.
    predict_usage = {
        "method": "POST",
        "body": {"prompt": "Your question here"},
        "response": {"output": "Agent's answer"},
    }

    return {
        "name": "GAIA Test Agent",
        "description": "AI Agent for GAIA Benchmark Evaluation",
        "endpoints": endpoints,
        "usage": {"predict": predict_usage},
    }
73
-
74
@app.get("/info")
async def info():
    """Get agent information"""
    # Human-readable list of what the agent is advertised to handle.
    capabilities = [
        "General question answering",
        "Mathematical calculations",
        "Factual queries",
        "Yes/No questions",
        "Reasoning tasks",
    ]
    description = {
        "agent_type": "General AI Assistant",
        "model": "DeepSeek V3.1 Terminus via OpenRouter",
        "capabilities": capabilities,
        "optimized_for": "GAIA benchmark evaluation",
    }
    return description
89
-
90
- # For debugging - remove in production
91
@app.get("/test")
async def test_endpoint():
    """Test endpoint to verify the agent works"""
    # Fixed smoke-test question sent to the agent.
    question = "What is 2 + 2?"
    try:
        test_answer = run_agent(question)
    except Exception as e:
        # Report the failure in the body instead of raising, so the caller
        # always receives a JSON document.
        return {"status": "Error", "error": str(e)}

    return {
        "test_question": question,
        "test_answer": test_answer,
        "status": "Agent working correctly",
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple
2
+ import re
3
+ import tempfile
4
+ from pathlib import Path
5
+ import pandas as pd
6
+ import requests
7
+ from pandas import DataFrame
8
+
9
# --- Constants ---
# Base URL of the scoring service; every endpoint below is derived from it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
QUESTIONS_URL = DEFAULT_API_URL + "/questions"  # GET: question list
SUBMIT_URL = DEFAULT_API_URL + "/submit"  # POST: answers payload
FILE_PATH = DEFAULT_API_URL + "/files/"  # prefix; a task id is appended
14
+
15
+
16
+ # --- Helper Methods ---
17
def fetch_all_questions() -> List[Dict]:
    """Fetch all questions from the scoring API.

    Sends a GET request to ``QUESTIONS_URL`` with a 15 second timeout, decodes
    the JSON body and returns it unchanged.

    Returns:
        The decoded, non-empty JSON payload. ``run_agent`` in this module
        iterates it as a list of question dictionaries; a JSON object is still
        passed through for backward compatibility.
        NOTE(review): the exact schema is defined by the remote service -
        confirm against the API.

    Raises:
        UserWarning: If the request fails (connection error, timeout, non-2xx
            status), the body is not valid JSON, or the decoded payload is
            empty or not a list/dict. The original exception, when there is
            one, is chained as ``__cause__``.
    """
    print(f"Fetching questions from: {QUESTIONS_URL}")

    # The request itself must be inside the ``try``: connection errors and
    # timeouts are raised by ``requests.get``, not by ``raise_for_status``.
    try:
        response = requests.get(QUESTIONS_URL, timeout=15)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        raise UserWarning(f"Error fetching questions: {e}") from e
    except Exception as e:
        print(f"An unexpected error occurred fetching questions: {e}")
        raise UserWarning(
            f"An unexpected error occurred fetching questions: {e}"
        ) from e

    # Decoding is handled separately. requests' JSONDecodeError is also a
    # RequestException, so a shared ``try`` would report it as a fetch error.
    # Every JSON decode error raised by ``Response.json`` is a ValueError.
    try:
        questions_data = response.json()
    except ValueError as e:
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        raise UserWarning(
            f"Error decoding server response for questions: {e}"
        ) from e

    # Raised outside any ``try`` so the message reaches the caller as written
    # instead of being re-wrapped as an "unexpected error".
    if not questions_data or not isinstance(questions_data, (list, dict)):
        print("Fetched questions list is empty.")
        raise UserWarning("Fetched questions list is empty or invalid format.")

    print(f"Fetched {len(questions_data)} questions.")
    return questions_data
52
+
53
+
54
def _describe_http_error(response) -> str:
    """Build a one-line description of a non-2xx response from the scoring API."""
    detail = f"Server responded with status {response.status_code}."
    try:
        error_json = response.json()
    except ValueError:
        # Body is not JSON at all: show a bounded excerpt of the raw text.
        return f"{detail} Response: {response.text[:500]}"
    if isinstance(error_json, dict):
        return f"{detail} Detail: {error_json.get('detail', response.text)}"
    # Valid JSON but not an object (e.g. a list or string): ``.get`` would fail.
    return f"{detail} Response: {response.text[:500]}"


def _format_submission_result(result_data: dict) -> str:
    """Render the scoring API's result object as the user-facing success text."""
    return (
        f"Submission Successful!\n"
        f"User: {result_data.get('username')}\n"
        f"Overall Score: {result_data.get('score', 'N/A')}% "
        f"({result_data.get('correct_count', '?')}/"
        f"{result_data.get('total_attempted', '?')} correct)\n"
        f"Message: {result_data.get('message', 'No message received.')}"
    )


def submit_answers(submission_data: dict, results_log: list) -> Tuple[str, DataFrame]:
    """Submit answers to the scoring API and report the outcome.

    POSTs ``submission_data`` as JSON to ``SUBMIT_URL`` (60 second timeout).
    Request failures are never raised to the caller; they are converted into a
    status message so the results table can still be shown.

    Args:
        submission_data (dict): JSON-serialisable payload sent to the API.
            NOTE(review): the expected keys are defined by the remote service.
        results_log (list): Rows (dictionaries) converted into the returned
            DataFrame.

    Returns:
        Tuple[str, DataFrame]: A tuple containing:
            - A status message: the formatted score summary on success, or a
              description of the failure otherwise.
            - A DataFrame built from ``results_log``.
    """
    # Every outcome returns the same table, so build it once up front.
    results_df = pd.DataFrame(results_log)

    try:
        response = requests.post(SUBMIT_URL, json=submission_data, timeout=60)
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        status_message = f"Submission Failed: {_describe_http_error(e.response)}"
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: The request timed out."
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
    except Exception as e:
        status_message = f"An unexpected error occurred during submission: {e}"
    else:
        # Decode outside the request ``try``: requests' JSONDecodeError is a
        # RequestException and would otherwise be reported as a network error.
        try:
            result_data = response.json()
        except ValueError as e:
            status_message = (
                f"Submission Failed: Invalid JSON in server response - {e}"
            )
        else:
            if isinstance(result_data, dict):
                print("Submission successful.")
                return _format_submission_result(result_data), results_df
            status_message = (
                "Submission Failed: Unexpected response format from server."
            )

    print(status_message)
    return status_message, results_df
116
+
117
+
118
def run_agent(gaia_agent, questions_data: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Run the agent on each question and collect the results.

    Items missing a ``task_id`` or a ``question`` are skipped before any file
    lookup is attempted. For valid items the question text is first passed
    through ``process_file`` and the result is handed to the agent.

    Args:
        gaia_agent: Callable invoked as ``gaia_agent(task_id, question_text)``
            that returns the answer for one question.
        questions_data (List[Dict]): Question dictionaries; the ``task_id`` and
            ``question`` keys are read from each.

    Returns:
        Tuple[List[Dict], List[Dict]]: A tuple containing:
            - The results log: one dictionary per processed item with the keys
              'Task ID', 'Question' and 'Submitted Answer'. When the agent
              raises, the answer is the text ``AGENT ERROR: <exception>``.
            - The answers payload: one dictionary per *successful* item with
              the keys 'task_id' and 'submitted_answer'. Failed items are
              logged but not added to the payload.
    """
    results_log = []
    answers_payload = []

    print(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")

        # Validate BEFORE the file lookup. Otherwise a missing task_id triggers
        # a request for ".../files/None", and a missing question can come back
        # from process_file as a non-None string that defeats this guard.
        if not task_id or question_text is None:
            print(f"Skipping invalid item (missing task_id or question): {item}")
            continue

        question_text = process_file(task_id, question_text)

        try:
            submitted_answer = gaia_agent(task_id, question_text)
            answers_payload.append(
                {"task_id": task_id, "submitted_answer": submitted_answer}
            )
        except Exception as e:
            # One failing task must not abort the whole run; record the error
            # in the log only.
            print(f"Error running agent on task {task_id}: {e}")
            submitted_answer = f"AGENT ERROR: {e}"

        results_log.append(
            {
                "Task ID": task_id,
                "Question": question_text,
                "Submitted Answer": submitted_answer,
            }
        )
    return results_log, answers_payload
166
+
167
+
168
def _safe_filename(candidate: str, fallback: str) -> str:
    """Reduce *candidate* to a bare file name, or return *fallback* if none remains."""
    # Treat backslashes as separators too, then keep only the final component
    # so the name can never point outside the target directory.
    name = Path(str(candidate).replace("\\", "/")).name
    return name if name not in ("", ".", "..") else fallback


def process_file(task_id: str, question_text: str) -> str:
    """Download the file attached to a task, if any, and reference it in the question.

    Requests ``FILE_PATH + task_id`` with a 30 second timeout.

    - On success the body is written to ``<system temp dir>/gaia_cached_files``
      and the question text is returned with a note containing the local path.
    - If ``task_id`` is empty/None, or the request fails for any reason
      (connection error, timeout, non-2xx status such as 404), the original
      question text is returned unchanged; request errors are not propagated.

    Args:
        task_id (str): Identifier appended to ``FILE_PATH`` to build the URL.
        question_text (str): The question to return or extend.

    Returns:
        str: ``question_text`` itself, or ``question_text`` followed by a block
        naming the path of the downloaded file.
    """
    # Nothing to look up without an id; avoids requesting ".../files/None".
    if not task_id:
        return question_text

    file_url = f"{FILE_PATH}{task_id}"

    try:
        response = requests.get(file_url, timeout=30)
        response.raise_for_status()
    except requests.exceptions.RequestException as exc:
        print(f"Exception in process_file>> {str(exc)}")
        return question_text  # Unable to get the file

    # Determine filename from 'Content-Disposition' header (quoted or
    # unquoted form), fallback to task_id.
    content_disposition = response.headers.get("content-disposition", "")
    match = re.search(
        r'filename="?([^";]+)"?', content_disposition, flags=re.IGNORECASE
    )
    fallback_name = _safe_filename(task_id, "gaia_file")
    # The header is server-controlled input: strip any directory components
    # before using it as a local file name.
    filename = (
        _safe_filename(match.group(1).strip(), fallback_name)
        if match
        else fallback_name
    )

    # Save file in a temp directory
    temp_storage_dir = Path(tempfile.gettempdir()) / "gaia_cached_files"
    temp_storage_dir.mkdir(parents=True, exist_ok=True)

    file_path = temp_storage_dir / filename
    file_path.write_bytes(response.content)

    print(f"Downloaded file for task {task_id}: {filename}")

    return (
        f"{question_text}\n\n"
        f"---\n"
        f"A file was downloaded for this task and saved locally at:\n"
        f"{str(file_path)}\n"
        f"---\n\n"
    )