Spaces:
Runtime error
Runtime error
File size: 6,138 Bytes
0a25329 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import logging
from typing import List, Literal, Optional, TypedDict

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from qdrant_client.http.models import (
    FieldCondition,
    Filter,
    MatchValue,
)

from clients import LLM, VECTOR_STORE
class RetrievalState(TypedDict):
    """State shared by every node of the agentic retrieval graph.

    Nodes never mutate this mapping in place; each one returns a fresh copy
    (``{**state, ...}``) with only the keys it owns replaced.
    """

    # The user's query as first received. Grading and rewriting are anchored
    # on this value so relevance is judged against the real intent.
    original_query: str
    # The query actually sent to the vector store; rewrite_query replaces it.
    current_query: str
    # Optional exact-match filters, applied to metadata.category and
    # metadata.topic in the Qdrant search when truthy.
    category: Optional[str]
    topic: Optional[str]
    # Raw similarity-search hits from the most recent retrieval.
    documents: List[Document]
    # The subset of ``documents`` that survived LLM relevance grading.
    relevant_documents: List[Document]
    # Final plain-text answer produced by generate_response.
    generation: str
    # Rewrites performed so far, and the maximum number allowed.
    retry_count: int
    max_retries: int
# Structured-output schema for the per-document relevance grader; it is passed
# to LLM.with_structured_output() in grade_documents.
# NOTE(review): LangChain normally derives the schema it shows the model from
# this class's docstring and the Field descriptions below, so treat their text
# as part of the grader prompt -- editing it presumably changes grader behavior.
class GradeDocuments(BaseModel):
    """Grade whether a document is relevant to the query."""
    # Binary verdict; grade_documents keeps a document only when this is "yes".
    is_relevant: Literal["yes", "no"] = Field(
        description="Is the document relevant to the query? 'yes' or 'no'"
    )
    # Free-text justification from the model; not read by the code visible here.
    reason: str = Field(description="Brief reason for the relevance decision")
def retrieve_documents(state: RetrievalState) -> RetrievalState:
    """Run a similarity search for ``current_query`` and store the hits.

    ``category`` and ``topic``, when truthy, each become an exact-match Qdrant
    condition on ``metadata.category`` / ``metadata.topic``; every present
    condition must hold. With neither set, the search is unfiltered.

    Returns:
        A copy of ``state`` with ``documents`` set to (up to) the 5 nearest
        matches returned by the vector store.
    """
    search_text = state["current_query"]
    # (payload key, requested value) pairs, in the order conditions are added.
    metadata_filters = (
        ("metadata.category", state.get("category")),
        ("metadata.topic", state.get("topic")),
    )
    must_match = [
        FieldCondition(key=payload_key, match=MatchValue(value=wanted))
        for payload_key, wanted in metadata_filters
        if wanted
    ]
    hits = VECTOR_STORE.similarity_search(
        search_text,
        k=5,
        filter=Filter(must=must_match) if must_match else None,
    )
    return {**state, "documents": hits}
def grade_documents(state: RetrievalState) -> RetrievalState:
    """Keep only the retrieved documents that the LLM grader marks relevant.

    Each entry of ``state["documents"]`` is graded against
    ``state["original_query"]`` -- not the possibly rewritten
    ``current_query`` -- so relevance is always judged against what the user
    actually asked. Only the first 1000 characters of each document are sent
    to the grader, and the prompt instructs it to be lenient.

    Grading is fail-open: if grading a document fails for any reason, the
    document is kept and a warning (with traceback) is logged, so a broken
    grader degrades to "no filtering" visibly rather than silently.

    Returns:
        A copy of ``state`` with ``relevant_documents`` holding the kept
        documents in retrieval order (``[]`` if nothing was retrieved).
    """
    query = state["original_query"]
    documents = state["documents"]
    if not documents:
        return {**state, "relevant_documents": []}

    logger = logging.getLogger(__name__)
    # Structured output parses each reply into a GradeDocuments instance.
    grader_llm = LLM.with_structured_output(GradeDocuments)
    grading_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a grader assessing relevance of a retrieved document to a user query.
If the document contains keywords or semantic meaning related to the query, grade it as relevant.
Be lenient - even partial relevance should be marked as 'yes'.
Only mark 'no' if the document is completely unrelated.""",
            ),
            (
                "human",
                """Query: {query}
