import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import numpy as np
from tqdm.auto import tqdm
import os

# CSS to style the custom share button (for the "Sparse Representation" tab)
css = """
.share-button-container {
    display: flex;
    justify-content: center;
    margin-top: 10px;
    margin-bottom: 20px;
}
.custom-share-button {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border: none;
    color: white;
    padding: 8px 16px;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 14px;
    margin: 4px 2px;
    cursor: pointer;
    border-radius: 6px;
    transition: all 0.3s ease;
}
.custom-share-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
/* IMPORTANT: This CSS targets Gradio's *default* share button that appears
   when demo.launch(share=True) is used.
   You might want to comment this out if you prefer Gradio's default positioning
   for the main share button (usually in the header/footer) and rely only on your custom one. */
.share-button {
    position: fixed !important;
    top: 20px !important;
    right: 20px !important;
    z-index: 1000 !important;
    background: #4CAF50 !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 8px 16px !important;
    font-weight: bold !important;
    box-shadow: 0 2px 10px rgba(0,0,0,0.2) !important;
}
.share-button:hover {
    background: #45a049 !important;
    transform: translateY(-1px) !important;
}
"""

# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_splade_lexical = None
model_splade_lexical = None
tokenizer_splade_doc = None
model_splade_doc = None

# Load SPLADE cocondenser model (used as the MLM encoder)
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()
    print("SPLADE-cocondenser-distil model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")

# Load SPLADE v3 Lexical model
try:
    splade_lexical_model_name = "naver/splade-v3-lexical"
    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
    model_splade_lexical.eval()
    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Lexical model: {e}")
    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")

# Load SPLADE v3 Doc model - the model itself is still loaded, even though only its
# tokenizer (not its logits) is used for the binary BoW representation below
try:
    splade_doc_model_name = "naver/splade-v3-doc"
    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)  # Still load the model
    model_splade_doc.eval()
    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Doc model: {e}")
    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")

# --- Helper function for lexical mask (handles batches, but used for single inputs here) ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    """
    Creates a batch of lexical BOW masks.
    input_ids_batch: torch.Tensor of shape (batch_size, sequence_length)
    vocab_size: int, size of the tokenizer vocabulary
    tokenizer: the tokenizer object
    Returns: torch.Tensor of shape (batch_size, vocab_size)
    """
    batch_size = input_ids_batch.shape[0]
    bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)
    for i in range(batch_size):
        input_ids = input_ids_batch[i]  # Get input_ids for the current item in the batch
        meaningful_token_ids = []
        for token_id in input_ids.tolist():
            if token_id not in [
                tokenizer.pad_token_id,
                tokenizer.cls_token_id,
                tokenizer.sep_token_id,
                tokenizer.mask_token_id,
                tokenizer.unk_token_id
            ]:
                meaningful_token_ids.append(token_id)
        if meaningful_token_ids:
            # Apply mask to the current row in the batch
            bow_masks[i, list(set(meaningful_token_ids))] = 1
    return bow_masks
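# Rough illustration (hypothetical input): for the text "solar power", the returned mask row
# is 1 only at the vocabulary ids of "solar" and "power", so multiplying a SPLADE vector by
# it zeroes every expansion term that does not literally appear in the input.
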
# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
# These functions return a tuple: (main_representation_str, info_str)
def get_splade_cocondenser_representation(text):
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)
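
    # SPLADE pooling, as in the SPLADE papers: the weight of vocabulary entry j is
    # w_j = max_i log(1 + ReLU(logit_{i,j})), maximised over token positions i,
    # with padding positions removed via the attention mask.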
    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()  # Squeeze is fine here as it's a single input
    else:
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLM Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Render the terms as a single comma-separated paragraph
        terms_list = []
        for term, weight in sorted_representation:
            terms_list.append(f"**{term}**: {weight:.4f}")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"
    return formatted_output, info_output

def get_splade_lexical_representation(text):
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_lexical(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()  # Squeeze is fine here
    else:
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""

    # Always apply the lexical mask for this model's specific behavior
    vocab_size = tokenizer_splade_lexical.vocab_size
    # input_ids is already batched as (1, seq_len); squeeze back to a single vector afterwards
    bow_mask = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_lexical
    ).squeeze()
    splade_vector = splade_vector * bow_mask
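    # After masking, only terms that literally occur in the input keep a non-zero weight:
    # the lexical model re-weights input terms but performs no vocabulary expansion.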

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_lexical.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLP Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Render the terms as a single comma-separated paragraph
        terms_list = []
        for term, weight in sorted_representation:
            terms_list.append(f"**{term}**: {weight:.4f}")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"
    return formatted_output, info_output

def get_splade_doc_representation(text):
    if tokenizer_splade_doc is None:  # Only the tokenizer is needed for the binary BoW
        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()}  # Ensure on CPU for direct mask creation

    vocab_size = tokenizer_splade_doc.vocab_size
    # Directly create the binary Bag-of-Words vector using the input_ids
    binary_bow_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()  # Squeeze back for single output

    indices = torch.nonzero(binary_bow_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []
    values = [1.0] * len(indices)  # All values are 1 for binary representation
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Sort alphabetically for clarity

    formatted_output = "Binary:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Render the terms as a single comma-separated paragraph
        terms_list = []
        for term, _ in sorted_representation:  # For binary, the weight is always 1, so it is not displayed
            terms_list.append(f"**{term}**")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}"
    return formatted_output, info_output

# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
    if model_choice == "MLM encoder (SPLADE-cocondenser-distil)":
        return get_splade_cocondenser_representation(text)
    elif model_choice == "MLP encoder (SPLADE-v3-lexical)":
        return get_splade_lexical_representation(text)
    elif model_choice == "Binary":
        return get_splade_doc_representation(text)
    else:
        return "Please select a model.", ""  # Return two strings for consistency

# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
# These mirror the functions above, but return the raw sparse vectors.
def get_splade_cocondenser_vector(text):
    if tokenizer_splade is None or model_splade is None:
        return None

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
        return splade_vector
    return None

def get_splade_lexical_vector(text):
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return None

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_lexical(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
        vocab_size = tokenizer_splade_lexical.vocab_size
        bow_mask = create_lexical_bow_mask(
            inputs['input_ids'], vocab_size, tokenizer_splade_lexical
        ).squeeze()
        splade_vector = splade_vector * bow_mask
        return splade_vector
    return None

def get_splade_doc_vector(text):
    if tokenizer_splade_doc is None:  # Only the tokenizer is needed for the binary BoW
        return None

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()}  # Ensure on CPU for direct mask creation

    vocab_size = tokenizer_splade_doc.vocab_size
    # Directly create the binary Bag-of-Words vector using the input_ids
    binary_bow_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()
    return binary_bow_vector
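
# Note: the raw vectors returned above live in vocabulary space (one dimension per
# WordPiece term, roughly 30k entries for these BERT-based checkpoints), which is what
# makes the term-level dot-product breakdown further below possible.
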
# --- Function to get a formatted representation from a raw vector and tokenizer ---
# Generic formatter for any sparse vector.
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
    if splade_vector is None:
        return "Failed to generate vector.", ""

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    if is_binary:
        values = [1.0] * len(indices)
    else:
        values = splade_vector[indices].cpu().tolist()

    token_weights = dict(zip(indices, values))
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    if is_binary:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0])  # Sort alphabetically for binary
    else:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = ""
    if not sorted_representation:
        formatted_output += "No significant terms found.\n"
    else:
        terms_list = []
        for term, weight in sorted_representation:
            if is_binary:
                terms_list.append(f"**{term}**")
            else:
                terms_list.append(f"**{term}**: {weight:.4f}")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms: {len(indices)}\n"
    return formatted_output, info_output

# --- Helper to get the correct vector function, tokenizer, binary flag, and display name ---
def get_model_assets(model_choice_str):
    if model_choice_str == "MLM encoder (SPLADE-cocondenser-distil)":
        return get_splade_cocondenser_vector, tokenizer_splade, False, "MLM encoder (SPLADE-cocondenser-distil)"
    elif model_choice_str == "MLP encoder (SPLADE-v3-lexical)":
        return get_splade_lexical_vector, tokenizer_splade_lexical, False, "MLP encoder (SPLADE-v3-lexical)"
    elif model_choice_str == "Binary":
        return get_splade_doc_vector, tokenizer_splade_doc, True, "Binary Bag-of-Words"
    else:
        return None, None, False, "Unknown Model"

# --- Dot Product Calculation Function for the similarity tab ---
def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
    query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
    doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)

    if query_vector_fn is None or doc_vector_fn is None:
        return "Please select valid models for both query and document encoding."

    query_vector = query_vector_fn(query_text)
    doc_vector = doc_vector_fn(doc_text)

    if query_vector is None or doc_vector is None:
        return "Failed to generate one or both vectors. Please check model loading and input text."

    # Calculate the overall dot product
    dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
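    # For sparse vectors this is simply a sum over shared vocabulary terms:
    # score(q, d) = sum_t q_t * d_t, over terms t that are non-zero in both vectors.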

    # Format representations for display
    query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
    doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)

    # --- Per-term contributions: products of overlapping query and document terms ---
    overlapping_terms_dot_products = {}

    query_indices = torch.nonzero(query_vector).squeeze().cpu()
    doc_indices = torch.nonzero(doc_vector).squeeze().cpu()

    # Handle cases where vectors are empty or have a single element
    if query_indices.dim() == 0 and query_indices.numel() == 1:
        query_indices = query_indices.unsqueeze(0)
    if doc_indices.dim() == 0 and doc_indices.numel() == 1:
        doc_indices = doc_indices.unsqueeze(0)

    # Convert indices to sets for efficient intersection
    query_index_set = set(query_indices.tolist())
    doc_index_set = set(doc_indices.tolist())
    common_indices = sorted(list(query_index_set.intersection(doc_index_set)))

    if common_indices:
        for idx in common_indices:
            query_weight = query_vector[idx].item()
            doc_weight = doc_vector[idx].item()
            # The models here share the same BERT WordPiece vocabulary, so either
            # tokenizer decodes a given id to the same term
            term = query_tokenizer.decode([idx])
            if term not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(term.strip()) > 0:
                overlapping_terms_dot_products[term] = query_weight * doc_weight

    sorted_overlapping_dot_products = sorted(
        overlapping_terms_dot_products.items(),
        key=lambda item: item[1],
        reverse=True
    )
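    # Sanity note: summing these per-term products recovers the overall dot product above,
    # up to any contribution from the special tokens filtered out in the loop.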

    # Combine the output into a single string for the Markdown component
    full_output = f"### Overall Dot Product Score: {dot_product:.6f}\n\n"
    full_output += "---\n\n"

    # Overlapping terms and their per-term products
    if sorted_overlapping_dot_products:
        full_output += "### Product of Query and Document Term Scores:\n"
        full_output += "\n"
        overlap_list = []
        for term, product_val in sorted_overlapping_dot_products:
            overlap_list.append(f"**{term}**: {product_val:.4f}")
        full_output += ", ".join(overlap_list) + ".\n\n"
        full_output += "---\n\n"
    else:
        full_output += "### No Overlapping Terms Found.\n\n"
        full_output += "---\n\n"

    # Query Representation
    full_output += f"#### Query Representation: {query_model_name_display}\n"
    full_output += f"> {query_main_rep_str}\n"  # Blockquote for the sparse term list
    full_output += f"> {query_info_str}\n"  # Blockquote for the info line as well
    full_output += "\n---\n\n"  # Separator

    # Document Representation
    full_output += f"#### Document Representation: {doc_model_name_display}\n"
    full_output += f"> {doc_main_rep_str}\n"
    full_output += f"> {doc_info_str}"

    return full_output

# Share URL for this Space, surfaced by the custom share button on the "Sparse Representation" tab
global_share_url = "https://huggingface.co/spaces/SiddharthAK/TextLSRDemo"

def get_current_share_url():
    """Returns the globally stored share URL."""
    return global_share_url if global_share_url else "Share URL not available yet."

# --- Gradio Interface Setup with Tabs ---
with gr.Blocks(title="SPLADE Demos", css=css) as demo:
    gr.Markdown("# 🌌 Sparse Encoder Playground")
    gr.Markdown("Explore different SPLADE models and their sparse representation types, and calculate similarity between query and document representations.")

    with gr.Tabs():
        with gr.TabItem("Sparse Representation"):
            gr.Markdown("### Produce a Sparse Representation of an Input Text")
            with gr.Row():
                with gr.Column(scale=1):  # Left column for inputs and info
                    model_radio = gr.Radio(
                        [
                            "MLM encoder (SPLADE-cocondenser-distil)",
                            "MLP encoder (SPLADE-v3-lexical)",
                            "Binary"
                        ],
                        label="Choose Sparse Encoder",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    )
                    input_text = gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?"
                    )

                    # Custom Share Button and URL display
                    with gr.Row(elem_classes="share-button-container"):
                        share_button = gr.Button(
                            "🔗 Get Share Link",
                            elem_classes="custom-share-button",
                            size="sm"
                        )
                    share_output = gr.Textbox(
                        label="Share URL",
                        interactive=True,
                        visible=False,
                        placeholder="Click 'Get Share Link' to generate URL..."
                    )
                    info_output_display = gr.Markdown(
                        value="",
                        label="Vector Information",
                        elem_id="info_output_display"
                    )

                with gr.Column(scale=2):  # Right column for the main representation output
                    main_representation_output = gr.Markdown()

            # Connect share button.
            share_button.click(
                fn=get_current_share_url,
                outputs=share_output
            ).then(
                fn=lambda: gr.update(visible=True),
                outputs=share_output
            )

            # Connect the core prediction logic
            model_radio.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
            input_text.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )

            # Initial call to populate on load (optional, but good for demo)
            demo.load(
                fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
                outputs=[main_representation_output, info_output_display]
            )

| with gr.TabItem("Compute Query-Document Similarity Score"): | |
| gr.Markdown("### Calculate Dot Product Similarity Between Encoded Query and Document") | |
| model_choices = [ | |
| "MLM encoder (SPLADE-cocondenser-distil)", | |
| "MLP encoder (SPLADE-v3-lexical)", | |
| "Binary" | |
| ] | |
| # Input components for the second tab | |
| query_model_radio = gr.Radio( | |
| model_choices, | |
| label="Choose Query Encoding Model", | |
| value="MLM encoder (SPLADE-cocondenser-distil)" | |
| ) | |
| doc_model_radio = gr.Radio( | |
| model_choices, | |
| label="Choose Document Encoding Model", | |
| value="MLM encoder (SPLADE-cocondenser-distil)" | |
| ) | |
| query_text_input = gr.Textbox( | |
| lines=3, | |
| label="Enter Query Text:", | |
| placeholder="e.g., famous dishes of Padua" | |
| ) | |
| doc_text_input = gr.Textbox( | |
| lines=5, | |
| label="Enter Document Text:", | |
| placeholder="e.g., Padua's cuisine is as famous as its legendary University." | |
| ) | |
            # Output component: a gr.Markdown with a fixed height so long outputs scroll
            output_dot_product_markdown = gr.Markdown(
                value="",
                height=500,  # Adjust (e.g. 500 px or a percentage) to taste
                elem_classes=["scrollable-output"]  # Custom class for CSS targeting if needed
            )

            # Wire the similarity computation to the inputs and the Markdown output above
            gr.Interface(
                fn=calculate_dot_product_and_representations_independent,
                inputs=[
                    query_model_radio,
                    doc_model_radio,
                    query_text_input,
                    doc_text_input
                ],
                outputs=output_dot_product_markdown,
                allow_flagging="never"
            )

    # --- References ---
    gr.Markdown(
        """
---
### References
This demo utilizes **SPLADE** models. For more details, please refer to the following papers:
1. Formal, T., Lassance, C., Piwowarski, B., & Clinchant, S. (2022). **From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective**. *arXiv preprint arXiv:2205.04733*. Available at: [https://arxiv.org/abs/2205.04733](https://arxiv.org/abs/2205.04733)
2. Lassance, C., Déjean, H., Formal, T., & Clinchant, S. (2024). **SPLADE-v3: New baselines for SPLADE**. *arXiv preprint arXiv:2403.06789*. Available at: [https://arxiv.org/abs/2403.06789](https://arxiv.org/abs/2403.06789)
"""
    )

if __name__ == "__main__":
    # Launch with a public share link; Gradio prints the link to the console.
    launched_demo = demo.launch(share=True)
    print("\n--- Gradio App Launched ---")
    print("If a public share link is generated, it will be displayed in your console.")
    print("You can also use the '🔗 Get Share Link' button on the 'Sparse Representation' tab.")
    print("---------------------------\n")