# study-buddy-ai / app.py
# --- The Final, Definitive, and Corrected Application ---
import os
import time

import google.generativeai as genai
import gradio as gr
import PyPDF2
from docx import Document
from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.utils.logging import set_verbosity_error
# --- 1. GLOBAL SETUP ---
set_verbosity_error()
load_dotenv()
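# load_dotenv() pulls variables from a local .env file into the environment, so
# os.getenv("GEMINI_API_KEY") in process_request works both locally and on Spaces
# (where secrets are already set as environment variables).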
# --- 2. ONE-TIME MODEL INITIALIZATION ---
def initialize_hf_models():
    """Loads all local Hugging Face models ONCE."""
    print("--- Initializing Hugging Face Models (once) ---")
    device = -1  # -1 forces CPU in transformers pipelines
    print("✅ Using device: CPU (forced for HF models for stability)")
    start_time = time.time()
    summarizer_pipeline = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=device)
    summarizer = HuggingFacePipeline(pipeline=summarizer_pipeline)
    print(f"-> Summarization model loaded in {time.time() - start_time:.2f} seconds.")
    start_time = time.time()
    qa_pipeline_obj = pipeline("question-answering", model="distilbert-base-cased-distilled-squad", device=device)
    print(f"-> Q&A model loaded in {time.time() - start_time:.2f} seconds.")
    return summarizer, qa_pipeline_obj
SUMMARIZER_MODEL, QA_PIPELINE_MODEL = initialize_hf_models()
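# Loaded once at import time so every Gradio request reuses the same in-memory
# models instead of paying the load cost on each call.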
# Initialize the long-context Hugging Face summarization model (LED)
def initialize_hf_summarizer():
    """Initialize the Hugging Face LED summarization model."""
    print("--- Initializing Hugging Face Summarization Model ---")
    # Loaded without a device argument, so the model stays on CPU by default
    tokenizer = AutoTokenizer.from_pretrained("allenai/led-large-16384")
    model = AutoModelForSeq2SeqLM.from_pretrained("allenai/led-large-16384")
    print("✅ Hugging Face summarization model loaded.")
    return tokenizer, model
HF_TOKENIZER, HF_MODEL = initialize_hf_summarizer()
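# Two summarizers coexist deliberately: distilbart (above) handles short pasted
# text, while LED's 16,384-token window handles full uploaded documents
# (see process_request below).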
# Summarize long text or documents using the LED model
def summarize_text(tokenizer, model, text):
    print("\n⏳ Generating summary...")
    start_time = time.time()
    # LED accepts up to 16,384 tokens; longer inputs are truncated
    inputs = tokenizer(text, return_tensors="pt", max_length=16384, truncation=True)
    summary_ids = model.generate(inputs["input_ids"], max_length=512, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    print(f"-> Summary generated in {time.time() - start_time:.2f} seconds.")
    return summary
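# NOTE: LED uses sparse attention, and the Hugging Face docs suggest marking at
# least the first token for global attention when summarizing, e.g.:
#   global_attention_mask = torch.zeros_like(inputs["input_ids"])
#   global_attention_mask[:, 0] = 1
#   model.generate(inputs["input_ids"], global_attention_mask=global_attention_mask, ...)
# (torch would need to be imported for this sketch.) generate() works without
# it, as above, but summary quality may improve with it.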
# --- 3. TASK-SPECIFIC FUNCTIONS ---
def summarize_text_with_prompt(summarizer, text):
    print("\n⏳ Generating summary...")
    start_time = time.time()
    summary_template = PromptTemplate.from_template("Summarize the following text in a concise way:\n\n{text}")
    chain = summary_template | summarizer
    try:
        summary = chain.invoke({"text": text})
        print(f"-> Summary generated in {time.time() - start_time:.2f} seconds.")
        return summary
    except Exception as e:
        raise gr.Error(f"Error during summarization: {e}")
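# The "|" above builds a LangChain Expression Language (LCEL) chain: the
# prompt's output feeds the HuggingFacePipeline wrapper, and invoke() returns
# the generated string directly.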
def create_quiz(gemini_key, text, num_questions):
    print(f"\n⏳ Generating {num_questions} quiz questions with Gemini...")
    start_time = time.time()
    try:
        genai.configure(api_key=gemini_key)
        SELECTED_MODEL = "models/gemini-2.5-pro"  # any chat-capable Gemini model works here
        gemini_model = ChatGoogleGenerativeAI(model=SELECTED_MODEL, google_api_key=gemini_key, temperature=0.7)
    except Exception as e:
        raise gr.Error(f"Gemini API configuration error. Check your key. Details: {e}")
    example = """[START OF EXAMPLE]
Context: The Moon is Earth's only natural satellite. It is the fifth largest satellite in the Solar System. The dark areas on its surface are called maria.
Quiz:
Q: What is the Moon's status relative to Earth?
A) A man-made satellite
B) A natural satellite
C) A dwarf planet
D) A star
Answer: B
Q: The dark areas on the Moon's surface are known as what?
A) Craters
B) Valleys
C) Maria
D) Highlands
Answer: C
[END OF EXAMPLE]"""
    # Doubled braces survive the f-string and become {text}/{num_questions}
    # template variables for PromptTemplate.
    prompt_text = f"{example}\n\n[START OF TASK]\nContext: {{text}}\n\nGenerate exactly {{num_questions}} multiple-choice questions in the same format. Each question must have 4 options (A-D) and indicate the correct Answer.\n\nQuiz:"
    prompt = PromptTemplate.from_template(prompt_text)
    chain = LLMChain(llm=gemini_model, prompt=prompt)
    try:
        quiz_text = chain.run(text=text, num_questions=num_questions)
        print(f"-> Quiz generated in {time.time() - start_time:.2f} seconds.")
        return quiz_text
    except Exception as e:
        raise gr.Error(f"Error during quiz generation: {e}")
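# NOTE: LLMChain and .run() are deprecated in recent LangChain releases; the
# modern equivalent is roughly:
#   chain = prompt | gemini_model
#   quiz_text = chain.invoke({"text": text, "num_questions": num_questions}).content
# The deprecated form is kept because it still works on older LangChain
# releases. The same applies to create_flashcards below.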
def answer_question(qa_pipeline, text, question):
    print(f"\n⏳ Answering question: '{question}'")
    start_time = time.time()
    try:
        result = qa_pipeline(question=question, context=text)
        print(f"-> Answer generated in {time.time() - start_time:.2f} seconds.")
        return f"Answer: {result['answer']}"
    except Exception as e:
        raise gr.Error(f"Error during Q&A: {e}")
def create_flashcards(gemini_key, text, num_flashcards):
    print(f"\n⏳ Generating {num_flashcards} flashcards with Gemini...")
    start_time = time.time()
    try:
        genai.configure(api_key=gemini_key)
        gemini_model = ChatGoogleGenerativeAI(model="models/gemini-2.5-pro", google_api_key=gemini_key, temperature=0.7)
    except Exception as e:
        raise gr.Error(f"Gemini API configuration error. Check your key. Details: {e}")
    example = """[START OF EXAMPLE]
Context: The Moon is Earth's only natural satellite. It is the fifth largest satellite in the Solar System. The dark areas on its surface are called maria.
Flashcards:
Flashcard 1:
Front: What is Earth's only natural satellite?
Back: The Moon
Flashcard 2:
Front: What are the dark areas on the Moon's surface called?
Back: Maria
[END OF EXAMPLE]"""
    prompt_text = f"{example}\n\n[START OF TASK]\nContext: {{text}}\n\nGenerate exactly {{num_flashcards}} flashcards in the same format.\n\nFlashcards:"
    prompt = PromptTemplate.from_template(prompt_text)
    chain = LLMChain(llm=gemini_model, prompt=prompt)
    try:
        flashcards_text = chain.run(text=text, num_flashcards=num_flashcards)
        print(f"-> Flashcards generated in {time.time() - start_time:.2f} seconds.")
        return flashcards_text
    except Exception as e:
        raise gr.Error(f"Error during flashcard generation: {e}")
# --- 4. MAIN PROCESSING FUNCTION ---
# Routes each task; uploaded documents are summarized with the long-context LED
# model, pasted text with the lighter distilbart pipeline.
def process_request(text, task, num_items, question, file, progress=gr.Progress()):
    """Main function called by the Gradio interface."""
    progress(0, desc="Starting...")
    gemini_key = os.getenv("GEMINI_API_KEY")
    # If a file is uploaded, its extracted content replaces the pasted text
    if file is not None:
        text = extract_text_from_file(file)
    if not text:
        raise gr.Error("Please provide input text or upload a document.")
    output_content = "An unexpected error occurred."
    if task == "Summary":
        progress(0.5, desc="Generating summary...")
        try:
            if file is not None:
                # Use the LED model for documents
                output_content = summarize_text(HF_TOKENIZER, HF_MODEL, text)
            else:
                # Use the distilbart summarizer for pasted text
                output_content = summarize_text_with_prompt(SUMMARIZER_MODEL, text)
        except Exception as e:
            raise gr.Error(f"Error during summarization: {e}")
    elif task == "Q&A":
        if not question or not question.strip():
            raise gr.Error("Please enter a question for the Q&A task.")
        progress(0.5, desc="Finding answer...")
        output_content = answer_question(QA_PIPELINE_MODEL, text, question)
    elif task == "Quiz":
        if not gemini_key:
            raise gr.Error("API Key Error: The app owner has not set the GEMINI_API_KEY secret in the Hugging Face Space.")
        progress(0.5, desc=f"Generating {num_items} quiz questions...")
        output_content = create_quiz(gemini_key, text, num_questions=num_items)
    elif task == "Flashcards":
        if not gemini_key:
            raise gr.Error("API Key Error: The app owner has not set the GEMINI_API_KEY secret in the Hugging Face Space.")
        progress(0.5, desc=f"Generating {num_items} flashcards...")
        output_content = create_flashcards(gemini_key, text, num_flashcards=num_items)
    progress(1, desc="Done!")
    return output_content
# --- Document upload support ---
def extract_text_from_file(file):
    """Extract text from an uploaded file based on its extension."""
    name = file.name.lower()
    if name.endswith(".txt"):
        # Handle .txt files
        with open(file.name, "r", encoding="utf-8") as f:
            return f.read()
    elif name.endswith(".pdf"):
        # Handle .pdf files; extract_text() can return None for image-only pages
        pdf_reader = PyPDF2.PdfReader(file.name)
        text = ""
        for page in pdf_reader.pages:
            text += page.extract_text() or ""
        return text
    elif name.endswith(".docx"):
        # Handle .docx files
        doc = Document(file.name)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs)
    else:
        raise gr.Error("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")
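# NOTE: this assumes gr.File passes a tempfile-style wrapper whose .name
# attribute is the path on disk, which is why file.name is opened above. On
# newer Gradio versions gr.File can instead return a plain path string, in
# which case .name would not exist and the path should be used directly.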
# --- 5. GRADIO INTERFACE ---
with gr.Blocks(title="Study Buddy AI with Document Upload") as demo:
    gr.Markdown("# Study Buddy AI: Summary, Quiz, Q&A, Flashcards with Document Upload")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Input Text", lines=10, placeholder="Paste your study material here...")
            file_input = gr.File(label="Upload Document (.txt, .pdf, .docx)")
        with gr.Column(scale=1):
            task_dropdown = gr.Dropdown(choices=["Summary", "Quiz", "Q&A", "Flashcards"], label="Select a Task", value="Summary")
            num_items_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Number of Questions/Flashcards")
            question_input = gr.Textbox(label="Your Question (for Q&A task only)", placeholder="e.g., What is the Great Red Spot?")
            submit_button = gr.Button("Generate", variant="primary")
    output_textbox = gr.Textbox(label="Output", lines=15, interactive=False)
    submit_button.click(
        fn=process_request,
        inputs=[text_input, task_dropdown, num_items_slider, question_input, file_input],
        outputs=output_textbox
    )

if __name__ == "__main__":
    demo.launch()
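# A minimal requirements.txt implied by the imports above (versions are
# assumptions, not pinned by this file):
#   gradio
#   google-generativeai
#   langchain
#   langchain-google-genai
#   langchain-huggingface
#   transformers
#   torch
#   python-dotenv
#   PyPDF2
#   python-docx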