import html
import importlib.util
import os
import shutil
import sys
from importlib import import_module

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from huggingface_hub import hf_hub_download
# Fetch the model repo's inference helper, load it as the `inference` module,
# and expose its prediction function, label list and per-label thresholds.
repo_id = "logasanjeev/emotions-analyzer-bert"
local_file = hf_hub_download(repo_id=repo_id, filename="inference.py")
print("Downloaded inference.py successfully!")
current_dir = os.getcwd()
destination = os.path.join(current_dir, "inference.py")
shutil.copy(local_file, destination)
print("Copied inference.py to current directory!")
# Load the copy by file path. import_module("inference") only finds it when the
# working directory happens to be on sys.path, and a source file created after
# interpreter start-up can be missed by the import system's cached listings.
_inference_spec = importlib.util.spec_from_file_location("inference", destination)
inference_module = importlib.util.module_from_spec(_inference_spec)
# Register before executing so the module is importable by name as well.
sys.modules["inference"] = inference_module
_inference_spec.loader.exec_module(inference_module)
predict_emotions = inference_module.predict_emotions
print("Imported predict_emotions successfully!")
# Warm-up call on a dummy string; the result is discarded.
# NOTE(review): EMOTION_LABELS / THRESHOLDS are read only after this call —
# confirm whether inference.py initialises them lazily or this is just warm-up.
_, _ = predict_emotions("dummy text")
emotion_labels = inference_module.EMOTION_LABELS
default_thresholds = inference_module.THRESHOLDS
# Merge WordPiece sub-tokens (and their scores) back into whole words.
def merge_subwords(tokens, scores):
    """Re-assemble WordPiece pieces into words and average their scores.

    Pieces beginning with ``##`` continue the word before them; the special
    tokens ``[CLS]``, ``[SEP]`` and ``[PAD]`` are skipped entirely. Tokens and
    scores are paired with ``zip``, so surplus items in the longer input are
    ignored.

    Args:
        tokens: Token strings, e.g. from ``tokenizer.convert_ids_to_tokens``.
        scores: Numeric score for each token, aligned with ``tokens``.

    Returns:
        ``(words, word_scores)`` -- parallel lists where each score is the
        mean of the piece scores that formed that word.
    """
    ignored = ('[CLS]', '[SEP]', '[PAD]')
    merged_words = []
    merged_scores = []
    text, total, pieces = "", 0.0, 0

    def emit():
        # Record the word built so far, but only if it has any characters.
        if text:
            merged_words.append(text)
            merged_scores.append(total / pieces)

    for token, score in zip(tokens, scores):
        if token in ignored:
            continue
        if not token.startswith('##'):
            # Start of a new word: close off the previous one first.
            emit()
            text, total, pieces = token, score, 1
            continue
        # Continuation piece: glue it on without the '##' marker.
        text += token[2:]
        total += score
        pieces += 1
    emit()
    return merged_words, merged_scores
# Prediction pipeline: parsed predictions, top-5 probabilities, a grouped bar
# chart comparing them, and an optional attention heatmap.
def _parse_prediction_lines(predictions_str):
    """Parse ``predict_emotions`` text output into ``(emotion, confidence)`` pairs.

    Args:
        predictions_str: Newline-separated ``"emotion: confidence"`` lines, or
            the sentinel ``"No emotions predicted."``.

    Returns:
        A list of ``(emotion, float)`` tuples in their original order.
    """
    if predictions_str == "No emotions predicted.":
        return []
    parsed = []
    for line in predictions_str.splitlines():
        # rpartition splits on the last ": " so only the trailing number is the
        # score; blank or malformed lines are skipped instead of crashing.
        emotion, sep, confidence = line.rpartition(": ")
        if sep:
            parsed.append((emotion.strip(), float(confidence)))
    return parsed


