from typing import List, Literal, Optional, TypedDict

from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from qdrant_client.http.models import (
    FieldCondition,
    Filter,
    MatchValue,
)

from clients import LLM, VECTOR_STORE


class RetrievalState(TypedDict):
    """State for the agentic retrieval graph."""

    original_query: str
    current_query: str
    category: Optional[str]
    topic: Optional[str]
    documents: List[Document]
    relevant_documents: List[Document]
    generation: str
    retry_count: int
    max_retries: int


class GradeDocuments(BaseModel):
    """Grade whether a document is relevant to the query."""

    is_relevant: Literal["yes", "no"] = Field(
        description="Is the document relevant to the query? 'yes' or 'no'"
    )
    reason: str = Field(description="Brief reason for the relevance decision")


def retrieve_documents(state: RetrievalState) -> RetrievalState:
    """Retrieve documents from the vector store."""
    query = state["current_query"]
    category = state.get("category")
    topic = state.get("topic")

    # Build an optional metadata filter from the category/topic hints.
    conditions = []
    if category:
        conditions.append(
            FieldCondition(key="metadata.category", match=MatchValue(value=category))
        )
    if topic:
        conditions.append(
            FieldCondition(key="metadata.topic", match=MatchValue(value=topic))
        )

    qdrant_filter = Filter(must=conditions) if conditions else None

    documents = VECTOR_STORE.similarity_search(
        query,
        k=5,
        filter=qdrant_filter,
    )

    return {**state, "documents": documents}


def grade_documents(state: RetrievalState) -> RetrievalState:
    """Grade documents for relevance using the LLM."""
    # Grade against the original query, even if retrieval used a rewrite.
    query = state["original_query"]
    documents = state["documents"]

    if not documents:
        return {**state, "relevant_documents": []}

    grader_llm = LLM.with_structured_output(GradeDocuments)

    grading_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a grader assessing relevance of a retrieved document to a user query.

If the document contains keywords or semantic meaning related to the query, grade it as relevant.
Be lenient - even partial relevance should be marked as 'yes'.
Only mark 'no' if the document is completely unrelated.""",
            ),
            (
                "human",
                """Query: {query}

Document content: {document}

Is this document relevant to the query?""",
            ),
        ]
    )

    relevant_docs = []
    for doc in documents:
        try:
            result = grader_llm.invoke(
                grading_prompt.format_messages(
                    query=query,
                    document=doc.page_content[:1000],
                )
            )
            if result.is_relevant == "yes":
                relevant_docs.append(doc)
        except Exception:
            # Fail open: if grading errors out, keep the document.
            relevant_docs.append(doc)

    return {**state, "relevant_documents": relevant_docs}


def rewrite_query(state: RetrievalState) -> RetrievalState:
    """Rewrite the query for better retrieval."""
    original_query = state["original_query"]
    retry_count = state["retry_count"]

    rewrite_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert at reformulating search queries.
Given the original query, generate a better search query that might retrieve more relevant documents.

Focus on:
- Extracting key concepts and entities
- Using synonyms or related terms
- Being more specific or more general as appropriate

Return ONLY the rewritten query, nothing else.""",
            ),
            ("human", "Original query: {query}\n\nRewritten query:"),
        ]
    )

    response = LLM.invoke(rewrite_prompt.format_messages(query=original_query))
    new_query = response.content.strip()

    return {
        **state,
        "current_query": new_query,
        "retry_count": retry_count + 1,
    }


def generate_response(state: RetrievalState) -> RetrievalState:
    """Generate the final response from relevant documents."""
    relevant_docs = state["relevant_documents"]

    if not relevant_docs:
        return {**state, "generation": "No relevant memories found."}

    formatted = []
    for i, doc in enumerate(relevant_docs, 1):
        meta = doc.metadata
        formatted.append(
            f"{i}. [{meta.get('category', 'N/A')}/{meta.get('topic', 'N/A')}]: {doc.page_content}"
        )

    return {**state, "generation": "\n".join(formatted)}


def should_retry(state: RetrievalState) -> Literal["rewrite", "generate"]:
    """Decide whether to retry with a rewritten query."""
    relevant_docs = state["relevant_documents"]
    retry_count = state["retry_count"]
    max_retries = state["max_retries"]

    # Relevant documents found: proceed to generation.
    if relevant_docs:
        return "generate"

    # Nothing relevant yet and retries remain: rewrite the query and retry.
    if retry_count < max_retries:
        return "rewrite"

    # Retry budget exhausted: generate with whatever we have.
    return "generate"


def build_retrieval_graph():
    """Assemble and compile the retrieve -> grade -> (rewrite | generate) graph."""
    workflow = StateGraph(RetrievalState)

    workflow.add_node("retrieve", retrieve_documents)
    workflow.add_node("grade", grade_documents)
    workflow.add_node("rewrite", rewrite_query)
    workflow.add_node("generate", generate_response)

    workflow.add_edge(START, "retrieve")
    workflow.add_edge("retrieve", "grade")
    workflow.add_conditional_edges(
        "grade",
        should_retry,
        {
            "rewrite": "rewrite",
            "generate": "generate",
        },
    )
    workflow.add_edge("rewrite", "retrieve")
    workflow.add_edge("generate", END)

    return workflow.compile()
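

# Minimal usage sketch. Assumes `clients.LLM` and `clients.VECTOR_STORE` are
# configured; the query string and max_retries value below are illustrative,
# not prescribed by this module.
if __name__ == "__main__":
    graph = build_retrieval_graph()
    initial_state: RetrievalState = {
        "original_query": "What did I say about project deadlines?",
        "current_query": "What did I say about project deadlines?",
        "category": None,
        "topic": None,
        "documents": [],
        "relevant_documents": [],
        "generation": "",
        "retry_count": 0,
        "max_retries": 2,
    }
    final_state = graph.invoke(initial_state)
    print(final_state["generation"])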