Spaces:
Runtime error
Runtime error
Saiteja Solleti
committed on
Commit
·
a46269a
1
Parent(s):
748ac82
fine tuning and reranking is pushed
Browse files- app.py +7 -2
- finetuneresults.py +61 -0
- generationhelper.py +8 -0
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -6,6 +6,7 @@ from createmilvusschema import CreateMilvusDbSchema
|
|
| 6 |
from insertmilvushelper import EmbedAllDocumentsAndInsert
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
from searchmilvushelper import SearchTopKDocuments
|
|
|
|
| 9 |
|
| 10 |
from model import generate_response
|
| 11 |
from huggingface_hub import login
|
|
@@ -15,6 +16,7 @@ from huggingface_hub import dataset_info
|
|
| 15 |
|
| 16 |
# Load embedding model
|
| 17 |
QUERY_EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
|
|
|
|
| 18 |
WINDOW_SIZE = 5
|
| 19 |
OVERLAP = 2
|
| 20 |
RETRIVE_TOP_K_SIZE=10
|
|
@@ -38,8 +40,11 @@ EmbedAllDocumentsAndInsert(QUERY_EMBEDDING_MODEL, rag_extracted_data, db_collect
|
|
| 38 |
"""
|
| 39 |
query = "what would the net revenue have been in 2015 if there wasn't a stipulated settlement from the business combination in october 2015?"
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def chatbot(prompt):
|
|
|
|
| 6 |
from insertmilvushelper import EmbedAllDocumentsAndInsert
|
| 7 |
from sentence_transformers import SentenceTransformer
|
| 8 |
from searchmilvushelper import SearchTopKDocuments
|
| 9 |
+
from finetuneresults import FineTuneAndRerankSearchResults
|
| 10 |
|
| 11 |
from model import generate_response
|
| 12 |
from huggingface_hub import login
|
|
|
|
| 16 |
|
| 17 |
# Load embedding model
|
| 18 |
QUERY_EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
|
| 19 |
+
RERANKING_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
| 20 |
WINDOW_SIZE = 5
|
| 21 |
OVERLAP = 2
|
| 22 |
RETRIVE_TOP_K_SIZE=10
|
|
|
|
| 40 |
"""
|
| 41 |
query = "what would the net revenue have been in 2015 if there wasn't a stipulated settlement from the business combination in october 2015?"
|
| 42 |
|
| 43 |
+
results_for_top10_chunks = SearchTopKDocuments(db_collection, query, QUERY_EMBEDDING_MODEL, top_k=RETRIVE_TOP_K_SIZE)
|
| 44 |
+
|
| 45 |
+
reranked_results = FineTuneAndRerankSearchResults(results_for_top10_chunks, rag_extracted_data, query, RERANKING_MODEL)
|
| 46 |
+
|
| 47 |
+
print(reranked_results)
|
| 48 |
|
| 49 |
|
| 50 |
def chatbot(prompt):
|
finetuneresults.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sentence_transformers import CrossEncoder
|
| 2 |
+
|
| 3 |
+
def retrieve_full_documents(top_documents, df):
    """
    Retrieves unique full documents based on the top-ranked document IDs.

    Args:
        top_documents (list): List of dictionaries, each containing a 'doc_id' key.
        df (pd.DataFrame): The dataset containing document IDs ('id') and text ('documents').

    Returns:
        pd.DataFrame: A DataFrame with 'doc_id' and 'document' columns,
        one row per unique matching document id.
    """
    # Fix: the original placed this docstring as a bare string literal ABOVE
    # the def, so it never attached as the function's __doc__; moved inside.

    # Deduplicate ids up front so each document is looked up once.
    unique_doc_ids = list(set(doc["doc_id"] for doc in top_documents))

    # Debugging aid: show which full documents the chunk hits resolve to.
    print(f"Extracted Doc IDs: {unique_doc_ids}")

    # Keep only matching rows; drop duplicate ids so each document appears once.
    filtered_df = df[df["id"].isin(unique_doc_ids)][["id", "documents"]].drop_duplicates(subset="id")

    # Rename columns for clarity in downstream reranking.
    filtered_df = filtered_df.rename(columns={"id": "doc_id", "documents": "document"})

    return filtered_df
|
| 28 |
+
|
| 29 |
+
def rerank_documents(query, retrieved_docs_df, model_name="cross-encoder/ms-marco-MiniLM-L-6-v2"):
    """
    Reranks the retrieved documents based on their relevance to the query
    using a Cross-Encoder model.

    Args:
        query (str): The search query.
        retrieved_docs_df (pd.DataFrame): DataFrame with 'doc_id' and 'document'.
        model_name (str): Name of the Cross-Encoder model.

    Returns:
        pd.DataFrame: The input DataFrame with an added 'relevance_score'
        column, sorted by descending relevance and reindexed.
    """
    # Fix: docstring was a detached string literal above the def; moved inside.

    # Load Cross-Encoder model. NOTE(review): this loads the model on every
    # call — consider caching it at module level if called repeatedly.
    model = CrossEncoder(model_name)

    # Prepare (query, document-text) pairs. Bug fix: the original applied
    # " ".join(doc) unconditionally, which space-separates every single
    # character when the document cell is already a plain string; join only
    # when the cell is a sequence of parts.
    query_doc_pairs = [
        (query, doc if isinstance(doc, str) else " ".join(doc))
        for doc in retrieved_docs_df["document"]
    ]

    # Compute relevance scores (higher score = more relevant).
    scores = model.predict(query_doc_pairs)

    # Attach scores to the DataFrame (mutates the caller's frame, as the
    # original did) and sort best-first with a clean index.
    retrieved_docs_df["relevance_score"] = scores
    reranked_docs_df = retrieved_docs_df.sort_values(by="relevance_score", ascending=False).reset_index(drop=True)

    return reranked_docs_df
|
| 57 |
+
|
| 58 |
+
def FineTuneAndRerankSearchResults(top_10_chunk_results, rag_extarcted_data, question, reranking_model):
    """
    Resolve top-k chunk search hits to their full documents, then rerank
    those documents against the question with a cross-encoder.

    Args:
        top_10_chunk_results (list): Chunk search hits, each a dict with 'doc_id'.
        rag_extarcted_data (pd.DataFrame): Source dataset with 'id' and 'documents'.
        question (str): The user query to rerank against.
        reranking_model (str): Cross-Encoder model name passed to rerank_documents.

    Returns:
        pd.DataFrame: Documents sorted by descending relevance to the question.
    """
    unique_docs = retrieve_full_documents(top_10_chunk_results, rag_extarcted_data)
    reranked_results = rerank_documents(question, unique_docs, reranking_model)
    # Bug fix: the original returned the rerank_documents FUNCTION OBJECT
    # instead of the computed results, so callers received a callable.
    return reranked_results
|
generationhelper.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
from groq import Groq

# Groq API token read from the environment. May be None when GROQ_TOKEN is
# unset — the client is still constructed and will only fail at first use.
# NOTE(review): consider failing fast with a clear error if the variable
# is missing.
groq_token = os.getenv("GROQ_TOKEN")

# Module-level Groq client shared by the generation helpers in this file.
groq_client = Groq(
    api_key = groq_token
)
|
requirements.txt
CHANGED
|
@@ -4,4 +4,5 @@ torch
|
|
| 4 |
huggingface_hub
|
| 5 |
pymilvus
|
| 6 |
nltk
|
| 7 |
-
sentence-transformers
|
|
|
|
|
|
| 4 |
huggingface_hub
|
| 5 |
pymilvus
|
| 6 |
nltk
|
| 7 |
+
sentence-transformers
|
| 8 |
+
Groq
|