def _build_comparison_chart(filtered_predictions, top_5_emotions):
    """Build a grouped bar chart of above-threshold vs. top-5 confidences.

    Args:
        filtered_predictions: ``(emotion, confidence)`` pairs that passed the
            thresholds.
        top_5_emotions: The five highest-probability ``(emotion, confidence)``
            pairs.

    Returns:
        A Plotly figure, or ``None`` when both inputs are empty.
    """
    if not filtered_predictions and not top_5_emotions:
        return None
    above = dict(filtered_predictions)
    top_5 = dict(top_5_emotions)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the bars
    # follow the top-5 ranking rather than a set's hash-dependent order.
    ordered_emotions = dict.fromkeys(
        [emotion for emotion, _ in top_5_emotions]
        + [emotion for emotion, _ in filtered_predictions]
    )
    data = {"Emotion": [], "Confidence": [], "Category": []}
    for emotion in ordered_emotions:
        for category, scores in (("Above Threshold", above), ("Top 5", top_5)):
            if emotion in scores:
                data["Emotion"].append(emotion)
                data["Confidence"].append(scores[emotion])
                data["Category"].append(category)
    df = pd.DataFrame(data)
    fig = px.bar(
        df,
        x="Emotion",
        y="Confidence",
        color="Category",
        barmode="group",
        title="Emotion Confidence Comparison",
        height=400,
        color_discrete_map={"Above Threshold": "#6366f1", "Top 5": "#10b981"},
    )
    fig.update_traces(texttemplate='%{y:.2f}', textposition='auto')
    fig.update_layout(
        margin=dict(t=50, b=50),
        xaxis_title="",
        yaxis_title="Confidence",
        legend_title="",
        legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.5),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
        font=dict(color="#e5e7eb"),
    )
    return fig


def _build_attention_heatmap(outputs, input_ids, attention_mask):
    """Render last-layer [CLS] attention over the input words as HTML.

    Args:
        outputs: Model output produced with ``output_attentions=True``.
        input_ids: Token-id tensor that was fed to the model (batch of one).
        attention_mask: The matching attention-mask tensor.

    Returns:
        An HTML fragment, or ``""`` if the model returned no attentions.
    """
    attentions = getattr(outputs, "attentions", None)
    if not attentions:
        return ""
    # Last layer, first batch item, row 0 (attention from [CLS]), averaged over
    # heads; HF documents each layer as (batch, heads, seq, seq).
    cls_attention = attentions[-1][0, :, 0, :].mean(dim=0).cpu().numpy()
    # Number of non-padding positions; assumes right padding (TODO confirm).
    seq_len = int(attention_mask[0].sum())
    tokens = inference_module.TOKENIZER.convert_ids_to_tokens(input_ids[0][:seq_len].tolist())
    words, word_scores = merge_subwords(tokens, cls_attention[:seq_len])
    # Min-max normalise to [0, 1] for use as an alpha channel. An empty list
    # (preprocessing left no words) or a constant one maps to zeros.
    if word_scores and max(word_scores) > min(word_scores):
        low, high = min(word_scores), max(word_scores)
        word_scores = [(s - low) / (high - low) for s in word_scores]
    else:
        word_scores = [0.0] * len(words)
    # Words derive from user input, so escape them before embedding in HTML.
    spans = " ".join(
        f'<span style="background-color: rgba(255, 100, 100, {score:.2f}); '
        f'padding: 2px 4px; border-radius: 4px;">{html.escape(word)}</span>'
        for word, score in zip(words, word_scores)
    )
    return (
        '<div style="padding: 12px; line-height: 2;">'
        "<h4>Attention Heatmap (Focus Areas)</h4>"
        f"<div>{spans}</div>"
        '<p style="font-size: 12px; color: #9ca3af;">'
        "Red intensity indicates model's focus (based on CLS token attention in last layer)."
        "</p></div>"
    )


