# modelx / src/graphs/dataRetrievalAgentGraph.py
# Author: nivakaran — commit b4c4175 (verified): "Upload folder using huggingface_hub"
"""
dataRetrievalAgentGraph.py - Data Retrieval Agent Graph Builder
"""
from collections.abc import Mapping

from langgraph.graph import StateGraph, START, END

from src.llms.groqllm import GroqLLM
from src.states.dataRetrievalAgentState import DataRetrievalAgentState
from src.nodes.dataRetrievalAgentNode import DataRetrievalAgentNode
class DataRetrievalAgentGraph(DataRetrievalAgentNode):
    """Assemble the data retrieval agent's LangGraph workflows.

    ``DataRetrievalAgentNode`` supplies the node callables used here
    (``master_agent_node``, ``worker_agent_node``, ``tool_node`` and
    ``classifier_agent_node``). This class adds the glue nodes and wires the
    top-level pipeline:

        master_delegator -> prepare_worker_tasks -> worker
            -> aggregate_results -> classifier_agent -> format_output -> END

    The ``worker`` node runs a compiled worker sub-graph
    (``worker_agent -> tool_node``) once per prepared task.
    """

    def __init__(self, llm):
        """Initialise the node base class with *llm* and keep a reference to it."""
        super().__init__(llm)
        self.llm = llm

    def prepare_worker_tasks(self, state: DataRetrievalAgentState) -> dict:
        """Split ``state.generated_tasks`` into one initial worker state per task.

        Returns:
            ``{"tasks_for_workers": [...]}``, where each entry has the form
            ``{"generated_tasks": [task]}``. A missing (``None``) or empty task
            list yields an empty list instead of raising.
        """
        tasks = state.generated_tasks or []
        return {"tasks_for_workers": [{"generated_tasks": [task]} for task in tasks]}

    def create_worker_graph(self):
        """Compile the per-task worker sub-graph: worker_agent -> tool_node -> END."""
        worker_graph_builder = StateGraph(DataRetrievalAgentState)
        worker_graph_builder.add_node("worker_agent", self.worker_agent_node)
        worker_graph_builder.add_node("tool_node", self.tool_node)
        # Equivalent to set_entry_point("worker_agent").
        worker_graph_builder.add_edge(START, "worker_agent")
        worker_graph_builder.add_edge("worker_agent", "tool_node")
        worker_graph_builder.add_edge("tool_node", END)
        return worker_graph_builder.compile()

    @staticmethod
    def _run_workers(worker_graph, state: DataRetrievalAgentState) -> dict:
        """Invoke *worker_graph* once per entry of ``state.tasks_for_workers``.

        ``map()`` wraps the compiled sub-graph so ``invoke`` takes a list of
        inputs and returns the list of outputs. A missing (``None``) task list
        is treated as empty.
        """
        worker_inputs = state.tasks_for_workers or []
        return {"worker": worker_graph.map().invoke(worker_inputs)}

    def aggregate_results(self, state: DataRetrievalAgentState) -> dict:
        """Flatten the ``worker_results`` of every worker output into one list.

        Reads ``state.worker`` (the list written by the ``worker`` node). The
        following are skipped rather than raising:
        - a missing or non-list ``worker`` value;
        - entries that are not mappings;
        - entries whose ``worker_results`` is absent or empty.

        Returns:
            The combined list under both ``worker_results`` and
            ``latest_worker_results``.
        """
        worker_outputs = getattr(state, "worker", None)
        new_results = []
        if isinstance(worker_outputs, list):
            for output in worker_outputs:
                if not isinstance(output, Mapping):
                    continue  # nothing addressable by key to collect
                results = output.get("worker_results")
                if results:
                    new_results.extend(results)
        return {"worker_results": new_results, "latest_worker_results": new_results}

    def format_output(self, state: DataRetrievalAgentState) -> dict:
        """Convert classified events into insight dicts for the parent graph.

        Each event in ``state.classified_buffer`` becomes a dict with these keys:
        - ``source_event_id``, from ``event_id``;
        - ``domain``, from ``target_agent``;
        - ``summary``, from ``content_summary``;
        - ``risk_score``, from ``confidence_score``;
        - ``severity``, always ``"medium"``.

        A missing (``None``) buffer yields no insights.

        Returns:
            ``{"domain_insights": [...]}``.
        """
        classified_events = state.classified_buffer or []
        insights = [
            {
                "source_event_id": event.event_id,
                "domain": event.target_agent,
                # NOTE(review): constant, not derived from the event.
                "severity": "medium",
                "summary": event.content_summary,
                "risk_score": event.confidence_score,
            }
            for event in classified_events
        ]
        print(f"[DATA RETRIEVAL] Formatted {len(insights)} insights for parent graph")
        return {"domain_insights": insights}

    def build_data_retrieval_agent_graph(self):
        """Build and compile the top-level data retrieval workflow.

        Pipeline: master_delegator -> prepare_worker_tasks -> worker ->
        aggregate_results -> classifier_agent -> format_output -> END.
        """
        worker_graph = self.create_worker_graph()
        workflow = StateGraph(DataRetrievalAgentState)
        workflow.add_node("master_delegator", self.master_agent_node)
        workflow.add_node("prepare_worker_tasks", self.prepare_worker_tasks)
        workflow.add_node(
            "worker",
            # Bind the compiled sub-graph here; LangGraph passes only the state.
            lambda state: self._run_workers(worker_graph, state),
        )
        workflow.add_node("aggregate_results", self.aggregate_results)
        workflow.add_node("classifier_agent", self.classifier_agent_node)
        workflow.add_node("format_output", self.format_output)
        # Equivalent to set_entry_point("master_delegator").
        workflow.add_edge(START, "master_delegator")
        workflow.add_edge("master_delegator", "prepare_worker_tasks")
        workflow.add_edge("prepare_worker_tasks", "worker")
        workflow.add_edge("worker", "aggregate_results")
        workflow.add_edge("aggregate_results", "classifier_agent")
        workflow.add_edge("classifier_agent", "format_output")
        workflow.add_edge("format_output", END)
        return workflow.compile()
# Module-level instances built at import time:
#   llm           - the model returned by GroqLLM().get_llm()
#   graph_builder - a DataRetrievalAgentGraph bound to that model
#   graph         - the compiled data retrieval workflow
llm = GroqLLM().get_llm()
graph_builder = DataRetrievalAgentGraph(llm=llm)
graph = graph_builder.build_data_retrieval_agent_graph()