Perunio committed on
Commit
ea2aeef
·
1 Parent(s): f2f3107
Files changed (1) hide show
  1. galis_app.py +36 -29
galis_app.py CHANGED
@@ -1,20 +1,33 @@
1
  from pathlib import Path
2
  import streamlit as st
3
- import torch.nn.functional as F
4
- from predictor.link_predictor import (
5
- prepare_system,
6
- get_citation_predictions,
7
- abstract_to_vector,
8
- format_top_k_predictions,
9
- )
10
  from llm.related_work_generator import generate_related_work
11
 
12
- MODEL_PATH = Path("model.pth")
13
-
14
 
15
  @st.cache_resource
16
- def load_prediction_system(model_path):
17
- return prepare_system(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
  def app():
@@ -30,7 +43,7 @@ def app():
30
  if "abstract_text" not in st.session_state:
31
  st.session_state.abstract_text = ""
32
 
33
- gcn_model, st_model, dataset, z_all = load_prediction_system(MODEL_PATH)
34
 
35
  col1, col2 = st.columns(2, gap="large")
36
 
@@ -91,32 +104,26 @@ def app():
91
  related_work_placeholder.empty()
92
 
93
  with st.spinner("Analyzing abstract and predicting references..."):
94
- new_vector = abstract_to_vector(
95
- abstract_input, abstract_title, st_model
96
- )
97
-
98
- probabilities = get_citation_predictions(
99
- vector=F.normalize(new_vector.view(1, -1), p=2, dim=1),
100
- model=gcn_model,
101
- z_all=z_all,
102
- num_nodes=dataset.data.num_nodes,
103
- )
104
- references = format_top_k_predictions(
105
- probabilities, dataset, top_k=num_citations
106
  )
 
107
  st.session_state.references = references
108
 
109
  with references_placeholder.container():
110
  st.header("Suggested References")
111
  with st.container(height=200):
112
- st.markdown(st.session_state.references)
113
 
114
  with related_work_placeholder.container():
115
  with st.spinner("Generating related work section..."):
 
116
  related_work = generate_related_work(
117
  st.session_state.abstract_title,
118
  st.session_state.abstract_text,
119
- st.session_state.references
120
  )
121
  st.session_state.related_work = related_work
122
 
@@ -124,14 +131,14 @@ def app():
124
  with references_placeholder.container():
125
  st.header("Suggested References")
126
  with st.container(height=200):
127
- st.markdown(st.session_state.references)
128
 
129
  if st.session_state.related_work:
130
  with related_work_placeholder.container():
131
  st.header("Suggested Related Works")
132
  with st.container(height=200):
133
- st.markdown(st.session_state.related_work)
134
 
135
 
136
  if __name__ == "__main__":
137
- app()
 
1
  from pathlib import Path
2
  import streamlit as st
3
+ from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
4
+ from model.paper_similarity import PaperSimilarityFinder
 
 
 
 
 
5
  from llm.related_work_generator import generate_related_work
6
 
 
 
7
 
8
  @st.cache_resource
9
def load_similarity_finder():
    """Build the paper dataset and a similarity finder over it.

    A fresh ``OGBNLinkPredDataset`` is constructed and handed to a
    ``PaperSimilarityFinder`` configured for the ``sentence_transformer``
    method with the ``all-mpnet-base-v2`` model and the ``embeddings_cache``
    directory as its embeddings cache path.

    NOTE(review): in the file this function is decorated with
    ``st.cache_resource``, so it is presumably built once and reused
    across Streamlit reruns -- confirm against the Streamlit docs.

    Returns:
        A ``(similarity_finder, dataset)`` tuple.
    """
    dataset = OGBNLinkPredDataset()
    finder = PaperSimilarityFinder(
        dataset,
        method="sentence_transformer",
        model_name="all-mpnet-base-v2",
        embeddings_cache_path=Path("embeddings_cache"),
    )
    return finder, dataset
22
+
23
+
24
def format_top_k_predictions_from_similarity(similar_papers: list) -> str:
    """Format a list of similar papers as a numbered Markdown list.

    Args:
        similar_papers: Sequence of ``(idx, score, text)`` triples. Only
            ``score`` and ``text`` are used. The first non-blank line of
            ``text`` is taken as the paper title.

    Returns:
        One entry per paper in the form
        ``"N. **title** (Similarity: 0.1234)"``, numbered from 1 and joined
        with newlines. An empty input yields an empty string.
    """
    entries = []
    for rank, (_idx, score, text) in enumerate(similar_papers, start=1):
        # Strip the whole text first so leading blank lines or whitespace
        # do not produce an empty title; then keep only the first line.
        title = text.strip().partition("\n")[0].strip()
        entries.append(f"{rank}. **{title}** (Similarity: {score:.4f})")
    return "\n".join(entries)
31
 
32
 
33
  def app():
 
43
  if "abstract_text" not in st.session_state:
44
  st.session_state.abstract_text = ""
45
 
46
+ similarity_finder, dataset = load_similarity_finder()
47
 
48
  col1, col2 = st.columns(2, gap="large")
49
 
 
104
  related_work_placeholder.empty()
105
 
106
  with st.spinner("Analyzing abstract and predicting references..."):
107
+ similar_papers = similarity_finder.find_similar_papers(
108
+ title=abstract_title,
109
+ abstract=abstract_input,
110
+ top_k=num_citations
 
 
 
 
 
 
 
 
111
  )
112
+ references = format_top_k_predictions_from_similarity(similar_papers)
113
  st.session_state.references = references
114
 
115
  with references_placeholder.container():
116
  st.header("Suggested References")
117
  with st.container(height=200):
118
+ st.markdown(st.session_state.references, unsafe_allow_html=True)
119
 
120
  with related_work_placeholder.container():
121
  with st.spinner("Generating related work section..."):
122
+ # Make sure generate_related_work accepts this format
123
  related_work = generate_related_work(
124
  st.session_state.abstract_title,
125
  st.session_state.abstract_text,
126
+ st.session_state.references,
127
  )
128
  st.session_state.related_work = related_work
129
 
 
131
  with references_placeholder.container():
132
  st.header("Suggested References")
133
  with st.container(height=200):
134
+ st.markdown(st.session_state.references, unsafe_allow_html=True)
135
 
136
  if st.session_state.related_work:
137
  with related_work_placeholder.container():
138
  st.header("Suggested Related Works")
139
  with st.container(height=200):
140
+ st.markdown(st.session_state.related_work, unsafe_allow_html=True)
141
 
142
 
143
  if __name__ == "__main__":
144
+ app()