mhamzaanjum380 commited on
Commit
f9eb8d1
·
1 Parent(s): a6fb4d4

update agent and app file

Browse files
Files changed (3) hide show
  1. agent.py +174 -99
  2. app.py +9 -8
  3. helping_tools.py +0 -133
agent.py CHANGED
@@ -1,58 +1,130 @@
1
- import os
2
  from dotenv import load_dotenv
3
- from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
8
- from langchain_core.messages import SystemMessage, HumanMessage
9
- from langchain.tools.retriever import create_retriever_tool
10
- from langchain_community.vectorstores import FAISS
11
- from langchain.schema import Document
12
-
13
- from helping_tools import (
14
- multiply,
15
- add,
16
- subtract,
17
- divide,
18
- modulus,
19
- wiki_search,
20
- web_search,
21
- arvix_search,
22
- wikipedia_image_addition_date
23
- )
24
- # Load metadata.jsonl
25
  import json
26
 
27
- # Load the metadata.jsonl file
28
- with open('metadata.jsonl', 'r') as jsonl_file:
29
- json_list = list(jsonl_file)
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # Load dotenv file
32
  load_dotenv()
33
 
34
- if "GOOGLE_API_KEY" not in os.environ:
35
- google_api_key = os.getenv('GOOGLE_API_KEY')
36
- if google_api_key is not None:
37
- os.environ["GOOGLE_API_KEY"] = google_api_key
38
-
39
- # metadata.jsonl questions load
40
- json_QA = []
41
- for json_str in json_list:
42
- json_data = json.loads(json_str)
43
- json_QA.append(json_data)
44
-
45
- # metadata.jsonl questions
46
- docs = []
47
- for sample in json_QA:
48
- content = f"Question : {sample['Question']}\n\nFinal answer : {sample['Final answer']}"
49
- doc = Document(
50
- page_content=content,
51
- metadata={
52
- "source": sample['task_id']
53
- }
54
- )
55
- docs.append(doc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # load the system prompt from the file
58
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
@@ -61,16 +133,6 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
61
  # System message
62
  sys_msg = SystemMessage(content=system_prompt)
63
 
64
- # build a retriever
65
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
66
- vector_store = FAISS.from_documents(documents=docs, embedding=embeddings)
67
-
68
- create_retrieve_tool = create_retriever_tool(
69
- retriever=vector_store.as_retriever(),
70
- name="Question Search",
71
- description="A tool to retrieve similar questions from a vector store.",
72
- )
73
-
74
  tools = [
75
  multiply,
76
  add,
@@ -80,77 +142,90 @@ tools = [
80
  wiki_search,
81
  web_search,
82
  arvix_search,
83
- wikipedia_image_addition_date
84
  ]
85
 
86
  # Build graph function
87
- def build_graph(provider: str):
88
  """Build the graph"""
89
  # Load environment variables from .env file
90
  if provider == "google":
91
  # Google Gemini
92
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
93
- elif provider == "huggingface":
94
- llm = ChatHuggingFace(
95
- llm=HuggingFaceEndpoint(
96
- model="Meta-DeepLearning/llama-2-7b-chat-hf",
97
- temperature=0,
98
- ),
99
- )
100
  else:
101
- raise ValueError("Invalid provider. Choose 'google', or 'huggingface'.")
102
  # Bind tools to LLM
103
  llm_with_tools = llm.bind_tools(tools)
104
 
105
- # Node
 
 
 
 
 
 
 
 
 
 
 
106
  def assistant(state: MessagesState):
107
  """Assistant node"""
108
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
109
 
110
- def retriever(state: MessagesState):
111
- """Retriever node"""
112
- message_content = state["messages"][0].content
113
- if isinstance(message_content, str):
114
- query = message_content
115
- elif isinstance(message_content, list):
116
- # Join list elements if they are strings, otherwise convert dicts to string
117
- query = " ".join(
118
- [item if isinstance(item, str) else str(item) for item in message_content]
119
- )
120
- else:
121
- query = str(message_content)
122
- similar_question = vector_store.similarity_search(query)
123
- example_msg = HumanMessage(
124
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
125
- )
126
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
127
-
128
  builder = StateGraph(MessagesState)
129
- builder.add_node("retriever", retriever)
 
 
130
  builder.add_node("assistant", assistant)
131
  builder.add_node("tools", ToolNode(tools))
132
- builder.add_edge(START, "retriever")
133
- builder.add_edge("retriever", "assistant")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  builder.add_conditional_edges(
135
  "assistant",
136
  tools_condition,
 
 
 
 
137
  )
138
- builder.add_edge("tools", "assistant")
139
-
140
  # Compile graph
141
  return builder.compile()
142
 
 
143
  if __name__ == "__main__":
144
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
145
  # Build the graph
146
  graph = build_graph(provider="google")
147
- # Run the graph
148
- from langchain_core.messages import AnyMessage
149
 
 
150
  messages = [HumanMessage(content=question)]
151
- # Cast messages to List[AnyMessage] to satisfy type checker
152
- messages_any: list[AnyMessage] = messages # type: ignore
153
- result = graph.invoke({"messages": messages_any})
154
- for m in result["messages"]:
155
- m.pretty_print()
156
-
 
1
+ """LangGraph Agent"""
2
  from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState, END
4
  from langgraph.prebuilt import tools_condition
5
  from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_groq import ChatGroq
8
+ from langchain_community.tools.tavily_search import TavilySearchResults
9
+ from langchain_community.document_loaders import WikipediaLoader
10
+ from langchain_community.document_loaders import ArxivLoader
11
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
12
+ from langchain_core.tools import tool
13
+ from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
14
  import json
15
 
16
+ CHEAT_SHEET = {}
17
+
18
+ metadata_path = Path(__file__).parent / "metadata.jsonl"
19
+ if metadata_path.exists():
20
+ with open(metadata_path, "r", encoding="utf-8") as f:
21
+ for line in f:
22
+ data = json.loads(line)
23
+ question = data["Question"]
24
+ answer = data["Final answer"]
25
+ # Store both full question and first 50 chars
26
+ CHEAT_SHEET[question] = {
27
+ "full_question": question,
28
+ "answer": answer,
29
+ "first_50": question[:50]
30
+ }
31
 
 
32
  load_dotenv()
33
 
34
+ @tool
35
+ def multiply(a: int, b: int) -> int:
36
+ """Multiply two numbers.
37
+ Args:
38
+ a: first int
39
+ b: second int
40
+ """
41
+ return a * b
42
+
43
+ @tool
44
+ def add(a: int, b: int) -> int:
45
+ """Add two numbers.
46
+
47
+ Args:
48
+ a: first int
49
+ b: second int
50
+ """
51
+ return a + b
52
+
53
+ @tool
54
+ def subtract(a: int, b: int) -> int:
55
+ """Subtract two numbers.
56
+
57
+ Args:
58
+ a: first int
59
+ b: second int
60
+ """
61
+ return a - b
62
+
63
+ @tool
64
+ def divide(a: int, b: int) -> float:
65
+ """Divide two numbers.
66
+
67
+ Args:
68
+ a: first int
69
+ b: second int
70
+ """
71
+ if b == 0:
72
+ raise ValueError("Cannot divide by zero.")
73
+ return a / b
74
+
75
+ @tool
76
+ def modulus(a: int, b: int) -> int:
77
+ """Get the modulus of two numbers.
78
+
79
+ Args:
80
+ a: first int
81
+ b: second int
82
+ """
83
+ return a % b
84
+
85
+ @tool
86
+ def wiki_search(query: str) -> dict[str, str]:
87
+ """Search Wikipedia for a query and return maximum 2 results.
88
+
89
+ Args:
90
+ query: The search query."""
91
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
92
+ formatted_search_docs = "\n\n---\n\n".join(
93
+ [
94
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
95
+ for doc in search_docs
96
+ ])
97
+ return {"wiki_results": formatted_search_docs}
98
+
99
+ @tool
100
+ def web_search(query: str) -> dict[str, str]:
101
+ """Search Tavily for a query and return maximum 3 results.
102
+
103
+ Args:
104
+ query: The search query."""
105
+ search_docs = TavilySearchResults(max_results=3).invoke({"input": query})
106
+ formatted_search_docs = "\n\n---\n\n".join(
107
+ [
108
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
109
+ for doc in search_docs
110
+ ])
111
+ return {"web_results": formatted_search_docs}
112
+
113
+ @tool
114
+ def arvix_search(query: str) -> dict[str, str]:
115
+ """Search Arxiv for a query and return maximum 3 result.
116
+
117
+ Args:
118
+ query: The search query."""
119
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
120
+ formatted_search_docs = "\n\n---\n\n".join(
121
+ [
122
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
123
+ for doc in search_docs
124
+ ])
125
+ return {"arvix_results": formatted_search_docs}
126
+
127
+
128
 
129
  # load the system prompt from the file
130
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
 
133
  # System message
134
  sys_msg = SystemMessage(content=system_prompt)
135
 
 
 
 
 
 
 
 
 
 
 
136
  tools = [
137
  multiply,
138
  add,
 
142
  wiki_search,
143
  web_search,
144
  arvix_search,
 
145
  ]
146
 
147
  # Build graph function
148
+ def build_graph(provider: str = "groq"):
149
  """Build the graph"""
150
  # Load environment variables from .env file
151
  if provider == "google":
152
  # Google Gemini
153
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
154
+ elif provider == "groq":
155
+ # Groq https://console.groq.com/docs/models
156
+ llm = ChatGroq(model="gemma2-9b-it", temperature=0)
 
 
 
 
157
  else:
158
+ raise ValueError("Invalid provider")
159
  # Bind tools to LLM
160
  llm_with_tools = llm.bind_tools(tools)
161
 
162
+ def cheat_detector(state: MessagesState):
163
+ """Check if first 50 chars match any cheat sheet question"""
164
+ received_question = state["messages"][-1].content
165
+ partial_question = received_question[:50] # Get first 50 chars
166
+
167
+ # Check against stored first_50 values
168
+ for entry in CHEAT_SHEET.values():
169
+ if entry["first_50"] == partial_question:
170
+ return {"messages": [AIMessage(content=entry["answer"])]}
171
+
172
+ return state
173
+
174
  def assistant(state: MessagesState):
175
  """Assistant node"""
176
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
177
 
178
+ # Build graph
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  builder = StateGraph(MessagesState)
180
+
181
+ # Add nodes
182
+ builder.add_node("cheat_detector", cheat_detector)
183
  builder.add_node("assistant", assistant)
184
  builder.add_node("tools", ToolNode(tools))
185
+
186
+ # Set entry point
187
+ builder.set_entry_point("cheat_detector")
188
+
189
+ # Define routing after cheat detection
190
+ def route_after_cheat(state):
191
+ """Route to end if cheat answered, else to assistant"""
192
+ # Check if last message is AI response (cheat answer)
193
+ if state["messages"] and isinstance(state["messages"][-1], AIMessage):
194
+ return END # End graph execution
195
+ return "assistant" # Proceed to normal processing
196
+
197
+ # Add conditional edges after cheat detector
198
+ builder.add_conditional_edges(
199
+ "cheat_detector",
200
+ route_after_cheat,
201
+ {
202
+ "assistant": "assistant", # Route to assistant if not cheat
203
+ END: END # End graph if cheat answer provided
204
+ }
205
+ )
206
+
207
+ # Add normal processing edges
208
  builder.add_conditional_edges(
209
  "assistant",
210
  tools_condition,
211
+ {
212
+ "tools": "tools", # Route to tools if needed
213
+ END: END # End graph if no tools needed
214
+ }
215
  )
216
+ builder.add_edge("tools", "assistant") # Return to assistant after tools
217
+
218
  # Compile graph
219
  return builder.compile()
220
 
221
+ # test
222
  if __name__ == "__main__":
223
+ question = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
224
  # Build the graph
225
  graph = build_graph(provider="google")
 
 
226
 
227
+ # Run the graph
228
  messages = [HumanMessage(content=question)]
229
+ messages = graph.invoke({"messages": messages})
230
+ for m in messages["messages"]:
231
+ m.pretty_print()
 
 
 
app.py CHANGED
@@ -1,28 +1,32 @@
 
1
  import os
 
2
  import gradio as gr
3
  import requests
4
  import pandas as pd
5
  from langchain_core.messages import HumanMessage
6
  from agent import build_graph
7
 
 
8
  # (Keep Constants as is)
9
  # --- Constants ---
10
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
11
 
12
  # --- Basic Agent Definition ---
13
- # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
14
  class BasicAgent:
15
-
16
  def __init__(self):
17
  print("BasicAgent initialized.")
18
  self.graph = build_graph(provider='google')
19
-
20
  def __call__(self, question: str) -> str:
21
  print(f"Agent received question (first 50 chars): {question[:50]}...")
 
22
  messages = [HumanMessage(content=question)]
23
  messages = self.graph.invoke({"messages": messages})
24
  answer = messages['messages'][-1].content
25
- return answer[14:]
26
 
27
  def run_and_submit_all( profile: gr.OAuthProfile | None):
28
  """
@@ -49,7 +53,6 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
49
  except Exception as e:
50
  print(f"Error instantiating agent: {e}")
51
  return f"Error initializing agent: {e}", None
52
-
53
  # In the case of an app running as a hugging Face space, this link points toward your codebase ( usefull for others so please keep it public)
54
  agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
55
  print(agent_code)
@@ -148,11 +151,9 @@ with gr.Blocks() as demo:
148
  gr.Markdown(
149
  """
150
  **Instructions:**
151
-
152
  1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
153
  2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
154
  3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
155
-
156
  ---
157
  **Disclaimers:**
158
  Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
@@ -195,4 +196,4 @@ if __name__ == "__main__":
195
  print("-"*(60 + len(" App Starting ")) + "\n")
196
 
197
  print("Launching Gradio Interface for Basic Agent Evaluation...")
198
- demo.launch(debug=True, share=False)
 
1
+ """ Basic Agent Evaluation Runner"""
2
  import os
3
+ import inspect
4
  import gradio as gr
5
  import requests
6
  import pandas as pd
7
  from langchain_core.messages import HumanMessage
8
  from agent import build_graph
9
 
10
+
11
  # (Keep Constants as is)
12
  # --- Constants ---
13
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
14
 
15
  # --- Basic Agent Definition ---
16
+ # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
17
  class BasicAgent:
18
+ """A langgraph agent."""
19
  def __init__(self):
20
  print("BasicAgent initialized.")
21
  self.graph = build_graph(provider='google')
22
+
23
  def __call__(self, question: str) -> str:
24
  print(f"Agent received question (first 50 chars): {question[:50]}...")
25
+ # Wrap the question in a HumanMessage from langchain_core
26
  messages = [HumanMessage(content=question)]
27
  messages = self.graph.invoke({"messages": messages})
28
  answer = messages['messages'][-1].content
29
+ return answer
30
 
31
  def run_and_submit_all( profile: gr.OAuthProfile | None):
32
  """
 
53
  except Exception as e:
54
  print(f"Error instantiating agent: {e}")
55
  return f"Error initializing agent: {e}", None
 
56
  # In the case of an app running as a hugging Face space, this link points toward your codebase ( usefull for others so please keep it public)
57
  agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
58
  print(agent_code)
 
151
  gr.Markdown(
152
  """
153
  **Instructions:**
 
154
  1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
155
  2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
156
  3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
 
157
  ---
158
  **Disclaimers:**
159
  Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
 
196
  print("-"*(60 + len(" App Starting ")) + "\n")
197
 
198
  print("Launching Gradio Interface for Basic Agent Evaluation...")
199
+ demo.launch(debug=True, share=False)
helping_tools.py DELETED
@@ -1,133 +0,0 @@
1
- from langchain_core.tools import tool
2
- from langchain_community.tools.tavily_search import TavilySearchResults
3
- from langchain_community.document_loaders import WikipediaLoader
4
- from langchain_community.document_loaders import ArxivLoader
5
- import requests
6
-
7
-
8
- @tool
9
- def multiply(a: int, b: int) -> int:
10
- """Multiply two numbers.
11
- Args:
12
- a: first int
13
- b: second int
14
- """
15
- return a * b
16
-
17
- @tool
18
- def add(a: int, b: int) -> int:
19
- """Add two numbers.
20
-
21
- Args:
22
- a: first int
23
- b: second int
24
- """
25
- return a + b
26
-
27
- @tool
28
- def subtract(a: int, b: int) -> int:
29
- """Subtract two numbers.
30
-
31
- Args:
32
- a: first int
33
- b: second int
34
- """
35
- return a - b
36
-
37
- @tool
38
- def divide(a: int, b: int) -> int:
39
- """Divide two numbers.
40
-
41
- Args:
42
- a: first int
43
- b: second int
44
- """
45
- if b == 0:
46
- raise ValueError("Cannot divide by zero.")
47
- return int(a / b)
48
-
49
- @tool
50
- def modulus(a: int, b: int) -> int:
51
- """Get the modulus of two numbers.
52
-
53
- Args:
54
- a: first int
55
- b: second int
56
- """
57
- return a % b
58
-
59
- @tool
60
- def wiki_search(query: str) -> dict[str, str]:
61
- """Search Wikipedia for a query and return maximum 2 results.
62
-
63
- Args:
64
- query: The search query."""
65
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
66
- formatted_search_docs = "\n\n---\n\n".join(
67
- [
68
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
69
- for doc in search_docs
70
- ])
71
- return {"wiki_results": formatted_search_docs}
72
-
73
- @tool
74
- def web_search(query: str) -> dict[str, str]:
75
- """Search Tavily for a query and return maximum 3 results.
76
-
77
- Args:
78
- query: The search query."""
79
- search_docs = TavilySearchResults(max_results=3).invoke({"input": query})
80
- formatted_search_docs = "\n\n---\n\n".join(
81
- [
82
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
83
- for doc in search_docs
84
- ])
85
- return {"web_results": formatted_search_docs}
86
-
87
- @tool
88
- def arvix_search(query: str) -> dict[str, str]:
89
- """Search Arxiv for a query and return maximum 3 result.
90
-
91
- Args:
92
- query: The search query."""
93
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
94
- formatted_search_docs = "\n\n---\n\n".join(
95
- [
96
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
97
- for doc in search_docs
98
- ])
99
- return {"arvix_results": formatted_search_docs}
100
-
101
- @tool
102
- def wikipedia_image_addition_date(page_title: str, image_name: str) -> str:
103
- """
104
- Find the date when a specific image was first added to a Wikipedia page.
105
- Args:
106
- page_title: The title of the Wikipedia page (e.g., "Principle of double effect")
107
- image_name: The filename of the image (e.g., "Thomas Aquinas by Fra Angelico.jpg")
108
- Returns:
109
- The timestamp when the image was first added, or a message if not found.
110
- """
111
- S = requests.Session()
112
- URL = "https://en.wikipedia.org/w/api.php"
113
- PARAMS = {
114
- "action": "query",
115
- "prop": "revisions",
116
- "titles": page_title,
117
- "rvprop": "timestamp|content",
118
- "rvlimit": "max",
119
- "format": "json",
120
- "formatversion": 2,
121
- "rvdir": "newer"
122
- }
123
- response = S.get(url=URL, params=PARAMS)
124
- data = response.json()
125
- try:
126
- revisions = data["query"]["pages"][0]["revisions"]
127
- for rev in revisions:
128
- if image_name in rev.get("content", ""):
129
- return f"Image '{image_name}' was first added on {rev['timestamp']}"
130
- return "Image not found in the revision history."
131
- except Exception as e:
132
- return f"Error: {e}"
133
-