Document content: {document}
Is this document relevant to the query?""",
            ),
        ]
    )

    relevant_docs = []
    for position, doc in enumerate(documents, 1):
        try:
            result = grader_llm.invoke(
                grading_prompt.format_messages(
                    query=query,
                    document=doc.page_content[:1000],  # Limit content length
                )
            )
            # Kept inside the try: a missing/unparsed result (e.g. None)
            # must fall through to the fail-open path, not crash the graph.
            is_relevant = result.is_relevant == "yes"
        except Exception:
            # Fail-open by design, but leave a trace: a grader that always
            # errors would otherwise silently disable relevance filtering.
            logger.warning(
                "Relevance grading failed for document %d; keeping it.",
                position,
                exc_info=True,
            )
            is_relevant = True
        if is_relevant:
            relevant_docs.append(doc)
    return {**state, "relevant_documents": relevant_docs}
def rewrite_query(state: RetrievalState) -> RetrievalState:
    """Have the LLM reformulate the query for another retrieval attempt.

    The rewrite is anchored on ``original_query`` (the user's intent), and the
    query that just failed to retrieve anything relevant (``current_query``)
    is shown to the model as well, so each retry asks for a formulation that
    differs from the one already tried.

    Returns:
        A copy of ``state`` with ``current_query`` set to the rewritten query
        (falling back to ``original_query`` if the model returns only
        whitespace) and ``retry_count`` incremented by one.
    """
    original_query = state["original_query"]
    # Without the previous attempt in the prompt, every retry would receive an
    # identical prompt and a deterministic model would repeat the same rewrite.
    previous_query = state.get("current_query") or original_query
    retry_count = state["retry_count"]
    rewrite_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert at reformulating search queries.
Given the original query, generate a better search query that might retrieve more relevant documents.
The previously tried query did not retrieve relevant documents, so your rewrite must differ from it.
Focus on:
- Extracting key concepts and entities
- Using synonyms or related terms
- Being more specific or more general as appropriate
Return ONLY the rewritten query, nothing else.""",
            ),
            (
                "human",
                "Original query: {query}\n\n"
                "Previously tried query: {previous_query}\n\n"
                "Rewritten query:",
            ),
        ]
    )
    response = LLM.invoke(
        rewrite_prompt.format_messages(
            query=original_query,
            previous_query=previous_query,
        )
    )
    content = response.content
    if not isinstance(content, str):
        # Some chat models return a list of content blocks (plain strings or
        # dicts such as {"type": "text", "text": ...}) instead of a string.
        content = "".join(
            block if isinstance(block, str) else str(block.get("text", ""))
            for block in content
        )
    # Never search with an empty query; reuse the original instead.
    new_query = content.strip() or original_query
    return {
        **state,
        "current_query": new_query,
        "retry_count": retry_count + 1,
    }
def generate_response(state: RetrievalState) -> RetrievalState:
    """Render the graded documents as a numbered plain-text list.

    Each line reads ``"<n>. [<category>/<topic>]: <page_content>"``, numbered
    from 1, with ``N/A`` substituted for a category or topic missing from the
    document's metadata. When grading kept nothing, the fixed message
    "No relevant memories found." is produced instead.

    Returns:
        A copy of ``state`` with ``generation`` set to the rendered text.
    """

    def as_line(index, doc):
        meta = doc.metadata
        return f"{index}. [{meta.get('category', 'N/A')}/{meta.get('topic', 'N/A')}]: {doc.page_content}"

    kept = state["relevant_documents"]
    if kept:
        rendered = "\n".join(as_line(n, d) for n, d in enumerate(kept, 1))
    else:
        rendered = "No relevant memories found."
    return {**state, "generation": rendered}
def should_retry(state: RetrievalState) -> Literal["rewrite", "generate"]:
    """Pick the next node after grading.

    Returns ``"rewrite"`` only when grading kept no documents and fewer than
    ``max_retries`` rewrites have happened so far (``retry_count``). In every
    other case it returns ``"generate"``, which renders whatever was kept --
    possibly the empty-result message once the retry budget is spent.
    """
    kept = state["relevant_documents"]
    attempts = state["retry_count"]
    budget = state["max_retries"]
    # Short-circuit: the retry budget is only consulted when nothing was kept.
    needs_another_try = not kept and attempts < budget
    return "rewrite" if needs_another_try else "generate"
def build_retrieval_graph():
    """Assemble and compile the agentic retrieval workflow.

    Flow: START -> retrieve -> grade, then ``should_retry`` routes either to
    rewrite (which loops back to retrieve) or to generate -> END. The loop
    terminates because rewrite_query increments ``retry_count`` and
    should_retry stops rewriting once it reaches ``max_retries``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    graph = StateGraph(RetrievalState)

    node_table = (
        ("retrieve", retrieve_documents),
        ("grade", grade_documents),
        ("rewrite", rewrite_query),
        ("generate", generate_response),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    graph.add_edge(START, "retrieve")
    graph.add_edge("retrieve", "grade")
    # Route labels returned by should_retry map 1:1 onto node names.
    graph.add_conditional_edges(
        "grade",
        should_retry,
        {label: label for label in ("rewrite", "generate")},
    )
    graph.add_edge("rewrite", "retrieve")
    graph.add_edge("generate", END)

    return graph.compile()
|