# --- The Final, Definitive, and Corrected Application ---
import os
import time
import google.generativeai as genai
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.utils.logging import set_verbosity_error
import gradio as gr
import PyPDF2
from docx import Document
# --- 1. GLOBAL SETUP ---
set_verbosity_error()  # silence non-error transformers logging
load_dotenv()          # load GEMINI_API_KEY (and any other secrets) from .env
# --- 2. ONE-TIME MODEL INITIALIZATION ---
def initialize_hf_models():
    """Loads all local Hugging Face models ONCE."""
    print("--- Initializing Hugging Face Models (once) ---")
    device = -1  # -1 tells transformers.pipeline to run on CPU
    print("✅ Using device: CPU (forced for HF models for stability)")
    start_time = time.time()
    summarizer_pipeline = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=device)
    summarizer = HuggingFacePipeline(pipeline=summarizer_pipeline)
    print(f"-> Summarization model loaded in {time.time() - start_time:.2f} seconds.")
    start_time = time.time()
    qa_pipeline_obj = pipeline("question-answering", model="distilbert-base-cased-distilled-squad", device=device)
    print(f"-> Q&A model loaded in {time.time() - start_time:.2f} seconds.")
    return summarizer, qa_pipeline_obj

SUMMARIZER_MODEL, QA_PIPELINE_MODEL = initialize_hf_models()
# Initialize the Hugging Face summarization model for long documents
def initialize_hf_summarizer():
    """Initialize the LED long-document summarization model (runs on CPU by default)."""
    print("--- Initializing Hugging Face Summarization Model ---")
    tokenizer = AutoTokenizer.from_pretrained("allenai/led-large-16384")
    model = AutoModelForSeq2SeqLM.from_pretrained("allenai/led-large-16384")
    print("✅ Hugging Face summarization model loaded.")
    return tokenizer, model

HF_TOKENIZER, HF_MODEL = initialize_hf_summarizer()
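# Note: led-large-16384 is a large checkpoint that is downloaded on first run;
# allenai/led-base-16384 is a lighter alternative if memory or startup time
# is a concern.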
# Summarize text or a full document using the LED model
def summarize_text(tokenizer, model, text):
    print("\n⏳ Generating summary...")
    start_time = time.time()
    inputs = tokenizer(text, return_tensors="pt", max_length=16384, truncation=True)
    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=512,
        min_length=50,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    print(f"-> Summary generated in {time.time() - start_time:.2f} seconds.")
    return summary
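# Note: LED models generally summarize better when at least one token is given
# global attention. A minimal optional tweak (an assumption, not part of the
# original code) would be:
#
#   global_attention_mask = torch.zeros_like(inputs["input_ids"])
#   global_attention_mask[:, 0] = 1  # global attention on the first token
#   summary_ids = model.generate(inputs["input_ids"],
#                                global_attention_mask=global_attention_mask, ...)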
# --- 3. TASK-SPECIFIC FUNCTIONS (No changes here, they were correct) ---
def summarize_text_with_prompt(summarizer, text):
    print("\n⏳ Generating summary...")
    start_time = time.time()
    summary_template = PromptTemplate.from_template("Summarize the following text in a concise way:\n\n{text}")
    chain = summary_template | summarizer
    try:
        summary = chain.invoke({"text": text})
        print(f"-> Summary generated in {time.time() - start_time:.2f} seconds.")
        return summary
    except Exception as e:
        raise gr.Error(f"Error during summarization: {e}")
def create_quiz(gemini_key, text, num_questions):
    print(f"\n⏳ Generating {num_questions} quiz questions with Gemini...")
    start_time = time.time()
    try:
        genai.configure(api_key=gemini_key)
        # Any current Gemini chat model name works here; genai.list_models()
        # shows what your key can access.
        SELECTED_MODEL = "models/gemini-2.5-pro"
        gemini_model = ChatGoogleGenerativeAI(model=SELECTED_MODEL, google_api_key=gemini_key, temperature=0.7)
    except Exception as e:
        raise gr.Error(f"Gemini API configuration error. Check your key. Details: {e}")
    example = """[START OF EXAMPLE]
Context: The Moon is Earth's only natural satellite. It is the fifth largest satellite in the Solar System. The dark areas on its surface are called maria.
Quiz:
Q: What is the Moon's status relative to Earth?
A) A man-made satellite
B) A natural satellite
C) A dwarf planet
D) A star
Answer: B
Q: The dark areas on the Moon's surface are known as what?
A) Craters
B) Valleys
C) Maria
D) Highlands
Answer: C
[END OF EXAMPLE]"""
    prompt_text = f"{example}\n\n[START OF TASK]\nContext: {{text}}\n\nGenerate exactly {{num_questions}} multiple-choice questions in the same format. Each question must have 4 options (A-D) and indicate the correct Answer.\n\nQuiz:"
    prompt = PromptTemplate.from_template(prompt_text)
    chain = LLMChain(llm=gemini_model, prompt=prompt)
    try:
        quiz_text = chain.run(text=text, num_questions=num_questions)
        print(f"-> Quiz generated in {time.time() - start_time:.2f} seconds.")
        return quiz_text
    except Exception as e:
        raise gr.Error(f"Error during quiz generation: {e}")
def answer_question(qa_pipeline, text, question):
    print(f"\n⏳ Answering question: '{question}'")
    start_time = time.time()
    try:
        result = qa_pipeline(question=question, context=text)
        print(f"-> Answer generated in {time.time() - start_time:.2f} seconds.")
        return f"Answer: {result['answer']}"
    except Exception as e:
        raise gr.Error(f"Error during Q&A: {e}")
def create_flashcards(gemini_key, text, num_flashcards):
    print(f"\n⏳ Generating {num_flashcards} flashcards with Gemini...")
    start_time = time.time()
    try:
        genai.configure(api_key=gemini_key)
        gemini_model = ChatGoogleGenerativeAI(model="models/gemini-2.5-pro", google_api_key=gemini_key, temperature=0.7)
    except Exception as e:
        raise gr.Error(f"Gemini API configuration error. Check your key. Details: {e}")
    example = """[START OF EXAMPLE]
Context: The Moon is Earth's only natural satellite. It is the fifth largest satellite in the Solar System. The dark areas on its surface are called maria.
Flashcards:
Flashcard 1:
Front: What is Earth's only natural satellite?
Back: The Moon
Flashcard 2:
Front: What are the dark areas on the Moon's surface called?
Back: Maria
[END OF EXAMPLE]"""
    prompt_text = f"{example}\n\n[START OF TASK]\nContext: {{text}}\n\nGenerate exactly {{num_flashcards}} flashcards in the same format.\n\nFlashcards:"
    prompt = PromptTemplate.from_template(prompt_text)
    chain = LLMChain(llm=gemini_model, prompt=prompt)
    try:
        flashcards_text = chain.run(text=text, num_flashcards=num_flashcards)
        print(f"-> Flashcards generated in {time.time() - start_time:.2f} seconds.")
        return flashcards_text
    except Exception as e:
        raise gr.Error(f"Error during flashcard generation: {e}")
# --- 4. MAIN PROCESSING FUNCTION (REWRITTEN FOR CLARITY AND CORRECTNESS) ---
# Routes each task to the right model: uploaded documents are summarized with
# the long-context LED model, pasted text with the DistilBART summarizer.
def process_request(text, task, num_items, question, file, progress=gr.Progress()):
    """Main function called by the Gradio interface with corrected logic."""
    progress(0, desc="Starting...")
    gemini_key = os.getenv("GEMINI_API_KEY")
    # If a file is uploaded, its extracted content replaces the text box input
    if file is not None:
        text = extract_text_from_file(file)
    if not text:
        raise gr.Error("Please provide input text or upload a document.")
    output_content = "An unexpected error occurred."
    if task == "Summary":
        progress(0.5, desc="Generating summary...")
        try:
            if file is not None:
                # Long-context LED model for uploaded documents
                output_content = summarize_text(HF_TOKENIZER, HF_MODEL, text)
            else:
                # DistilBART summarizer for pasted text
                output_content = summarize_text_with_prompt(SUMMARIZER_MODEL, text)
        except Exception as e:
            raise gr.Error(f"Error during summarization: {e}")
    elif task == "Q&A":
        if not question or not question.strip():
            raise gr.Error("Please enter a question for the Q&A task.")
        progress(0.5, desc="Finding answer...")
        output_content = answer_question(QA_PIPELINE_MODEL, text, question)
    elif task == "Quiz":
        if not gemini_key:
            raise gr.Error("API Key Error: The app owner has not set the GEMINI_API_KEY secret in the Hugging Face Space.")
        progress(0.5, desc=f"Generating {num_items} quiz questions...")
        output_content = create_quiz(gemini_key, text, num_questions=num_items)
    elif task == "Flashcards":
        if not gemini_key:
            raise gr.Error("API Key Error: The app owner has not set the GEMINI_API_KEY secret in the Hugging Face Space.")
        progress(0.5, desc=f"Generating {num_items} flashcards...")
        output_content = create_flashcards(gemini_key, text, num_flashcards=num_items)
    progress(1, desc="Done!")
    return output_content
# --- Document upload support: extract text from uploaded files ---
def extract_text_from_file(file):
    """Extract text from an uploaded file based on its extension."""
    # Depending on the Gradio version, the callback receives either a path
    # string or a temp-file wrapper whose .name attribute is the on-disk path.
    path = file if isinstance(file, str) else file.name
    if path.endswith(".txt"):
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    elif path.endswith(".pdf"):
        pdf_reader = PyPDF2.PdfReader(path)
        text = ""
        for page in pdf_reader.pages:
            # extract_text() can return None for image-only pages
            text += page.extract_text() or ""
        return text
    elif path.endswith(".docx"):
        doc = Document(path)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs)
    else:
        raise gr.Error("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")
# --- 5. GRADIO INTERFACE (NO CHANGES NEEDED HERE) ---
with gr.Blocks(title="Study Buddy AI with Document Upload") as demo:
    gr.Markdown("# Study Buddy AI: Summary, Quiz, Q&A, Flashcards with Document Upload")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Input Text", lines=10, placeholder="Paste your study material here...")
            file_input = gr.File(label="Upload Document (.txt, .pdf, .docx)")
        with gr.Column(scale=1):
            task_dropdown = gr.Dropdown(choices=["Summary", "Quiz", "Q&A", "Flashcards"], label="Select a Task", value="Summary")
            num_items_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Number of Questions/Flashcards")
            question_input = gr.Textbox(label="Your Question (for Q&A task only)", placeholder="e.g., What is the Great Red Spot?")
            submit_button = gr.Button("Generate", variant="primary")
    output_textbox = gr.Textbox(label="Output", lines=15, interactive=False)
    submit_button.click(
        fn=process_request,
        inputs=[text_input, task_dropdown, num_items_slider, question_input, file_input],
        outputs=output_textbox,
    )

if __name__ == "__main__":
    demo.launch()
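    # On Hugging Face Spaces the default launch() settings suffice; when
    # running locally, LAN access could be enabled with e.g.
    # demo.launch(server_name="0.0.0.0").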