File size: 3,435 Bytes
b4856f1
2473009
b4856f1
752f5cc
b4856f1
 
 
 
 
 
 
 
 
 
752f5cc
b4c4175
 
 
b4856f1
 
 
 
 
752f5cc
b4856f1
 
752f5cc
b4856f1
 
 
752f5cc
b4856f1
 
 
752f5cc
b4856f1
 
 
752f5cc
b4856f1
752f5cc
b4c4175
 
 
b4856f1
752f5cc
b4856f1
 
 
752f5cc
b4856f1
 
 
 
752f5cc
b4856f1
752f5cc
b4856f1
 
752f5cc
b4856f1
 
 
752f5cc
b4856f1
752f5cc
b4c4175
 
 
 
 
 
 
 
 
b4856f1
752f5cc
b4856f1
 
752f5cc
b4856f1
 
 
 
752f5cc
b4856f1
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
meteorologicalAgentGraph.py - Meteorological Agent Graph with Subgraph Architecture
"""

import uuid

from langgraph.graph import END, START, StateGraph

from src.states.meteorologicalAgentState import MeteorologicalAgentState
from src.nodes.meteorologicalAgentNode import MeteorologicalAgentNode
from src.llms.groqllm import GroqLLM


class MeteorologicalGraphBuilder:
    """Build the meteorological agent's LangGraph workflow out of subgraphs.

    ``build_graph`` wires three compiled subgraphs (official sources, social
    media, feed generation) and a final aggregation node into one main graph.
    """

    def __init__(self, llm):
        """Keep the LLM that ``build_graph`` hands to ``MeteorologicalAgentNode``.

        Args:
            llm: Chat model passed to ``MeteorologicalAgentNode(llm)``.
        """
        self.llm = llm

    def build_official_sources_subgraph(self, node: MeteorologicalAgentNode):
        """Compile a one-node subgraph that runs ``collect_official_sources``.

        Args:
            node: Supplies the ``collect_official_sources`` node function.

        Returns:
            The compiled subgraph produced by ``StateGraph.compile()`` (not a
            ``StateGraph`` builder).
        """
        subgraph = StateGraph(MeteorologicalAgentState)
        subgraph.add_node("collect_official", node.collect_official_sources)
        subgraph.add_edge(START, "collect_official")
        subgraph.add_edge("collect_official", END)
        return subgraph.compile()

    def build_social_media_subgraph(self, node: MeteorologicalAgentNode):
        """Compile a subgraph that fans out to three social/alert collectors.

        Each collector is wired ``START -> collector -> END``, so all three are
        scheduled in the same step.

        Args:
            node: Supplies the national, district and climate-alert collectors.

        Returns:
            The compiled subgraph produced by ``StateGraph.compile()``.
        """
        subgraph = StateGraph(MeteorologicalAgentState)

        subgraph.add_node("national_social", node.collect_national_social_media)
        subgraph.add_node("district_social", node.collect_district_social_media)
        subgraph.add_node("climate_alerts", node.collect_climate_alerts)

        # NOTE(review): these collectors run in parallel. Any state key that more
        # than one of them writes needs a reducer in MeteorologicalAgentState,
        # otherwise LangGraph rejects the concurrent updates. Confirm.
        for collector in ("national_social", "district_social", "climate_alerts"):
            subgraph.add_edge(START, collector)
            subgraph.add_edge(collector, END)

        return subgraph.compile()

    def build_feed_generation_subgraph(self, node: MeteorologicalAgentNode):
        """Compile the linear chain ``categorize -> llm_summary -> format_output``.

        Args:
            node: Supplies ``categorize_by_geography``, ``generate_llm_summary``
                and ``format_final_output``.

        Returns:
            The compiled subgraph produced by ``StateGraph.compile()``.
        """
        subgraph = StateGraph(MeteorologicalAgentState)

        subgraph.add_node("categorize", node.categorize_by_geography)
        subgraph.add_node("llm_summary", node.generate_llm_summary)
        subgraph.add_node("format_output", node.format_final_output)

        subgraph.add_edge(START, "categorize")
        subgraph.add_edge("categorize", "llm_summary")
        subgraph.add_edge("llm_summary", "format_output")
        subgraph.add_edge("format_output", END)

        return subgraph.compile()

    def build_graph(self):
        """Compile the main meteorological graph.

        Flow: ``official_sources_module`` and ``social_media_module`` both start
        from START. ``feed_generation_module`` runs after them, then
        ``feed_aggregator``, then END.

        Returns:
            The compiled main graph produced by ``StateGraph.compile()``.
        """
        # One node instance is shared by every subgraph and by the aggregator.
        node = MeteorologicalAgentNode(self.llm)

        official_subgraph = self.build_official_sources_subgraph(node)
        social_subgraph = self.build_social_media_subgraph(node)
        feed_subgraph = self.build_feed_generation_subgraph(node)

        main_graph = StateGraph(MeteorologicalAgentState)

        # The compiled subgraphs are added as nodes directly. This is LangGraph's
        # documented pattern for subgraphs that share the parent's state schema.
        # Wrapping them in ``lambda state: subgraph.invoke(state)`` would make
        # LangGraph see them as opaque functions. Added directly, they run as
        # real subgraphs and receive the run's config.
        main_graph.add_node("official_sources_module", official_subgraph)
        main_graph.add_node("social_media_module", social_subgraph)
        main_graph.add_node("feed_generation_module", feed_subgraph)
        main_graph.add_node("feed_aggregator", node.aggregate_and_store_feeds)

        # NOTE(review): both collection modules run in the same step, and each
        # returns its subgraph's full final state. Keys without a reducer that
        # both return will conflict. Keys with an additive reducer may get the
        # incoming values appended twice. Confirm MeteorologicalAgentState is
        # designed for this.
        main_graph.add_edge(START, "official_sources_module")
        main_graph.add_edge(START, "social_media_module")

        main_graph.add_edge("official_sources_module", "feed_generation_module")
        main_graph.add_edge("social_media_module", "feed_generation_module")
        main_graph.add_edge("feed_generation_module", "feed_aggregator")
        main_graph.add_edge("feed_aggregator", END)

        return main_graph.compile()


# Build the compiled workflow at import time and expose it as the module-level
# name ``graph``.
# NOTE(review): presumably some tooling (e.g. a LangGraph server/Studio config)
# loads ``graph`` by name. Confirm before making this lazy.
llm = GroqLLM().get_llm()
graph = MeteorologicalGraphBuilder(llm=llm).build_graph()