arahrooh committed
Commit 3c837c8 · 1 Parent(s): c5842f1

Add RAG chatbot functionality with OAuth authentication

- Modified app.py to use OAuth token pattern from template
- Added bot.py with RAG functionality
- Added requirements.txt with all dependencies
- Added chroma_db vector database
- Updated README.md with full description and usage instructions

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.sqlite3 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Cgt 3
3
- emoji: 💬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
@@ -13,4 +13,58 @@ hf_oauth_scopes:
13
  license: mit
14
  ---
15
 
16
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
 
1
  ---
2
+ title: CGT-LLM-Beta RAG Chatbot
3
+ emoji: 🧬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
 
13
  license: mit
14
  ---
15
 
16
+ # 🧬 CGT-LLM-Beta: Genetic Counseling RAG Chatbot
17
+
18
+ A Retrieval-Augmented Generation (RAG) chatbot for genetic counseling, cascade genetic testing, hereditary cancer syndromes, and related topics.
19
+
20
+ ## Features
21
+
22
+ - **RAG System**: Provides evidence-based answers from medical literature using vector database retrieval
23
+ - **Multiple Models**: Choose from various LLM models (Llama, Mistral, MediPhi, etc.)
24
+ - **Education Level Adaptation**: Answers are tailored to different education levels (Middle School, High School, College, Doctoral)
25
+ - **Source Citations**: View retrieved document chunks with similarity scores
26
+ - **Readability Scoring**: Flesch-Kincaid grade level scores for each answer
27
+ - **OAuth Authentication**: Secure access using Hugging Face OAuth tokens
28
+
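
The readability score reported with each answer is computed with `textstat`; below is a minimal sketch of that scoring step (the sample sentence is illustrative).

```python
# Flesch-Kincaid grade-level scoring, as reported alongside each answer.
# textstat is listed in requirements.txt; the sample text below is illustrative.
import textstat

sample = "BRCA1 and BRCA2 are genes that normally help repair damaged DNA."
grade = textstat.flesch_kincaid_grade(sample)
print(f"Flesch-Kincaid grade level: {grade:.1f}")
```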
29
+ ## How to Use
30
+
31
+ 1. **Log in**: Click the "Login" button in the sidebar to authenticate with your Hugging Face account
32
+ 2. **Ask a question**: Enter your question about genetic counseling, hereditary cancer, or related topics
33
+ 3. **Select options**:
34
+ - Choose your preferred LLM model
35
+ - Select your education level for personalized answers
36
+ - Adjust advanced settings (retrieval count, temperature, max tokens)
37
+ 4. **View results**: See the answer, readability score, source documents, and similarity scores
38
+
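
Behind the login button, Gradio hands the app an OAuth token that is forwarded to the Hugging Face Inference API. The snippet below is a minimal sketch of that pattern (it mirrors the approach in `app.py`; the model name is one of this Space's defaults):

```python
# Minimal sketch of the OAuth -> Inference API flow used by this Space.
# Gradio injects the logged-in user's token via the gr.OAuthToken type hint.
import gradio as gr
from huggingface_hub import InferenceClient

def answer(question: str, hf_token: gr.OAuthToken) -> str:
    client = InferenceClient(token=hf_token.token)
    return client.text_generation(
        question,
        model="meta-llama/Llama-3.2-3B-Instruct",  # default entry in MODEL_MAP
        max_new_tokens=256,
        return_full_text=False,
    )

with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    question_box = gr.Textbox(label="Question")
    answer_box = gr.Textbox(label="Answer")
    question_box.submit(answer, inputs=question_box, outputs=answer_box)

if __name__ == "__main__":
    demo.launch()
```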
39
+ ## Example Questions
40
+
41
+ The chatbot includes 100+ example questions covering topics like:
42
+ - BRCA1/BRCA2 mutations and cancer risk
43
+ - Lynch Syndrome (MLH1, MSH2, MSH6, PMS2, EPCAM)
44
+ - Genetic testing recommendations
45
+ - Family communication about genetic results
46
+ - Insurance and legal considerations (GINA)
47
+ - Screening and prevention strategies
48
+
49
+ ## Technical Details
50
+
51
+ - **Vector Database**: ChromaDB for fast semantic search
52
+ - **Embeddings**: Sentence-transformers (all-MiniLM-L6-v2)
53
+ - **Inference**: Hugging Face Inference API (via OAuth)
54
+ - **Interface**: Gradio 5.42.0
55
+
56
+ ## Important Notes
57
+
58
+ ⚠️ **Medical Disclaimer**: This chatbot provides informational answers based on medical literature. It is not a substitute for professional medical advice, diagnosis, or treatment. Always consult with qualified healthcare providers for medical decisions.
59
+
60
+ ## Resources
61
+
62
+ The chatbot's knowledge base includes:
63
+ - NCCN Guidelines
64
+ - Medical literature on hereditary cancer syndromes
65
+ - Genetic counseling resources
66
+ - Patient education materials
67
+
68
+ ---
69
+
70
+ Built with [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py CHANGED
@@ -1,70 +1,922 @@
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
3
 
 
4
 
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
- messages = [{"role": "system", "content": system_message}]
 
 
20
 
21
- messages.extend(history)
 
 
22
 
23
- messages.append({"role": "user", "content": message})
 
 
24
 
25
- response = ""
26
 
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
 
 
31
  temperature=temperature,
32
  top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
 
 
38
 
39
- response += token
40
- yield response
41
 
 
42
 
43
  """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
 
 
54
  minimum=0.1,
55
  maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
 
 
61
  )
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
67
 
 
 
68
 
 
69
  if __name__ == "__main__":
 
 
70
  demo.launch()
 
1
+ """
2
+ Gradio Chatbot Interface for CGT-LLM-Beta RAG System
3
+
4
+ This application provides a web interface for the RAG chatbot with OAuth authentication.
5
+ It uses Hugging Face Inference API with OAuth tokens for authentication.
6
+ """
7
+
8
  import gradio as gr
9
+ import argparse
10
+ import sys
11
+ import os
12
+ from typing import Tuple, Optional, List
13
+ import logging
14
+ import textstat
15
+ import torch
16
+
17
+ # Set up logging first (before any logger usage)
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
 
21
+ # Import from bot.py - wrap in try/except to handle import errors gracefully
22
+ try:
23
+ from bot import RAGBot, parse_args, Chunk
24
+ BOT_AVAILABLE = True
25
+ except ImportError as e:
26
+ logger.error(f"Failed to import bot module: {e}")
27
+ BOT_AVAILABLE = False
28
+ # Create dummy classes so the module can still load
29
+ class RAGBot:
30
+ pass
31
+ class Chunk:
32
+ pass
33
+ def parse_args():
34
+ return None
35
 
36
+ # For Hugging Face Inference API
37
+ try:
38
+ from huggingface_hub import InferenceClient
39
+ HF_INFERENCE_AVAILABLE = True
40
+ except ImportError:
41
+ HF_INFERENCE_AVAILABLE = False
42
+ logger.warning("huggingface_hub not available, InferenceClient will not work")
 
 
43
 
44
+ # Model mapping: short name -> full HuggingFace path
45
+ MODEL_MAP = {
46
+ "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
47
+ "Mistral-7B-Instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
48
+ "Llama-4-Scout-17B-16E-Instruct": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
49
+ "MediPhi-Instruct": "microsoft/MediPhi-Instruct",
50
+ "MediPhi": "microsoft/MediPhi",
51
+ "Phi-4-reasoning": "microsoft/Phi-4-reasoning",
52
+ }
53
 
54
+ # Education level mapping
55
+ EDUCATION_LEVELS = {
56
+ "Middle School": "middle_school",
57
+ "High School": "high_school",
58
+ "College": "college",
59
+ "Doctoral": "doctoral"
60
+ }
61
 
62
+ # Example questions from the results CSV (hardcoded for easy access)
63
+ EXAMPLE_QUESTIONS = [
64
+ "Can a BRCA2 variant skip a generation?",
65
+ "Can a PMS2 variant skip a generation?",
66
+ "Can an EPCAM/MSH2 variant skip a generation?",
67
+ "Can an MLH1 variant skip a generation?",
68
+ "Can an MSH2 variant skip a generation?",
69
+ "Can an MSH6 variant skip a generation?",
70
+ "Can I pass this MSH2 variant to my kids?",
71
+ "Can only women carry a BRCA inherited mutation?",
72
+ "Does GINA cover life or disability insurance?",
73
+ "Does having a BRCA1 mutation mean I will definitely have cancer?",
74
+ "Does having a BRCA2 mutation mean I will definitely have cancer?",
75
+ "Does having a PMS2 mutation mean I will definitely have cancer?",
76
+ "Does having an EPCAM/MSH2 mutation mean I will definitely have cancer?",
77
+ "Does having an MLH1 mutation mean I will definitely have cancer?",
78
+ "Does having an MSH2 mutation mean I will definitely have cancer?",
79
+ "Does having an MSH6 mutation mean I will definitely have cancer?",
80
+ "Does this BRCA1 genetic variant affect my cancer treatment?",
81
+ "Does this BRCA2 genetic variant affect my cancer treatment?",
82
+ "Does this EPCAM/MSH2 genetic variant affect my cancer treatment?",
83
+ "Does this MLH1 genetic variant affect my cancer treatment?",
84
+ "Does this MSH2 genetic variant affect my cancer treatment?",
85
+ "Does this MSH6 genetic variant affect my cancer treatment?",
86
+ "Does this PMS2 genetic variant affect my cancer treatment?",
87
+ "How can I cope with this diagnosis?",
88
+ "How can I get my kids tested?",
89
+ "How can I help others with my condition?",
90
+ "How might my genetic test results change over time?",
91
+ "I don't talk to my family/parents/sister/brother. How can I share this with them?",
92
+ "I have a BRCA pathogenic variant and I want to have children, what are my options?",
93
+ "Is genetic testing for my family members covered by insurance?",
94
+ "Is new research being done on my condition?",
95
+ "Is this BRCA1 variant something I inherited?",
96
+ "Is this BRCA2 variant something I inherited?",
97
+ "Is this EPCAM/MSH2 variant something I inherited?",
98
+ "Is this MLH1 variant something I inherited?",
99
+ "Is this MSH2 variant something I inherited?",
100
+ "Is this MSH6 variant something I inherited?",
101
+ "Is this PMS2 variant something I inherited?",
102
+ "My relative doesn't have insurance. What should they do?",
103
+ "People who test positive for a genetic mutation are they at risk of losing their health insurance?",
104
+ "Should I contact my male and female relatives?",
105
+ "Should my family members get tested?",
106
+ "What are the Risks and Benefits of Risk-Reducing Surgeries for Lynch Syndrome?",
107
+ "What are the recommendations for my family members if I have a BRCA1 mutation?",
108
+ "What are the recommendations for my family members if I have a BRCA2 mutation?",
109
+ "What are the recommendations for my family members if I have a PMS2 mutation?",
110
+ "What are the recommendations for my family members if I have an EPCAM/MSH2 mutation?",
111
+ "What are the recommendations for my family members if I have an MLH1 mutation?",
112
+ "What are the recommendations for my family members if I have an MSH2 mutation?",
113
+ "What are the recommendations for my family members if I have an MSH6 mutation?",
114
+ "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have a BRCA mutation?",
115
+ "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an EPCAM/MSH2 mutation?",
116
+ "What are the surveillance and preventions I can take to reduce my risk of cancer or detecting cancer early if I have an MSH2 mutation?",
117
+ "What does a BRCA1 genetic variant mean for me?",
118
+ "What does a BRCA2 genetic variant mean for me?",
119
+ "What does a PMS2 genetic variant mean for me?",
120
+ "What does an EPCAM/MSH2 genetic variant mean for me?",
121
+ "What does an MLH1 genetic variant mean for me?",
122
+ "What does an MSH2 genetic variant mean for me?",
123
+ "What does an MSH6 genetic variant mean for me?",
124
+ "What if I feel overwhelmed?",
125
+ "What if I want to have children and have a hereditary cancer gene? What are my reproductive options?",
126
+ "What if a family member doesn't want to get tested?",
127
+ "What is Lynch Syndrome?",
128
+ "What is my cancer risk if I have BRCA1 Hereditary Breast and Ovarian Cancer syndrome?",
129
+ "What is my cancer risk if I have BRCA2 Hereditary Breast and Ovarian Cancer syndrome?",
130
+ "What is my cancer risk if I have MLH1 Lynch syndrome?",
131
+ "What is my cancer risk if I have MSH2 or EPCAM-associated Lynch syndrome?",
132
+ "What is my cancer risk if I have MSH6 Lynch syndrome?",
133
+ "What is my cancer risk if I have PMS2 Lynch syndrome?",
134
+ "What other resources are available to help me?",
135
+ "What screening tests do you recommend for BRCA1 carriers?",
136
+ "What screening tests do you recommend for BRCA2 carriers?",
137
+ "What screening tests do you recommend for EPCAM/MSH2 carriers?",
138
+ "What screening tests do you recommend for MLH1 carriers?",
139
+ "What screening tests do you recommend for MSH2 carriers?",
140
+ "What screening tests do you recommend for MSH6 carriers?",
141
+ "What screening tests do you recommend for PMS2 carriers?",
142
+ "What steps can I take to manage my cancer risk if I have Lynch syndrome?",
143
+ "What types of cancers am I at risk for with a BRCA1 mutation?",
144
+ "What types of cancers am I at risk for with a BRCA2 mutation?",
145
+ "What types of cancers am I at risk for with a PMS2 mutation?",
146
+ "What types of cancers am I at risk for with an EPCAM/MSH2 mutation?",
147
+ "What types of cancers am I at risk for with an MLH1 mutation?",
148
+ "What types of cancers am I at risk for with an MSH2 mutation?",
149
+ "What types of cancers am I at risk for with an MSH6 mutation?",
150
+ "Where can I find a genetic counselor?",
151
+ "Which of my relatives are at risk?",
152
+ "Who are my first-degree relatives?",
153
+ "Who do my family members call to have genetic testing?",
154
+ "Why do some families with Lynch syndrome have more cases of cancer than others?",
155
+ "Why should I share my BRCA1 genetic results with family?",
156
+ "Why should I share my BRCA2 genetic results with family?",
157
+ "Why should I share my EPCAM/MSH2 genetic results with family?",
158
+ "Why should I share my MLH1 genetic results with family?",
159
+ "Why should I share my MSH2 genetic results with family?",
160
+ "Why should I share my MSH6 genetic results with family?",
161
+ "Why should I share my PMS2 genetic results with family?",
162
+ "Why would my relatives want to know if they have this? What can they do about it?",
163
+ "Will my insurance cover testing for my parents/brother/sister?",
164
+ "Will this affect my health insurance?",
165
+ ]
166
 
 
167
 
168
+ class InferenceAPIBot:
169
+ """Wrapper that uses Hugging Face Inference API with OAuth token"""
170
+
171
+ def __init__(self, bot: RAGBot):
172
+ """Initialize with a RAGBot (for vector DB)"""
173
+ self.bot = bot # Use bot for vector DB and formatting
174
+ self.current_model = bot.args.model
175
+ logger.info(f"InferenceAPIBot initialized with model: {self.current_model}")
176
+
177
+ def _get_client(self, hf_token: Optional[str] = None) -> InferenceClient:
178
+ """Create InferenceClient with token (can be None for public models)"""
179
+ if hf_token:
180
+ return InferenceClient(token=hf_token)
181
+ else:
182
+ # Try without token (works for public models)
183
+ return InferenceClient()
184
+
185
+ @property
186
+ def args(self):
187
+ """Access args from the wrapped bot"""
188
+ return self.bot.args
189
+
190
+ def generate_answer(self, prompt: str, hf_token: Optional[str] = None, **kwargs) -> str:
191
+ """Generate answer using Inference API"""
192
+ try:
193
+ max_tokens = kwargs.get('max_new_tokens', 512)
194
+ temperature = kwargs.get('temperature', 0.2)
195
+ top_p = kwargs.get('top_p', 0.9)
196
+
197
+ # Create client with token
198
+ client = self._get_client(hf_token)
199
+
200
+ # Use text_generation API directly
201
+ logger.info(f"Calling Inference API for model: {self.current_model}")
202
+ response = client.text_generation(
203
+ prompt,
204
+ model=self.current_model,
205
+ max_new_tokens=max_tokens,
206
  temperature=temperature,
207
  top_p=top_p,
208
+ return_full_text=False,
209
+ )
210
+ logger.info(f"Inference API response received (length: {len(response) if response else 0})")
211
+ return response
212
+ except Exception as e:
213
+ logger.error(f"Error calling Inference API: {e}", exc_info=True)
214
+ import traceback
215
+ logger.error(f"Traceback: {traceback.format_exc()}")
216
+ return f"Error generating answer: {str(e)}. Please check the logs for details."
217
+
218
+ def enhance_readability(self, answer: str, target_level: str = "middle_school", hf_token: Optional[str] = None) -> Tuple[str, float]:
219
+ """Enhance readability using Inference API"""
220
+ try:
221
+ # Define prompts for different reading levels
222
+ if target_level == "middle_school":
223
+ level_description = "middle school reading level (ages 12-14, 6th-8th grade)"
224
+ instructions = """
225
+ - Use simpler medical terms or explain them
226
+ - Medium-length sentences
227
+ - Clear, structured explanations
228
+ - Keep important medical information accessible"""
229
+ elif target_level == "high_school":
230
+ level_description = "high school reading level (ages 15-18, 9th-12th grade)"
231
+ instructions = """
232
+ - Use appropriate medical terminology with context
233
+ - Varied sentence length
234
+ - Comprehensive yet accessible explanations
235
+ - Maintain technical accuracy while ensuring clarity"""
236
+ elif target_level == "college":
237
+ level_description = "college reading level (undergraduate level, ages 18-22)"
238
+ instructions = """
239
+ - Use standard medical terminology with brief explanations
240
+ - Professional and clear writing style
241
+ - Include relevant clinical context
242
+ - Maintain scientific accuracy and precision
243
+ - Appropriate for undergraduate students in health sciences"""
244
+ elif target_level == "doctoral":
245
+ level_description = "doctoral/professional reading level (graduate level, medical professionals)"
246
+ instructions = """
247
+ - Use advanced medical and scientific terminology
248
+ - Include detailed clinical and research context
249
+ - Reference specific mechanisms, pathways, and evidence
250
+ - Provide comprehensive technical explanations
251
+ - Appropriate for medical professionals, researchers, and graduate students
252
+ - Include nuanced discussions of clinical implications and research findings"""
253
+ else:
254
+ raise ValueError(f"Unknown target_level: {target_level}")
255
+
256
+ system_message = f"""You are a helpful medical assistant who specializes in explaining complex medical information at appropriate reading levels. Rewrite the following medical answer for {level_description}:
257
+ {instructions}
258
+ - Keep the same important information but adapt the complexity
259
+ - Provide context for technical terms
260
+ - Ensure the answer is informative yet understandable"""
261
+
262
+ user_message = f"Please rewrite this medical answer for {level_description}:\n\n{answer}"
263
+
264
+ # Combine system and user messages for text generation
265
+ combined_prompt = f"{system_message}\n\n{user_message}"
266
+ logger.info(f"Enhancing readability for {target_level} level")
267
+
268
+ # Create client with token
269
+ client = self._get_client(hf_token)
270
+
271
+ max_tokens = 512 if target_level in ["college", "doctoral"] else 384
272
+ temperature = 0.4 if target_level in ["college", "doctoral"] else 0.3
273
+
274
+ enhanced_answer = client.text_generation(
275
+ combined_prompt,
276
+ model=self.current_model,
277
+ max_new_tokens=max_tokens,
278
+ temperature=temperature,
279
+ return_full_text=False,
280
+ )
281
+ # Clean the answer (same as bot.py)
282
+ cleaned = self.bot._clean_readability_answer(enhanced_answer, target_level)
283
+
284
+ # Calculate Flesch score
285
+ try:
286
+ flesch_score = textstat.flesch_kincaid_grade(cleaned)
287
+ except Exception:
288
+ flesch_score = 0.0
289
+
290
+ return cleaned, flesch_score
291
+ except Exception as e:
292
+ logger.error(f"Error enhancing readability: {e}", exc_info=True)
293
+ return answer, 0.0
294
+
295
+ # Delegate other methods to bot
296
+ def format_prompt(self, context_chunks: List[Chunk], question: str) -> str:
297
+ return self.bot.format_prompt(context_chunks, question)
298
+
299
+ def retrieve_with_scores(self, query: str, k: int) -> Tuple[List[Chunk], List[float]]:
300
+ return self.bot.retrieve_with_scores(query, k)
301
+
302
+ def _categorize_question(self, question: str) -> str:
303
+ return self.bot._categorize_question(question)
304
+
305
+ @property
306
+ def vector_retriever(self):
307
+ return self.bot.vector_retriever
308
 
 
 
309
 
310
+ class GradioRAGInterface:
311
+ """Wrapper class to integrate RAGBot with Gradio using OAuth"""
312
+
313
+ def __init__(self, initial_bot: RAGBot):
314
+ # Always use Inference API on Spaces
315
+ if HF_INFERENCE_AVAILABLE:
316
+ self.bot = InferenceAPIBot(initial_bot)
317
+ self.use_inference_api = True
318
+ logger.info("Using Hugging Face Inference API with OAuth")
319
+ else:
320
+ self.bot = initial_bot
321
+ self.use_inference_api = False
322
+ logger.warning("Inference API not available, falling back to local model")
323
+
324
+ # Get current model from bot args
325
+ self.current_model = self.bot.args.model if hasattr(self.bot, 'args') else getattr(self.bot, 'current_model', None)
326
+ if self.current_model is None and hasattr(self.bot, 'bot'):
327
+ self.current_model = self.bot.bot.args.model
328
+ self.data_dir = initial_bot.args.data_dir
329
+ logger.info("GradioRAGInterface initialized")
330
+
331
+ def _find_file_path(self, filename: str) -> str:
332
+ """Find the full file path for a given filename"""
333
+ from pathlib import Path
334
+ data_path = Path(self.data_dir)
335
+
336
+ if not data_path.exists():
337
+ return ""
338
+
339
+ # Search for the file recursively
340
+ for file_path in data_path.rglob(filename):
341
+ return str(file_path)
342
+
343
+ return ""
344
+
345
+ def reload_model(self, model_short_name: str) -> str:
346
+ """Reload the model when user selects a different one"""
347
+ if model_short_name not in MODEL_MAP:
348
+ return f"Error: Unknown model '{model_short_name}'"
349
+
350
+ new_model_path = MODEL_MAP[model_short_name]
351
+
352
+ # If same model, no need to reload
353
+ if new_model_path == self.current_model:
354
+ return f"Model already loaded: {model_short_name}"
355
+
356
+ try:
357
+ logger.info(f"Switching model from {self.current_model} to {new_model_path}")
358
+
359
+ if self.use_inference_api:
360
+ # For Inference API, just update the model name
361
+ self.bot.current_model = new_model_path
362
+ self.current_model = new_model_path
363
+ return f"✓ Model switched to: {model_short_name} (using Inference API)"
364
+ else:
365
+ # For local model, reload it
366
+ self.bot.args.model = new_model_path
367
+
368
+ # Clear old model from memory
369
+ if hasattr(self.bot, 'model') and self.bot.model is not None:
370
+ del self.bot.model
371
+ del self.bot.tokenizer
372
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
373
+
374
+ # Load new model
375
+ self.bot._load_model()
376
+ self.current_model = new_model_path
377
+
378
+ return f"✓ Model loaded: {model_short_name}"
379
+ except Exception as e:
380
+ logger.error(f"Error reloading model: {e}", exc_info=True)
381
+ return f"✗ Error loading model: {str(e)}"
382
+
383
+ def process_question(
384
+ self,
385
+ question: str,
386
+ model_name: str,
387
+ education_level: str,
388
+ k: int,
389
+ temperature: float,
390
+ max_tokens: int,
391
+ hf_token: Optional[str] = None
392
+ ) -> Tuple[str, str, str, str, str]:
393
+ """
394
+ Process a single question and return formatted results
395
+
396
+ Returns:
397
+ Tuple of (answer, flesch_score, sources, similarity_scores, question_category)
398
+ """
399
+ import time
400
+
401
+ if not question or not question.strip():
402
+ return "Please enter a question.", "N/A", "", "", ""
403
+
404
+ # Check if token is provided when using Inference API
405
+ if self.use_inference_api and not hf_token:
406
+ return (
407
+ "⚠️ **Authentication Required**\n\nPlease log in using the Hugging Face login button in the sidebar to use the Inference API.",
408
+ "N/A",
409
+ "",
410
+ "",
411
+ "Error"
412
+ )
413
+
414
+ try:
415
+ start_time = time.time()
416
+ logger.info(f"Processing question: {question[:50]}...")
417
+
418
+ # Reload model if changed
419
+ if model_name in MODEL_MAP:
420
+ model_path = MODEL_MAP[model_name]
421
+ if model_path != self.current_model:
422
+ logger.info(f"Model changed, reloading from {self.current_model} to {model_path}")
423
+ reload_status = self.reload_model(model_name)
424
+ if reload_status.startswith("✗"):
425
+ return f"Error: {reload_status}", "N/A", "", "", ""
426
+ logger.info(f"Model reloaded in {time.time() - start_time:.1f}s")
427
+
428
+ # Update bot args for this query
429
+ self.bot.args.k = k
430
+ self.bot.args.temperature = temperature
431
+ self.bot.args.max_new_tokens = min(max_tokens, 512) # Cap at 512 for faster responses
432
+
433
+ # Categorize question
434
+ logger.info("Categorizing question...")
435
+ question_group = self.bot._categorize_question(question)
436
+
437
+ # Retrieve relevant chunks with similarity scores
438
+ logger.info("Retrieving relevant documents...")
439
+ retrieve_start = time.time()
440
+ context_chunks, similarity_scores = self.bot.retrieve_with_scores(question, k)
441
+ logger.info(f"Retrieved {len(context_chunks)} chunks in {time.time() - retrieve_start:.2f}s")
442
+
443
+ if not context_chunks:
444
+ return (
445
+ "I don't have enough information to answer this question. Please try rephrasing or asking about a different topic.",
446
+ "N/A",
447
+ "No sources found",
448
+ "No matches found",
449
+ question_group
450
+ )
451
+
452
+ # Format similarity scores
453
+ similarity_scores_str = ", ".join([f"{score:.3f}" for score in similarity_scores])
454
+
455
+ # Format sources with chunk text and file paths
456
+ sources_list = []
457
+ for i, (chunk, score) in enumerate(zip(context_chunks, similarity_scores)):
458
+ file_path = self._find_file_path(chunk.filename)
459
+
460
+ source_info = f"""
461
+ {'='*80}
462
+ SOURCE {i+1} | Similarity: {score:.3f}
463
+ {'='*80}
464
+ 📄 File: {chunk.filename}
465
+ 📍 Path: {file_path if file_path else 'File path not found (search in Data Resources directory)'}
466
+ 📊 Chunk: {chunk.chunk_id + 1}/{chunk.total_chunks} (Position: {chunk.start_pos}-{chunk.end_pos})
467
+
468
+ 📝 Full Chunk Text:
469
+ {chunk.text}
470
 
471
  """
472
+ sources_list.append(source_info)
473
+
474
+ sources = "\n".join(sources_list)
475
+
476
+ # Generation kwargs
477
+ gen_kwargs = {
478
+ 'max_new_tokens': min(max_tokens, 512),
479
+ 'temperature': temperature,
480
+ 'top_p': self.bot.args.top_p,
481
+ 'repetition_penalty': self.bot.args.repetition_penalty
482
+ }
483
+
484
+ # Generate answer based on education level
485
+ answer = ""
486
+ flesch_score = 0.0
487
+
488
+ # Generate original answer first
489
+ logger.info("Generating original answer...")
490
+ gen_start = time.time()
491
+ prompt = self.bot.format_prompt(context_chunks, question)
492
+ original_answer = self.bot.generate_answer(prompt, hf_token=hf_token, **gen_kwargs)
493
+ logger.info(f"Original answer generated in {time.time() - gen_start:.1f}s")
494
+
495
+ # Enhance based on education level
496
+ logger.info(f"Enhancing answer for {education_level} level...")
497
+ enhance_start = time.time()
498
+ if education_level == "middle_school":
499
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="middle_school", hf_token=hf_token)
500
+ elif education_level == "high_school":
501
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="high_school", hf_token=hf_token)
502
+ elif education_level == "college":
503
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="college", hf_token=hf_token)
504
+ elif education_level == "doctoral":
505
+ answer, flesch_score = self.bot.enhance_readability(original_answer, target_level="doctoral", hf_token=hf_token)
506
+ else:
507
+ answer = "Invalid education level selected."
508
+ flesch_score = 0.0
509
+
510
+ logger.info(f"Answer enhanced in {time.time() - enhance_start:.1f}s")
511
+ total_time = time.time() - start_time
512
+ logger.info(f"Total processing time: {total_time:.1f}s")
513
+
514
+ # Clean the answer - remove special tokens and formatting
515
+ import re
516
+ cleaned_answer = answer
517
+
518
+ # Remove special tokens (case-insensitive)
519
+ special_tokens = [
520
+ "<|end|>",
521
+ "<|endoftext|>",
522
+ "<|end_of_text|>",
523
+ "<|eot_id|>",
524
+ "<|start_header_id|>",
525
+ "<|end_header_id|>",
526
+ "<|assistant|>",
527
+ "<|endoftext|>",
528
+ "<|end_of_text|>",
529
+ ]
530
+ for token in special_tokens:
531
+ cleaned_answer = re.sub(re.escape(token), '', cleaned_answer, flags=re.IGNORECASE)
532
+
533
+ # Remove any remaining special token patterns
534
+ cleaned_answer = re.sub(r'<\|[^|]+\|>', '', cleaned_answer)
535
+ cleaned_answer = re.sub(r'^\*\*.*?\*\*.*?\n', '', cleaned_answer, flags=re.MULTILINE)
536
+ cleaned_answer = re.sub(r'\n\s*\n\s*\n+', '\n\n', cleaned_answer)
537
+ cleaned_answer = re.sub(r'^\s+|\s+$', '', cleaned_answer, flags=re.MULTILINE)
538
+ cleaned_answer = cleaned_answer.strip()
539
+
540
+ return (
541
+ cleaned_answer,
542
+ f"{flesch_score:.1f}",
543
+ sources,
544
+ similarity_scores_str,
545
+ question_group
546
+ )
547
+
548
+ except Exception as e:
549
+ logger.error(f"Error processing question: {e}", exc_info=True)
550
+ return (
551
+ f"An error occurred while processing your question: {str(e)}",
552
+ "N/A",
553
+ "",
554
+ "",
555
+ "Error"
556
+ )
557
+
558
+
559
+ def create_interface(initial_bot: RAGBot) -> gr.Blocks:
560
+ """Create and configure the Gradio interface with OAuth"""
561
+
562
+ try:
563
+ interface = GradioRAGInterface(initial_bot)
564
+ except Exception as e:
565
+ logger.error(f"Failed to create GradioRAGInterface: {e}")
566
+ with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
567
+ gr.Markdown(f"""
568
+ # ⚠️ Initialization Error
569
+
570
+ Failed to initialize the chatbot interface.
571
+
572
+ **Error:** {str(e)}
573
+
574
+ Please check the logs for more details.
575
+ """)
576
+ return demo
577
+
578
+ # Get initial model name from bot
579
+ initial_model_short = None
580
+ for short_name, full_path in MODEL_MAP.items():
581
+ if full_path == initial_bot.args.model:
582
+ initial_model_short = short_name
583
+ break
584
+ if initial_model_short is None:
585
+ initial_model_short = list(MODEL_MAP.keys())[0]
586
+
587
+ # Create the Gradio interface
588
+ try:
589
+ with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
590
+ with gr.Sidebar():
591
+ gr.LoginButton()
592
+ gr.Markdown("### 🔐 Authentication")
593
+ gr.Markdown("Please log in with your Hugging Face account to use the Inference API.")
594
+
595
+ gr.Markdown("""
596
+ # 🧬 CGT-LLM-Beta: Genetic Counseling RAG Chatbot
597
+
598
+ Ask questions about genetic counseling, cascade genetic testing, hereditary cancer syndromes, and related topics.
599
+
600
+ The chatbot uses a Retrieval-Augmented Generation (RAG) system to provide evidence-based answers from medical literature.
601
+ """)
602
+
603
+ with gr.Row():
604
+ with gr.Column(scale=2):
605
+ question_input = gr.Textbox(
606
+ label="Your Question",
607
+ placeholder="e.g., What is Lynch Syndrome? What screening is recommended for BRCA1 carriers?",
608
+ lines=3
609
+ )
610
+
611
+ with gr.Row():
612
+ model_dropdown = gr.Dropdown(
613
+ choices=list(MODEL_MAP.keys()),
614
+ value=initial_model_short,
615
+ label="Select Model",
616
+ info="Choose which LLM model to use for generating answers"
617
+ )
618
+
619
+ education_dropdown = gr.Dropdown(
620
+ choices=list(EDUCATION_LEVELS.keys()),
621
+ value=list(EDUCATION_LEVELS.keys())[0],
622
+ label="Education Level",
623
+ info="Select your education level for personalized answers"
624
+ )
625
+
626
+ with gr.Accordion("Advanced Settings", open=False):
627
+ k_slider = gr.Slider(
628
+ minimum=1,
629
+ maximum=10,
630
+ value=5,
631
+ step=1,
632
+ label="Number of document chunks to retrieve (k)"
633
+ )
634
+ temperature_slider = gr.Slider(
635
  minimum=0.1,
636
  maximum=1.0,
637
+ value=0.2,
638
+ step=0.1,
639
+ label="Temperature (lower = more focused)"
640
+ )
641
+ max_tokens_slider = gr.Slider(
642
+ minimum=128,
643
+ maximum=1024,
644
+ value=512,
645
+ step=128,
646
+ label="Max Tokens (lower = faster responses)"
647
+ )
648
+
649
+ submit_btn = gr.Button("Ask Question", variant="primary", size="lg")
650
+
651
+ with gr.Column(scale=3):
652
+ answer_output = gr.Textbox(
653
+ label="Answer",
654
+ lines=20,
655
+ interactive=False,
656
+ elem_classes=["answer-box"]
657
+ )
658
+
659
+ with gr.Row():
660
+ flesch_output = gr.Textbox(
661
+ label="Flesch-Kincaid Grade Level",
662
+ value="N/A",
663
+ interactive=False,
664
+ scale=1
665
+ )
666
+
667
+ similarity_output = gr.Textbox(
668
+ label="Similarity Scores",
669
+ value="",
670
+ interactive=False,
671
+ scale=1
672
+ )
673
+
674
+ category_output = gr.Textbox(
675
+ label="Question Category",
676
+ value="",
677
+ interactive=False,
678
+ scale=1
679
+ )
680
+
681
+ sources_output = gr.Textbox(
682
+ label="Source Documents (with Chunk Text)",
683
+ lines=15,
684
+ interactive=False,
685
+ info="Shows the retrieved document chunks with full text. File paths are shown for easy access."
686
+ )
687
+
688
+ # Example questions
689
+ gr.Markdown("### 💡 Example Questions")
690
+ gr.Markdown(f"Select a question below to use it in the chatbot ({len(EXAMPLE_QUESTIONS)} questions - scrollable dropdown):")
691
+
692
+ example_questions_dropdown = gr.Dropdown(
693
+ choices=EXAMPLE_QUESTIONS,
694
+ label="Example Questions",
695
+ value=None,
696
+ info="Open the dropdown and scroll through all questions. Select one to use it.",
697
+ interactive=True,
698
+ container=True,
699
+ scale=1
700
+ )
701
+
702
+ def update_question_from_dropdown(selected_question):
703
+ return selected_question if selected_question else ""
704
+
705
+ example_questions_dropdown.change(
706
+ fn=update_question_from_dropdown,
707
+ inputs=example_questions_dropdown,
708
+ outputs=question_input
709
+ )
710
+
711
+ # Footer
712
+ gr.Markdown("""
713
+ ---
714
+ **Note:** This chatbot provides informational answers based on medical literature.
715
+ It is not a substitute for professional medical advice, diagnosis, or treatment.
716
+ Always consult with qualified healthcare providers for medical decisions.
717
+ """)
718
+
719
+ # Connect the submit button with OAuth token
720
+ def process_with_education_level(question, model, education, k, temp, max_tok, hf_token: gr.OAuthToken):
721
+ education_key = EDUCATION_LEVELS[education]
722
+ token = hf_token.token if hf_token else None
723
+ return interface.process_question(question, model, education_key, k, temp, max_tok, hf_token=token)
724
+
725
+ submit_btn.click(
726
+ fn=process_with_education_level,
727
+ inputs=[
728
+ question_input,
729
+ model_dropdown,
730
+ education_dropdown,
731
+ k_slider,
732
+ temperature_slider,
733
+ max_tokens_slider,
734
+ gr.OAuthToken()
735
+ ],
736
+ outputs=[
737
+ answer_output,
738
+ flesch_output,
739
+ sources_output,
740
+ similarity_output,
741
+ category_output
742
+ ]
743
+ )
744
+
745
+ # Also allow Enter key to submit
746
+ question_input.submit(
747
+ fn=process_with_education_level,
748
+ inputs=[
749
+ question_input,
750
+ model_dropdown,
751
+ education_dropdown,
752
+ k_slider,
753
+ temperature_slider,
754
+ max_tokens_slider,
755
+ gr.OAuthToken()
756
+ ],
757
+ outputs=[
758
+ answer_output,
759
+ flesch_output,
760
+ sources_output,
761
+ similarity_output,
762
+ category_output
763
+ ]
764
+ )
765
+ except Exception as interface_error:
766
+ logger.error(f"Error setting up Gradio interface components: {interface_error}", exc_info=True)
767
+ import traceback
768
+ error_trace = traceback.format_exc()
769
+ with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
770
+ gr.Markdown(f"""
771
+ # ⚠️ Interface Setup Error
772
+
773
+ An error occurred while setting up the interface components.
774
+
775
+ **Error:** {str(interface_error)}
776
+
777
+ **Traceback:**
778
+ ```
779
+ {error_trace[:1000]}...
780
+ ```
781
+
782
+ Please check the logs for more details.
783
+ """)
784
+ return demo
785
+
786
+ logger.info("Gradio interface created successfully")
787
+ return demo
788
+
789
+
790
+ # Check if we're on Spaces
791
+ IS_SPACES = (
792
+ os.getenv("SPACE_ID") is not None or
793
+ os.getenv("SYSTEM") == "spaces" or
794
+ os.getenv("HF_SPACE_ID") is not None
795
  )
796
 
797
+ # Initialize demo variable
798
+ demo = None
799
+
800
+ def _create_demo():
801
+ """Create the demo - separated into function for better error handling"""
802
+ try:
803
+ logger.info("=" * 80)
804
+ logger.info("Starting demo creation...")
805
+ logger.info(f"IS_SPACES: {IS_SPACES}")
806
+ logger.info(f"BOT_AVAILABLE: {BOT_AVAILABLE}")
807
+
808
+ if not BOT_AVAILABLE:
809
+ raise ImportError("bot module is not available - cannot create demo")
810
+
811
+ # Initialize with default args
812
+ parser = argparse.ArgumentParser()
813
+ parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct')
814
+ parser.add_argument('--vector-db-dir', default='./chroma_db')
815
+ parser.add_argument('--data-dir', default='./Data Resources')
816
+ parser.add_argument('--max-new-tokens', type=int, default=1024)
817
+ parser.add_argument('--temperature', type=float, default=0.2)
818
+ parser.add_argument('--top-p', type=float, default=0.9)
819
+ parser.add_argument('--repetition-penalty', type=float, default=1.1)
820
+ parser.add_argument('--k', type=int, default=5)
821
+ parser.add_argument('--skip-indexing', action='store_true', default=True)
822
+ parser.add_argument('--verbose', action='store_true', default=False)
823
+ parser.add_argument('--seed', type=int, default=42)
824
+
825
+ args = parser.parse_args([]) # Empty args
826
+ args.skip_model_loading = IS_SPACES # Skip model loading on Spaces, use Inference API
827
+
828
+ logger.info("Creating RAGBot...")
829
+ bot = RAGBot(args)
830
+
831
+ if bot.vector_retriever is None:
832
+ raise Exception("Vector database not available")
833
+
834
+ # Check if vector database has documents
835
+ collection_stats = bot.vector_retriever.get_collection_stats()
836
+ if collection_stats.get('total_chunks', 0) == 0:
837
+ logger.warning("Vector database is empty. The chatbot may not find relevant documents.")
838
+
839
+ logger.info("Creating interface...")
840
+ demo = create_interface(bot)
841
+ logger.info(f"Demo created successfully: {type(demo)}")
842
+ return demo
843
+
844
+ except Exception as bot_error:
845
+ logger.error(f"Error initializing: {bot_error}", exc_info=True)
846
+ import traceback
847
+ error_trace = traceback.format_exc()
848
+ logger.error(f"Full traceback: {error_trace}")
849
+
850
+ with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as error_demo:
851
+ gr.Markdown(f"""
852
+ # ⚠️ Initialization Error
853
+
854
+ The chatbot encountered an error during initialization:
855
+
856
+ **Error:** {str(bot_error)}
857
+
858
+ **Possible causes:**
859
+ - Missing vector database (chroma_db directory)
860
+ - Missing dependencies
861
+ - Configuration issues
862
+
863
+ **Error Details:**
864
+ ```
865
+ {error_trace[:1000]}...
866
+ ```
867
+ """)
868
+ logger.info(f"Error demo created: {type(error_demo)}")
869
+ return error_demo
870
+
871
+ # Create demo at module level
872
+ try:
873
+ if IS_SPACES:
874
+ logger.info("Creating demo directly at module level for Spaces...")
875
+ else:
876
+ logger.info("Creating demo for local execution...")
877
+
878
+ demo = _create_demo()
879
+
880
+ if demo is None or not isinstance(demo, (gr.Blocks, gr.Interface)):
881
+ raise ValueError(f"Demo creation returned invalid result: {type(demo)}")
882
+
883
+ logger.info("Demo creation completed successfully")
884
+ except Exception as e:
885
+ logger.error(f"CRITICAL: Error creating demo: {e}", exc_info=True)
886
+ import traceback
887
+ error_trace = traceback.format_exc()
888
+ logger.error(f"Full traceback: {error_trace}")
889
+ with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
890
+ gr.Markdown(f"""
891
+ # Error Initializing Chatbot
892
+
893
+ A critical error occurred while initializing the chatbot.
894
+
895
+ **Error:** {str(e)}
896
+
897
+ **Traceback:**
898
+ ```
899
+ {error_trace[:1500]}...
900
+ ```
901
+
902
+ Please check the logs for more details.
903
+ """)
904
+ logger.info(f"Fallback error demo created: {type(demo)}")
905
 
906
+ # Final verification
907
+ if demo is None:
908
+ logger.error("CRITICAL: Demo variable is None! Creating fallback demo.")
909
+ with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
910
+ gr.Markdown("# Error: Demo was not created properly\n\nPlease check the logs for details.")
911
+ elif not isinstance(demo, (gr.Blocks, gr.Interface)):
912
+ logger.error(f"CRITICAL: Demo is not a valid Gradio object: {type(demo)}")
913
+ with gr.Blocks(title="CGT-LLM-Beta RAG Chatbot") as demo:
914
+ gr.Markdown(f"# Error: Invalid demo type\n\nDemo type: {type(demo)}\n\nPlease check the logs for details.")
915
+ else:
916
+ logger.info(f"✅ Final demo check passed: demo type={type(demo)}")
917
 
918
+ # For local execution only (not on Spaces)
919
  if __name__ == "__main__":
920
+ if not IS_SPACES:
921
+ # For local use, we can launch it
922
  demo.launch()
bot.py ADDED
@@ -0,0 +1,1777 @@
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ RAG Chatbot Implementation for CGT-LLM-Beta with Vector Database
4
+ Production-ready local RAG system with ChromaDB and MPS acceleration for Apple Silicon
5
+ """
6
+
7
+ import argparse
8
+ import csv
9
+ import json
10
+ import logging
11
+ import os
12
+ import re
13
+ import sys
14
+ import time
15
+ import hashlib
16
+ from pathlib import Path
17
+ from typing import List, Tuple, Dict, Any, Optional, Union
18
+ from dataclasses import dataclass
19
+ from collections import defaultdict
20
+
21
+ import textstat
22
+
23
+ import torch
24
+ import numpy as np
25
+ import pandas as pd
26
+ from tqdm import tqdm
27
+
28
+ # Optional imports with graceful fallbacks
29
+ try:
30
+ import chromadb
31
+ from chromadb.config import Settings
32
+ CHROMADB_AVAILABLE = True
33
+ except ImportError:
34
+ CHROMADB_AVAILABLE = False
35
+ print("Warning: chromadb not available. Install with: pip install chromadb")
36
+
37
+ try:
38
+ from sentence_transformers import SentenceTransformer
39
+ SENTENCE_TRANSFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SENTENCE_TRANSFORMERS_AVAILABLE = False
42
+ print("Warning: sentence-transformers not available. Install with: pip install sentence-transformers")
43
+
44
+ try:
45
+ import pypdf
46
+ PDF_AVAILABLE = True
47
+ except ImportError:
48
+ PDF_AVAILABLE = False
49
+ print("Warning: pypdf not available. PDF files will be skipped.")
50
+
51
+ try:
52
+ from docx import Document
53
+ DOCX_AVAILABLE = True
54
+ except ImportError:
55
+ DOCX_AVAILABLE = False
56
+ print("Warning: python-docx not available. DOCX files will be skipped.")
57
+
58
+ try:
59
+ from rank_bm25 import BM25Okapi
60
+ BM25_AVAILABLE = True
61
+ except ImportError:
62
+ BM25_AVAILABLE = False
63
+ print("Warning: rank-bm25 not available. BM25 retrieval disabled.")
64
+
65
+ # Configure logging
66
+ logging.basicConfig(
67
+ level=logging.INFO,
68
+ format='%(asctime)s - %(levelname)s - %(message)s',
69
+ handlers=[
70
+ logging.StreamHandler(),
71
+ logging.FileHandler('rag_bot.log')
72
+ ]
73
+ )
74
+ logger = logging.getLogger(__name__)
75
+
76
+
77
+ @dataclass
78
+ class Document:
79
+ """Represents a document with metadata"""
80
+ filename: str
81
+ content: str
82
+ filepath: str
83
+ file_type: str
84
+ chunk_count: int = 0
85
+ file_hash: str = ""
86
+
87
+
88
+ @dataclass
89
+ class Chunk:
90
+ """Represents a text chunk with metadata"""
91
+ text: str
92
+ filename: str
93
+ chunk_id: int
94
+ total_chunks: int
95
+ start_pos: int
96
+ end_pos: int
97
+ metadata: Dict[str, Any]
98
+ chunk_hash: str = ""
99
+
100
+
101
+ class VectorRetriever:
102
+ """ChromaDB-based vector retrieval"""
103
+
104
+ def __init__(self, collection_name: str = "cgt_documents", persist_directory: str = "./chroma_db"):
105
+ if not CHROMADB_AVAILABLE:
106
+ raise ImportError("ChromaDB is required for vector retrieval")
107
+
108
+ self.collection_name = collection_name
109
+ self.persist_directory = persist_directory
110
+
111
+ # Initialize ChromaDB client
112
+ self.client = chromadb.PersistentClient(path=persist_directory)
113
+
114
+ # Get or create collection
115
+ try:
116
+ self.collection = self.client.get_collection(name=collection_name)
117
+ logger.info(f"Loaded existing collection '{collection_name}' with {self.collection.count()} documents")
118
+ except Exception:
119
+ self.collection = self.client.create_collection(
120
+ name=collection_name,
121
+ metadata={"description": "CGT-LLM-Beta document collection"}
122
+ )
123
+ logger.info(f"Created new collection '{collection_name}'")
124
+
125
+ # Initialize embedding model
126
+ if SENTENCE_TRANSFORMERS_AVAILABLE:
127
+ self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
128
+ logger.info("Loaded sentence-transformers embedding model")
129
+ else:
130
+ self.embedding_model = None
131
+ logger.warning("Sentence-transformers not available, using ChromaDB default embeddings")
132
+
133
+ def add_documents(self, chunks: List[Chunk]) -> None:
134
+ """Add document chunks to the vector database"""
135
+ if not chunks:
136
+ return
137
+
138
+ logger.info(f"Adding {len(chunks)} chunks to vector database...")
139
+
140
+ # Prepare data for ChromaDB
141
+ documents = []
142
+ metadatas = []
143
+ ids = []
144
+
145
+ for chunk in chunks:
146
+ chunk_id = f"{chunk.filename}_{chunk.chunk_id}"
147
+ documents.append(chunk.text)
148
+
149
+ metadata = {
150
+ "filename": chunk.filename,
151
+ "chunk_id": chunk.chunk_id,
152
+ "total_chunks": chunk.total_chunks,
153
+ "start_pos": chunk.start_pos,
154
+ "end_pos": chunk.end_pos,
155
+ "chunk_hash": chunk.chunk_hash,
156
+ **chunk.metadata
157
+ }
158
+ metadatas.append(metadata)
159
+ ids.append(chunk_id)
160
+
161
+ # Add to collection
162
+ try:
163
+ self.collection.add(
164
+ documents=documents,
165
+ metadatas=metadatas,
166
+ ids=ids
167
+ )
168
+ logger.info(f"Successfully added {len(chunks)} chunks to vector database")
169
+ except Exception as e:
170
+ logger.error(f"Error adding documents to vector database: {e}")
171
+
172
+ def search(self, query: str, k: int = 5) -> List[Tuple[Chunk, float]]:
173
+ """Search for similar chunks using vector similarity"""
174
+ try:
175
+ # Perform vector search
176
+ results = self.collection.query(
177
+ query_texts=[query],
178
+ n_results=k
179
+ )
180
+
181
+ chunks_with_scores = []
182
+ if results['documents'] and results['documents'][0]:
183
+ for i, (doc, metadata, distance) in enumerate(zip(
184
+ results['documents'][0],
185
+ results['metadatas'][0],
186
+ results['distances'][0]
187
+ )):
188
+ # Convert distance to similarity score (ChromaDB uses cosine distance)
189
+ similarity_score = 1 - distance
190
+
191
+ chunk = Chunk(
192
+ text=doc,
193
+ filename=metadata['filename'],
194
+ chunk_id=metadata['chunk_id'],
195
+ total_chunks=metadata['total_chunks'],
196
+ start_pos=metadata['start_pos'],
197
+ end_pos=metadata['end_pos'],
198
+ metadata={k: v for k, v in metadata.items()
199
+ if k not in ['filename', 'chunk_id', 'total_chunks', 'start_pos', 'end_pos', 'chunk_hash']},
200
+ chunk_hash=metadata.get('chunk_hash', '')
201
+ )
202
+ chunks_with_scores.append((chunk, similarity_score))
203
+
204
+ return chunks_with_scores
205
+
206
+ except Exception as e:
207
+ logger.error(f"Error searching vector database: {e}")
208
+ return []
209
+
210
+ def get_collection_stats(self) -> Dict[str, Any]:
211
+ """Get statistics about the collection"""
212
+ try:
213
+ count = self.collection.count()
214
+ return {
215
+ "total_chunks": count,
216
+ "collection_name": self.collection_name,
217
+ "persist_directory": self.persist_directory
218
+ }
219
+ except Exception as e:
220
+ logger.error(f"Error getting collection stats: {e}")
221
+ return {}
222
+
223
+
224
+ class RAGBot:
225
+ """Main RAG chatbot class with vector database"""
226
+
227
+ def __init__(self, args):
228
+ self.args = args
229
+ self.device = self._setup_device()
230
+ self.model = None
231
+ self.tokenizer = None
232
+ self.vector_retriever = None
233
+
234
+ # Load model (unless skipping for Inference API)
235
+ if not hasattr(args, 'skip_model_loading') or not args.skip_model_loading:
236
+ self._load_model()
237
+
238
+ # Initialize vector retriever
239
+ self._setup_vector_retriever()
240
+
241
+ def _setup_device(self) -> str:
242
+ """Setup device with MPS support for Apple Silicon"""
243
+ if torch.backends.mps.is_available():
244
+ device = "mps"
245
+ logger.info("Using device: mps (Apple Silicon)")
246
+ elif torch.cuda.is_available():
247
+ device = "cuda"
248
+ logger.info("Using device: cuda")
249
+ else:
250
+ device = "cpu"
251
+ logger.info("Using device: cpu")
252
+
253
+ return device
254
+
255
+ def _load_model(self):
256
+ """Load the specified LLM model and tokenizer"""
257
+ try:
258
+ model_name = self.args.model
259
+ logger.info(f"Loading model: {model_name}...")
260
+ from transformers import AutoTokenizer, AutoModelForCausalLM
261
+
262
+ # Get Hugging Face token from environment (for gated models)
263
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
264
+
265
+ # Load tokenizer
266
+ tokenizer_kwargs = {
267
+ "trust_remote_code": True
268
+ }
269
+ if hf_token:
270
+ tokenizer_kwargs["token"] = hf_token
271
+ logger.info("Using HF_TOKEN for authentication")
272
+
273
+ self.tokenizer = AutoTokenizer.from_pretrained(
274
+ model_name,
275
+ **tokenizer_kwargs
276
+ )
277
+
278
+ # Determine appropriate torch dtype based on device and model
279
+ # Use float16 for MPS/CUDA, float32 for CPU
280
+ # Some models work better with bfloat16
281
+ if self.device == "mps":
282
+ torch_dtype = torch.float16
283
+ elif self.device == "cuda":
284
+ torch_dtype = torch.float16
285
+ else:
286
+ torch_dtype = torch.float32
287
+
288
+ # Load model with appropriate settings
289
+ model_kwargs = {
290
+ "torch_dtype": torch_dtype,
291
+ "trust_remote_code": True,
292
+ }
293
+
294
+ # Add token if available (for gated models)
295
+ if hf_token:
296
+ model_kwargs["token"] = hf_token
297
+
298
+ # Use 8-bit quantization on CPU to reduce memory usage
299
+ # This reduces memory by ~50% with minimal quality loss
300
+ if self.device == "cpu":
301
+ try:
302
+ from transformers import BitsAndBytesConfig
303
+ # Use 8-bit quantization for CPU (reduces memory significantly)
304
+ model_kwargs["load_in_8bit"] = False # 8-bit not available on CPU
305
+ # Instead, use float16 even on CPU to save memory
306
+ model_kwargs["torch_dtype"] = torch.float16
307
+ logger.info("Using float16 on CPU to reduce memory usage")
308
+ except ImportError:
309
+ # Fallback: use float16 anyway
310
+ model_kwargs["torch_dtype"] = torch.float16
311
+ logger.info("Using float16 on CPU to reduce memory usage (fallback)")
312
+
313
+ # For MPS, use device_map; for CUDA, let it auto-detect
314
+ if self.device == "mps":
315
+ model_kwargs["device_map"] = self.device
316
+ elif self.device == "cuda":
317
+ model_kwargs["device_map"] = "auto"
318
+ # For CPU, don't specify device_map
319
+
320
+ self.model = AutoModelForCausalLM.from_pretrained(
321
+ model_name,
322
+ **model_kwargs
323
+ )
324
+
325
+ # Move to device if not using device_map
326
+ if self.device == "cpu":
327
+ self.model = self.model.to(self.device)
328
+
329
+ # Set pad token if not already set
330
+ if self.tokenizer.pad_token is None:
331
+ if self.tokenizer.eos_token is not None:
332
+ self.tokenizer.pad_token = self.tokenizer.eos_token
333
+ else:
334
+ # Some models might need a different approach
335
+ self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
336
+
337
+ logger.info(f"Model {model_name} loaded successfully on {self.device}")
338
+
339
+ except Exception as e:
340
+ logger.error(f"Failed to load model {self.args.model}: {e}")
341
+ logger.error("Make sure the model name is correct and you have access to it on HuggingFace")
342
+ logger.error("For gated models (like Llama), you need to:")
343
+ logger.error(" 1. Request access at: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct")
344
+ logger.error(" 2. Add HF_TOKEN as a secret in your Hugging Face Space settings")
345
+ logger.error(" 3. Get your token from: https://huggingface.co/settings/tokens")
346
+ logger.error("For local use, ensure you're logged in: huggingface-cli login")
347
+ sys.exit(2)
348
+
349
+ def _setup_vector_retriever(self):
350
+ """Setup the vector retriever"""
351
+ try:
352
+ self.vector_retriever = VectorRetriever(
353
+ collection_name="cgt_documents",
354
+ persist_directory=self.args.vector_db_dir
355
+ )
356
+ logger.info("Vector retriever initialized successfully")
357
+ except Exception as e:
358
+ logger.error(f"Failed to setup vector retriever: {e}")
359
+ sys.exit(2)
360
+
361
+ def _calculate_file_hash(self, filepath: str) -> str:
362
+ """Calculate hash of file for change detection"""
363
+ try:
364
+ with open(filepath, 'rb') as f:
365
+ return hashlib.md5(f.read()).hexdigest()
366
+ except:
367
+ return ""
368
+
369
+ def _calculate_chunk_hash(self, text: str) -> str:
370
+ """Calculate hash of chunk text"""
371
+ return hashlib.md5(text.encode('utf-8')).hexdigest()
372
+
373
+ def load_corpus(self, data_dir: str) -> List[Document]:
374
+ """Load all documents from the data directory"""
375
+ logger.info(f"Loading corpus from {data_dir}")
376
+ documents = []
377
+ data_path = Path(data_dir)
378
+
379
+ if not data_path.exists():
380
+ logger.error(f"Data directory {data_dir} does not exist")
381
+ sys.exit(1)
382
+
383
+ # Supported file extensions
384
+ supported_extensions = {'.txt', '.md', '.json', '.csv'}
385
+ if PDF_AVAILABLE:
386
+ supported_extensions.add('.pdf')
387
+ if DOCX_AVAILABLE:
388
+ supported_extensions.add('.docx')
389
+ supported_extensions.add('.doc')
390
+
391
+ # Find all files recursively
392
+ files = []
393
+ for ext in supported_extensions:
394
+ files.extend(data_path.rglob(f"*{ext}"))
395
+
396
+ logger.info(f"Found {len(files)} files to process")
397
+
398
+ # Process files with progress bar
399
+ for file_path in tqdm(files, desc="Loading documents"):
400
+ try:
401
+ content = self._read_file(file_path)
402
+ if content.strip(): # Only add non-empty documents
403
+ file_hash = self._calculate_file_hash(file_path)
404
+ doc = Document(
405
+ filename=file_path.name,
406
+ content=content,
407
+ filepath=str(file_path),
408
+ file_type=file_path.suffix.lower(),
409
+ file_hash=file_hash
410
+ )
411
+ documents.append(doc)
412
+ logger.debug(f"Loaded {file_path.name} ({len(content)} chars)")
413
+ else:
414
+ logger.warning(f"Skipping empty file: {file_path.name}")
415
+
416
+ except Exception as e:
417
+ logger.error(f"Failed to load {file_path.name}: {e}")
418
+ continue
419
+
420
+ logger.info(f"Successfully loaded {len(documents)} documents")
421
+ return documents
422
+
423
+ def _read_file(self, file_path: Path) -> str:
424
+ """Read content from various file types"""
425
+ suffix = file_path.suffix.lower()
426
+
427
+ try:
428
+ if suffix == '.txt':
429
+ return file_path.read_text(encoding='utf-8')
430
+
431
+ elif suffix == '.md':
432
+ return file_path.read_text(encoding='utf-8')
433
+
434
+ elif suffix == '.json':
435
+ with open(file_path, 'r', encoding='utf-8') as f:
436
+ data = json.load(f)
437
+ if isinstance(data, dict):
438
+ return json.dumps(data, indent=2)
439
+ else:
440
+ return str(data)
441
+
442
+ elif suffix == '.csv':
443
+ df = pd.read_csv(file_path)
444
+ return df.to_string()
445
+
446
+ elif suffix == '.pdf' and PDF_AVAILABLE:
447
+ text = ""
448
+ with open(file_path, 'rb') as f:
449
+ pdf_reader = pypdf.PdfReader(f)
450
+ for page in pdf_reader.pages:
451
+ text += page.extract_text() + "\n"
452
+ return text
453
+
454
+ elif suffix in ['.docx', '.doc'] and DOCX_AVAILABLE:
455
+ # Assumes python-docx is imported at module top (e.g. `import docx`); the local
+ # Document dataclass has no .paragraphs attribute, so it cannot be used here
+ doc = docx.Document(file_path)
456
+ text = ""
457
+ for paragraph in doc.paragraphs:
458
+ text += paragraph.text + "\n"
459
+ return text
460
+
461
+ else:
462
+ logger.warning(f"Unsupported file type: {suffix}")
463
+ return ""
464
+
465
+ except Exception as e:
466
+ logger.error(f"Error reading {file_path}: {e}")
467
+ return ""
468
+
469
+ def chunk_documents(self, docs: List[Document], chunk_size: int, overlap: int) -> List[Chunk]:
470
+ """Chunk documents into smaller pieces"""
471
+ logger.info(f"Chunking {len(docs)} documents (size={chunk_size}, overlap={overlap})")
472
+ chunks = []
473
+
474
+ for doc in docs:
475
+ doc_chunks = self._chunk_text(
476
+ doc.content,
477
+ doc.filename,
478
+ chunk_size,
479
+ overlap
480
+ )
481
+ chunks.extend(doc_chunks)
482
+
483
+ # Update document metadata
484
+ doc.chunk_count = len(doc_chunks)
485
+
486
+ logger.info(f"Created {len(chunks)} chunks from {len(docs)} documents")
487
+ return chunks
488
+
489
+ def _chunk_text(self, text: str, filename: str, chunk_size: int, overlap: int) -> List[Chunk]:
490
+ """Split text into overlapping chunks"""
491
+ # Clean text
492
+ text = re.sub(r'\s+', ' ', text.strip())
493
+
494
+ # Simple token-based chunking (approximate)
495
+ words = text.split()
496
+ chunks = []
497
+
498
+ for i in range(0, len(words), chunk_size - overlap):
499
+ chunk_words = words[i:i + chunk_size]
500
+ chunk_text = ' '.join(chunk_words)
501
+
502
+ if chunk_text.strip():
503
+ chunk_hash = self._calculate_chunk_hash(chunk_text)
504
+ chunk = Chunk(
505
+ text=chunk_text,
506
+ filename=filename,
507
+ chunk_id=len(chunks),
508
+ total_chunks=0, # Will be updated later
509
+ start_pos=i,
510
+ end_pos=i + len(chunk_words),
511
+ metadata={
512
+ 'word_count': len(chunk_words),
513
+ 'char_count': len(chunk_text)
514
+ },
515
+ chunk_hash=chunk_hash
516
+ )
517
+ chunks.append(chunk)
518
+
519
+ # Update total_chunks for each chunk
520
+ for chunk in chunks:
521
+ chunk.total_chunks = len(chunks)
522
+
523
+ return chunks
524
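+ # Worked example (illustrative): with the default chunk_size=500 and overlap=200 the
+ # stride is 300 words, so a 1,000-word document yields chunks over words [0:500),
+ # [300:800), [600:1000), and [900:1000), with consecutive chunks sharing ~200 words.
+ # The loop assumes chunk_size > overlap; otherwise range() receives a non-positive
+ # step and raises ValueError.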
+
525
+ def build_or_update_index(self, chunks: List[Chunk], force_rebuild: bool = False) -> None:
526
+ """Build or update the vector index"""
527
+ if not chunks:
528
+ logger.warning("No chunks provided for indexing")
529
+ return
530
+
531
+ # Check if we need to rebuild
532
+ collection_stats = self.vector_retriever.get_collection_stats()
533
+ existing_count = collection_stats.get('total_chunks', 0)
534
+
535
+ if existing_count > 0 and not force_rebuild:
536
+ logger.info(f"Vector database already contains {existing_count} chunks. Use --force-rebuild to rebuild.")
537
+ return
538
+
539
+ if force_rebuild and existing_count > 0:
540
+ logger.info("Force rebuild requested. Clearing existing collection...")
541
+ try:
542
+ # RAGBot has no ChromaDB client of its own; go through the retriever
+ # (assumes VectorRetriever keeps its client on a `client` attribute)
+ self.vector_retriever.client.delete_collection(self.vector_retriever.collection_name)
543
+ self.vector_retriever.collection = self.vector_retriever.client.create_collection(
544
+ name=self.vector_retriever.collection_name,
545
+ metadata={"description": "CGT-LLM-Beta document collection"}
546
+ )
547
+ except Exception as e:
548
+ logger.error(f"Error clearing collection: {e}")
549
+
550
+ # Add chunks to vector database
551
+ self.vector_retriever.add_documents(chunks)
552
+
553
+ logger.info("Vector index built successfully")
554
+
555
+ def retrieve(self, query: str, k: int) -> List[Chunk]:
556
+ """Retrieve relevant chunks for a query using vector search"""
557
+ results = self.vector_retriever.search(query, k)
558
+ chunks = [chunk for chunk, score in results]
559
+
560
+ if self.args.verbose:
561
+ logger.info(f"Retrieved {len(chunks)} chunks for query: {query[:50]}...")
562
+ for i, (chunk, score) in enumerate(results):
563
+ logger.info(f" {i+1}. {chunk.filename} (score: {score:.3f})")
564
+
565
+ return chunks
566
+
567
+ def retrieve_with_scores(self, query: str, k: int) -> Tuple[List[Chunk], List[float]]:
568
+ """Retrieve relevant chunks with similarity scores
569
+
570
+ Returns:
571
+ Tuple of (chunks, scores) where scores are similarity scores for each chunk
572
+ """
573
+ results = self.vector_retriever.search(query, k)
574
+ chunks = [chunk for chunk, score in results]
575
+ scores = [score for chunk, score in results]
576
+
577
+ if self.args.verbose:
578
+ logger.info(f"Retrieved {len(chunks)} chunks for query: {query[:50]}...")
579
+ for i, (chunk, score) in enumerate(results):
580
+ logger.info(f" {i+1}. {chunk.filename} (score: {score:.3f})")
581
+
582
+ return chunks, scores
583
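+ # Minimal usage sketch (illustrative; assumes `bot` is an initialized RAGBot):
+ #   chunks, scores = bot.retrieve_with_scores("What is Lynch Syndrome?", 5)
+ #   for chunk, score in zip(chunks, scores):
+ #       print(f"{chunk.filename}: {score:.3f}")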
+
584
+ def format_prompt(self, context_chunks: List[Chunk], question: str) -> str:
585
+ """Format the prompt with context and question, ensuring it fits within token limits"""
586
+ context_parts = []
587
+ for chunk in context_chunks:
588
+ context_parts.append(f"{chunk.text}")
589
+
590
+ context = "\n".join(context_parts)
591
+
592
+ # Try to use the tokenizer's chat template if available
593
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
594
+ try:
595
+ messages = [
596
+ {"role": "system", "content": "You are a helpful medical assistant. Answer questions based on the provided context. Be specific and informative."},
597
+ {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
598
+ ]
599
+ base_prompt = self.tokenizer.apply_chat_template(
600
+ messages,
601
+ tokenize=False,
602
+ add_generation_prompt=True
603
+ )
604
+ except Exception as e:
605
+ logger.warning(f"Failed to use chat template, falling back to manual format: {e}")
606
+ base_prompt = self._format_prompt_manual(context, question)
607
+ else:
608
+ # Fall back to manual formatting (for Llama models)
609
+ base_prompt = self._format_prompt_manual(context, question)
610
+
611
+ # Check if prompt is too long and truncate context if needed
612
+ max_context_tokens = 1200 # Leave room for generation
613
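+ # Budget note: this 1200-token cap on the assembled prompt keeps it under the
+ # 1500-token truncation applied in generate_answer(), so context is trimmed here
+ # deliberately rather than cut off silently at generation time.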
+ try:
614
+ tokenized = self.tokenizer(base_prompt, return_tensors="pt")
615
+ current_tokens = tokenized['input_ids'].shape[1]
616
+ except Exception as e:
617
+ logger.warning(f"Tokenization error, using base prompt as-is: {e}")
618
+ return base_prompt
619
+
620
+ if current_tokens > max_context_tokens:
621
+ # Truncate context to fit within limits
622
+ try:
623
+ context_tokens = self.tokenizer(context, return_tensors="pt")['input_ids'].shape[1]
624
+ available_tokens = max_context_tokens - (current_tokens - context_tokens)
625
+
626
+ if available_tokens > 0:
627
+ # Truncate context to fit
628
+ truncated_context = self.tokenizer.decode(
629
+ self.tokenizer(context, return_tensors="pt", truncation=True, max_length=available_tokens)['input_ids'][0],
630
+ skip_special_tokens=True
631
+ )
632
+
633
+ # Reformat with truncated context
634
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
635
+ try:
636
+ messages = [
637
+ {"role": "system", "content": "You are a helpful medical assistant. Answer questions based on the provided context. Be specific and informative."},
638
+ {"role": "user", "content": f"Context: {truncated_context}\n\nQuestion: {question}"}
639
+ ]
640
+ prompt = self.tokenizer.apply_chat_template(
641
+ messages,
642
+ tokenize=False,
643
+ add_generation_prompt=True
644
+ )
645
+ except:
646
+ prompt = self._format_prompt_manual(truncated_context, question)
647
+ else:
648
+ prompt = self._format_prompt_manual(truncated_context, question)
649
+ else:
650
+ # If even basic prompt is too long, use minimal format
651
+ prompt = self._format_prompt_manual(context[:500] + "...", question)
652
+ except Exception as e:
653
+ logger.warning(f"Error truncating context: {e}, using base prompt")
654
+ prompt = base_prompt
655
+ else:
656
+ prompt = base_prompt
657
+
658
+ return prompt
659
+
660
+ def _format_prompt_manual(self, context: str, question: str) -> str:
661
+ """Manual prompt formatting for models without chat templates (e.g., Llama)"""
662
+ return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
663
+
664
+ You are a helpful medical assistant. Answer questions based on the provided context. Be specific and informative.<|eot_id|><|start_header_id|>user<|end_header_id|>
665
+
666
+ Context: {context}
667
+
668
+ Question: {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
669
+
670
+ """
671
+
672
+ def format_improved_prompt(self, context_chunks: List[Chunk], question: str) -> Tuple[str, str]:
673
+ """Format an improved prompt with better tone, structure, and medical appropriateness
674
+
675
+ Returns:
676
+ Tuple of (prompt, prompt_text) where prompt_text is the system prompt instructions
677
+ """
678
+ context_parts = []
679
+ for chunk in context_chunks:
680
+ context_parts.append(f"{chunk.text}")
681
+
682
+ context = "\n".join(context_parts)
683
+
684
+ # Improved prompt with all the feedback incorporated
685
+ improved_prompt_text = """Provide a concise, neutral, and informative answer based on the provided medical context.
686
+
687
+ CRITICAL GUIDELINES:
688
+ - Format your response as clear, well-structured sentences and paragraphs
689
+ - Be concise and direct - focus on answering the specific question asked
690
+ - Use neutral, factual language - do NOT tell the questioner how to feel (avoid phrases like 'don't worry', 'the good news is', etc.)
691
+ - Do NOT use leading or coercive language - present information neutrally to preserve patient autonomy
692
+ - Do NOT make specific medical recommendations - instead state that management decisions should be made with a healthcare provider
693
+ - Use third-person voice only - never claim to be a medical professional or assistant
694
+ - Use consistent terminology: use 'children' (not 'offspring') consistently
695
+ - Do NOT include hypothetical examples with specific names (e.g., avoid 'Aunt Jenna' or similar)
696
+ - Include important distinctions when relevant (e.g., somatic vs. germline variants, reproductive risks)
697
+ - When citing sources, be consistent - always specify which guidelines or sources when mentioned
698
+ - Remove any formatting markers like asterisks (*) or bold markers
699
+ - Do NOT include phrases like 'Here's a rewritten version' - just provide the answer directly
700
+
701
+ If the question asks about medical management, screening, or interventions, conclude with: 'Management recommendations are individualized and should be discussed with a healthcare provider or genetic counselor.'"""
702
+
703
+ # Try to use the tokenizer's chat template if available
704
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
705
+ try:
706
+ messages = [
707
+ {"role": "system", "content": improved_prompt_text},
708
+ {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
709
+ ]
710
+ base_prompt = self.tokenizer.apply_chat_template(
711
+ messages,
712
+ tokenize=False,
713
+ add_generation_prompt=True
714
+ )
715
+ except Exception as e:
716
+ logger.warning(f"Failed to use chat template for improved prompt, falling back to manual format: {e}")
717
+ base_prompt = self._format_improved_prompt_manual(context, question, improved_prompt_text)
718
+ else:
719
+ # Fall back to manual formatting (for Llama models)
720
+ base_prompt = self._format_improved_prompt_manual(context, question, improved_prompt_text)
721
+
722
+ # Check if prompt is too long and truncate context if needed
723
+ max_context_tokens = 1200 # Leave room for generation
724
+ try:
725
+ tokenized = self.tokenizer(base_prompt, return_tensors="pt")
726
+ current_tokens = tokenized['input_ids'].shape[1]
727
+ except Exception as e:
728
+ logger.warning(f"Tokenization error for improved prompt, using base prompt as-is: {e}")
729
+ return base_prompt, improved_prompt_text
730
+
731
+ if current_tokens > max_context_tokens:
732
+ # Truncate context to fit within limits
733
+ try:
734
+ context_tokens = self.tokenizer(context, return_tensors="pt")['input_ids'].shape[1]
735
+ available_tokens = max_context_tokens - (current_tokens - context_tokens)
736
+
737
+ if available_tokens > 0:
738
+ # Truncate context to fit
739
+ truncated_context = self.tokenizer.decode(
740
+ self.tokenizer(context, return_tensors="pt", truncation=True, max_length=available_tokens)['input_ids'][0],
741
+ skip_special_tokens=True
742
+ )
743
+
744
+ # Reformat with truncated context
745
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
746
+ try:
747
+ messages = [
748
+ {"role": "system", "content": improved_prompt_text},
749
+ {"role": "user", "content": f"Context: {truncated_context}\n\nQuestion: {question}"}
750
+ ]
751
+ prompt = self.tokenizer.apply_chat_template(
752
+ messages,
753
+ tokenize=False,
754
+ add_generation_prompt=True
755
+ )
756
+ except:
757
+ prompt = self._format_improved_prompt_manual(truncated_context, question, improved_prompt_text)
758
+ else:
759
+ prompt = self._format_improved_prompt_manual(truncated_context, question, improved_prompt_text)
760
+ else:
761
+ # If even basic prompt is too long, use minimal format
762
+ prompt = self._format_improved_prompt_manual(context[:500] + "...", question, improved_prompt_text)
763
+ except Exception as e:
764
+ logger.warning(f"Error truncating context for improved prompt: {e}, using base prompt")
765
+ prompt = base_prompt
766
+ else:
767
+ prompt = base_prompt
768
+
769
+ return prompt, improved_prompt_text
770
+
771
+ def _format_improved_prompt_manual(self, context: str, question: str, improved_prompt_text: str) -> str:
772
+ """Manual prompt formatting for improved prompts (for models without chat templates)"""
773
+ return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
774
+
775
+ {improved_prompt_text}<|eot_id|><|start_header_id|>user<|end_header_id|>
776
+
777
+ Context: {context}
778
+
779
+ Question: {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
780
+
781
+ """
782
+
783
+ def generate_answer(self, prompt: str, **gen_kwargs) -> str:
784
+ """Generate answer using the language model"""
785
+ try:
786
+ if self.args.verbose:
787
+ logger.info(f"Full prompt (first 500 chars): {prompt[:500]}...")
788
+
789
+ # Tokenize input with more conservative limit to leave room for generation
790
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1500)
791
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
792
+
793
+ if self.args.verbose:
794
+ logger.info(f"Input tokens: {inputs['input_ids'].shape}")
795
+
796
+ # Generate
797
+ with torch.no_grad():
798
+ outputs = self.model.generate(
799
+ **inputs,
800
+ max_new_tokens=gen_kwargs.get('max_new_tokens', 512),
801
+ temperature=gen_kwargs.get('temperature', 0.7),
802
+ top_p=gen_kwargs.get('top_p', 0.95),
803
+ repetition_penalty=gen_kwargs.get('repetition_penalty', 1.05),
804
+ do_sample=True,
805
+ pad_token_id=self.tokenizer.eos_token_id,
806
+ eos_token_id=self.tokenizer.eos_token_id,
807
+ use_cache=True,
808
+ num_beams=1
809
+ )
810
+
811
+ # Decode response without skipping special tokens to preserve full length
812
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
813
+
814
+ if self.args.verbose:
815
+ logger.info(f"Full response (first 1000 chars): {response[:1000]}...")
816
+ logger.info(f"Looking for 'Answer:' in response: {'Answer:' in response}")
817
+ if "Answer:" in response:
818
+ answer_part = response.split("Answer:")[-1]
819
+ logger.info(f"Answer part (first 200 chars): {answer_part[:200]}...")
820
+
821
+ # Debug: Show the full response to understand the structure
822
+ logger.info(f"Full response length: {len(response)}")
823
+ logger.info(f"Prompt length: {len(prompt)}")
824
+ logger.info(f"Response after prompt (first 500 chars): {response[len(prompt):][:500]}...")
825
+
826
+ # Extract the answer more robustly by looking for the end of the prompt
827
+ # Find the actual end of the prompt in the response
828
+ prompt_end_marker = "<|start_header_id|>assistant<|end_header_id|>\n\n"
829
+ if prompt_end_marker in response:
830
+ answer = response.split(prompt_end_marker)[-1].strip()
831
+ else:
832
+ # Fallback to character-based extraction
833
+ answer = response[len(prompt):].strip()
834
+
835
+ if self.args.verbose:
836
+ logger.info(f"Full LLM output (first 200 chars): {answer[:200]}...")
837
+ logger.info(f"Full LLM output length: {len(answer)} characters")
838
+ logger.info(f"Full LLM output (last 200 chars): ...{answer[-200:]}")
839
+
840
+ # Only do minimal cleanup to preserve the full response
841
+ # Remove special tokens that might interfere with display, but preserve content
842
+ if "<|start_header_id|>" in answer:
843
+ # Only remove if it's at the very end
844
+ if answer.endswith("<|start_header_id|>"):
845
+ answer = answer[:-len("<|start_header_id|>")].strip()
846
+ if "<|eot_id|>" in answer:
847
+ # Only remove if it's at the very end
848
+ if answer.endswith("<|eot_id|>"):
849
+ answer = answer[:-len("<|eot_id|>")].strip()
850
+ if "<|end_of_text|>" in answer:
851
+ # Only remove if it's at the very end
852
+ if answer.endswith("<|end_of_text|>"):
853
+ answer = answer[:-len("<|end_of_text|>")].strip()
854
+
855
+ # Final validation - only reject if completely empty
856
+ if not answer or len(answer) < 3:
857
+ answer = "I don't know."
858
+
859
+ if self.args.verbose:
860
+ logger.info(f"Final answer: '{answer}'")
861
+
862
+ return answer
863
+
864
+ except Exception as e:
865
+ logger.error(f"Generation error: {e}")
866
+ return "I encountered an error while generating the answer."
867
+
868
+ def process_questions(self, questions_path: str, **kwargs) -> List[Tuple[str, str, str, str, float, str, float, str, float, str, str]]:
869
+ """Process all questions and generate answers with multiple readability levels
870
+
871
+ Returns:
872
+ List of tuples: (question, answer, sources, question_group, original_flesch,
873
+ middle_school_answer, middle_school_flesch,
874
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores)
875
+ """
876
+ logger.info(f"Processing questions from {questions_path}")
877
+
878
+ # Load questions
879
+ try:
880
+ with open(questions_path, 'r', encoding='utf-8') as f:
881
+ questions = [line.strip() for line in f if line.strip()]
882
+ except Exception as e:
883
+ logger.error(f"Failed to load questions: {e}")
884
+ sys.exit(1)
885
+
886
+ logger.info(f"Found {len(questions)} questions to process")
887
+
888
+ qa_pairs = []
889
+
890
+ # Get the improved prompt text for CSV header by calling format_improved_prompt with empty chunks
891
+ # This will give us the prompt text without actually generating
892
+ _, improved_prompt_text = self.format_improved_prompt([], "")
893
+
894
+ # Initialize CSV file with headers
895
+ self.write_csv([], kwargs.get('output_file', 'results.csv'), append=False, improved_prompt_text=improved_prompt_text)
896
+
897
+ # Process each question
898
+ for i, question in enumerate(tqdm(questions, desc="Processing questions")):
899
+ logger.info(f"Question {i+1}/{len(questions)}: {question[:50]}...")
900
+
901
+ try:
902
+ # Categorize question
903
+ question_group = self._categorize_question(question)
904
+
905
+ # Retrieve relevant chunks with similarity scores
906
+ context_chunks, similarity_scores = self.retrieve_with_scores(question, self.args.k)
907
+
908
+ # Format similarity scores as a string (comma-separated, 3 decimal places)
909
+ similarity_scores_str = ", ".join([f"{score:.3f}" for score in similarity_scores]) if similarity_scores else "0.000"
910
+
911
+ if not context_chunks:
912
+ answer = "I don't know."
913
+ sources = "No sources found"
914
+ middle_school_answer = "I don't know."
915
+ high_school_answer = "I don't know."
916
+ improved_answer = "I don't know."
917
+ original_flesch = 0.0
918
+ middle_school_flesch = 0.0
919
+ high_school_flesch = 0.0
920
+ similarity_scores_str = "0.000"
921
+ else:
922
+ # Format original prompt
923
+ prompt = self.format_prompt(context_chunks, question)
924
+
925
+ # Generate original answer
926
+ start_time = time.time()
927
+ answer = self.generate_answer(prompt, **kwargs)
928
+ gen_time = time.time() - start_time
929
+
930
+ # Generate improved answer
931
+ improved_prompt, _ = self.format_improved_prompt(context_chunks, question)
932
+ improved_start = time.time()
933
+ improved_answer = self.generate_answer(improved_prompt, **kwargs)
934
+ improved_time = time.time() - improved_start
935
+
936
+ # Clean up improved answer - remove unwanted phrases and formatting
937
+ improved_answer = self._clean_improved_answer(improved_answer)
938
+ logger.info(f"Improved answer generated in {improved_time:.2f}s")
939
+
940
+ # Extract source documents
941
+ sources = self._extract_sources(context_chunks)
942
+
943
+ # Calculate original answer Flesch score
944
+ try:
945
+ original_flesch = textstat.flesch_kincaid_grade(answer)
946
+ except:
947
+ original_flesch = 0.0
948
+
949
+ # Generate middle school version
950
+ readability_start = time.time()
951
+ middle_school_answer, middle_school_flesch = self.enhance_readability(answer, "middle_school")
952
+ readability_time = time.time() - readability_start
953
+ logger.info(f"Middle school readability in {readability_time:.2f}s")
954
+
955
+ # Generate high school version
956
+ readability_start = time.time()
957
+ high_school_answer, high_school_flesch = self.enhance_readability(answer, "high_school")
958
+ readability_time = time.time() - readability_start
959
+ logger.info(f"High school readability in {readability_time:.2f}s")
960
+
961
+ logger.info(f"Generated answer in {gen_time:.2f}s")
962
+ logger.info(f"Sources: {sources}")
963
+ logger.info(f"Similarity scores: {similarity_scores_str}")
964
+ logger.info(f"Original Flesch: {original_flesch:.1f}, Middle School: {middle_school_flesch:.1f}, High School: {high_school_flesch:.1f}")
965
+
966
+ qa_pairs.append((question, answer, sources, question_group, original_flesch,
967
+ middle_school_answer, middle_school_flesch,
968
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str))
969
+
970
+ # Write incrementally to CSV after each question
971
+ self.write_csv([(question, answer, sources, question_group, original_flesch,
972
+ middle_school_answer, middle_school_flesch,
973
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str)],
974
+ kwargs.get('output_file', 'results.csv'), append=True, improved_prompt_text=improved_prompt_text)
975
+ logger.info(f"Progress saved: {i+1}/{len(questions)} questions completed")
976
+
977
+ except Exception as e:
978
+ logger.error(f"Error processing question {i+1}: {e}")
979
+ error_answer = "I encountered an error processing this question."
980
+ sources = "Error retrieving sources"
981
+ question_group = self._categorize_question(question)
982
+ original_flesch = 0.0
983
+ middle_school_answer = "I encountered an error processing this question."
984
+ high_school_answer = "I encountered an error processing this question."
985
+ improved_answer = "I encountered an error processing this question."
986
+ middle_school_flesch = 0.0
987
+ high_school_flesch = 0.0
988
+ similarity_scores_str = "0.000"
989
+ qa_pairs.append((question, error_answer, sources, question_group, original_flesch,
990
+ middle_school_answer, middle_school_flesch,
991
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str))
992
+
993
+ # Still write the error to CSV
994
+ self.write_csv([(question, error_answer, sources, question_group, original_flesch,
995
+ middle_school_answer, middle_school_flesch,
996
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores_str)],
997
+ kwargs.get('output_file', 'results.csv'), append=True, improved_prompt_text=improved_prompt_text)
998
+ logger.info(f"Error saved: {i+1}/{len(questions)} questions completed")
999
+
1000
+ return qa_pairs
1001
+
1002
+ def _clean_readability_answer(self, answer: str, target_level: str) -> str:
1003
+ """Clean up readability-enhanced answers to remove unwanted phrases and formatting
1004
+
1005
+ Args:
1006
+ answer: The readability-enhanced answer
1007
+ target_level: Either "middle_school" or "high_school"
1008
+ """
1009
+ cleaned = answer
1010
+
1011
+ # Remove the "Here's a rewritten version" phrases
1012
+ if target_level == "middle_school":
1013
+ unwanted_phrases = [
1014
+ "Here's a rewritten version of the text at a middle school reading level:",
1015
+ "Here's a rewritten version of the text at a middle school reading level",
1016
+ "Here is a rewritten version of the text at a middle school reading level:",
1017
+ "Here is a rewritten version of the text at a middle school reading level",
1018
+ "Here's a rewritten version at a middle school reading level:",
1019
+ "Here's a rewritten version at a middle school reading level",
1020
+ ]
1021
+ elif target_level == "high_school":
1022
+ unwanted_phrases = [
1023
+ "Here's a rewritten version of the text at a high school reading level",
1024
+ "Here's a rewritten version of the text at a high school reading level:",
1025
+ "Here is a rewritten version of the text at a high school reading level",
1026
+ "Here is a rewritten version of the text at a high school reading level:",
1027
+ "Here's a rewritten version at a high school reading level",
1028
+ "Here's a rewritten version at a high school reading level:",
1029
+ ]
1030
+ else:
1031
+ unwanted_phrases = []
1032
+
1033
+ for phrase in unwanted_phrases:
1034
+ if phrase.lower() in cleaned.lower():
1035
+ # Find and remove the phrase (case-insensitive)
1036
+ pattern = re.compile(re.escape(phrase), re.IGNORECASE)
1037
+ cleaned = pattern.sub("", cleaned).strip()
1038
+ # Remove leading colons, semicolons, or dashes
1039
+ cleaned = re.sub(r'^[:;\-]\s*', '', cleaned).strip()
1040
+
1041
+ # Remove asterisks (but preserve bullet points if they use •)
1042
+ cleaned = re.sub(r'\*\*', '', cleaned) # Remove bold markers
1043
+ cleaned = re.sub(r'\(\*\)', '', cleaned) # Remove (*)
1044
+ cleaned = re.sub(r'\*', '', cleaned) # Remove remaining asterisks
1045
+
1046
+ # Clean up extra whitespace
1047
+ cleaned = ' '.join(cleaned.split())
1048
+
1049
+ return cleaned
1050
+
1051
+ def _clean_improved_answer(self, answer: str) -> str:
1052
+ """Clean up improved answer to remove unwanted phrases and formatting"""
1053
+ # Remove phrases like "Here's a rewritten version" or similar
1054
+ unwanted_phrases = [
1055
+ "Here's a rewritten version",
1056
+ "Here's a version",
1057
+ "Here is a rewritten version",
1058
+ "Here is a version",
1059
+ "Here's the answer",
1060
+ "Here is the answer"
1061
+ ]
1062
+
1063
+ cleaned = answer
1064
+ for phrase in unwanted_phrases:
1065
+ if phrase.lower() in cleaned.lower():
1066
+ # Find and remove the phrase and any following colon/semicolon
1067
+ pattern = re.compile(re.escape(phrase), re.IGNORECASE)
1068
+ cleaned = pattern.sub("", cleaned).strip()
1069
+ # Remove leading colons, semicolons, or dashes
1070
+ cleaned = re.sub(r'^[:;\-]\s*', '', cleaned).strip()
1071
+
1072
+ # Remove formatting markers like (*) or ** but preserve bullet points
1073
+ cleaned = re.sub(r'\*\*', '', cleaned) # Remove bold markers
1074
+ cleaned = re.sub(r'\(\*\)', '', cleaned) # Remove (*)
1075
+ # Note: single asterisks are left alone here in case they carry meaning;
1076
+ # the improved prompt already instructs the model to avoid formatting markers
1077
+
1078
+ # Remove "Don't worry" and similar emotional management phrases
1079
+ emotional_phrases = [
1080
+ r"don't worry[^.]*\.\s*",
1081
+ r"Don't worry[^.]*\.\s*",
1082
+ r"the good news is[^.]*\.\s*",
1083
+ r"The good news is[^.]*\.\s*",
1084
+ ]
1085
+ for pattern in emotional_phrases:
1086
+ cleaned = re.sub(pattern, '', cleaned, flags=re.IGNORECASE)
1087
+
1088
+ # Clean up extra whitespace
1089
+ cleaned = ' '.join(cleaned.split())
1090
+
1091
+ return cleaned
1092
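+ # Example (illustrative): an output such as
+ #   "Here's a rewritten version: **Don't worry, the variant is manageable.** It is..."
+ # comes back with the preamble, bold markers, and the "Don't worry" clause removed,
+ # leaving only the factual remainder.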
+
1093
+ def diagnose_system(self, sample_questions: List[str] = None) -> Dict[str, Any]:
1094
+ """Diagnose the document loading, chunking, and retrieval system
1095
+
1096
+ Args:
1097
+ sample_questions: Optional list of questions to test retrieval
1098
+
1099
+ Returns:
1100
+ Dictionary with diagnostic information
1101
+ """
1102
+ diagnostics = {
1103
+ 'vector_db_stats': {},
1104
+ 'document_stats': {},
1105
+ 'chunk_stats': {},
1106
+ 'retrieval_tests': []
1107
+ }
1108
+
1109
+ # Check vector database
1110
+ try:
1111
+ stats = self.vector_retriever.get_collection_stats()
1112
+ diagnostics['vector_db_stats'] = {
1113
+ 'total_chunks': stats.get('total_chunks', 0),
1114
+ 'collection_name': stats.get('collection_name', 'unknown'),
1115
+ 'status': 'OK' if stats.get('total_chunks', 0) > 0 else 'EMPTY'
1116
+ }
1117
+ except Exception as e:
1118
+ diagnostics['vector_db_stats'] = {
1119
+ 'status': 'ERROR',
1120
+ 'error': str(e)
1121
+ }
1122
+
1123
+ # Test document loading (without actually loading)
1124
+ try:
1125
+ data_path = Path(self.args.data_dir)
1126
+ if data_path.exists():
1127
+ supported_extensions = {'.txt', '.md', '.json', '.csv'}
1128
+ if PDF_AVAILABLE:
1129
+ supported_extensions.add('.pdf')
1130
+ if DOCX_AVAILABLE:
1131
+ supported_extensions.add('.docx')
1132
+ supported_extensions.add('.doc')
1133
+
1134
+ files = []
1135
+ for ext in supported_extensions:
1136
+ files.extend(data_path.rglob(f"*{ext}"))
1137
+
1138
+ # Sample a few files to check content
1139
+ sample_files = files[:5] if len(files) > 5 else files
1140
+ file_samples = []
1141
+ for file_path in sample_files:
1142
+ try:
1143
+ content = self._read_file(file_path)
1144
+ file_samples.append({
1145
+ 'filename': file_path.name,
1146
+ 'size_chars': len(content),
1147
+ 'size_words': len(content.split()),
1148
+ 'readable': True
1149
+ })
1150
+ except Exception as e:
1151
+ file_samples.append({
1152
+ 'filename': file_path.name,
1153
+ 'readable': False,
1154
+ 'error': str(e)
1155
+ })
1156
+
1157
+ diagnostics['document_stats'] = {
1158
+ 'total_files_found': len(files),
1159
+ 'sample_files': file_samples,
1160
+ 'status': 'OK'
1161
+ }
1162
+ else:
1163
+ diagnostics['document_stats'] = {
1164
+ 'status': 'ERROR',
1165
+ 'error': f'Data directory {self.args.data_dir} does not exist'
1166
+ }
1167
+ except Exception as e:
1168
+ diagnostics['document_stats'] = {
1169
+ 'status': 'ERROR',
1170
+ 'error': str(e)
1171
+ }
1172
+
1173
+ # Test chunking on a sample document
1174
+ try:
1175
+ if diagnostics['document_stats'].get('status') == 'OK':
1176
+ sample_file = None
1177
+ for file_info in diagnostics['document_stats'].get('sample_files', []):
1178
+ if file_info.get('readable', False):
1179
+ # Find the actual file
1180
+ data_path = Path(self.args.data_dir)
1181
+ for ext in ['.txt', '.md', '.pdf', '.docx']:
1182
+ files = list(data_path.rglob(f"*{file_info['filename']}"))
1183
+ if files:
1184
+ sample_file = files[0]
1185
+ break
1186
+ if sample_file:
1187
+ break
1188
+
1189
+ if sample_file:
1190
+ content = self._read_file(sample_file)
1191
+ # Create a dummy document (Document is already imported at top)
1192
+ sample_doc = Document(
1193
+ filename=sample_file.name,
1194
+ content=content,
1195
+ filepath=str(sample_file),
1196
+ file_type=sample_file.suffix.lower(),
1197
+ file_hash=""
1198
+ )
1199
+
1200
+ # Test chunking
1201
+ sample_chunks = self._chunk_text(
1202
+ content,
1203
+ sample_file.name,
1204
+ self.args.chunk_size,
1205
+ self.args.chunk_overlap
1206
+ )
1207
+
1208
+ chunk_lengths = [len(chunk.text.split()) for chunk in sample_chunks]
1209
+
1210
+ diagnostics['chunk_stats'] = {
1211
+ 'sample_document': sample_file.name,
1212
+ 'total_chunks': len(sample_chunks),
1213
+ 'avg_chunk_size_words': sum(chunk_lengths) / len(chunk_lengths) if chunk_lengths else 0,
1214
+ 'min_chunk_size_words': min(chunk_lengths) if chunk_lengths else 0,
1215
+ 'max_chunk_size_words': max(chunk_lengths) if chunk_lengths else 0,
1216
+ 'chunk_size_setting': self.args.chunk_size,
1217
+ 'chunk_overlap_setting': self.args.chunk_overlap,
1218
+ 'status': 'OK'
1219
+ }
1220
+ except Exception as e:
1221
+ diagnostics['chunk_stats'] = {
1222
+ 'status': 'ERROR',
1223
+ 'error': str(e)
1224
+ }
1225
+
1226
+ # Test retrieval with sample questions
1227
+ if sample_questions and diagnostics['vector_db_stats'].get('status') == 'OK':
1228
+ for question in sample_questions:
1229
+ try:
1230
+ context_chunks = self.retrieve(question, self.args.k)
1231
+ sources = self._extract_sources(context_chunks)
1232
+
1233
+ # Get similarity scores
1234
+ results = self.vector_retriever.search(question, self.args.k)
1235
+
1236
+ # Get sample chunk text (first 200 chars of first chunk)
1237
+ sample_chunk_text = context_chunks[0].text[:200] + "..." if context_chunks else "N/A"
1238
+
1239
+ diagnostics['retrieval_tests'].append({
1240
+ 'question': question,
1241
+ 'chunks_retrieved': len(context_chunks),
1242
+ 'sources': sources,
1243
+ 'similarity_scores': [f"{score:.3f}" for _, score in results],
1244
+ 'sample_chunk_preview': sample_chunk_text,
1245
+ 'status': 'OK' if context_chunks else 'NO_RESULTS'
1246
+ })
1247
+ except Exception as e:
1248
+ diagnostics['retrieval_tests'].append({
1249
+ 'question': question,
1250
+ 'status': 'ERROR',
1251
+ 'error': str(e)
1252
+ })
1253
+
1254
+ return diagnostics
1255
+
1256
+ def print_diagnostics(self, diagnostics: Dict[str, Any]) -> None:
1257
+ """Print diagnostic information in a readable format"""
1258
+ print("\n" + "="*80)
1259
+ print("SYSTEM DIAGNOSTICS")
1260
+ print("="*80)
1261
+
1262
+ # Vector DB Stats
1263
+ print("\n📊 VECTOR DATABASE:")
1264
+ vdb = diagnostics.get('vector_db_stats', {})
1265
+ print(f" Status: {vdb.get('status', 'UNKNOWN')}")
1266
+ print(f" Total chunks: {vdb.get('total_chunks', 0)}")
1267
+ print(f" Collection: {vdb.get('collection_name', 'unknown')}")
1268
+ if 'error' in vdb:
1269
+ print(f" Error: {vdb['error']}")
1270
+
1271
+ # Document Stats
1272
+ print("\n📄 DOCUMENT LOADING:")
1273
+ doc_stats = diagnostics.get('document_stats', {})
1274
+ print(f" Status: {doc_stats.get('status', 'UNKNOWN')}")
1275
+ print(f" Total files found: {doc_stats.get('total_files_found', 0)}")
1276
+ if 'sample_files' in doc_stats:
1277
+ print(f" Sample files:")
1278
+ for file_info in doc_stats['sample_files']:
1279
+ if file_info.get('readable', False):
1280
+ print(f" ✓ {file_info['filename']}: {file_info.get('size_chars', 0):,} chars, {file_info.get('size_words', 0):,} words")
1281
+ else:
1282
+ print(f" ✗ {file_info['filename']}: {file_info.get('error', 'unreadable')}")
1283
+ if 'error' in doc_stats:
1284
+ print(f" Error: {doc_stats['error']}")
1285
+
1286
+ # Chunk Stats
1287
+ print("\n✂️ CHUNKING:")
1288
+ chunk_stats = diagnostics.get('chunk_stats', {})
1289
+ print(f" Status: {chunk_stats.get('status', 'UNKNOWN')}")
1290
+ if chunk_stats.get('status') == 'OK':
1291
+ print(f" Sample document: {chunk_stats.get('sample_document', 'N/A')}")
1292
+ print(f" Total chunks from sample: {chunk_stats.get('total_chunks', 0)}")
1293
+ print(f" Average chunk size: {chunk_stats.get('avg_chunk_size_words', 0):.1f} words")
1294
+ print(f" Chunk size range: {chunk_stats.get('min_chunk_size_words', 0)} - {chunk_stats.get('max_chunk_size_words', 0)} words")
1295
+ print(f" Settings: size={chunk_stats.get('chunk_size_setting', 0)}, overlap={chunk_stats.get('chunk_overlap_setting', 0)}")
1296
+ if 'error' in chunk_stats:
1297
+ print(f" Error: {chunk_stats['error']}")
1298
+
1299
+ # Retrieval Tests
1300
+ if diagnostics.get('retrieval_tests'):
1301
+ print("\n🔍 RETRIEVAL TESTS:")
1302
+ for test in diagnostics['retrieval_tests']:
1303
+ print(f"\n Question: {test.get('question', 'N/A')}")
1304
+ print(f" Status: {test.get('status', 'UNKNOWN')}")
1305
+ if test.get('status') == 'OK':
1306
+ print(f" Chunks retrieved: {test.get('chunks_retrieved', 0)}")
1307
+ print(f" Sources: {test.get('sources', 'N/A')}")
1308
+ scores = test.get('similarity_scores', [])
1309
+ if scores:
1310
+ print(f" Similarity scores: {', '.join(scores)}")
1311
+ # Warn if scores are low
1312
+ try:
1313
+ score_values = [float(s) for s in scores]
1314
+ if max(score_values) < 0.3:
1315
+ print(f" ⚠️ WARNING: Low similarity scores - retrieved chunks may not be very relevant")
1316
+ elif max(score_values) < 0.5:
1317
+ print(f" ⚠️ NOTE: Moderate similarity - consider increasing --k or checking chunk quality")
1318
+ except:
1319
+ pass
1320
+ if 'sample_chunk_preview' in test:
1321
+ print(f" Sample chunk preview: {test['sample_chunk_preview']}")
1322
+ elif 'error' in test:
1323
+ print(f" Error: {test['error']}")
1324
+
1325
+ print("\n" + "="*80 + "\n")
1326
+
1327
+ def _extract_sources(self, context_chunks: List[Chunk]) -> str:
1328
+ """Extract source document names from context chunks"""
1329
+ sources = []
1330
+ for chunk in context_chunks:
1331
+ # Debug: Print chunk filename if verbose
1332
+ if self.args.verbose:
1333
+ logger.info(f"Chunk filename: {chunk.filename}")
1334
+
1335
+ # Extract filename from chunk attribute (not metadata)
1336
+ source = chunk.filename if hasattr(chunk, 'filename') and chunk.filename else 'Unknown source'
1337
+ # Clean up the source name
1338
+ if source.endswith('.pdf'):
1339
+ source = source[:-4] # Remove .pdf extension
1340
+ elif source.endswith('.txt'):
1341
+ source = source[:-4] # Remove .txt extension
1342
+ elif source.endswith('.md'):
1343
+ source = source[:-3] # Remove .md extension
1344
+
1345
+ sources.append(source)
1346
+
1347
+ # Remove duplicates while preserving order
1348
+ unique_sources = []
1349
+ for source in sources:
1350
+ if source not in unique_sources:
1351
+ unique_sources.append(source)
1352
+
1353
+ return "; ".join(unique_sources)
1354
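+ # Example (illustrative, hypothetical filenames): chunks drawn from "NCCN_guidelines.pdf",
+ # "brca_overview.txt", and "NCCN_guidelines.pdf" again yield "NCCN_guidelines; brca_overview".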
+
1355
+ def _categorize_question(self, question: str) -> str:
1356
+ """Categorize a question into one of 5 categories"""
1357
+ question_lower = question.lower()
1358
+
1359
+ # Gene-Specific Recommendations
1360
+ if any(gene in question_lower for gene in ['msh2', 'mlh1', 'msh6', 'pms2', 'epcam', 'brca1', 'brca2']):
1361
+ if any(kw in question_lower for kw in ['screening', 'surveillance', 'prevention', 'recommendation', 'risk', 'cancer risk', 'steps', 'management']):
1362
+ return "Gene-Specific Recommendations"
1363
+
1364
+ # Inheritance Patterns
1365
+ if any(kw in question_lower for kw in ['inherit', 'inherited', 'pass', 'skip a generation', 'generation', 'can i pass']):
1366
+ return "Inheritance Patterns"
1367
+
1368
+ # Family Risk Assessment
1369
+ if any(kw in question_lower for kw in ['family member', 'relative', 'first-degree', 'family risk', 'which relative', 'should my family']):
1370
+ return "Family Risk Assessment"
1371
+
1372
+ # Genetic Variant Interpretation
1373
+ if any(kw in question_lower for kw in ['what does', 'genetic variant mean', 'variant mean', 'mutation mean', 'genetic result']):
1374
+ return "Genetic Variant Interpretation"
1375
+
1376
+ # Support and Resources
1377
+ if any(kw in question_lower for kw in ['cope', 'overwhelmed', 'resource', 'genetic counselor', 'support', 'research', 'help', 'insurance', 'gina']):
1378
+ return "Support and Resources"
1379
+
1380
+ # Default to Genetic Variant Interpretation if unclear
1381
+ return "Genetic Variant Interpretation"
1382
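+ # Example (illustrative): "What screening is recommended for MSH2 carriers?" matches a
+ # gene name plus the 'screening' keyword, so it is labeled "Gene-Specific Recommendations";
+ # "Can I pass this on to my children?" contains 'pass' and falls into "Inheritance Patterns".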
+
1383
+ def enhance_readability(self, answer: str, target_level: str = "middle_school") -> Tuple[str, float]:
1384
+ """Enhance answer readability to different levels and calculate Flesch-Kincaid Grade Level
1385
+
1386
+ Args:
1387
+ answer: The original answer to simplify or enhance
1388
+ target_level: One of "middle_school", "high_school", "college", or "doctoral"
1389
+
1390
+ Returns:
1391
+ Tuple of (enhanced_answer, grade_level)
1392
+ """
1393
+ try:
1394
+ # Define prompts for different reading levels
1395
+ if target_level == "middle_school":
1396
+ level_description = "middle school reading level (ages 12-14, 6th-8th grade)"
1397
+ instructions = """
1398
+ - Use simpler medical terms or explain them
1399
+ - Medium-length sentences
1400
+ - Clear, structured explanations
1401
+ - Keep important medical information accessible"""
1402
+ elif target_level == "high_school":
1403
+ level_description = "high school reading level (ages 15-18, 9th-12th grade)"
1404
+ instructions = """
1405
+ - Use appropriate medical terminology with context
1406
+ - Varied sentence length
1407
+ - Comprehensive yet accessible explanations
1408
+ - Maintain technical accuracy while ensuring clarity"""
1409
+ elif target_level == "college":
1410
+ level_description = "college reading level (undergraduate level, ages 18-22)"
1411
+ instructions = """
1412
+ - Use standard medical terminology with brief explanations
1413
+ - Professional and clear writing style
1414
+ - Include relevant clinical context
1415
+ - Maintain scientific accuracy and precision
1416
+ - Appropriate for undergraduate students in health sciences"""
1417
+ elif target_level == "doctoral":
1418
+ level_description = "doctoral/professional reading level (graduate level, medical professionals)"
1419
+ instructions = """
1420
+ - Use advanced medical and scientific terminology
1421
+ - Include detailed clinical and research context
1422
+ - Reference specific mechanisms, pathways, and evidence
1423
+ - Provide comprehensive technical explanations
1424
+ - Appropriate for medical professionals, researchers, and graduate students
1425
+ - Include nuanced discussions of clinical implications and research findings"""
1426
+ else:
1427
+ raise ValueError(f"Unknown target_level: {target_level}. Must be one of: middle_school, high_school, college, doctoral")
1428
+
1429
+ # Create a prompt to enhance the medical answer for the target level
1430
+ # Try to use chat template if available, otherwise use manual format
1431
+ system_message = f"""You are a helpful medical assistant who specializes in explaining complex medical information at appropriate reading levels. Rewrite the following medical answer for {level_description}:
1432
+ {instructions}
1433
+ - Keep the same important information but adapt the complexity
1434
+ - Provide context for technical terms
1435
+ - Ensure the answer is informative yet understandable"""
1436
+
1437
+ user_message = f"Please rewrite this medical answer for {level_description}:\n\n{answer}"
1438
+
1439
+ # Try to use chat template if available
1440
+ if hasattr(self.tokenizer, 'apply_chat_template') and self.tokenizer.chat_template is not None:
1441
+ try:
1442
+ messages = [
1443
+ {"role": "system", "content": system_message},
1444
+ {"role": "user", "content": user_message}
1445
+ ]
1446
+ readability_prompt = self.tokenizer.apply_chat_template(
1447
+ messages,
1448
+ tokenize=False,
1449
+ add_generation_prompt=True
1450
+ )
1451
+ except Exception as e:
1452
+ logger.warning(f"Failed to use chat template for readability, falling back to manual format: {e}")
1453
+ readability_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1454
+
1455
+ {system_message}
1456
+
1457
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
1458
+
1459
+ {user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1460
+
1461
+ """
1462
+ else:
1463
+ # Fall back to manual formatting (for Llama models)
1464
+ readability_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1465
+
1466
+ {system_message}
1467
+
1468
+ <|eot_id|><|start_header_id|>user<|end_header_id|>
1469
+
1470
+ {user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1471
+
1472
+ """
1473
+
1474
+ # Generate simplified answer
1475
+ inputs = self.tokenizer(readability_prompt, return_tensors="pt", truncation=True, max_length=2048)
1476
+ if self.device == "mps":
1477
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
1478
+
1479
+ # Adjust generation parameters based on target level
1480
+ if target_level in ["college", "doctoral"]:
1481
+ max_tokens = 512 # Reduced from 1024 for faster responses
1482
+ temp = 0.4 # Slightly higher temperature for more natural flow
1483
+ else:
1484
+ max_tokens = 384 # Reduced from 512 for faster responses
1485
+ temp = 0.3 # Lower temperature for more consistent simplification
1486
+
1487
+ with torch.no_grad():
1488
+ outputs = self.model.generate(
1489
+ **inputs,
1490
+ max_new_tokens=max_tokens,
1491
+ temperature=temp,
1492
+ top_p=0.9,
1493
+ repetition_penalty=1.05,
1494
+ do_sample=True,
1495
+ pad_token_id=self.tokenizer.eos_token_id,
1496
+ eos_token_id=self.tokenizer.eos_token_id,
1497
+ use_cache=True,
1498
+ num_beams=1
1499
+ )
1500
+
1501
+ # Decode response
1502
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
1503
+
1504
+ # Extract enhanced answer
1505
+ # Try to find the assistant response marker
1506
+ prompt_end_marker = "<|start_header_id|>assistant<|end_header_id|>\n\n"
1507
+ if prompt_end_marker in response:
1508
+ simplified_answer = response.split(prompt_end_marker)[-1].strip()
1509
+ elif "<|assistant|>" in response:
1510
+ # Some chat templates use <|assistant|>
1511
+ simplified_answer = response.split("<|assistant|>")[-1].strip()
1512
+ else:
1513
+ # Fallback: extract everything after the prompt
1514
+ simplified_answer = response[len(readability_prompt):].strip()
1515
+
1516
+ # Clean up special tokens
1517
+ if "<|eot_id|>" in simplified_answer:
1518
+ if simplified_answer.endswith("<|eot_id|>"):
1519
+ simplified_answer = simplified_answer[:-len("<|eot_id|>")].strip()
1520
+ if "<|end_of_text|>" in simplified_answer:
1521
+ if simplified_answer.endswith("<|end_of_text|>"):
1522
+ simplified_answer = simplified_answer[:-len("<|end_of_text|>")].strip()
1523
+
1524
+ # Clean up unwanted phrases and formatting
1525
+ simplified_answer = self._clean_readability_answer(simplified_answer, target_level)
1526
+
1527
+ # Calculate Flesch-Kincaid Grade Level
1528
+ try:
1529
+ grade_level = textstat.flesch_kincaid_grade(simplified_answer)
1530
+ except:
1531
+ grade_level = 0.0
1532
+
1533
+ if self.args.verbose:
1534
+ logger.info(f"Simplified answer length: {len(simplified_answer)} characters")
1535
+ logger.info(f"Flesch-Kincaid Grade Level: {grade_level:.1f}")
1536
+
1537
+ return simplified_answer, grade_level
1538
+
1539
+ except Exception as e:
1540
+ logger.error(f"Error enhancing readability: {e}")
1541
+ # Fallback: return original answer with estimated grade level
1542
+ try:
1543
+ grade_level = textstat.flesch_kincaid_grade(answer)
1544
+ except:
1545
+ grade_level = 12.0 # Default to high school level
1546
+ return answer, grade_level
1547
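+ # Minimal usage sketch (illustrative; assumes `bot` is a loaded RAGBot and `answer` is a string):
+ #   college_answer, grade = bot.enhance_readability(answer, target_level="college")
+ #   logger.info(f"College-level rewrite scored at grade {grade:.1f}")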
+
1548
+ def write_csv(self, qa_pairs: List[Tuple[str, str, str, str, float, str, float, str, float, str, str]], output_path: str, append: bool = False, improved_prompt_text: str = "") -> None:
1549
+ """Write Q&A pairs to CSV file in results folder
1550
+
1551
+ Expected tuple format: (question, answer, sources, question_group, original_flesch,
1552
+ middle_school_answer, middle_school_flesch,
1553
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores)
1554
+ """
1555
+ # Ensure results directory exists
1556
+ os.makedirs('results', exist_ok=True)
1557
+
1558
+ # If output_path doesn't already have results/ prefix, add it
1559
+ if not output_path.startswith('results/'):
1560
+ output_path = f'results/{output_path}'
1561
+
1562
+ if append:
1563
+ logger.info(f"Appending results to {output_path}")
1564
+ else:
1565
+ logger.info(f"Writing results to {output_path}")
1566
+
1567
+ # Create output directory if needed
1568
+ output_path = Path(output_path)
1569
+ output_path.parent.mkdir(parents=True, exist_ok=True)
1570
+
1571
+ try:
1572
+ # Check if file exists and if we're appending
1573
+ file_exists = output_path.exists()
1574
+ write_mode = 'a' if append and file_exists else 'w'
1575
+
1576
+ with open(output_path, write_mode, newline='', encoding='utf-8') as f:
1577
+ writer = csv.writer(f)
1578
+
1579
+ # Write header only if creating new file or first append
1580
+ if not append or not file_exists:
1581
+ # Create improved answer header with prompt text
1582
+ improved_header = f'improved_answer (PROMPT: {improved_prompt_text})'
1583
+ writer.writerow(['question', 'question_group', 'answer', 'original_flesch', 'sources',
1584
+ 'similarity_scores', 'middle_school_answer', 'middle_school_flesch',
1585
+ 'high_school_answer', 'high_school_flesch', improved_header])
1586
+
1587
+ for data in qa_pairs:
1588
+ # Unpack the data tuple
1589
+ (question, answer, sources, question_group, original_flesch,
1590
+ middle_school_answer, middle_school_flesch,
1591
+ high_school_answer, high_school_flesch, improved_answer, similarity_scores) = data
1592
+
1593
+ # Clean and escape the answers for CSV
1594
+ def clean_text(text):
1595
+ # Replace newlines with spaces and clean up formatting
1596
+ cleaned = text.replace('\n', ' ').replace('\r', ' ')
1597
+ # Remove extra whitespace but preserve the full content
1598
+ cleaned = ' '.join(cleaned.split())
1599
+ # Quote escaping is left to csv.writer, which already quotes fields as needed;
1600
+ # doubling quotes here as well would double-escape them in the output
1601
+ return cleaned
1602
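+ # Example (illustrative): 'Line one\nLine   two' becomes 'Line one Line two';
+ # newlines and runs of whitespace collapse to single spaces, and csv.writer
+ # handles any quoting when the row is written out.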
+
1603
+ clean_question = clean_text(question)
1604
+ clean_answer = clean_text(answer)
1605
+ clean_sources = clean_text(sources)
1606
+ clean_middle_school = clean_text(middle_school_answer)
1607
+ clean_high_school = clean_text(high_school_answer)
1608
+ clean_improved = clean_text(improved_answer)
1609
+
1610
+ # Log the full answer length for debugging
1611
+ if self.args.verbose:
1612
+ logger.info(f"Writing answer length: {len(clean_answer)} characters")
1613
+ logger.info(f"Middle school answer length: {len(clean_middle_school)} characters")
1614
+ logger.info(f"High school answer length: {len(clean_high_school)} characters")
1615
+ logger.info(f"Improved answer length: {len(clean_improved)} characters")
1616
+ logger.info(f"Question group: {question_group}")
1617
+
1618
+ # Use proper CSV quoting - let csv.writer handle the quoting
1619
+ writer.writerow([
1620
+ clean_question,
1621
+ question_group,
1622
+ clean_answer,
1623
+ f"{original_flesch:.1f}",
1624
+ clean_sources,
1625
+ similarity_scores, # Similarity scores (comma-separated)
1626
+ clean_middle_school,
1627
+ f"{middle_school_flesch:.1f}",
1628
+ clean_high_school,
1629
+ f"{high_school_flesch:.1f}",
1630
+ clean_improved
1631
+ ])
1632
+
1633
+ if append:
1634
+ logger.info(f"Appended {len(qa_pairs)} Q&A pairs to {output_path}")
1635
+ else:
1636
+ logger.info(f"Successfully wrote {len(qa_pairs)} Q&A pairs to {output_path}")
1637
+
1638
+ except Exception as e:
1639
+ logger.error(f"Failed to write CSV: {e}")
1640
+ sys.exit(4)
1641
+
1642
+
1643
+ def parse_args():
+     """Parse command line arguments"""
+     parser = argparse.ArgumentParser(description="RAG Chatbot for CGT-LLM-Beta with Vector Database")
+
+     # File paths
+     parser.add_argument('--data-dir', default='./Data Resources',
+                         help='Directory containing documents to index')
+     parser.add_argument('--questions', default='./questions.txt',
+                         help='File containing questions (one per line)')
+     parser.add_argument('--out', default='./answers.csv',
+                         help='Output CSV file for answers')
+     parser.add_argument('--vector-db-dir', default='./chroma_db',
+                         help='Directory for ChromaDB persistence')
+
+     # Retrieval parameters
+     parser.add_argument('--k', type=int, default=5,
+                         help='Number of chunks to retrieve per question')
+
+     # Chunking parameters
+     parser.add_argument('--chunk-size', type=int, default=500,
+                         help='Size of text chunks in tokens')
+     parser.add_argument('--chunk-overlap', type=int, default=200,
+                         help='Overlap between chunks in tokens')
+
+     # Model selection
+     parser.add_argument('--model', type=str, default='meta-llama/Llama-3.2-3B-Instruct',
+                         help='HuggingFace model name to use (e.g., meta-llama/Llama-3.2-3B-Instruct, mistralai/Mistral-7B-Instruct-v0.2)')
+
+     # Generation parameters
+     parser.add_argument('--max-new-tokens', type=int, default=1024,
+                         help='Maximum new tokens to generate')
+     parser.add_argument('--temperature', type=float, default=0.2,
+                         help='Generation temperature')
+     parser.add_argument('--top-p', type=float, default=0.9,
+                         help='Top-p sampling parameter')
+     parser.add_argument('--repetition-penalty', type=float, default=1.1,
+                         help='Repetition penalty')
+
+     # Database options
+     parser.add_argument('--force-rebuild', action='store_true',
+                         help='Force rebuild of vector database')
+     parser.add_argument('--skip-indexing', action='store_true',
+                         help='Skip document indexing, use existing database')
+
+     # Other options
+     parser.add_argument('--seed', type=int, default=42,
+                         help='Random seed for reproducibility')
+     parser.add_argument('--verbose', action='store_true',
+                         help='Enable verbose logging')
+     parser.add_argument('--dry-run', action='store_true',
+                         help='Build index and test retrieval without generation')
+     parser.add_argument('--diagnose', action='store_true',
+                         help='Run system diagnostics and exit')
+
+     return parser.parse_args()
+
+
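For reference, a minimal invocation sketch using the flags defined above. The file name `bot.py` is assumed from the commit description; the same arguments can be passed directly on the command line instead of via `subprocess`.

```python
# Sketch only: assumes the script above is saved as bot.py.
import subprocess
import sys

subprocess.run([
    sys.executable, "bot.py",
    "--data-dir", "./Data Resources",   # documents to index
    "--questions", "./questions.txt",   # one question per line
    "--out", "./answers.csv",           # CSV produced by the writer shown earlier
    "--model", "meta-llama/Llama-3.2-3B-Instruct",
    "--k", "5",                         # chunks retrieved per question
    "--temperature", "0.2",
    "--verbose",
], check=True)
```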
+ def main():
+     """Main function"""
+     args = parse_args()
+
+     # Set random seed
+     torch.manual_seed(args.seed)
+     np.random.seed(args.seed)
+
+     # Set logging level
+     if args.verbose:
+         logging.getLogger().setLevel(logging.DEBUG)
+
+     logger.info("Starting RAG Chatbot with Vector Database")
+     logger.info(f"Arguments: {vars(args)}")
+
+     try:
+         # Initialize bot
+         bot = RAGBot(args)
+
+         # Check if we should skip indexing
+         if not args.skip_indexing:
+             # Load and process documents
+             documents = bot.load_corpus(args.data_dir)
+             if not documents:
+                 logger.error("No documents found to process")
+                 sys.exit(3)
+
+             # Chunk documents
+             chunks = bot.chunk_documents(documents, args.chunk_size, args.chunk_overlap)
+             if not chunks:
+                 logger.error("No chunks created from documents")
+                 sys.exit(3)
+
+             # Build or update index
+             bot.build_or_update_index(chunks, args.force_rebuild)
+         else:
+             logger.info("Skipping document indexing, using existing vector database")
+
+         # Run diagnostics if requested
+         if args.diagnose:
+             sample_questions = [
+                 "What is Lynch Syndrome?",
+                 "What does a BRCA1 genetic variant mean?",
+                 "What screening tests are recommended for MSH2 carriers?"
+             ]
+             diagnostics = bot.diagnose_system(sample_questions=sample_questions)
+             bot.print_diagnostics(diagnostics)
+             return
+
+         if args.dry_run:
+             logger.info("Dry run completed successfully")
+             return
+
+         # Process questions
+         generation_kwargs = {
+             'max_new_tokens': args.max_new_tokens,
+             'temperature': args.temperature,
+             'top_p': args.top_p,
+             'repetition_penalty': args.repetition_penalty
+         }
+
+         qa_pairs = bot.process_questions(args.questions, output_file=args.out, **generation_kwargs)
+
+         logger.info("RAG Chatbot completed successfully")
+
+     except KeyboardInterrupt:
+         logger.info("Interrupted by user")
+         sys.exit(0)
+     except Exception as e:
+         logger.error(f"Unexpected error: {e}")
+         if args.verbose:
+             import traceback
+             traceback.print_exc()
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
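The readability columns written to the CSV are Flesch-Kincaid grade levels, for which textstat is pinned in requirements.txt. A minimal scoring sketch follows; the actual helper used in bot.py is not shown in this hunk, and the sample text is illustrative only.

```python
# Sketch only: textstat is listed in requirements.txt for readability scoring.
import textstat

answer = ("BRCA1 and BRCA2 are genes that normally help repair DNA. "
          "A harmful variant in either gene raises the risk of breast and ovarian cancer.")

grade = textstat.flesch_kincaid_grade(answer)   # U.S. school grade level
print(f"Flesch-Kincaid grade: {grade:.1f}")
```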
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/data_level0.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:80fe29380be0f587de8c3d0df3bbd891219ebe35d3ab4e007721d322ca704b9f
+ size 18888520
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/header.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56091853c1c20a1ec97ba4a7935cb7ab95f58b91d1ca56b990bf768f7bd2df88
+ size 100
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:754f12ddf66368443039e44c7d3625dbfa54c42604f231054e5c8ab8df162ebb
+ size 548379
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/length.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e72c9f5fb80c8fa3f488f68172cf32cdaf226d94cb6cff09ff68990b34fbb04c
+ size 45080
chroma_db/7eddb202-b9b0-46c1-ae4b-37838cdc5aac/link_lists.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0046b8333ff42649a27896a5da1f0fd89ee54954221fde9172dfe284d94262b
+ size 99820
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:70340ab0d0dddb6b5bcf29c0e09f316b0f695f6645be0231302346d5af463700
+ size 294584320
requirements.txt ADDED
@@ -0,0 +1,56 @@
+ # =============================================================================
+ # RAG Chatbot with Vector Database - Requirements
+ # =============================================================================
+ # Production-ready dependencies for medical document analysis and Q&A
+
+ # Core ML/AI Framework
+ torch>=2.0.0                   # PyTorch for model inference
+ transformers>=4.30.0           # Hugging Face transformers
+ huggingface_hub>=0.23.0        # Hugging Face Hub API (for Inference API)
+ accelerate>=0.20.0             # Model loading optimization
+ safetensors>=0.3.0             # Safe model loading
+
+ # Vector Database & Embeddings
+ chromadb>=0.4.0                # Vector database for fast retrieval
+ sentence-transformers>=2.2.0   # Semantic embeddings (all-MiniLM-L6-v2)
+
+ # Data Processing
+ pandas>=1.3.0                  # Data manipulation and CSV handling
+ numpy>=1.21.0                  # Numerical computing
+ scikit-learn>=1.0.0            # ML utilities and TF-IDF
+
+ # Text Analysis & Readability
+ textstat>=0.7.0                # Flesch-Kincaid Grade Level calculation
+ nltk>=3.8.0                    # Natural language processing utilities
+
+ # Document Processing (Core)
+ pypdf>=3.0.0                   # PDF document parsing
+ python-docx>=0.8.11            # DOCX document parsing
+
+ # Optional Document Processing
+ rank-bm25>=0.2.2               # BM25 retrieval algorithm (alternative to TF-IDF)
+
+ # Utilities & Progress
+ tqdm>=4.65.0                   # Progress bars
+ pathlib2>=2.3.0                # Enhanced path handling (if needed)
+
+ # Web Interface
+ gradio>=4.44.1                 # Gradio web interface for chatbot (updated for Spaces compatibility)
+
+ # Development & Testing (Optional)
+ pytest>=7.0.0                  # Testing framework
+ black>=22.0.0                  # Code formatting
+ flake8>=4.0.0                  # Code linting
+
+ # Performance Monitoring (Optional)
+ psutil>=5.8.0                  # System resource monitoring
+ memory-profiler>=0.60.0        # Memory usage profiling
+
+ # =============================================================================
+ # Installation Notes:
+ # =============================================================================
+ # 1. Install with: pip install -r requirements.txt
+ # 2. For Apple Silicon: PyTorch will automatically use MPS acceleration
+ # 3. Optional packages can be installed separately if needed
+ # 4. Model files (~6GB) will be downloaded on first run
+ # =============================================================================