logasanjeev commited on
Commit
aa820aa
·
verified ·
1 Parent(s): 42f15b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -7
app.py CHANGED
@@ -6,6 +6,7 @@ from huggingface_hub import hf_hub_download
6
  from importlib import import_module
7
  import shutil
8
  import os
 
9
 
10
  # Load inference.py and model
11
  repo_id = "logasanjeev/emotions-analyzer-bert"
@@ -25,10 +26,36 @@ _, _ = predict_emotions("dummy text")
25
  emotion_labels = inference_module.EMOTION_LABELS
26
  default_thresholds = inference_module.THRESHOLDS
27
 
28
- # Prediction function with grouped bar chart
29
- def predict_emotions_with_details(text, confidence_threshold=0.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  if not text.strip():
31
- return "Please enter some text.", "", "", None
32
 
33
  predictions_str, processed_text = predict_emotions(text)
34
 
@@ -51,7 +78,11 @@ def predict_emotions_with_details(text, confidence_threshold=0.0):
51
  attention_mask = encodings['attention_mask'].to(inference_module.DEVICE)
52
 
53
  with torch.no_grad():
54
- outputs = inference_module.MODEL(input_ids, attention_mask=attention_mask)
 
 
 
 
55
  logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
56
 
57
  # All emotions for Top 5
@@ -120,7 +151,39 @@ def predict_emotions_with_details(text, confidence_threshold=0.0):
120
  font=dict(color="#e5e7eb")
121
  )
122
 
