| """ | |
| dataRetrievalAgentGraph.py - Data Retrieval Agent Graph Builder | |
| """ | |
| from langgraph.graph import StateGraph, START, END | |
| from src.llms.groqllm import GroqLLM | |
| from src.states.dataRetrievalAgentState import DataRetrievalAgentState | |
| from src.nodes.dataRetrievalAgentNode import DataRetrievalAgentNode | |
class DataRetrievalAgentGraph(DataRetrievalAgentNode):
    """Build the LangGraph workflow for the data retrieval agent.

    The node callables ``master_agent_node``, ``worker_agent_node``,
    ``tool_node`` and ``classifier_agent_node`` are not defined here; they are
    looked up on ``self`` and are expected to come from
    ``DataRetrievalAgentNode``.

    Top-level flow (see ``build_data_retrieval_agent_graph``)::

        master_delegator -> prepare_worker_tasks -> worker
            -> aggregate_results -> classifier_agent -> format_output -> END
    """

    def __init__(self, llm):
        """Initialise the node base class and keep a reference to the LLM.

        Args:
            llm: Language model object, also forwarded to the base class.
        """
        super().__init__(llm)
        self.llm = llm

    def prepare_worker_tasks(self, state: DataRetrievalAgentState) -> dict:
        """Split the generated tasks into one initial state per worker run.

        Args:
            state: Current graph state; ``state.generated_tasks`` is read.

        Returns:
            A state update ``{"tasks_for_workers": [...]}`` where every entry
            is ``{"generated_tasks": [task]}`` for exactly one task.
        """
        # A missing/None task list is treated as "no work" instead of raising
        # TypeError while iterating.
        tasks = state.generated_tasks or []
        return {"tasks_for_workers": [{"generated_tasks": [task]} for task in tasks]}

    def create_worker_graph(self):
        """Compile the per-task worker sub-graph.

        The sub-graph is linear: ``worker_agent -> tool_node -> END``.

        Returns:
            The compiled worker graph.
        """
        builder = StateGraph(DataRetrievalAgentState)
        builder.add_node("worker_agent", self.worker_agent_node)
        builder.add_node("tool_node", self.tool_node)
        builder.set_entry_point("worker_agent")
        builder.add_edge("worker_agent", "tool_node")
        builder.add_edge("tool_node", END)
        return builder.compile()

    def aggregate_results(self, state: DataRetrievalAgentState) -> dict:
        """Flatten the ``worker_results`` of all worker runs into one list.

        Args:
            state: Current graph state; its ``worker`` attribute (if present)
                is expected to be the list of per-task worker outputs.

        Returns:
            A state update with the flattened list under both
            ``"worker_results"`` and ``"latest_worker_results"``. The list is
            empty when ``state.worker`` is absent or is not a list.
        """
        collected = []
        worker_outputs = getattr(state, "worker", None)

        if isinstance(worker_outputs, list):
            for output in worker_outputs:
                # Worker outputs are normally mappings; fall back to attribute
                # access for object-style outputs, and skip anything (e.g.
                # None) that carries no results instead of raising.
                try:
                    results = output["worker_results"]
                except (KeyError, IndexError, TypeError):
                    results = getattr(output, "worker_results", None)
                if results:
                    collected.extend(results)

        return {"worker_results": collected, "latest_worker_results": collected}

    def format_output(self, state: DataRetrievalAgentState) -> dict:
        """Convert the classified events into insight dicts.

        Args:
            state: Current graph state; ``state.classified_buffer`` is read.
                Each event must expose ``event_id``, ``target_agent``,
                ``content_summary`` and ``confidence_score``.

        Returns:
            A state update ``{"domain_insights": [...]}`` with one insight
            dict per classified event.
        """
        # A missing/None buffer yields zero insights rather than a TypeError.
        classified_events = state.classified_buffer or []
        insights = [
            {
                "source_event_id": event.event_id,
                "domain": event.target_agent,
                # Severity is a fixed placeholder; it is not derived from the event.
                "severity": "medium",
                "summary": event.content_summary,
                "risk_score": event.confidence_score,
            }
            for event in classified_events
        ]
        print(f"[DATA RETRIEVAL] Formatted {len(insights)} insights for parent graph")
        return {"domain_insights": insights}

    def build_data_retrieval_agent_graph(self):
        """Assemble and compile the full data retrieval workflow.

        Returns:
            The compiled top-level graph.
        """
        worker_graph = self.create_worker_graph()

        def run_workers(state):
            # ``.map()`` applies the compiled worker graph to each element of
            # ``tasks_for_workers`` and returns the list of outputs.
            return {"worker": worker_graph.map().invoke(state.tasks_for_workers)}

        workflow = StateGraph(DataRetrievalAgentState)
        workflow.add_node("master_delegator", self.master_agent_node)
        workflow.add_node("prepare_worker_tasks", self.prepare_worker_tasks)
        # NOTE(review): the node name "worker" is also the key this node
        # writes. Depending on the LangGraph version and on whether
        # DataRetrievalAgentState declares a "worker" field, this may be
        # rejected or the update may be dropped -- confirm against the schema.
        workflow.add_node("worker", run_workers)
        workflow.add_node("aggregate_results", self.aggregate_results)
        workflow.add_node("classifier_agent", self.classifier_agent_node)
        workflow.add_node("format_output", self.format_output)

        workflow.set_entry_point("master_delegator")
        workflow.add_edge("master_delegator", "prepare_worker_tasks")
        workflow.add_edge("prepare_worker_tasks", "worker")
        workflow.add_edge("worker", "aggregate_results")
        workflow.add_edge("aggregate_results", "classifier_agent")
        workflow.add_edge("classifier_agent", "format_output")
        workflow.add_edge("format_output", END)
        return workflow.compile()
# Module-level graph instance: the LLM is created and the workflow is compiled
# once, when this module is imported, and exposed as ``graph``.
llm = GroqLLM().get_llm()
graph_builder = DataRetrievalAgentGraph(llm=llm)
graph = graph_builder.build_data_retrieval_agent_graph()