Starberry15 committed on
Commit
32b1a73
Β·
verified Β·
1 Parent(s): 42984b6

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +43 -10
src/streamlit_app.py CHANGED
@@ -7,6 +7,7 @@ from typing import List, Dict, Any
7
  import numpy as np
8
  import streamlit as st
9
  import PyPDF2
 
10
  from dotenv import load_dotenv
11
  from huggingface_hub import InferenceClient, login
12
  from streamlit_chat import message as st_message
@@ -77,6 +78,7 @@ def find_handbook() -> List[str]:
77
  st.error("❌ No PDF found in the same folder as this app.")
78
  return []
79
 
 
80
  def load_pdf_texts(pdf_paths: List[str]) -> List[Dict[str, Any]]:
81
  """Extract text from all pages of provided PDFs."""
82
  pages = []
@@ -89,6 +91,7 @@ def load_pdf_texts(pdf_paths: List[str]) -> List[Dict[str, Any]]:
89
  pages.append({"filename": os.path.basename(path), "page": i + 1, "text": text})
90
  return pages
91
 
 
92
  def chunk_text(pages: List[Dict[str, Any]], size: int, overlap: int) -> List[Dict[str, Any]]:
93
  """Split text into overlapping chunks."""
94
  chunks = []
@@ -106,20 +109,45 @@ def chunk_text(pages: List[Dict[str, Any]], size: int, overlap: int) -> List[Dic
106
  start += size - overlap
107
  return chunks
108
 
 
109
  def embed_texts(texts: List[str]) -> np.ndarray:
110
- """Get embeddings via Hugging Face Inference API."""
111
- if not hf_client:
112
- st.error("❌ No Hugging Face client initialized.")
113
  return np.zeros((len(texts), 768))
 
 
114
  try:
115
- emb = hf_client.post(
116
- "/embeddings",
117
- json={"inputs": texts, "model": EMBED_MODEL},
118
  )
119
- return np.array(emb["embeddings"])
120
- except Exception as e:
121
- st.error(f"Embedding error: {e}")
122
- return np.zeros((len(texts), 768))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  def build_faiss_index(chunks: List[Dict[str, Any]]) -> None:
125
  """Build and save FAISS index for handbook chunks."""
@@ -139,6 +167,7 @@ def build_faiss_index(chunks: List[Dict[str, Any]]) -> None:
139
  with open(EMB_DIM_FILE, "w") as f:
140
  json.dump({"dim": dim}, f)
141
 
 
142
  def load_faiss_index():
143
  """Load FAISS index and metadata if available."""
144
  if not (os.path.exists(INDEX_FILE) and os.path.exists(META_FILE)):
@@ -148,6 +177,7 @@ def load_faiss_index():
148
  meta = json.load(f)
149
  return index, meta
150
 
 
151
  def search_index(query: str, index, meta, top_k: int, threshold: float) -> List[Dict[str, Any]]:
152
  """Search FAISS for top-K similar chunks."""
153
  query_emb = embed_texts([query])
@@ -160,6 +190,7 @@ def search_index(query: str, index, meta, top_k: int, threshold: float) -> List[
160
  results.append(result)
161
  return results
162
 
 
163
  def generate_answer(context: str, query: str) -> str:
164
  """Generate robust answer with explicit citations β€” auto-switches between endpoints."""
165
  prompt = f"""
@@ -210,6 +241,7 @@ If the answer is not explicitly found, respond with:
210
  except Exception as e2:
211
  return f"⚠️ Error generating answer: {e2}"
212
 
 
213
  # =============================================================
214
  # πŸ” Index Handling
215
  # =============================================================
@@ -233,6 +265,7 @@ def ensure_index():
233
  st.stop()
234
  return index, meta
235
 
 
236
  # =============================================================
237
  # πŸ’¬ Chat Interface
238
  # =============================================================
 
7
  import numpy as np
8
  import streamlit as st
9
  import PyPDF2
10
+ import requests
11
  from dotenv import load_dotenv
12
  from huggingface_hub import InferenceClient, login
13
  from streamlit_chat import message as st_message
 
78
  st.error("❌ No PDF found in the same folder as this app.")
79
  return []
80
 
81
+
82
  def load_pdf_texts(pdf_paths: List[str]) -> List[Dict[str, Any]]:
83
  """Extract text from all pages of provided PDFs."""
84
  pages = []
 
91
  pages.append({"filename": os.path.basename(path), "page": i + 1, "text": text})
92
  return pages
93
 
94
+
95
  def chunk_text(pages: List[Dict[str, Any]], size: int, overlap: int) -> List[Dict[str, Any]]:
96
  """Split text into overlapping chunks."""
97
  chunks = []
 
109
  start += size - overlap
110
  return chunks
111
 
112
+
113
def embed_texts(texts: List[str]) -> np.ndarray:
    """Get embeddings via the Hugging Face Inference API, with a REST fallback.

    Args:
        texts: Strings to embed.

    Returns:
        2-D float array of shape (len(texts), dim). On configuration or API
        failure a zero array of shape (len(texts), 768) is returned so that
        callers can continue without crashing.
    """
    if not HF_TOKEN:
        st.error("❌ Missing HF_TOKEN.")
        return np.zeros((len(texts), 768))

    # Nothing to embed — avoid calling the API (and np.vstack on []) at all.
    if not texts:
        return np.zeros((0, 768))

    # --- Primary method: InferenceClient.feature_extraction ---
    try:
        vectors = []
        for text in texts:
            # feature_extraction takes the text as its first positional
            # argument; there is no `inputs=` keyword on the client method.
            emb = np.asarray(hf_client.feature_extraction(text, model=EMBED_MODEL))
            # Token-level output is 2-D (tokens x dim): mean-pool to one vector.
            # Already-pooled output is 1-D and must NOT be averaged again.
            if emb.ndim > 1:
                emb = emb.mean(axis=0)
            vectors.append(emb)
        return np.vstack(vectors)

    # --- Fallback method: REST API ---
    except Exception as e1:
        st.warning(f"⚠️ feature_extraction() failed, using REST API fallback: {e1}")
        try:
            headers = {"Authorization": f"Bearer {HF_TOKEN}"}
            response = requests.post(
                f"https://api-inference.huggingface.co/pipeline/feature-extraction/{EMBED_MODEL}",
                headers=headers,
                json={"inputs": texts},
                timeout=60,  # never hang the Streamlit session on a dead endpoint
            )
            response.raise_for_status()
            data = response.json()

            vectors = []
            for item in data:
                emb = np.asarray(item)
                # Same pooling rule as above: only collapse token-level (2-D)
                # output; a flat vector averaged over axis 0 would become a scalar.
                if emb.ndim > 1:
                    emb = emb.mean(axis=0)
                vectors.append(emb)
            return np.vstack(vectors)
        except Exception as e2:
            st.error(f"Embedding error: {e2}")
            return np.zeros((len(texts), 768))
151
 
152
  def build_faiss_index(chunks: List[Dict[str, Any]]) -> None:
153
  """Build and save FAISS index for handbook chunks."""
 
167
  with open(EMB_DIM_FILE, "w") as f:
168
  json.dump({"dim": dim}, f)
169
 
170
+
171
  def load_faiss_index():
172
  """Load FAISS index and metadata if available."""
173
  if not (os.path.exists(INDEX_FILE) and os.path.exists(META_FILE)):
 
177
  meta = json.load(f)
178
  return index, meta
179
 
180
+
181
  def search_index(query: str, index, meta, top_k: int, threshold: float) -> List[Dict[str, Any]]:
182
  """Search FAISS for top-K similar chunks."""
183
  query_emb = embed_texts([query])
 
190
  results.append(result)
191
  return results
192
 
193
+
194
  def generate_answer(context: str, query: str) -> str:
195
  """Generate robust answer with explicit citations β€” auto-switches between endpoints."""
196
  prompt = f"""
 
241
  except Exception as e2:
242
  return f"⚠️ Error generating answer: {e2}"
243
 
244
+
245
  # =============================================================
246
  # πŸ” Index Handling
247
  # =============================================================
 
265
  st.stop()
266
  return index, meta
267
 
268
+
269
  # =============================================================
270
  # πŸ’¬ Chat Interface
271
  # =============================================================