123
- return processed_text, thresholded_output, top_5_output, fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  # Enhanced CSS with modern design
126
  custom_css = """
@@ -343,6 +406,11 @@ with gr.Blocks(css=custom_css) as demo:
343
  info="Filter emotions below this confidence level",
344
  elem_classes=["input-slider"]
345
  )
 
 
 
 
 
346
  submit_btn = gr.Button("Analyze Emotions", variant="primary")
347
 
348
  # Output Section
@@ -372,6 +440,10 @@ with gr.Blocks(css=custom_css) as demo:
372
  label="Emotion Confidence Visualization",
373
  elem_classes=["output-plot"]
374
  )
 
 
 
 
375
 
376
  # Example carousel
377
  with gr.Group():
@@ -405,8 +477,8 @@ with gr.Blocks(css=custom_css) as demo:
405
  # Bind predictions
406
  submit_btn.click(
407
  fn=predict_emotions_with_details,
408
- inputs=[text_input, confidence_slider],
409
- outputs=[processed_text_output, thresholded_output, top_5_output, output_plot]
410
  )
411
 
412
  # Launch
 
6
  from importlib import import_module
7
  import shutil
8
  import os
9
+ import numpy as np
10
 
11
  # Load inference.py and model
12
  repo_id = "logasanjeev/emotions-analyzer-bert"
 
26
  emotion_labels = inference_module.EMOTION_LABELS
27
  default_thresholds = inference_module.THRESHOLDS
28
 
29
+ # Function to merge subwords and their scores
30
+ def merge_subwords(tokens, scores):
31
+ words = []
32
+ word_scores = []
33
+ current_word = ""
34
+ current_score = 0.0
35
+ count = 0
36
+ for t, s in zip(tokens, scores):
37
+ if t in ['[CLS]', '[SEP]', '[PAD]']:
38
+ continue
39
+ if t.startswith('##'):
40
+ current_word += t[2:]
41
+ current_score += s
42
+ count += 1
43
+ else:
44
+ if current_word:
45
+ words.append(current_word)
46
+ word_scores.append(current_score / count)
47
+ current_word = t
48
+ current_score = s
49
+ count = 1
50
+ if current_word:
51
+ words.append(current_word)
52
+ word_scores.append(current_score / count)
53
+ return words, word_scores
54
+
55
+ # Prediction function with grouped bar chart and optional heatmap
56
+ def predict_emotions_with_details(text, confidence_threshold=0.0, show_heatmap=False):
57
  if not text.strip():
58
+ return "Please enter some text.", "", "", None, ""
59
 
60
  predictions_str, processed_text = predict_emotions(text)
61
 
 
78
  attention_mask = encodings['attention_mask'].to(inference_module.DEVICE)
79
 
80
  with torch.no_grad():
81
+ outputs = inference_module.MODEL(
82
+ input_ids,
83
+ attention_mask=attention_mask,
84
+ output_attentions=show_heatmap
85
+ )
86
  logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
87
 
88
  # All emotions for Top 5
 
151
  font=dict(color="#e5e7eb")
152
  )
153
 
154
+ # Generate heatmap if enabled
155
+ heatmap_html = ""
156
+ if show_heatmap:
157
+ attentions = outputs.attentions[-1] # Last layer attention [batch, heads, seq, seq]
158
+ cls_att = attentions[0, :, 0, :].mean(dim=0).cpu().numpy() # Average over heads, from CLS
159
+ seq_len = int(attention_mask[0].sum())
160
+ att_scores = cls_att[:seq_len]
161
+ input_tokens = inference_module.TOKENIZER.convert_ids_to_tokens(input_ids[0][:seq_len])
162
+
163
+ words, word_scores = merge_subwords(input_tokens, att_scores)
164
+
165
+ # Normalize scores to 0-1
166
+ if max(word_scores) > min(word_scores):
167
+ word_scores = [(s - min(word_scores)) / (max(word_scores) - min(word_scores)) for s in word_scores]
168
+ else:
169
+ word_scores = [0.0] * len(words)
170
+
171
+ # Generate HTML with colored spans
172
+ html = ""
173
+ for word, score in zip(words, word_scores):
174
+ alpha = score
175
+ color = f"rgba(255, 100, 100, {alpha:.2f})" # Gradient from transparent to red
176
+ html += f'<span style="background-color: {color}; padding: 2px 4px; margin: 0 2px; border-radius: 4px; color: black;">{word}</span>'
177
+
178
+ heatmap_html = f"""
179
+ <div style="padding: 16px; background: rgba(55, 65, 81, 0.7); border-radius: 12px; border: 1px solid rgba(255, 255, 255, 0.1); margin-top: 24px;">
180
+ <h4 style="color: #e5e7eb; margin-bottom: 12px;">Attention Heatmap (Focus Areas)</h4>
181
+ <div style="overflow-x: auto; white-space: nowrap;">{html}</div>
182
+ <p style="color: #d1d5db; font-size: 0.875rem; margin-top: 8px;">Red intensity indicates model's focus (based on CLS token attention in last layer).</p>
183
+ </div>
184
+ """
185
+
186
+ return processed_text, thresholded_output, top_5_output, fig, heatmap_html
187
 
188
  # Enhanced CSS with modern design
189
  custom_css = """
 
406
  info="Filter emotions below this confidence level",
407
  elem_classes=["input-slider"]
408
  )
409
+ show_heatmap = gr.Checkbox(
410
+ label="Show Attention Heatmap",
411
+ value=False,
412
+ info="Visualize model focus on text (explainability)"
413
+ )
414
  submit_btn = gr.Button("Analyze Emotions", variant="primary")
415
 
416
  # Output Section
 
440
  label="Emotion Confidence Visualization",
441
  elem_classes=["output-plot"]
442
  )
443
+ heatmap_output = gr.HTML(
444
+ label="Attention Heatmap",
445
+ visible=True
446
+ )
447
 
448
  # Example carousel
449
  with gr.Group():
 
477
  # Bind predictions
478
  submit_btn.click(
479
  fn=predict_emotions_with_details,
480
+ inputs=[text_input, confidence_slider, show_heatmap],
481
+ outputs=[processed_text_output, thresholded_output, top_5_output, output_plot, heatmap_output]
482
  )
483
 
484
  # Launch