# streamlit_app.py
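"""
USTP Student Handbook Assistant: retrieval-augmented QA over a single PDF.

Pipeline, as implemented below:
  1. Extract text per page with PyPDF2 and split pages into overlapping character chunks.
  2. Embed the chunks (Hugging Face Inference API if available, else a local
     SentenceTransformer, else a TF-IDF fallback) and store them in a FAISS index,
     or inline in the chunk metadata when FAISS is not installed.
  3. At query time, retrieve the top-K most similar chunks, then either ask a hosted
     instruct model to answer from that context or fall back to an extractive,
     citation-backed answer built from the retrieved sentences.
"""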
import os
import glob
import json
import time
import math
import re
from typing import List, Dict, Any, Tuple

import numpy as np
import streamlit as st
from streamlit_chat import message as st_message  # renders the chat bubbles in the history view
import PyPDF2
from dotenv import load_dotenv
from huggingface_hub import InferenceClient, login
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from heapq import nlargest

# FAISS (optional)
try:
    import faiss
except Exception:
    faiss = None

# =========================
# Page + env
# =========================
st.set_page_config(page_title="📘 Handbook Assistant", page_icon="📘", layout="wide")
st.title("📘 USTP Student Handbook Assistant (2023 Edition)")
st.caption("This assistant answers only from the handbook. Place 'USTP Student Handbook 2023 Edition.pdf' in the same folder.")

load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")

hf_client = None
if HF_TOKEN:
    try:
        login(HF_TOKEN)
    except Exception:
        # login might be unnecessary depending on environment
        pass
    try:
        hf_client = InferenceClient(token=HF_TOKEN)
    except Exception as e:
        st.warning(f"Could not init InferenceClient: {e}")

# =========================
# Sidebar configuration
# =========================
with st.sidebar:
    st.header("⚙️ Settings")
    model_options = {
        "Qwen 2.5 14B Instruct (default)": "Qwen/Qwen2.5-14B-Instruct",
        "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
        "Llama 3 8B Instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
        "Falcon 7B Instruct": "tiiuae/falcon-7b-instruct",
        "Mixtral 8x7B Instruct": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    }
    model_label = st.selectbox("Choose model", list(model_options.keys()), index=0)
    DEFAULT_MODEL = model_options[model_label]

    st.markdown("---")
    similarity_threshold = st.slider("Similarity threshold", 0.30, 0.95, 0.62, 0.01)
    top_k = st.slider("Top K retrieved chunks", 1, 10, 4)
    chunk_size_chars = st.number_input("Chunk size (chars)", min_value=400, max_value=3000, value=1200, step=100)
    chunk_overlap = st.number_input("Chunk overlap (chars)", min_value=20, max_value=800, value=150, step=10)
    regenerate_index = st.button("🔁 Rebuild handbook index (re-extract & re-embed)")

# =========================
# Filenames for index/meta
# =========================
HAND_INDEX_FN = "handbook_faiss.index"
HAND_META_FN = "handbook_metadata.json"
HAND_EMB_DIM_FN = "handbook_emb_dim.json"

# =========================
# Utilities: find/load PDF
# =========================
def find_handbook(preferred_name: str = "USTP Student Handbook 2023 Edition.pdf") -> List[str]:
    """Return list containing handbook path (preferred) or first pdf found."""
    current_dir = os.path.dirname(os.path.abspath(__file__))
    preferred_path = os.path.join(current_dir, preferred_name)
    if os.path.exists(preferred_path):
        st.info(f"📘 Found handbook: {preferred_name}")
        return [preferred_path]
    # fallback: any pdf
    pdfs = glob.glob(os.path.join(current_dir, "*.pdf"))
    if pdfs:
        st.warning(f"⚠️ Preferred handbook not found. Using {os.path.basename(pdfs[0])}")
        return [pdfs[0]]
    st.error("❌ No PDF found in the app folder. Please add the handbook PDF.")
    return []

def load_pdf_texts_with_page_info(pdf_paths: List[str]) -> List[Dict[str, Any]]:
    """Extract text (per page) and return list of dicts with filename, page, text."""
    pages = []
    for p in pdf_paths:
        try:
            with open(p, "rb") as f:
                reader = PyPDF2.PdfReader(f)
                for i, page in enumerate(reader.pages):
                    try:
                        raw = page.extract_text() or ""
                    except Exception:
                        raw = ""
                    pages.append({"filename": os.path.basename(p), "page": i + 1, "text": raw})
        except Exception as e:
            st.warning(f"Failed to read {p}: {e}")
    return pages

def chunk_pages_into_segments(pages: List[Dict[str, Any]], chunk_size: int, overlap: int) -> List[Dict[str, Any]]:
    """
    Split pages into overlapping character chunks while preserving filename/page metadata.
    """
    chunks = []
    for pg in pages:
        text = (pg.get("text") or "").strip()
        if not text:
            continue
        filename = pg.get("filename", "handbook")
        page_no = pg.get("page", 0)
        start = 0
        chunk_id = 0
        L = len(text)
        while start < L:
            end = min(start + chunk_size, L)
            seg = text[start:end].strip()
            if len(seg) >= 30:
                chunks.append({
                    "filename": filename,
                    "page": page_no,
                    "chunk_id": f"{filename}_p{page_no}_c{chunk_id}",
                    "content": seg
                })
                chunk_id += 1
            if end >= L:
                # reached the end of the page; stepping back by `overlap` here would loop forever
                break
            # slide the window forward, keeping `overlap` characters of context, and always make progress
            start = max(end - overlap, start + 1)
    return chunks
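# Illustrative numbers: with the default chunk_size=1200 and overlap=150, a 2,000-character
# page yields windows [0, 1200) and [1050, 2000); consecutive chunks share 150 characters
# of context, so text near a window boundary is retrievable from either chunk.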

# =========================
# Embeddings: robust pipeline
# =========================
TFIDF_MAX_FEATURES = 50000

@st.cache_resource
def get_tfidf_vectorizer():
    return TfidfVectorizer(stop_words="english", max_features=TFIDF_MAX_FEATURES)

@st.cache_resource
def load_local_embedder():
    """
    Try to load a SentenceTransformer model. Will raise if cannot load.
    """
    # compact, fast model recommended
    MODEL_NAME = "all-MiniLM-L6-v2"
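    # all-MiniLM-L6-v2 encodes to 384-dimensional vectors and is fast enough for CPU-only hosts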
    return SentenceTransformer(MODEL_NAME)

def hf_embeddings_call_if_possible(texts: List[str], model_name: str = "sentence-transformers/all-mpnet-base-v2") -> Tuple[bool, Any]:
    """
    Try calling HF InferenceClient embeddings call in a few ways depending on client version.
    Returns (success_bool, embeddings_or_error)
    """
    if not hf_client:
        return False, "No HF client"
    try:
        # Preferred modern method
        if hasattr(hf_client, "embeddings"):
            out = hf_client.embeddings(model=model_name, inputs=texts)
            # handle common shapes
            if isinstance(out, dict) and "embedding" in out:
                # single input case
                return True, np.array(out["embedding"], dtype=np.float32)
            # sometimes returns list of dicts
            if isinstance(out, list) and out and isinstance(out[0], dict) and "embedding" in out[0]:
                arr = [d["embedding"] for d in out]
                return True, np.array(arr, dtype=np.float32)
            # sometimes returns list-of-lists
            if isinstance(out, list) and len(out) and isinstance(out[0], (list, tuple)):
                return True, np.array(out, dtype=np.float32)
            return False, f"Unexpected hf_client.embeddings output shape: {type(out)}"
        # older client versions may have 'feature_extraction'
        if hasattr(hf_client, "feature_extraction"):
            out = hf_client.feature_extraction(texts, model=model_name)
            return True, np.array(out, dtype=np.float32)
        # As a last resort, try the generic .post() helper (keyword-only on the client versions that have it)
        if hasattr(hf_client, "post"):
            out = hf_client.post(json={"inputs": texts}, model=model_name, task="feature-extraction")
            if isinstance(out, (bytes, bytearray)):
                out = json.loads(out)
            return True, np.array(out, dtype=np.float32)
    except Exception as e:
        return False, e
    return False, "No known embeddings method on hf_client"

def fallback_vectorize(texts: List[str]) -> np.ndarray:
    """TF-IDF fallback embeddings (normalized)."""
    if not texts:
        return np.zeros((0, 0), dtype=np.float32)
    vect = get_tfidf_vectorizer()
    X = vect.fit_transform(texts)  # sparse matrix
    arr = X.toarray().astype(np.float32)
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    arr = arr / norms
    return arr
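# Note: fallback_vectorize refits the cached vectorizer on every call, so vectors produced by
# separate calls (e.g. an index built earlier vs. a query embedded later) live in different
# feature spaces and are not directly comparable; the brute-force TF-IDF path in
# retrieve_top_chunks avoids this by fitting on the chunk texts and transforming the query
# with that same fit.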

def embed_texts(texts: List[str]) -> np.ndarray:
    """
    Unified embedding function:
    1) Try HF embedding call (if client present)
    2) Try local SentenceTransformer embedder
    3) Fallback to TF-IDF
    Returns normalized float32 numpy array.
    """
    if not texts:
        return np.zeros((0, 0), dtype=np.float32)
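    # Dimensionality caveat: the three paths below produce different vector sizes
    # (hosted all-mpnet-base-v2: 768-d, local all-MiniLM-L6-v2: 384-d, TF-IDF: vocabulary-sized),
    # so queries must be embedded by the same path that built the index; if availability
    # changes between indexing and querying, rebuild the index from the sidebar.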

    # 1) HF first (cheap if credits available)
    success, out = hf_embeddings_call_if_possible(texts)
    if success:
        try:
            arr = np.array(out, dtype=np.float32)
            # if single vector returned for single input, reshape
            if arr.ndim == 1:
                arr = arr.reshape(1, -1)
            norms = np.linalg.norm(arr, axis=1, keepdims=True)
            norms[norms == 0] = 1.0
            return arr / norms
        except Exception:
            pass

    # 2) Local model
    try:
        model = load_local_embedder()
        arr = model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
        arr = np.array(arr, dtype=np.float32)
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        norms = np.linalg.norm(arr, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        return arr / norms
    except Exception as e:
        st.warning(f"⚠️ Local SentenceTransformer failed or unavailable: {e}")

    # 3) TF-IDF fallback
    try:
        st.info("Using TF-IDF fallback embeddings (offline).")
        return fallback_vectorize(texts)
    except Exception as e:
        st.error(f"Embedding fallback failed completely: {e}")
        return np.zeros((len(texts), 128), dtype=np.float32)

# =========================
# Build / load index
# =========================
def build_faiss_index(chunks: List[Dict[str, Any]]) -> Tuple[Any, List[Dict[str, Any]]]:
    """
    Build FAISS index (if faiss available) and return index + metadata (chunks)
    """
    texts = [c["content"] for c in chunks]
    emb = embed_texts(texts)
    if emb.size == 0:
        raise RuntimeError("No embeddings produced.")
    if faiss is not None:
        d = emb.shape[1]
        # Use Inner Product on normalized vectors for cosine
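        # (for unit vectors a and b, a . b = cos(a, b), so IndexFlatIP scores are cosine similarities)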
        index = faiss.IndexFlatIP(d)
        # ensure normalized
        norms = np.linalg.norm(emb, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        emb_norm = emb / norms
        index.add(emb_norm.astype("float32"))
        # Save index & metadata
        faiss.write_index(index, HAND_INDEX_FN)
        with open(HAND_META_FN, "w", encoding="utf-8") as f:
            json.dump(chunks, f, indent=2)
        with open(HAND_EMB_DIM_FN, "w", encoding="utf-8") as f:
            json.dump({"dim": d}, f)
        return index, chunks
    else:
        # No FAISS: store each chunk's embedding inline in the metadata so retrieval can brute-force cosine similarity later
        for i, c in enumerate(chunks):
            c["_embedding"] = emb[i].tolist()
        with open(HAND_META_FN, "w", encoding="utf-8") as f:
            json.dump(chunks, f, indent=2)
        return None, chunks

def load_index_and_metadata() -> Tuple[Any, List[Dict[str, Any]]]:
    if os.path.exists(HAND_META_FN) and os.path.exists(HAND_EMB_DIM_FN) and os.path.exists(HAND_INDEX_FN) and faiss is not None:
        try:
            index = faiss.read_index(HAND_INDEX_FN)
            with open(HAND_META_FN, "r", encoding="utf-8") as f:
                meta = json.load(f)
            return index, meta
        except Exception as e:
            st.warning(f"Failed to load saved FAISS index: {e}")
            return None, None
    # fallback to metadata only
    if os.path.exists(HAND_META_FN):
        with open(HAND_META_FN, "r", encoding="utf-8") as f:
            meta = json.load(f)
        return None, meta
    return None, None

# =========================
# Retrieval
# =========================
def retrieve_top_chunks(query: str, k: int = 4, metadata: List[Dict[str, Any]] = None, index = None) -> Tuple[List[Dict[str, Any]], List[float]]:
    """
    Return top-k chunks and similarity scores (cosine-like).
    Works with FAISS if available, otherwise does brute-force using stored embeddings or TF-IDF.
    """
    if not metadata:
        metadata = []
    # If FAISS index available
    if index is not None:
        q_emb = embed_texts([query])
        if q_emb.ndim == 1:
            q_emb = q_emb.reshape(1, -1)
        # normalize and search
        norms = np.linalg.norm(q_emb, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        q_emb_norm = q_emb / norms
        D, I = index.search(q_emb_norm.astype("float32"), k)
        scores = D[0].tolist()
        idxs = I[0].tolist()
        results, kept_scores = [], []
        for idx, score in zip(idxs, scores):
            # FAISS pads with -1 when fewer than k vectors exist; keep results and scores aligned
            if 0 <= idx < len(metadata):
                results.append(metadata[idx])
                kept_scores.append(float(score))
        return results, kept_scores
    # else brute-force: metadata may include stored embeddings or we compute embeddings of metadata texts
    # If metadata items have "_embedding", use them
    if metadata and "_embedding" in metadata[0]:
        emb_mat = np.array([np.array(m["_embedding"], dtype=np.float32) for m in metadata])
        q_emb = embed_texts([query]).astype(np.float32)
        if q_emb.ndim == 1:
            q_emb = q_emb.reshape(1, -1)
        # cosine
        emb_norms = np.linalg.norm(emb_mat, axis=1, keepdims=True)
        emb_norms[emb_norms == 0] = 1.0
        emb_mat_n = emb_mat / emb_norms
        q_norms = np.linalg.norm(q_emb, axis=1, keepdims=True)
        q_norms[q_norms == 0] = 1.0
        qn = q_emb / q_norms
        sims = (emb_mat_n @ qn.T).squeeze()  # cosine values
        idxs = np.argsort(-sims)[:k]
        results = [metadata[int(i)] for i in idxs]
        scores = [float(sims[int(i)]) for i in idxs]
        return results, scores
    # final fallback: TF-IDF direct scoring between query and chunk contents (cheap)
    texts = [m["content"] for m in metadata]
    vect = TfidfVectorizer(stop_words="english", max_features=TFIDF_MAX_FEATURES)
    if texts:
        X = vect.fit_transform(texts)
        qv = vect.transform([query])
        sims = (X @ qv.T).toarray().squeeze()
        idxs = np.argsort(-sims)[:k]
        results = [metadata[int(i)] for i in idxs]
        scores = [float(sims[int(i)]) for i in idxs]
        return results, scores
    return [], []
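# Illustrative call (mirrors the chat handler below):
#   retrieve_top_chunks("What is the grading system?", k=4,
#                       metadata=st.session_state.metadata, index=st.session_state.faiss_index)
# returns (chunks, scores) with the best-matching chunk first.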

# =========================
# Extractive answer fallback
# =========================
def extractive_answer_from_chunks(retrieved_chunks: List[Dict[str, Any]], query: str) -> str:
    if not retrieved_chunks:
        return "The handbook does not specify that."
    q_tokens = set([t.lower() for t in re.findall(r"\w+", query) if len(t) > 2])
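    # Score each handbook sentence by length-damped token overlap with the query:
    # |query_tokens ∩ sentence_tokens| / (1 + |sentence_tokens|), so short sentences
    # dense in query words rank highest.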
    scored = []
    for rc in retrieved_chunks:
        text = rc.get("content") or rc.get("text") or ""
        sents = re.split(r'(?<=[.!?])\s+', text)
        for s in sents:
            tokens = set([t.lower() for t in re.findall(r"\w+", s) if len(t) > 2])
            if not tokens:
                continue
            overlap = len(q_tokens & tokens) / (1 + len(tokens))
            scored.append((overlap, s.strip(), rc))
    if not scored:
        return "The handbook does not specify that."
    topk = nlargest(2, scored, key=lambda x: x[0])
    parts = []
    for score, sent, rc in topk:
        cite = f"(Source: {rc.get('filename','handbook')}, page {rc.get('page',0)})"
        short_sent = sent if len(sent) <= 400 else sent[:397] + "..."
        parts.append(f"\"{short_sent}\" {cite}")
    final = "\n\n".join(parts)
    final += "\n\nTakeaway: Refer to the cited section(s) above for the official handbook wording."
    return final

# =========================
# Generation with HF fallback
# =========================
def try_hf_generate(prompt: str) -> Tuple[bool, str]:
    """
    Try various HF generation endpoints. Returns (success, text_or_error).
    Handles different InferenceClient versions gracefully.
    """
    if not hf_client:
        return False, "No HF client"
    # 1) text_generation method
    try:
        if hasattr(hf_client, "text_generation"):
            # text_generation takes the prompt as its first positional argument
            out = hf_client.text_generation(prompt, model=DEFAULT_MODEL, max_new_tokens=400, temperature=0.25)
            # usually a plain string; handle dict/list shapes from older clients defensively
            if isinstance(out, str):
                return True, out
            if isinstance(out, dict) and "generated_text" in out:
                return True, out["generated_text"]
            if isinstance(out, list) and out and isinstance(out[0], dict) and "generated_text" in out[0]:
                return True, out[0]["generated_text"]
            return True, str(out)
    except Exception:
        # ignore and fall through to chat-style calls below
        pass
    # 2) chat-style calls: try common patterns across client versions
    try:
        # Modern clients expose chat_completion()
        if hasattr(hf_client, "chat_completion"):
            resp = hf_client.chat_completion(messages=[{"role": "user", "content": prompt}], model=DEFAULT_MODEL, max_tokens=400, temperature=0.25)
            try:
                return True, resp.choices[0].message.content
            except Exception:
                return True, str(resp)
        # OpenAI-compatible namespace: chat.completions.create()
        if hasattr(hf_client, "chat") and hasattr(hf_client.chat, "completions") and hasattr(hf_client.chat.completions, "create"):
            resp = hf_client.chat.completions.create(model=DEFAULT_MODEL, messages=[{"role": "user", "content": prompt}], max_tokens=400, temperature=0.25)
            try:
                return True, resp.choices[0].message.content
            except Exception:
                return True, str(resp)
        # Some older clients expose a directly callable chat()
        if hasattr(hf_client, "chat") and callable(hf_client.chat):
            resp = hf_client.chat(model=DEFAULT_MODEL, messages=[{"role": "user", "content": prompt}], max_tokens=400, temperature=0.25)
            if isinstance(resp, dict) and "choices" in resp:
                try:
                    return True, resp["choices"][0]["message"]["content"]
                except Exception:
                    return True, str(resp)
            if isinstance(resp, list) and resp and isinstance(resp[0], dict) and "generated_text" in resp[0]:
                return True, resp[0]["generated_text"]
            return True, str(resp)
        # Last resort: a top-level create() (present on some client versions)
        if hasattr(hf_client, "create"):
            resp = hf_client.create(model=DEFAULT_MODEL, inputs=prompt, max_new_tokens=400, temperature=0.25)
            if isinstance(resp, dict) and "generated_text" in resp:
                return True, resp["generated_text"]
            return True, str(resp)
    except Exception as e:
        return False, e
    return False, "No known generation method"

def generate_answer(context: str, query: str, retrieved_chunks: List[Dict[str, Any]] = None) -> str:
    """
    Attempt to call HF generation; if that fails, fallback to extractive, citation-backed answer.
    Pass retrieved_chunks (list) so extractive fallback can cite pages.
    """
    prompt = f"""
You are a precise academic assistant specialized in university policies.
Use only the provided USTP Student Handbook content below. If the answer is not in the provided text, respond exactly:
"The handbook does not specify that."

Context:
{context}

Question: {query}

Provide a concise answer including source citations (filename + page).
"""
    success, out = try_hf_generate(prompt)
    if success:
        # if out is not str, ensure str
        return out if isinstance(out, str) else str(out)
    # HF failed (e.g., 402 or no credits) -> extractive fallback
    st.warning("HF generation unavailable β€” using extractive handbook-backed answer (no hallucination).")
    return extractive_answer_from_chunks(retrieved_chunks or [], query)

# =========================
# Index management (persist/load)
# =========================
def ensure_handbook_index(rebuild: bool = False):
    """
    Create or load index and metadata.
    Stores results in st.session_state as well for quick reuse.
    """
    # If already built and not rebuilding, return
    if st.session_state.get("handbook_ready") and not rebuild:
        return

    pdfs = find_handbook()
    if not pdfs:
        st.session_state.handbook_ready = False
        st.session_state.handbook_chunks = []
        return

    # if saved index exists & not rebuilding
    if not rebuild and os.path.exists(HAND_META_FN) and (faiss is not None and os.path.exists(HAND_INDEX_FN) and os.path.exists(HAND_EMB_DIM_FN)):
        try:
            idx, meta = load_index_and_metadata()
            if meta:
                st.session_state.faiss_index = idx
                st.session_state.metadata = meta
                st.session_state.handbook_ready = True
                st.success(f"Loaded saved index ({len(meta)} chunks).")
                return
        except Exception:
            pass

    # extract pages -> chunks
    pages = load_pdf_texts_with_page_info(pdfs)
    chunks = chunk_pages_into_segments(pages, chunk_size=int(chunk_size_chars), overlap=int(chunk_overlap))
    if not chunks:
        st.error("No text found in PDFs.")
        st.session_state.handbook_ready = False
        return

    # build index (this will attempt HF embeddings -> local -> TFIDF)
    try:
        idx, meta = build_faiss_index(chunks)
        st.session_state.faiss_index = idx
        st.session_state.metadata = meta
        st.session_state.handbook_ready = True
        st.success(f"Indexed {len(meta)} chunks.")
    except Exception as e:
        st.error(f"Failed to build index: {e}")
        # as fallback, store chunks in session
        st.session_state.metadata = chunks
        st.session_state.faiss_index = None
        st.session_state.handbook_ready = True

# build / load index
ensure_handbook_index(rebuild=regenerate_index)

# =========================
# Chat UI
# =========================
st.divider()
st.subheader("💬 Ask the handbook (only handbook-based answers)")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Input and handling
user_query = st.chat_input("Ask a question about the handbook...")
if user_query:
    ts = int(time.time() * 1000)
    st.session_state.chat_history.append({"role": "user", "content": user_query, "ts": ts})
    # Retrieve top chunks
    index = st.session_state.get("faiss_index")
    metadata = st.session_state.get("metadata", [])
    with st.spinner("🔎 Retrieving relevant handbook excerpts..."):
        retrieved, scores = retrieve_top_chunks(user_query, k=int(top_k), metadata=metadata, index=index)
    # Reject if no good match
    if not retrieved or (scores and max(scores) < float(similarity_threshold)):
        reply = "Sorry, I can only answer questions based on the school's handbook. I couldn't find relevant information in the handbook for your question."
        st.session_state.chat_history.append({"role": "assistant", "content": reply, "ts": int(time.time() * 1000)})
    else:
        # Build context snippet for model (concise)
        context_text = "\n\n".join([f"--- {r['chunk_id']} | {r['filename']} | page {r['page']} ---\n{r['content']}" if 'chunk_id' in r else f"(Page {r.get('page')})\n{r.get('content')}" for r in retrieved])
        # Query model or fallback extractive
        with st.spinner("🤖 Generating answer..."):
            ans = generate_answer(context_text, user_query, retrieved_chunks=retrieved)
        # Append citation block
        citations = "\n".join([f"{r.get('chunk_id', 'n/a')} - {r.get('filename')} p{r.get('page')} (score {float(s):.3f})" for r, s in zip(retrieved, scores or [])])
        final = f"{ans}\n\n**Retrieved sources (top results):**\n{citations}"
        st.session_state.chat_history.append({"role": "assistant", "content": final, "ts": int(time.time() * 1000)})

# Display chat history with unique keys
st.divider()
st.subheader("Conversation")
for i, entry in enumerate(st.session_state.chat_history):
    is_user = entry.get("role") == "user"
    # use ts and i to ensure uniqueness across identical messages
    key = f"msg_{i}_{entry.get('ts',0)}"
    st_message(entry["content"], is_user=is_user, key=key)

# Toolbar
st.divider()
col1, col2 = st.columns([1, 1])
with col1:
    if st.button("🔄 Reset chat"):
        st.session_state.chat_history = []
        st.success("Chat reset.")
with col2:
    transcript = "\n\n".join([f"{m['role'].upper()}: {m['content']}" for m in st.session_state.chat_history])
    st.download_button("📥 Download transcript", data=transcript, file_name="handbook_transcript.txt")

st.caption("⚑ FAISS + Local embeddings + Hugging Face (when available). Default model: Qwen 2.5 14B")