import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface.llms import HuggingFacePipeline
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import os
from dotenv import load_dotenv
import tiktoken
load_dotenv()
#HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
#embeddings_model_name = "cointegrated/rubert-tiny2"
embeddings_model_name = "text-embedding-3-small"  # must match the model that built the FAISS store below
llm_model_name = "gpt-4o-mini"
store_save_path = "stores/openai"
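# ChatOpenAI and OpenAIEmbeddings read OPENAI_API_KEY from the environment;
# load_dotenv() above makes a local .env file sufficient for that.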
# Step 1: Document Loading and Splitting
def load_and_split_documents(pdf_path="docs/test_file.pdf"):
"""
Loads a PDF document and splits it into smaller chunks.
"""
loader = PyPDFLoader(pdf_path)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=800,
chunk_overlap=200
)
docs = text_splitter.split_documents(documents)
return docs
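# A quick sanity check for the splitter settings (a sketch, not part of the app;
# 800-character chunks with 200 characters of overlap keep sentences that
# straddle a chunk boundary available in both neighbors):
#_docs = load_and_split_documents()
#print(f"Chunks: {len(_docs)}; first chunk preview: {_docs[0].page_content[:100]}")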
# Step 2: Embeddings and Vector Store
def get_vector_store(docs, store_save_path=store_save_path):
"""
Loads an existing vector store or creates a new one if it doesn't exist.
"""
if os.path.exists(store_save_path):
print("Loading vector store from disk...")
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
db = FAISS.load_local(store_save_path, embeddings, allow_dangerous_deserialization=True)
else:
print("Creating a new vector store...")
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
db = FAISS.from_documents(docs, embeddings)
db.save_local(store_save_path)
return db
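# Note: if the embedding model changes, the saved index must be rebuilt.
# A minimal sketch, assuming the old index can simply be discarded:
#import shutil
#shutil.rmtree(store_save_path, ignore_errors=True)  # next run re-embeds the PDF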
# Step 3: Initialize the LLM
def initialize_llm():
"""
Initializes a Russian-specific LLM locally using transformers
"""
#repo_id = "ai-forever/rugpt3large_based_on_gpt2"
#repo_id = "ai-forever/ruBert-base"
#repo_id = "ai-forever/ruGPT-3.5-13B"
'''
llm = HuggingFaceEndpoint(
repo_id=repo_id,
temperature=0.5,
#max_new_tokens=300,
task='text-generation'
)
'''
llm = ChatOpenAI(
model=llm_model_name,
temperature=0.7
)
return llm
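# Note: temperature=0.7 gives varied phrasings between runs; for factual,
# document-grounded QA a lower value (e.g. 0) is a common choice.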
# Step 4: Create the LCEL RAG Chain
def setup_rag_chain(pdf_path):
"""
Sets up the complete Retrieval-Augmented Generation chain using LCEL.
"""
docs = load_and_split_documents(pdf_path)
db = get_vector_store(docs)
retriever = db.as_retriever()
llm = initialize_llm()
# Checking the vector store
#print(f"Number of vectors in FAISS index: {db.index.ntotal}")
# Define the prompt template
template = """Используйте следующие фрагменты контекста, чтобы ответить на вопрос в конце. Если вы не знаете ответа, просто скажите, что не знаете, не пытайтесь что-то придумать. Всегда будьте вежливым.
{context}
Вопрос: {question}
Полезный ответ:"""
prompt = PromptTemplate.from_template(template)
# Corrected RAG chain construction
rag_chain = (
{"context": retriever, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
return rag_chain
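# The retriever branch above feeds the prompt a stringified list of Document
# objects, metadata included. A common refinement (a sketch, not wired into the
# chain) is to join only the page contents, swapping "context": retriever for
# "context": retriever | format_docs in the dict above:
def format_docs(docs):
    """Join retrieved chunks into one plain-text context block."""
    return "\n\n".join(doc.page_content for doc in docs)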
# Initialize the chain
document_name = "docs/test_file.pdf"
qa_chain = setup_rag_chain(pdf_path=document_name)
# Gradio Interface
def chat_with_doc(query):
"""
Function to handle the user query and return a response.
"""
try:
# Pass the query directly, not as a dictionary
result = qa_chain.invoke(query)
return result
except Exception as e:
return f"Произошла ошибка: {type(e).__name__} - {e!r}"
def count_tokens(text, model_name):
encoding = tiktoken.encoding_for_model(model_name)
num_tokens = len(encoding.encode(text))
return num_tokens
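# Hypothetical usage sketch: estimate the token count of a query before sending
# it (recent tiktoken versions map gpt-4o-mini to the o200k_base encoding):
#print(count_tokens("Что такое НТД?", llm_model_name))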
iface = gr.Interface(
    fn=chat_with_doc,
    inputs=gr.Textbox(lines=5, placeholder="Спросите что-нибудь о документе..."),
    outputs="text",
    title="RAG LLM модель для AIGINEER",
    description="Задайте вопрос о содержании документации",
)
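# Note: iface is a minimal single-textbox alternative; it is defined but never
# launched. The Blocks layout below (demo) is the interface actually served.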
css_code = """
#submit-button {
    background-color: #4CAF50 !important;
    color: white !important;
}
#centered-text {
    text-align: center;
    /* justify-content: center; */
}
#fixed-height-textarea textarea {
    overflow-y: auto !important;
}
"""
heading_text = "# AIGINEER-ИИ Модель"
subheading_text = 'Узнайте любую информацию о нормативно-технической документации (НТД) со 100% точностью при помощи ИИ модели AIGINEER'
with gr.Blocks(css=css_code) as demo:
    gr.Markdown(heading_text, elem_id='centered-text')
    gr.Markdown(subheading_text, elem_id='centered-text')
    with gr.Row(scale=1):
        with gr.Column():
            query_input = gr.Textbox(interactive=True, label='Вопрос', lines=5, placeholder="Спросите что-нибудь о документе...")
            with gr.Row():
                clear_button = gr.ClearButton(components=[query_input], variant='secondary', value='Очистить')
                submit_button = gr.Button(variant='primary', value='Отправить')
        #with gr.Column():
        #    count_tokens_output = gr.TextArea(interactive=False, label='Стоимость запроса в токенах')
        #    count_tokens_button = gr.Button(variant='secondary', value='Посчитать стоимость в токенах')
        response_output = gr.TextArea(interactive=True, label='Ответ', lines=8, placeholder='Тут будет отображаться ответ.')
    submit_button.click(fn=chat_with_doc, inputs=query_input, outputs=response_output)
    #count_tokens_button.click(fn=lambda text_input: count_tokens(text_input, llm_model_name), inputs=[query_input], outputs=[count_tokens_output])
# Launch the Gradio app
if __name__ == "__main__":
    # Uncomment to run as CLI
    #query = input(f"Спросите что-нибудь о документе {document_name}: ")
    #result = chat_with_doc(query)
    #print(result)
    # Run Gradio app
    demo.launch()