Starberry15 committed
Commit 79dbb99 · verified · 1 Parent(s): fa9ab75

Update src/streamlit_app.py

Files changed (1):
  src/streamlit_app.py +517 -189
src/streamlit_app.py CHANGED
@@ -1,252 +1,580 @@
- # ======================================================
- # 📘 Handbook Assistant (FAST OPTIMIZED VERSION)
- # ======================================================
- # Requirements:
- # pip install streamlit python-dotenv PyPDF2 numpy faiss-cpu scikit-learn huggingface-hub streamlit-chat sentence-transformers
-
  import os
- import time
  import glob
  import json
  import math
- from typing import List, Tuple, Dict, Any

  import numpy as np
  import streamlit as st
- from dotenv import load_dotenv
  import PyPDF2
- from streamlit_chat import message as st_message
-
- # Optional fast embedding model
  from sentence_transformers import SentenceTransformer

- # Try FAISS
  try:
      import faiss
  except Exception:
      faiss = None

- # ======================================================
- # ⚙️ CONFIGURATION
- # ======================================================
- st.set_page_config(page_title="📚 Handbook Assistant", page_icon="📘", layout="wide")
- st.title("📚 Handbook Assistant — Fast Local Version")
- st.caption("Place your handbook PDF (e.g., handbook.pdf) beside this script or upload below.")

  load_dotenv()
-
- # File names for saving
  HAND_INDEX_FN = "handbook_faiss.index"
  HAND_META_FN = "handbook_metadata.json"
  HAND_EMB_DIM_FN = "handbook_emb_dim.json"

- # ======================================================
- # ⚙️ SIDEBAR SETTINGS
- # ======================================================
- with st.sidebar:
-     st.header("⚙️ Settings")
-
-     similarity_threshold = st.slider("Similarity threshold", 0.3, 0.95, 0.62, 0.01)
-     top_k = st.slider("Top chunks retrieved", 1, 10, 4)
-     chunk_size_chars = st.number_input("Chunk size (chars)", min_value=400, max_value=3000, value=2000, step=100)
-     chunk_overlap = st.number_input("Chunk overlap (chars)", min_value=20, max_value=600, value=100, step=10)
-     regenerate_index = st.button("🔁 Rebuild handbook index (force re-embed)")
-     st.markdown("**Storage:** Cached FAISS index + metadata for fast restarts.")
-
-     uploaded_pdf = st.file_uploader("📄 Upload handbook PDF", type=["pdf"])
-     if uploaded_pdf:
-         temp_path = os.path.join(os.path.dirname(__file__), uploaded_pdf.name)
-         with open(temp_path, "wb") as f:
-             f.write(uploaded_pdf.getbuffer())
-         st.session_state.uploaded_pdf_path = temp_path
-         st.success(f"✅ Uploaded and saved: {uploaded_pdf.name}")
-
- # ======================================================
- # 🧩 UTILITIES
- # ======================================================
- @st.cache_resource(show_spinner=False)
- def get_local_embedder():
-     """Load MiniLM model (only once)."""
-     return SentenceTransformer("all-MiniLM-L6-v2")
-
- def find_pdfs(patterns=["handbook*.pdf", "*.pdf"]) -> List[str]:
-     """Find handbook PDFs in script folder or uploaded ones."""
-     base_dir = os.path.dirname(os.path.abspath(__file__))
-     files = []
-     for patt in patterns:
-         files += glob.glob(os.path.join(base_dir, patt))
-     if not files and "uploaded_pdf_path" in st.session_state:
-         files = [st.session_state.uploaded_pdf_path]
-     return sorted(list(set(files)))

  def load_pdf_texts_with_page_info(pdf_paths: List[str]) -> List[Dict[str, Any]]:
-     """Extract text from each page with filename and page number."""
-     all_pages = []
      for p in pdf_paths:
          try:
              with open(p, "rb") as f:
                  reader = PyPDF2.PdfReader(f)
                  for i, page in enumerate(reader.pages):
                      try:
-                         text = page.extract_text() or ""
                      except Exception:
-                         text = ""
-                     if text.strip():
-                         all_pages.append({"filename": os.path.basename(p), "page": i + 1, "text": text})
          except Exception as e:
-             st.warning(f"⚠️ Failed to read {p}: {e}")
-     return all_pages

  def chunk_pages_into_segments(pages: List[Dict[str, Any]], chunk_size: int, overlap: int) -> List[Dict[str, Any]]:
-     """Split long page text into overlapping chunks."""
      chunks = []
      for pg in pages:
-         text = pg["text"]
-         filename, page_no = pg["filename"], pg["page"]
-         start, chunk_id = 0, 0
-         while start < len(text):
-             end = min(start + chunk_size, len(text))
              seg = text[start:end].strip()
-             if len(seg) > 50:
                  chunks.append({
                      "filename": filename,
                      "page": page_no,
                      "chunk_id": f"{filename}_p{page_no}_c{chunk_id}",
-                     "text": seg
                  })
-                 chunk_id += 1
              start = end - overlap
              if start < 0:
                  start = 0
      return chunks

- def embed_texts(texts: List[str], batch_size: int = 16) -> np.ndarray:
-     """Fast local embedding using MiniLM in batches."""
-     model = get_local_embedder()
-     all_embeddings = []
-     for i in range(0, len(texts), batch_size):
-         batch = texts[i:i + batch_size]
-         emb = model.encode(batch, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False)
-         all_embeddings.append(emb)
-     return np.vstack(all_embeddings)
-
- def build_faiss_index(embeddings: np.ndarray):
-     """Build FAISS cosine index."""
-     if faiss is None:
-         raise RuntimeError("❌ FAISS not installed (pip install faiss-cpu)")
-     d = embeddings.shape[1]
-     index = faiss.IndexFlatIP(d)
-     index.add(embeddings)
-     return index, d
-
- def save_index_and_metadata(index, metadata, emb_dim: int):
-     faiss.write_index(index, HAND_INDEX_FN)
-     with open(HAND_META_FN, "w", encoding="utf-8") as f:
-         json.dump(metadata, f, indent=2)
-     with open(HAND_EMB_DIM_FN, "w") as f:
-         json.dump({"dim": emb_dim}, f)
-
- def load_index_and_metadata():
-     if not (os.path.exists(HAND_INDEX_FN) and os.path.exists(HAND_META_FN)):
-         return None, None
-     index = faiss.read_index(HAND_INDEX_FN)
-     with open(HAND_META_FN, "r", encoding="utf-8") as f:
-         meta = json.load(f)
-     with open(HAND_EMB_DIM_FN, "r") as f:
-         emb_dim = json.load(f)["dim"]
-     return index, meta
-
- # ======================================================
- # 🧠 INDEX BUILDER
- # ======================================================
- def ensure_handbook_index(rebuild=False):
-     """Build or load handbook FAISS index efficiently."""
-     if "handbook_ready" in st.session_state and st.session_state.handbook_ready and not rebuild:
          return

-     pdfs = find_pdfs()
      if not pdfs:
-         st.error("❌ No handbook PDF found.")
          st.session_state.handbook_ready = False
          return

-     # Try loading cached index
-     if os.path.exists(HAND_INDEX_FN) and not rebuild:
-         index, metadata = load_index_and_metadata()
-         if index is not None:
-             st.session_state.faiss_index = index
-             st.session_state.metadata = metadata
-             st.session_state.handbook_ready = True
-             st.success(f"✅ Loaded FAISS index with {len(metadata)} chunks.")
-             return
-
-     st.info("⚙️ Building FAISS index locally with MiniLM… this may take 30–60 seconds.")
-     start_time = time.time()
-
      pages = load_pdf_texts_with_page_info(pdfs)
-     chunks = chunk_pages_into_segments(pages, int(chunk_size_chars), int(chunk_overlap))
      if not chunks:
-         st.error("❌ No readable text found in the handbook.")
          return

-     texts = [c["text"] for c in chunks]
-     embeddings = embed_texts(texts, batch_size=16)
-
-     index, emb_dim = build_faiss_index(embeddings)
-     save_index_and_metadata(index, chunks, emb_dim)
-
-     st.session_state.faiss_index = index
-     st.session_state.metadata = chunks
-     st.session_state.handbook_ready = True
-
-     elapsed = time.time() - start_time
-     st.success(f"✅ Handbook indexed in {elapsed:.1f} seconds ({len(chunks)} chunks).")
-
- # ======================================================
- # 🔍 RETRIEVAL
- # ======================================================
- def embed_query(query: str) -> np.ndarray:
-     model = get_local_embedder()
-     emb = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0]
-     return emb.astype("float32")
-
- def retrieve_top_chunks(query: str, k: int):
-     index = st.session_state.get("faiss_index")
-     metadata = st.session_state.get("metadata", [])
-     if not index or not metadata:
-         return [], []
-     q_emb = embed_query(query).reshape(1, -1)
-     D, I = index.search(q_emb, k)
-     results = [metadata[i] for i in I[0] if i < len(metadata)]
-     return results, D[0].tolist()
-
- # ======================================================
- # 🗣️ CHAT INTERFACE
- # ======================================================
  ensure_handbook_index(rebuild=regenerate_index)

  st.divider()
- st.subheader("💬 Ask the handbook")

- user_input = st.chat_input("Ask a question about the handbook...")
- if user_input:
-     st_message(user_input, is_user=True)

-     retrieved, scores = retrieve_top_chunks(user_input, top_k)
-     if not retrieved or max(scores) < similarity_threshold:
-         reply = "Sorry, I can only answer based on the handbook, and I couldn't find relevant information."
-         st_message(reply, is_user=False)
      else:
-         answer = "Based on the handbook:\n\n"
-         for r, s in zip(retrieved, scores):
-             short = (r["text"][:300] + "…") if len(r["text"]) > 300 else r["text"]
-             answer += f"📄 **{r['filename']}**, page {r['page']} — (score {s:.3f})\n> {short}\n\n"
-         st_message(answer.strip(), is_user=False)
-
- # ======================================================
- # 🧾 HISTORY & EXPORT
- # ======================================================
  st.divider()
- st.subheader("Conversation History")
- if "chat_history" not in st.session_state:
-     st.session_state.chat_history = []
+ # streamlit_app.py
  import os
  import glob
  import json
+ import time
  import math
+ import re
+ from typing import List, Dict, Any, Tuple

  import numpy as np
  import streamlit as st
  import PyPDF2
+ from dotenv import load_dotenv
+ from huggingface_hub import InferenceClient, login
+ from streamlit_chat import message as st_message  # used by the chat history renderer below
  from sentence_transformers import SentenceTransformer
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from heapq import nlargest

+ # FAISS (optional)
  try:
      import faiss
  except Exception:
      faiss = None

+ # =========================
+ # Page + env
+ # =========================
+ st.set_page_config(page_title="📘 Handbook Assistant", page_icon="📘", layout="wide")
+ st.title("📘 USTP Student Handbook Assistant (2023 Edition)")
+ st.caption("This assistant answers only from the handbook. Place 'USTP Student Handbook 2023 Edition.pdf' in the same folder.")

  load_dotenv()
+ HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
+
+ hf_client = None
+ if HF_TOKEN:
+     try:
+         login(HF_TOKEN)
+     except Exception:
+         # login might be unnecessary depending on environment
+         pass
+     try:
+         hf_client = InferenceClient(token=HF_TOKEN)
+     except Exception as e:
+         st.warning(f"Could not init InferenceClient: {e}")
+
+ # =========================
+ # Sidebar configuration
+ # =========================
+ with st.sidebar:
+     st.header("⚙️ Settings")
+     model_options = {
+         "Qwen 2.5 14B Instruct (default)": "Qwen/Qwen2.5-14B-Instruct",
+         "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
+         "Llama 3 8B Instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
+         "Falcon 7B Instruct": "tiiuae/falcon-7b-instruct",
+         "Mixtral 8x7B Instruct": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+     }
+     model_label = st.selectbox("Choose model", list(model_options.keys()), index=0)
+     DEFAULT_MODEL = model_options[model_label]
+
+     st.markdown("---")
+     similarity_threshold = st.slider("Similarity threshold", 0.30, 0.95, 0.62, 0.01)
+     top_k = st.slider("Top K retrieved chunks", 1, 10, 4)
+     chunk_size_chars = st.number_input("Chunk size (chars)", min_value=400, max_value=3000, value=1200, step=100)
+     chunk_overlap = st.number_input("Chunk overlap (chars)", min_value=20, max_value=800, value=150, step=10)
+     regenerate_index = st.button("🔁 Rebuild handbook index (re-extract & re-embed)")
+
+ # =========================
+ # Filenames for index/meta
+ # =========================
  HAND_INDEX_FN = "handbook_faiss.index"
  HAND_META_FN = "handbook_metadata.json"
  HAND_EMB_DIM_FN = "handbook_emb_dim.json"

+ # =========================
+ # Utilities: find/load PDF
+ # =========================
+ def find_handbook(preferred_name: str = "USTP Student Handbook 2023 Edition.pdf") -> List[str]:
+     """Return a list containing the preferred handbook path, or the first PDF found."""
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+     preferred_path = os.path.join(current_dir, preferred_name)
+     if os.path.exists(preferred_path):
+         st.info(f"📘 Found handbook: {preferred_name}")
+         return [preferred_path]
+     # fallback: any pdf
+     pdfs = glob.glob(os.path.join(current_dir, "*.pdf"))
+     if pdfs:
+         st.warning(f"⚠️ Preferred handbook not found. Using {os.path.basename(pdfs[0])}")
+         return [pdfs[0]]
+     st.error("❌ No PDF found in the app folder. Please add the handbook PDF.")
+     return []

  def load_pdf_texts_with_page_info(pdf_paths: List[str]) -> List[Dict[str, Any]]:
+     """Extract text per page and return a list of dicts with filename, page, and text."""
+     pages = []
      for p in pdf_paths:
          try:
              with open(p, "rb") as f:
                  reader = PyPDF2.PdfReader(f)
                  for i, page in enumerate(reader.pages):
                      try:
+                         raw = page.extract_text() or ""
                      except Exception:
+                         raw = ""
+                     pages.append({"filename": os.path.basename(p), "page": i + 1, "text": raw})
          except Exception as e:
+             st.warning(f"Failed to read {p}: {e}")
+     return pages

  def chunk_pages_into_segments(pages: List[Dict[str, Any]], chunk_size: int, overlap: int) -> List[Dict[str, Any]]:
+     """Split pages into overlapping character chunks while preserving filename/page metadata."""
      chunks = []
      for pg in pages:
+         text = (pg.get("text") or "").strip()
+         if not text:
+             continue
+         filename = pg.get("filename", "handbook")
+         page_no = pg.get("page", 0)
+         start = 0
+         chunk_id = 0
+         L = len(text)
+         while start < L:
+             end = min(start + chunk_size, L)
              seg = text[start:end].strip()
+             if len(seg) >= 30:
                  chunks.append({
                      "filename": filename,
                      "page": page_no,
                      "chunk_id": f"{filename}_p{page_no}_c{chunk_id}",
+                     "content": seg
                  })
+                 chunk_id += 1
+             if end >= L:
+                 break  # last chunk reached; stepping back by `overlap` here would loop forever
              start = end - overlap
              if start < 0:
                  start = 0
      return chunks

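A quick check of the boundary arithmetic above, as a standalone sketch using the new sidebar defaults (chunk_size=1200, overlap=150):

```python
# Minimal sketch of the overlap arithmetic in chunk_pages_into_segments (toy values only).
def chunk_spans(length: int, chunk_size: int = 1200, overlap: int = 150):
    spans, start = [], 0
    while start < length:
        end = min(start + chunk_size, length)
        spans.append((start, end))
        if end >= length:  # last chunk reached; stop instead of re-looping on the tail
            break
        start = end - overlap
    return spans

print(chunk_spans(2000))  # [(0, 1200), (1050, 2000)] -> 150 chars shared between chunks
```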
+ # =========================
+ # Embeddings: robust pipeline
+ # =========================
+ TFIDF_MAX_FEATURES = 50000
+
+ @st.cache_resource
+ def get_tfidf_vectorizer():
+     return TfidfVectorizer(stop_words="english", max_features=TFIDF_MAX_FEATURES)
+
+ @st.cache_resource
+ def load_local_embedder():
+     """Load a SentenceTransformer model once; raises if the model cannot be loaded."""
+     # compact, fast model recommended
+     MODEL_NAME = "all-MiniLM-L6-v2"
+     return SentenceTransformer(MODEL_NAME)
+
+ def hf_embeddings_call_if_possible(texts: List[str], model_name: str = "sentence-transformers/all-mpnet-base-v2") -> Tuple[bool, Any]:
+     """
+     Try the HF InferenceClient embeddings call in a few ways, depending on the client version.
+     Returns (success_bool, embeddings_or_error).
+     """
+     if not hf_client:
+         return False, "No HF client"
+     try:
+         # Preferred modern method
+         if hasattr(hf_client, "embeddings"):
+             out = hf_client.embeddings(model=model_name, inputs=texts)
+             # handle common output shapes
+             if isinstance(out, dict) and "embedding" in out:
+                 # single-input case
+                 return True, np.array(out["embedding"], dtype=np.float32)
+             # sometimes returns a list of dicts
+             if isinstance(out, list) and out and isinstance(out[0], dict) and "embedding" in out[0]:
+                 arr = [d["embedding"] for d in out]
+                 return True, np.array(arr, dtype=np.float32)
+             # sometimes returns a list of lists
+             if isinstance(out, list) and len(out) and isinstance(out[0], (list, tuple)):
+                 return True, np.array(out, dtype=np.float32)
+             return False, f"Unexpected hf_client.embeddings output shape: {type(out)}"
+         # older client versions may have 'feature_extraction'
+         if hasattr(hf_client, "feature_extraction"):
+             out = hf_client.feature_extraction(texts, model=model_name)
+             return True, np.array(out, dtype=np.float32)
+         # Last resort: raw .post() (keyword-only in huggingface_hub; returns response bytes)
+         if hasattr(hf_client, "post"):
+             out = hf_client.post(json={"inputs": texts}, model=model_name, task="feature-extraction")
+             return True, np.array(json.loads(out), dtype=np.float32)
+     except Exception as e:
+         return False, e
+     return False, "No known embeddings method on hf_client"
+
+ def fallback_vectorize(texts: List[str]) -> np.ndarray:
+     """TF-IDF fallback embeddings (L2-normalized)."""
+     if not texts:
+         return np.zeros((0, 0), dtype=np.float32)
+     vect = get_tfidf_vectorizer()
+     X = vect.fit_transform(texts)  # sparse matrix
+     arr = X.toarray().astype(np.float32)
+     norms = np.linalg.norm(arr, axis=1, keepdims=True)
+     norms[norms == 0] = 1.0
+     arr = arr / norms
+     return arr
+
+ def embed_texts(texts: List[str]) -> np.ndarray:
+     """
+     Unified embedding function:
+     1) Try the HF embedding call (if a client is present)
+     2) Try the local SentenceTransformer embedder
+     3) Fall back to TF-IDF
+     Returns a normalized float32 numpy array.
+     """
+     if not texts:
+         return np.zeros((0, 0), dtype=np.float32)
+
+     # 1) HF first (cheap if credits are available)
+     success, out = hf_embeddings_call_if_possible(texts)
+     if success:
+         try:
+             arr = np.array(out, dtype=np.float32)
+             # if a single vector was returned for a single input, reshape to 2-D
+             if arr.ndim == 1:
+                 arr = arr.reshape(1, -1)
+             norms = np.linalg.norm(arr, axis=1, keepdims=True)
+             norms[norms == 0] = 1.0
+             return arr / norms
+         except Exception:
+             pass
+
+     # 2) Local model
+     try:
+         model = load_local_embedder()
+         arr = model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
+         arr = np.array(arr, dtype=np.float32)
+         if arr.ndim == 1:
+             arr = arr.reshape(1, -1)
+         norms = np.linalg.norm(arr, axis=1, keepdims=True)
+         norms[norms == 0] = 1.0
+         return arr / norms
+     except Exception as e:
+         st.warning(f"⚠️ Local SentenceTransformer failed or unavailable: {e}")
+
+     # 3) TF-IDF fallback
+     try:
+         st.info("Using TF-IDF fallback embeddings (offline).")
+         return fallback_vectorize(texts)
+     except Exception as e:
+         st.error(f"Embedding fallback failed completely: {e}")
+         return np.zeros((len(texts), 128), dtype=np.float32)
+
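Whichever tier answers, embed_texts L2-normalizes the rows, so downstream inner products behave like cosine similarity. A minimal sanity check (hypothetical strings; the dimensionality depends on which tier produced the vectors):

```python
# Hypothetical sanity check for embed_texts: rows come back unit-norm
# whichever tier (HF API, local MiniLM, or TF-IDF) produced them.
vecs = embed_texts(["library hours", "dress code policy"])
print(vecs.shape)                    # (2, d); d depends on the tier used
print(np.linalg.norm(vecs, axis=1))  # approximately [1. 1.]
```

One caveat worth noting: the TF-IDF tier re-fits its vectorizer on each call, so TF-IDF vectors are only mutually comparable when corpus and query are vectorized together, which is how the final fallback in retrieve_top_chunks below handles it.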
+ # =========================
+ # Build / load index
+ # =========================
+ def build_faiss_index(chunks: List[Dict[str, Any]]) -> Tuple[Any, List[Dict[str, Any]]]:
+     """Build a FAISS index (if faiss is available) and return index + metadata (chunks)."""
+     texts = [c["content"] for c in chunks]
+     emb = embed_texts(texts)
+     if emb.size == 0:
+         raise RuntimeError("No embeddings produced.")
+     if faiss is not None:
+         d = emb.shape[1]
+         # Use inner product on normalized vectors for cosine similarity
+         index = faiss.IndexFlatIP(d)
+         # ensure normalized
+         norms = np.linalg.norm(emb, axis=1, keepdims=True)
+         norms[norms == 0] = 1.0
+         emb_norm = emb / norms
+         index.add(emb_norm.astype("float32"))
+         # Save index & metadata
+         faiss.write_index(index, HAND_INDEX_FN)
+         with open(HAND_META_FN, "w", encoding="utf-8") as f:
+             json.dump(chunks, f, indent=2)
+         with open(HAND_EMB_DIM_FN, "w", encoding="utf-8") as f:
+             json.dump({"dim": d}, f)
+         return index, chunks
+     else:
+         # No FAISS: keep the embeddings in the metadata itself for brute-force search
+         for i, c in enumerate(chunks):
+             c["_embedding"] = emb[i].tolist()
+         with open(HAND_META_FN, "w", encoding="utf-8") as f:
+             json.dump(chunks, f, indent=2)
+         return None, chunks
+
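The index choice leans on a standard identity: for unit vectors, inner product equals cosine similarity, so IndexFlatIP over the normalized rows performs exact cosine search. A self-contained check, assuming faiss-cpu is installed:

```python
# Inner product over unit vectors == cosine similarity (exact search).
import numpy as np
import faiss

rng = np.random.default_rng(0)
emb = rng.normal(size=(100, 384)).astype("float32")
emb /= np.linalg.norm(emb, axis=1, keepdims=True)  # normalize rows

index = faiss.IndexFlatIP(emb.shape[1])
index.add(emb)

D, I = index.search(emb[:1], 3)  # query with a vector the index contains
print(I[0][0], round(float(D[0][0]), 3))  # 0 1.0 -> self-match scores cosine 1
```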
+ def load_index_and_metadata() -> Tuple[Any, List[Dict[str, Any]]]:
+     if os.path.exists(HAND_META_FN) and os.path.exists(HAND_EMB_DIM_FN) and os.path.exists(HAND_INDEX_FN) and faiss is not None:
+         try:
+             index = faiss.read_index(HAND_INDEX_FN)
+             with open(HAND_META_FN, "r", encoding="utf-8") as f:
+                 meta = json.load(f)
+             return index, meta
+         except Exception as e:
+             st.warning(f"Failed to load saved FAISS index: {e}")
+             return None, None
+     # fall back to metadata only
+     if os.path.exists(HAND_META_FN):
+         with open(HAND_META_FN, "r", encoding="utf-8") as f:
+             meta = json.load(f)
+         return None, meta
+     return None, None
+
+ # =========================
+ # Retrieval
+ # =========================
+ def retrieve_top_chunks(query: str, k: int = 4, metadata: List[Dict[str, Any]] = None, index=None) -> Tuple[List[Dict[str, Any]], List[float]]:
+     """
+     Return the top-k chunks and their similarity scores (cosine-like).
+     Uses FAISS if available; otherwise brute-forces with stored embeddings or TF-IDF.
+     """
+     if not metadata:
+         metadata = []
+     # If a FAISS index is available
+     if index is not None:
+         q_emb = embed_texts([query])
+         if q_emb.ndim == 1:
+             q_emb = q_emb.reshape(1, -1)
+         # normalize and search
+         norms = np.linalg.norm(q_emb, axis=1, keepdims=True)
+         norms[norms == 0] = 1.0
+         q_emb_norm = q_emb / norms
+         D, I = index.search(q_emb_norm.astype("float32"), k)
+         scores = D[0].tolist()
+         idxs = I[0].tolist()
+         results = []
+         for idx, score in zip(idxs, scores):
+             if 0 <= idx < len(metadata):
+                 results.append(metadata[idx])
+         return results, scores
+     # else brute force: metadata may carry stored embeddings
+     if metadata and "_embedding" in metadata[0]:
+         emb_mat = np.array([np.array(m["_embedding"], dtype=np.float32) for m in metadata])
+         q_emb = embed_texts([query]).astype(np.float32)
+         if q_emb.ndim == 1:
+             q_emb = q_emb.reshape(1, -1)
+         # cosine similarity
+         emb_norms = np.linalg.norm(emb_mat, axis=1, keepdims=True)
+         emb_norms[emb_norms == 0] = 1.0
+         emb_mat_n = emb_mat / emb_norms
+         qn = q_emb / np.linalg.norm(q_emb, axis=1, keepdims=True)
+         sims = (emb_mat_n @ qn.T).squeeze()  # cosine values
+         idxs = np.argsort(-sims)[:k]
+         results = [metadata[int(i)] for i in idxs]
+         scores = [float(sims[int(i)]) for i in idxs]
+         return results, scores
+     # final fallback: TF-IDF scoring between the query and chunk contents (cheap)
+     texts = [m["content"] for m in metadata]
+     vect = TfidfVectorizer(stop_words="english", max_features=TFIDF_MAX_FEATURES)
+     if texts:
+         X = vect.fit_transform(texts)
+         qv = vect.transform([query])
+         sims = (X @ qv.T).toarray().squeeze()
+         idxs = np.argsort(-sims)[:k]
+         results = [metadata[int(i)] for i in idxs]
+         scores = [float(sims[int(i)]) for i in idxs]
+         return results, scores
+     return [], []
+
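Outside the Streamlit callbacks, the retriever can be exercised directly against a previously saved index (a hypothetical query; assumes the files persisted by build_faiss_index exist on disk):

```python
# Hypothetical standalone use of the retriever after an index has been persisted.
idx, meta = load_index_and_metadata()
hits, scores = retrieve_top_chunks("What is the dress code?", k=4, metadata=meta, index=idx)
for h, s in zip(hits, scores):
    print(f"{s:.3f}  {h['filename']} p{h['page']}  {h['content'][:80]!r}")
```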
+ # =========================
+ # Extractive answer fallback
+ # =========================
+ def extractive_answer_from_chunks(retrieved_chunks: List[Dict[str, Any]], query: str) -> str:
+     if not retrieved_chunks:
+         return "The handbook does not specify that."
+     q_tokens = set([t.lower() for t in re.findall(r"\w+", query) if len(t) > 2])
+     scored = []
+     for rc in retrieved_chunks:
+         text = rc.get("content") or rc.get("text") or ""
+         sents = re.split(r'(?<=[.!?])\s+', text)
+         for s in sents:
+             tokens = set([t.lower() for t in re.findall(r"\w+", s) if len(t) > 2])
+             if not tokens:
+                 continue
+             overlap = len(q_tokens & tokens) / (1 + len(tokens))
+             scored.append((overlap, s.strip(), rc))
+     if not scored:
+         return "The handbook does not specify that."
+     topk = nlargest(2, scored, key=lambda x: x[0])
+     parts = []
+     for score, sent, rc in topk:
+         cite = f"(Source: {rc.get('filename','handbook')}, page {rc.get('page',0)})"
+         short_sent = sent if len(sent) <= 400 else sent[:397] + "..."
+         parts.append(f"\"{short_sent}\" {cite}")
+     final = "\n\n".join(parts)
+     final += "\n\nTakeaway: Refer to the cited section(s) above for the official handbook wording."
+     return final
+
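The sentence scorer is plain lexical overlap: tokens longer than two characters are lowercased, and a sentence scores |Q ∩ S| / (1 + |S|), which mildly favors short, on-topic sentences. Worked through on toy strings:

```python
# Worked example of the overlap score used by extractive_answer_from_chunks.
import re

query = "When is the dress code enforced?"
sent = "The dress code is enforced during official functions."

q = set(t.lower() for t in re.findall(r"\w+", query) if len(t) > 2)
s = set(t.lower() for t in re.findall(r"\w+", sent) if len(t) > 2)
print(len(q & s) / (1 + len(s)))  # 4 shared tokens / (1 + 7) = 0.5
```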
+ # =========================
+ # Generation with HF fallback
+ # =========================
+ def try_hf_generate(prompt: str) -> Tuple[bool, str]:
+     """
+     Try various HF generation endpoints. Returns (success, text_or_error).
+     Handles different InferenceClient versions gracefully.
+     """
+     if not hf_client:
+         return False, "No HF client"
+     # 1) text_generation method (the prompt is its first positional argument)
+     try:
+         if hasattr(hf_client, "text_generation"):
+             out = hf_client.text_generation(prompt, model=DEFAULT_MODEL, max_new_tokens=400, temperature=0.25)
+             # out may be a str, dict, or list depending on the client version
+             if isinstance(out, dict) and "generated_text" in out:
+                 return True, out["generated_text"]
+             if isinstance(out, list) and out and "generated_text" in out[0]:
+                 return True, out[0]["generated_text"]
+             return True, str(out)
+     except Exception:
+         # ignore and fall through to chat-style calls
+         pass
+     # 2) chat style: try common patterns
+     try:
+         # Modern clients expose hf_client.chat.completions.create()
+         if hasattr(hf_client, "chat") and hasattr(hf_client.chat, "completions") and hasattr(hf_client.chat.completions, "create"):
+             resp = hf_client.chat.completions.create(model=DEFAULT_MODEL, messages=[{"role": "user", "content": prompt}], max_tokens=400, temperature=0.25)
+             try:
+                 return True, resp.choices[0].message.content
+             except Exception:
+                 return True, str(resp)
+         # Some clients expose a callable hf_client.chat()
+         if hasattr(hf_client, "chat") and callable(hf_client.chat):
+             resp = hf_client.chat(model=DEFAULT_MODEL, messages=[{"role": "user", "content": prompt}], max_tokens=400, temperature=0.25)
+             # try to extract common response shapes
+             if isinstance(resp, dict) and "choices" in resp:
+                 try:
+                     return True, resp["choices"][0]["message"]["content"]
+                 except Exception:
+                     return True, str(resp)
+             if isinstance(resp, list) and resp and isinstance(resp[0], dict) and "generated_text" in resp[0]:
+                 return True, resp[0]["generated_text"]
+             return True, str(resp)
+         # Last resort: some clients have 'create' at the top level
+         if hasattr(hf_client, "create"):
+             resp = hf_client.create(model=DEFAULT_MODEL, inputs=prompt, max_new_tokens=400, temperature=0.25)
+             if isinstance(resp, dict) and "generated_text" in resp:
+                 return True, resp["generated_text"]
+             return True, str(resp)
+     except Exception as e:
+         return False, e
+     return False, "No known generation method"
+
+ def generate_answer(context: str, query: str, retrieved_chunks: List[Dict[str, Any]] = None) -> str:
+     """
+     Attempt HF generation; if that fails, fall back to an extractive, citation-backed answer.
+     Pass retrieved_chunks (list) so the extractive fallback can cite pages.
+     """
+     prompt = f"""
+ You are a precise academic assistant specialized in university policies.
+ Use only the provided USTP Student Handbook content below. If the answer is not in the provided text, respond exactly:
+ "The handbook does not specify that."
+
+ Context:
+ {context}
+
+ Question: {query}
+
+ Provide a concise answer including source citations (filename + page).
+ """
+     success, out = try_hf_generate(prompt)
+     if success:
+         # ensure a string is returned
+         return out if isinstance(out, str) else str(out)
+     # HF failed (e.g., 402 or no credits) -> extractive fallback
+     st.warning("HF generation unavailable — using extractive handbook-backed answer (no hallucination).")
+     return extractive_answer_from_chunks(retrieved_chunks or [], query)
+
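Because try_hf_generate returns False immediately when no client was created, the pipeline degrades to the extractive path whenever HF_TOKEN is absent. A quick offline check with hypothetical chunk data (filename and page invented for illustration):

```python
# Hypothetical offline check: with hf_client = None, generate_answer falls
# through to the extractive, citation-backed answer.
chunk = {"filename": "USTP Student Handbook 2023 Edition.pdf", "page": 42,
         "content": "Students must wear the prescribed uniform during official functions."}
print(generate_answer(chunk["content"], "What must students wear during official functions?",
                      retrieved_chunks=[chunk]))
# -> a quoted sentence plus "(Source: ..., page 42)" and the takeaway line
```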
+ # =========================
+ # Index management (persist/load)
+ # =========================
+ def ensure_handbook_index(rebuild: bool = False):
+     """
+     Create or load the index and metadata.
+     Stores results in st.session_state for quick reuse.
+     """
+     # If already built and not rebuilding, return
+     if st.session_state.get("handbook_ready") and not rebuild:
          return

+     pdfs = find_handbook()
      if not pdfs:
          st.session_state.handbook_ready = False
+         st.session_state.handbook_chunks = []
          return

+     # if a saved index exists & we're not rebuilding
+     if not rebuild and os.path.exists(HAND_META_FN) and (faiss is not None and os.path.exists(HAND_INDEX_FN) and os.path.exists(HAND_EMB_DIM_FN)):
+         try:
+             idx, meta = load_index_and_metadata()
+             if meta:
+                 st.session_state.faiss_index = idx
+                 st.session_state.metadata = meta
+                 st.session_state.handbook_ready = True
+                 st.success(f"Loaded saved index ({len(meta)} chunks).")
+                 return
+         except Exception:
+             pass
+
+     # extract pages -> chunks
      pages = load_pdf_texts_with_page_info(pdfs)
+     chunks = chunk_pages_into_segments(pages, chunk_size=int(chunk_size_chars), overlap=int(chunk_overlap))
      if not chunks:
+         st.error("No text found in PDFs.")
+         st.session_state.handbook_ready = False
          return

+     # build index (attempts HF embeddings -> local model -> TF-IDF)
+     try:
+         idx, meta = build_faiss_index(chunks)
+         st.session_state.faiss_index = idx
+         st.session_state.metadata = meta
+         st.session_state.handbook_ready = True
+         st.success(f"Indexed {len(meta)} chunks.")
+     except Exception as e:
+         st.error(f"Failed to build index: {e}")
+         # as a fallback, store raw chunks in the session
+         st.session_state.metadata = chunks
+         st.session_state.faiss_index = None
+         st.session_state.handbook_ready = True
+
+ # build / load index
  ensure_handbook_index(rebuild=regenerate_index)

+ # =========================
+ # Chat UI
+ # =========================
  st.divider()
+ st.subheader("💬 Ask the handbook (only handbook-based answers)")

+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []

+ # Input and handling
+ user_query = st.chat_input("Ask a question about the handbook...")
+ if user_query:
+     ts = int(time.time() * 1000)
+     st.session_state.chat_history.append({"role": "user", "content": user_query, "ts": ts})
+     # Retrieve top chunks
+     index = st.session_state.get("faiss_index")
+     metadata = st.session_state.get("metadata", [])
+     with st.spinner("🔎 Retrieving relevant handbook excerpts..."):
+         retrieved, scores = retrieve_top_chunks(user_query, k=int(top_k), metadata=metadata, index=index)
+     # Reject if there is no good match
+     if not retrieved or (scores and max(scores) < float(similarity_threshold)):
+         reply = "Sorry, I can only answer questions based on the school's handbook. I couldn't find relevant information in the handbook for your question."
+         st.session_state.chat_history.append({"role": "assistant", "content": reply, "ts": int(time.time() * 1000)})
      else:
+         # Build a concise context snippet for the model
+         context_text = "\n\n".join([f"--- {r['chunk_id']} | {r['filename']} | page {r['page']} ---\n{r['content']}" if 'chunk_id' in r else f"(Page {r.get('page')})\n{r.get('content')}" for r in retrieved])
+         # Query the model, or fall back to the extractive answer
+         with st.spinner("🤖 Generating answer..."):
+             ans = generate_answer(context_text, user_query, retrieved_chunks=retrieved)
+         # Append a citation block
+         citations = "\n".join([f"{r.get('chunk_id', 'n/a')} — {r.get('filename')} p{r.get('page')} (score {float(s):.3f})" for r, s in zip(retrieved, scores or [])])
+         final = f"{ans}\n\n**Retrieved sources (top results):**\n{citations}"
+         st.session_state.chat_history.append({"role": "assistant", "content": final, "ts": int(time.time() * 1000)})
+
+ # Display chat history with unique keys
  st.divider()
+ st.subheader("Conversation")
+ for i, entry in enumerate(st.session_state.chat_history):
+     is_user = entry.get("role") == "user"
+     # use ts and i to ensure key uniqueness across identical messages
+     key = f"msg_{i}_{entry.get('ts', 0)}"
+     st_message(entry["content"], is_user=is_user, key=key)
+
+ # Toolbar
+ st.divider()
+ col1, col2 = st.columns([1, 1])
+ with col1:
+     if st.button("🔄 Reset chat"):
+         st.session_state.chat_history = []
+         st.success("Chat reset.")
+ with col2:
+     transcript = "\n\n".join([f"{m['role'].upper()}: {m['content']}" for m in st.session_state.chat_history])
+     st.download_button("📥 Download transcript", data=transcript, file_name="handbook_transcript.txt")
+
+ st.caption("⚡ FAISS + local embeddings + Hugging Face (when available). Default model: Qwen 2.5 14B")
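Setup note: the new code imports python-dotenv, huggingface-hub, scikit-learn, streamlit-chat, and sentence-transformers in addition to streamlit, PyPDF2, numpy, and the optional faiss-cpu, so the pip line from the removed header comment still covers the dependencies. The app reads HF_TOKEN (or HUGGINGFACEHUB_API_TOKEN) from a .env file when present and runs with `streamlit run src/streamlit_app.py`.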