def predict_emotions_with_details(text, confidence_threshold=0.0, show_heatmap=False):
    """Run the emotion model on *text* and prepare every UI output.

    Args:
        text: Raw user input. Empty, whitespace-only or ``None`` input returns
            a prompt message without running the model.
        confidence_threshold: Extra global threshold. A predicted emotion is
            kept only if its confidence reaches both this value and its
            per-label default threshold.
        show_heatmap: When true, request attentions and render heatmap HTML.

    Returns:
        ``(processed_text, thresholded_output, top_5_output, figure,
        heatmap_html)`` matching the five Gradio output components.
    """
    if not text or not text.strip():
        return "Please enter some text.", "", "", None, ""

    predictions_str, processed_text = predict_emotions(text)
    predictions = _parse_prediction_lines(predictions_str)

    # Re-run the model directly: predict_emotions only returns formatted text,
    # but the Top-5 view needs raw per-label probabilities.
    encodings = inference_module.TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )
    input_ids = encodings['input_ids'].to(inference_module.DEVICE)
    attention_mask = encodings['attention_mask'].to(inference_module.DEVICE)
    with torch.no_grad():
        outputs = inference_module.MODEL(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=show_heatmap,
        )
    probabilities = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    # Top 5 over all labels, highest probability first.
    all_emotions = sorted(
        ((emotion_labels[i], round(float(p), 4)) for i, p in enumerate(probabilities)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    top_5_emotions = all_emotions[:5]
    top_5_output = "\n".join(f"{emotion}: {confidence:.4f}" for emotion, confidence in top_5_emotions)

    # Keep predictions that clear the stricter of the per-label default
    # threshold and the user-selected threshold.
    filtered_predictions = []
    for emotion, confidence in predictions:
        label_threshold = default_thresholds[emotion_labels.index(emotion)]
        if confidence >= max(label_threshold, confidence_threshold):
            filtered_predictions.append((emotion, confidence))
    if filtered_predictions:
        thresholded_output = "\n".join(
            f"{emotion}: {confidence:.4f}" for emotion, confidence in filtered_predictions
        )
    else:
        thresholded_output = "No emotions predicted above thresholds."

    fig = _build_comparison_chart(filtered_predictions, top_5_emotions)
    heatmap_html = _build_attention_heatmap(outputs, input_ids, attention_mask) if show_heatmap else ""
    return processed_text, thresholded_output, top_5_output, fig, heatmap_html
# Stylesheet passed to gr.Blocks(css=...): dark gradient page, panel/button/
# textbox styling, rules for the #title / #description / #examples-title
# element ids, a footer style, a fade-in animation and a <=768px breakpoint.
# NOTE(review): most rules target `.gr-panel`, `.gr-button`, `.gr-textbox`,
# `.gr-slider`, `.gr-plot` and `.gr-examples`, while the UI attaches its own
# classes via elem_classes (`input-textbox`, `input-slider`, `output-textbox`,
# `output-plot`), which no rule here references. Whether Gradio itself emits the
# `.gr-*` classes depends on the installed Gradio version — confirm they match.
custom_css = """
body {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #111827 0%, #1f2937 100%);
color: #e5e7eb;
margin: 0;
padding: 24px;
min-height: 100vh;
display: flex;
flex-direction: column;
align-items: center;
}
.gr-panel {
border-radius: 16px;
box-shadow: 0 10px 30px rgba(0,0,0,0.3);
background: rgba(31, 41, 55, 0.9);
backdrop-filter: blur(12px);
padding: 32px;
margin: 24px auto;
max-width: 960px;
width: 100%;
border: 1px solid rgba(255, 255, 255, 0.1);
transition: transform 0.3s ease, box-shadow 0.3s ease;
}
.gr-panel:hover {
transform: translateY(-4px);
box-shadow: 0 12px 40px rgba(0,0,0,0.35);
}
.gr-button {
border-radius: 8px;
padding: 12px 32px;
font-weight: 600;
font-size: 16px;
background: linear-gradient(90deg, #6366f1 0%, #8b5cf6 100%);
color: #ffffff;
border: none;
transition: all 0.3s ease;
cursor: pointer;
margin-top: 16px;
}
.gr-button:hover {
background: linear-gradient(90deg, #8b5cf6 0%, #6366f1 100%);
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(99, 102, 241, 0.4);
}
.gr-button:focus {
outline: none;
box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.3);
}
.gr-textbox, .gr-slider {
margin-bottom: 24px;
}
.gr-textbox label, .gr-slider label {
font-size: 16px;
font-weight: 600;
color: #e5e7eb;
margin-bottom: 8px;
display: block;
}
.gr-textbox textarea, .gr-textbox input {
border: 1px solid rgba(255, 255, 255, 0.15);
border-radius: 8px;
padding: 12px;
font-size: 16px;
background: rgba(55, 65, 81, 0.5);
color: #e5e7eb;
transition: border-color 0.3s ease, box-shadow 0.3s ease;
}
.gr-textbox textarea:focus, .gr-textbox input:focus {
border-color: #6366f1;
box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.2);
outline: none;
}
#title {
font-size: 2.5rem;
font-weight: 800;
color: #ffffff;
text-align: center;
margin: 32px 0 16px 0;
background: linear-gradient(90deg, #6366f1, #8b5cf6);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
#description {
font-size: 1.125rem;
color: #d1d5db;
text-align: center;
max-width: 720px;
margin: 0 auto 48px auto;
line-height: 1.75;
}
#examples-title {
font-size: 1.25rem;
font-weight: 600;
color: #e5e7eb;
margin: 32px 0 16px 0;
text-align: center;
}
footer {
text-align: center;
margin: 48px 0 24px 0;
padding: 16px;
font-size: 14px;
color: #d1d5db;
}
footer a {
color: #6366f1;
text-decoration: none;
font-weight: 500;
transition: color 0.3s ease;
}
footer a:hover {
color: #8b5cf6;
}
.gr-plot {
margin-top: 24px;
background: rgba(31, 41, 55, 0.5);
border-radius: 12px;
padding: 16px;
border: 1px solid rgba(255, 255, 255, 0.1);
}
.gr-examples .example {
background: rgba(55, 65, 81, 0.7);
border-radius: 10px;
padding: 16px;
margin: 12px 0;
transition: all 0.3s ease;
cursor: pointer;
border: 1px solid rgba(255, 255, 255, 0.1);
}
.gr-examples .example:hover {
background: rgba(99, 102, 241, 0.15);
transform: translateY(-2px);
border-color: #6366f1;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(16px); }
to { opacity: 1; transform: translateY(0); }
}
.gr-panel, #title, #description, footer, .gr-examples .example {
animation: fadeIn 0.6s ease-out;
}
/* Responsive design */
@media (max-width: 768px) {
.gr-panel {
padding: 24px;
margin: 16px;
}
#title {
font-size: 2rem;
}
#description {
font-size: 1rem;
}
.gr-button {
padding: 10px 24px;
font-size: 14px;
}
}
"""
# Gradio Blocks UI: header, input controls, result panels, examples and footer,
# with the Analyze button wired to predict_emotions_with_details.
with gr.Blocks(css=custom_css) as demo:
    # Header (string literals are kept on one line; a raw newline inside a
    # plain "..." literal is a SyntaxError).
    gr.Markdown(
        "# Emotions Analyzer BERT",
        elem_id="title",
    )
    gr.Markdown(
        """
Uncover the emotions in your text with our fine-tuned BERT model, trained on the GoEmotions dataset.
Enter your text, fine-tune the confidence threshold, and visualize the results in a sleek, interactive chart.
""",
        elem_id="description",
    )

    # Input section
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=3):
                text_input = gr.Textbox(
                    label="Enter Your Text",
                    placeholder="Try: 'I'm over the moon today!' or 'This is so frustrating... 😣'",
                    lines=4,
                    show_label=True,
                    elem_classes=["input-textbox"],
                )
            with gr.Column(scale=1):
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=0.9,
                    value=0.0,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Filter emotions below this confidence level",
                    elem_classes=["input-slider"],
                )
                show_heatmap = gr.Checkbox(
                    label="Show Attention Heatmap",
                    value=False,
                    info="Visualize model focus on text (explainability)",
                )
        submit_btn = gr.Button("Analyze Emotions", variant="primary")

    # Output section
    with gr.Group():
        with gr.Row():
            with gr.Column(scale=1):
                processed_text_output = gr.Textbox(
                    label="Processed Text",
                    lines=2,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
                thresholded_output = gr.Textbox(
                    label="Detected Emotions (Above Threshold)",
                    lines=5,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
                top_5_output = gr.Textbox(
                    label="Top 5 Emotions",
                    lines=5,
                    interactive=False,
                    elem_classes=["output-textbox"],
                )
            with gr.Column(scale=2):
                output_plot = gr.Plot(
                    label="Emotion Confidence Visualization",
                    elem_classes=["output-plot"],
                )
                heatmap_output = gr.HTML(
                    label="Attention Heatmap",
                    visible=True,
                )

    # Example carousel. Each row needs one value per input component (only
    # text_input here); the emotion each example showcases is noted alongside.
    with gr.Group():
        gr.Markdown(
            "### Try These Examples",
            elem_id="examples-title",
        )
        examples = gr.Examples(
            examples=[
                ["I'm thrilled to win this award! 😁"],  # joy
                ["This is so frustrating, nothing works. 😣"],  # annoyance
                ["I feel so sorry for what happened. 😢"],  # sadness
                ["What a beautiful day to be alive! 😍"],  # admiration
                ["Feeling nervous about the exam tomorrow 😰 u/student r/study"],  # nervousness
            ],
            inputs=[text_input],
            label="",
        )

    # Footer
    gr.HTML(
        """
<footer>
Model: <a href="https://huggingface.co/logasanjeev/emotions-analyzer-bert" target="_blank" rel="noopener noreferrer">logasanjeev/emotions-analyzer-bert</a>
· fine-tuned BERT on the GoEmotions dataset
</footer>
"""
    )

    # Bind predictions: the three inputs map to the function's parameters and
    # its 5-tuple result maps onto the five output components in order.
    submit_btn.click(
        fn=predict_emotions_with_details,
        inputs=[text_input, confidence_slider, show_heatmap],
        outputs=[processed_text_output, thresholded_output, top_5_output, output_plot, heatmap_output],
    )

# Launch the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()