Perunio committed on
Commit
7dba0b0
·
1 Parent(s): 42dd08a
galis_app.py CHANGED
@@ -2,7 +2,10 @@ from pathlib import Path
2
  import streamlit as st
3
  from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
4
  from model.paper_similarity import PaperSimilarityFinder
5
- from llm.related_work_generator import generate_related_work
 
 
 
6
 
7
 
8
  @st.cache_resource
@@ -17,72 +20,116 @@ def load_similarity_finder():
17
  model_name=model_name,
18
  embeddings_cache_path=embeddings_dir,
19
  )
20
- return similarity_finder, dataset
 
 
 
21
 
22
 
23
def format_top_k_predictions_from_similarity(similar_papers: list) -> str:
    """Render ranked similarity hits as a numbered markdown list.

    Args:
        similar_papers: Sequence of ``(idx, score, text)`` tuples. Only the
            first line of ``text`` is displayed (it is used as the title);
            ``idx`` is not shown.

    Returns:
        One line per paper in the form
        ``"<rank>. **<title>** (Similarity: <score>)"`` with the score
        formatted to four decimals, joined by newlines. An empty input
        yields an empty string.
    """
    # Ranks are 1-based; the paper index is discarded while unpacking.
    entries = (
        (rank, text.split("\n", 1)[0].strip(), score)
        for rank, (_idx, score, text) in enumerate(similar_papers, start=1)
    )
    return "\n".join(
        f"{rank}. **{title}** (Similarity: {score:.4f})"
        for rank, title, score in entries
    )
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def app():
32
  st.set_page_config(page_title="Galis", layout="wide")
33
  st.title("Galis")
 
 
34
 
35
  if "references" not in st.session_state:
36
- st.session_state.references = None
37
  if "related_work" not in st.session_state:
38
- st.session_state.related_work = None
39
  if "abstract_title" not in st.session_state:
40
  st.session_state.abstract_title = ""
41
  if "abstract_text" not in st.session_state:
42
  st.session_state.abstract_text = ""
43
 
44
- similarity_finder, dataset = load_similarity_finder()
45
 
46
  col1, col2 = st.columns(2, gap="large")
47
 
48
- with col2:
49
- references_placeholder = st.empty()
50
- related_work_placeholder = st.empty()
51
-
52
  with col1:
53
  st.header("Abstract Title")
54
- abstract_title = st.text_input(
55
- "Paste your title here",
56
- st.session_state.abstract_title,
57
- key="abstract_title_input",
58
- label_visibility="collapsed",
59
  )
60
 
61
  st.header("Abstract Text")
62
- abstract_input = st.text_area(
63
  "Paste your abstract here",
64
- st.session_state.abstract_text,
65
- key="abstract_text_input",
66
- height=100,
67
  label_visibility="collapsed",
68
  )
69
 
70
- st.write("...or **upload** a .txt file (first line = title, rest = abstract)")
71
- uploaded_file = st.file_uploader(
72
- "Drag and drop file here", type=["txt"], help="Limit 200MB per file • TXT"
 
 
 
73
  )
74
 
75
- if uploaded_file is not None:
76
- content = uploaded_file.getvalue().decode("utf-8").splitlines()
77
- st.session_state.abstract_title = content[0] if content else ""
78
- st.session_state.abstract_text = (
79
- "\n".join(content[1:]) if len(content) > 1 else ""
80
- )
81
- st.rerun()
82
-
83
- st.session_state.abstract_title = abstract_title
84
- st.session_state.abstract_text = abstract_input
85
-
86
  num_citations = st.number_input(
87
  "Number of suggestions",
88
  min_value=1,
@@ -93,49 +140,63 @@ def app():
93
  )
94
 
95
  if st.button("Suggest References and related work", type="primary"):
96
- if not abstract_title.strip() or not abstract_input.strip():
 
 
 
97
  st.warning("Please provide both a title and an abstract.")
98
  else:
99
- st.session_state.references = None
100
- st.session_state.related_work = None
101
- references_placeholder.empty()
102
- related_work_placeholder.empty()
103
-
104
- with st.spinner("Analyzing abstract and predicting references..."):
105
- similar_papers = similarity_finder.find_similar_papers(
106
- title=abstract_title,
107
- abstract=abstract_input,
108
- top_k=num_citations
109
- )
110
- references = format_top_k_predictions_from_similarity(similar_papers)
111
- st.session_state.references = references
112
-
113
- with references_placeholder.container():
114
- st.header("Suggested References")
115
- with st.container(height=200):
116
- st.markdown(st.session_state.references, unsafe_allow_html=True)
117
-
118
- with related_work_placeholder.container():
119
- with st.spinner("Generating related work section..."):
120
- related_work = generate_related_work(
121
- st.session_state.abstract_title,
122
- st.session_state.abstract_text,
123
- st.session_state.references,
124
- )
125
- st.session_state.related_work = related_work
126
-
127
- if st.session_state.references:
128
- with references_placeholder.container():
129
  st.header("Suggested References")
130
- with st.container(height=200):
131
- st.markdown(st.session_state.references, unsafe_allow_html=True)
 
 
 
 
 
132
 
133
- if st.session_state.related_work:
134
- with related_work_placeholder.container():
135
  st.header("Suggested Related Works")
136
- with st.container(height=200):
137
- st.markdown(st.session_state.related_work, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
 
140
  if __name__ == "__main__":
141
- app()
 
2
  import streamlit as st
3
  from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
4
  from model.paper_similarity import PaperSimilarityFinder
5
+ from llm.related_work_generator import (
6
+ generate_related_work,
7
+ create_related_work_pipeline,
8
+ )
9
 
10
 
11
  @st.cache_resource
 
20
  model_name=model_name,
21
  embeddings_cache_path=embeddings_dir,
22
  )
23
+
24
+ pipeline = create_related_work_pipeline()
25
+
26
+ return pipeline, similarity_finder, dataset
27
 
28
 
29
def format_top_k_predictions_from_similarity(similar_papers: list) -> str:
    """Render ranked similarity hits as a numbered plain-text list.

    Args:
        similar_papers: Sequence of ``(idx, score, text)`` tuples. Only the
            first line of ``text`` is displayed (it is used as the title);
            ``idx`` is not shown.

    Returns:
        One line per paper in the form
        ``"<rank>. <title> (Similarity: <score>)"`` with the score formatted
        to four decimals, joined by newlines. An empty input yields an empty
        string.
    """
    # Ranks are 1-based; the paper index is discarded while unpacking.
    entries = (
        (rank, text.split("\n", 1)[0].strip(), score)
        for rank, (_idx, score, text) in enumerate(similar_papers, start=1)
    )
    return "\n".join(
        f"{rank}. {title} (Similarity: {score:.4f})"
        for rank, title, score in entries
    )
35
 
36
 
37
def process_uploaded_file():
    """Copy an uploaded ``.txt`` file into the title/abstract session fields.

    Reads the upload stored under ``st.session_state.file_uploader``. The
    first line of the file becomes ``st.session_state.abstract_title`` and
    the remaining lines, re-joined with newlines, become
    ``st.session_state.abstract_text``. Nothing happens when no file is
    present. Any failure is shown to the user through ``st.error`` instead
    of propagating.
    """
    try:
        uploaded_file = st.session_state.file_uploader
        if uploaded_file is None:
            return
        # "utf-8-sig" decodes ordinary UTF-8 identically but also drops a
        # leading byte-order mark, which would otherwise end up inside the
        # title for files saved by editors that prepend one.
        lines = uploaded_file.getvalue().decode("utf-8-sig").splitlines()
        st.session_state.abstract_title = lines[0] if lines else ""
        # Slicing past the end gives [], so join() already yields "" for
        # empty or single-line files.
        st.session_state.abstract_text = "\n".join(lines[1:])
    except Exception as e:
        # Deliberately broad: this runs at the UI boundary and a bad upload
        # should surface as a message rather than crash the page.
        st.error(f"Error processing file: {e}")
48
+
49
+
50
# Markdown help text describing GALIS and its two input workflows (manual
# entry and .txt upload).
# NOTE(review): leading whitespace inside this literal is not visible in the
# reviewed view; confirm nested list indentation against the real file.
GALIS_DESCRIPTION = """
### About GALIS

**GALIS** is a web-based application designed to streamline and improve the creation of related work and
references sections for research papers. It leverages an existing semantic graph that captures the
relationships and core concepts among cited papers to guide language model outputs.

### Objective
The primary objective is to provide a practical tool that helps researchers generate high-quality, coherent
related work and references sections, making the process of synthesizing literature more efficient and
insightful.

---

### How to Use GALIS

#### Option 1: Manual Input
1. **Enter your paper title** in the "Abstract Title" field
2. **Paste your abstract** in the "Abstract Text" area
3. **Set the number of suggestions** you want (1-100 papers)
4. **Click "Suggest References and related work"**

#### Option 2: File Upload
1. **Prepare a .txt file** with:
- **First line**: Your paper title
- **Remaining lines**: Your abstract text
2. **Upload the file** using the file uploader
3. **Set the number of suggestions** you want (1-100 papers)
4. **Click "Suggest References and related work"**

#### What You'll Get
- **Suggested References**: A curated list of relevant papers based on semantic similarity
- **Related Work Section**: An automatically generated related work section that synthesizes the suggested
papers
- **Regeneration Option**: You can regenerate the related work section if needed

---

*Note: File uploads are limited to 200MB and must be in .txt format*
"""
90
+
91
+
92
  def app():
93
  st.set_page_config(page_title="Galis", layout="wide")
94
  st.title("Galis")
95
+ with st.popover("What is Galis?"):
96
+ st.markdown(GALIS_DESCRIPTION)
97
 
98
  if "references" not in st.session_state:
99
+ st.session_state.references = ""
100
  if "related_work" not in st.session_state:
101
+ st.session_state.related_work = ""
102
  if "abstract_title" not in st.session_state:
103
  st.session_state.abstract_title = ""
104
  if "abstract_text" not in st.session_state:
105
  st.session_state.abstract_text = ""
106
 
107
+ pipeline, similarity_finder, dataset = load_similarity_finder()
108
 
109
  col1, col2 = st.columns(2, gap="large")
110
 
 
 
 
 
111
  with col1:
112
  st.header("Abstract Title")
113
+ st.text_input(
114
+ "Paste your title here", key="abstract_title", label_visibility="collapsed"
 
 
 
115
  )
116
 
117
  st.header("Abstract Text")
118
+ st.text_area(
119
  "Paste your abstract here",
120
+ key="abstract_text",
121
+ height=150,
 
122
  label_visibility="collapsed",
123
  )
124
 
125
+ st.file_uploader(
126
+ "Upload a .txt file here (first line = title, rest = abstract)",
127
+ type=["txt"],
128
+ help="Limit 200MB per file • TXT",
129
+ key="file_uploader",
130
+ on_change=process_uploaded_file,
131
  )
132
 
 
 
 
 
 
 
 
 
 
 
 
133
  num_citations = st.number_input(
134
  "Number of suggestions",
135
  min_value=1,
 
140
  )
141
 
142
  if st.button("Suggest References and related work", type="primary"):
143
+ if (
144
+ not st.session_state.abstract_title.strip()
145
+ or not st.session_state.abstract_text.strip()
146
+ ):
147
  st.warning("Please provide both a title and an abstract.")
148
  else:
149
+ st.session_state.references = "LOADING"
150
+ st.session_state.related_work = ""
151
+
152
+ with col2:
153
+ if st.session_state.references == "LOADING":
154
+ with st.spinner("Analyzing abstract and predicting references..."):
155
+ similar_papers = similarity_finder.find_similar_papers(
156
+ title=st.session_state.abstract_title,
157
+ abstract=st.session_state.abstract_text,
158
+ top_k=num_citations,
159
+ )
160
+ st.session_state.references = format_top_k_predictions_from_similarity(
161
+ similar_papers
162
+ )
163
+ st.session_state.related_work = "LOADING"
164
+ st.rerun()
165
+
166
+ if st.session_state.references not in ["", "LOADING"]:
 
 
 
 
 
 
 
 
 
 
 
 
167
  st.header("Suggested References")
168
+ st.text_area(
169
+ "References",
170
+ value=st.session_state.references,
171
+ height=150,
172
+ label_visibility="collapsed",
173
+ key="ref_output",
174
+ )
175
 
 
 
176
  st.header("Suggested Related Works")
177
+
178
+ if st.session_state.related_work == "LOADING":
179
+ with st.spinner("Generating related work section..."):
180
+ st.session_state.related_work = generate_related_work(
181
+ pipeline,
182
+ st.session_state.abstract_title,
183
+ st.session_state.abstract_text,
184
+ st.session_state.references,
185
+ )
186
+ st.rerun()
187
+ else:
188
+ st.text_area(
189
+ "Related Works",
190
+ value=st.session_state.related_work,
191
+ height=300,
192
+ label_visibility="collapsed",
193
+ key="rw_output",
194
+ )
195
+
196
+ if st.button("Regenerate Related Works"):
197
+ st.session_state.related_work = "LOADING"
198
+ st.rerun()
199
 
200
 
201
  if __name__ == "__main__":
202
+ app()
llm/related_work_generator.py CHANGED
@@ -59,8 +59,8 @@ the novelty and importance of the user's project.
59
  Use appropriate terminology and focus on concepts, methods, and challenges relevant to that particular field of study.
60
 
61
  7. **Output Format:** Generate only the text for the "Related Work" section. Do not include headers like
62
- "INSTRUCTIONS," "PAPER TITLE," or "PROVIDED CITATIONS" in the final output. The entire response should be the
63
- section text itself, ready to be inserted into an academic paper.
64
  """
65
 
66
 
@@ -74,16 +74,10 @@ def check_api_key():
74
 
75
 
76
  def create_related_work_pipeline():
77
- """Creates a ready-to-use pipeline for generating the Related Work section."""
78
-
79
- llm = ChatGoogleGenerativeAI(
80
- model="gemini-2.0-flash-exp",
81
- temperature=0.3
82
- )
83
 
84
  prompt = PromptTemplate(
85
- input_variables=["title", "abstract", "citations"],
86
- template=PROMPT_TEXT
87
  )
88
 
89
  parser = StrOutputParser()
@@ -93,24 +87,12 @@ def create_related_work_pipeline():
93
  return chain
94
 
95
 
96
def generate_related_work(title: str, abstract: str, citations_text: str) -> str:
    """Generate a Related Work section from a title, abstract and citations.

    A new pipeline is obtained from ``create_related_work_pipeline()`` on
    every call and invoked with the three inputs.

    Args:
        title: The paper's title.
        abstract: The paper's abstract.
        citations_text: Text describing the citations to draw on.

    Returns:
        The value produced by the pipeline's ``invoke`` call.
    """
    payload = {
        "title": title,
        "abstract": abstract,
        "citations": citations_text,
    }
    return create_related_work_pipeline().invoke(payload)
115
 
116
 
@@ -141,11 +123,12 @@ Top 5 Citation Predictions:
141
  print("-" * 50)
142
 
143
  try:
144
- related_work = generate_related_work(title, abstract, citations)
 
145
  print(related_work)
146
  except Exception as e:
147
  print(f"Error: {e}")
148
  print("1. Create a .env file in the same folder as the script")
149
  print("2. Add the line: GOOGLE_API_KEY=your_key")
150
  print("3. Get the key at: https://makersuite.google.com/app/apikey")
151
- check_api_key()
 
59
  Use appropriate terminology and focus on concepts, methods, and challenges relevant to that particular field of study.
60
 
61
  7. **Output Format:** Generate only the text for the "Related Work" section. Do not include headers like
62
+ "INSTRUCTIONS" "PAPER TITLE", "RELATED WORK" or "PROVIDED CITATIONS" in the final output. Do not use markdown syntax.
63
+ The entire response should be the section text itself, ready to be inserted into an academic paper.
64
  """
65
 
66
 
 
74
 
75
 
76
  def create_related_work_pipeline():
77
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp", temperature=0.3)
 
 
 
 
 
78
 
79
  prompt = PromptTemplate(
80
+ input_variables=["title", "abstract", "citations"], template=PROMPT_TEXT
 
81
  )
82
 
83
  parser = StrOutputParser()
 
87
  return chain
88
 
89
 
90
def generate_related_work(
    pipeline, title: str, abstract: str, citations_text: str
) -> str:
    """Generate a Related Work section using an already-built pipeline.

    Args:
        pipeline: Object exposing ``invoke(mapping)``; it is called once.
        title: The paper's title.
        abstract: The paper's abstract.
        citations_text: Text describing the citations to draw on.

    Returns:
        The value produced by ``pipeline.invoke``.
    """
    payload = {
        "title": title,
        "abstract": abstract,
        "citations": citations_text,
    }
    return pipeline.invoke(payload)
97
 
98
 
 
123
  print("-" * 50)
124
 
125
  try:
126
+ pipeline = create_related_work_pipeline()
127
+ related_work = generate_related_work(pipeline, title, abstract, citations)
128
  print(related_work)
129
  except Exception as e:
130
  print(f"Error: {e}")
131
  print("1. Create a .env file in the same folder as the script")
132
  print("2. Add the line: GOOGLE_API_KEY=your_key")
133
  print("3. Get the key at: https://makersuite.google.com/app/apikey")
134
+ check_api_key()
model/mlp.py DELETED
@@ -1,137 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from sklearn.metrics import roc_auc_score, average_precision_score
5
- import numpy as np
6
- from dataset.ogbn_link_pred_dataset import (
7
- OGBNLinkPredDataset,
8
- OGBNLinkPredNegDataset,
9
- # OGBNLinkPredNegDataset2,
10
- )
11
- from pathlib import Path
12
- from sentence_transformers import SentenceTransformer
13
- import argparse
14
-
15
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of edges processed per mini-batch in run_epoch() and evaluate().
BATCH_SIZE = 2048
# Number of training epochs run by the __main__ block.
NUM_EPOCHS = 50
18
-
19
-
20
def parse_args():
    """Parse the command-line switches of the training script.

    Both switches are on/off flags that default to ``False`` and also accept
    the ``--no-...`` form:

    * ``--custom-neg`` is stored as ``custom_neg``
    * ``--bert-embed`` is stored as ``bert_embed``

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--custom-neg", "--bert-embed"):
        cli.add_argument(
            flag, action=argparse.BooleanOptionalAction, default=False
        )
    return cli.parse_args()
29
-
30
-
31
- # --- Feature builder ---
32
def edge_features(emb, ei):
    """Build one feature row per node pair from node embeddings.

    Args:
        emb: Node embeddings, indexed by node id along the first axis.
        ei: Pair of index sequences ``(u, v)`` naming the two endpoints of
            each edge.

    Returns:
        The element-wise product of the endpoint embeddings concatenated
        (along dim 1) with their element-wise absolute difference.
    """
    src_nodes, dst_nodes = ei
    src_emb = emb[src_nodes]
    dst_emb = emb[dst_nodes]
    hadamard = src_emb * dst_emb
    abs_diff = torch.abs(src_emb - dst_emb)
    return torch.cat([hadamard, abs_diff], dim=1)
36
-
37
-
38
- # --- Simple MLP ---
39
class PairMLP(nn.Module):
    """Two-layer perceptron that maps a pair-feature vector to one logit."""

    def __init__(self, in_dim, hidden=256):
        """Create the layers.

        Args:
            in_dim: Size of each input feature vector.
            hidden: Width of the single hidden layer.
        """
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return one raw (pre-sigmoid) score per input row."""
        hidden_act = torch.relu(self.fc1(x))
        logits = self.fc2(hidden_act)
        # Drop the trailing size-1 output dimension.
        return logits.squeeze(-1)
48
-
49
-
50
- # --- Training loop ---
51
def run_epoch(data, train=True):
    """Run one pass over ``data`` and return the mean per-edge loss.

    Uses the module-level ``model``, ``emb`` and ``opt`` objects that the
    ``__main__`` block creates.

    Args:
        data: Split object exposing ``edge_label`` and ``edge_label_index``.
        train: When true, edges are visited in random order and the
            optimizer is stepped after every batch. When false, edges are
            visited in order and only the loss is computed.

    Returns:
        Binary cross-entropy (with logits) averaged over all edges.
    """
    model.train(train)
    num_edges = data.edge_label.size(0)
    order = torch.randperm(num_edges) if train else torch.arange(num_edges)
    total_loss = 0
    # Autograd bookkeeping is only needed when the optimizer will step;
    # disabling it otherwise avoids building an unused graph.
    with torch.set_grad_enabled(train):
        for start in range(0, num_edges, BATCH_SIZE):
            # Slicing clamps at the end, so the last batch may be shorter.
            batch_idx = order[start : start + BATCH_SIZE]
            feats = edge_features(emb, data.edge_label_index[:, batch_idx]).to(DEVICE)
            labels = data.edge_label[batch_idx].float().to(DEVICE)
            scores = model(feats)
            loss = F.binary_cross_entropy_with_logits(scores, labels)
            if train:
                opt.zero_grad()
                loss.backward()
                opt.step()
            # Weight by batch size so the final division gives a per-edge mean.
            total_loss += loss.item() * len(batch_idx)
    return total_loss / num_edges
72
-
73
-
74
@torch.no_grad()
def evaluate(data):
    """Score every labelled edge in ``data`` and compute ranking metrics.

    Uses the module-level ``model`` and ``emb`` objects that the
    ``__main__`` block creates.

    Args:
        data: Split object exposing ``edge_label`` and ``edge_label_index``.

    Returns:
        Tuple ``(roc_auc, average_precision)`` computed from the sigmoid of
        the model outputs against ``data.edge_label``.
    """
    # Make sure inference-mode behaviour is used; run_epoch() switches the
    # model back with model.train(...) at the start of each epoch.
    model.eval()
    num_edges = data.edge_label.size(0)
    scores_all, labels_all = [], []
    for start in range(0, num_edges, BATCH_SIZE):
        stop = min(start + BATCH_SIZE, num_edges)
        feats = edge_features(emb, data.edge_label_index[:, start:stop]).to(DEVICE)
        scores_all.append(torch.sigmoid(model(feats)).cpu().numpy())
        # Same bounds as the features; .cpu() keeps .numpy() valid even when
        # the labels live on an accelerator.
        labels_all.append(data.edge_label[start:stop].cpu().numpy())
    y_scores = np.concatenate(scores_all)
    y_true = np.concatenate(labels_all)
    return roc_auc_score(y_true, y_scores), average_precision_score(y_true, y_scores)
87
-
88
-
89
if __name__ == "__main__":
    # Script entry point: choose dataset and embeddings from the CLI flags,
    # train PairMLP for NUM_EPOCHS epochs, then report test metrics.
    args = parse_args()
    USE_CUSTOM_NEG = args.custom_neg
    USE_BERT_EMBED = args.bert_embed

    # --- Load dataset + frozen embeddings ---
    if USE_CUSTOM_NEG:
        print("using hard negatives")
        dataset = OGBNLinkPredNegDataset(val_size=0.1, test_size=0.2)
    else:
        print("using random negatives")
        dataset = OGBNLinkPredDataset(val_size=0.1, test_size=0.2)
    if USE_BERT_EMBED:
        print("using BERT embeds")
        if Path("model/embeddings.pth").exists():
            # Reuse embeddings cached by a previous run.
            emb = torch.load("model/embeddings.pth", map_location=DEVICE)
        else:
            st = SentenceTransformer("bongsoo/kpf-sbert-128d-v1", device=DEVICE)
            emb = st.encode(
                dataset.corpus, convert_to_tensor=True, show_progress_bar=True
            )
            Path("model").mkdir(parents=True, exist_ok=True)
            torch.save(emb, "model/embeddings.pth")
        emb = emb.to(DEVICE)
    else:
        print("using skipgram embeds")
        emb = dataset.data.x

    train_data, val_data, test_data = dataset.get_splits()

    # Pair features are twice the embedding width (see edge_features()).
    model = PairMLP(emb.size(1) * 2).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # --- Training ---
    best_roc, best_ap = 0.0, 0.0
    for epoch in range(NUM_EPOCHS):
        loss = run_epoch(train_data, train=True)
        val_roc, val_ap = evaluate(val_data)
        if val_roc > best_roc:
            # Record the new best so that later epochs only write a
            # checkpoint when they actually improve on it.
            best_roc, best_ap = val_roc, val_ap
            torch.save(
                model.state_dict(), f"model_roc{str(val_roc)[:4].replace('.', '_')}.pth"
            )
        print(
            f"Epoch {epoch + 1} | Loss {loss:.4f} | Val ROC {val_roc:.4f} | Val AP {val_ap:.4f}"
        )

    # --- Final test ---
    # NOTE(review): this scores the last-epoch weights, not the best
    # checkpoint saved above - confirm that is intended.
    test_roc, test_ap = evaluate(test_data)
    print(f"Test ROC {test_roc:.4f} | Test AP {test_ap:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/paper_similarity.py CHANGED
@@ -236,9 +236,9 @@ class PaperSimilarityFinder:
236
 
237
  def compare_methods(self, title: str, abstract: str, top_k: int = 5):
238
  """Compare TF-IDF vs sentence embeddings"""
239
- if not hasattr(self, 'corpus_vectors'):
240
  self._setup_tfidf()
241
- if not hasattr(self, 'corpus_embeddings'):
242
  self._setup_sentence_embeddings()
243
 
244
  query = f"{title}\n{abstract}"
@@ -246,10 +246,8 @@ class PaperSimilarityFinder:
246
  tfidf_results = self._find_similar_tfidf(query, top_k)
247
  sent_results = self._find_similar_sentence_transformer(query, top_k)
248
 
249
- return {
250
- 'tfidf': tfidf_results,
251
- 'sentence_transformer': sent_results
252
- }
253
 
254
  if __name__ == "__main__":
255
  dataset = OGBNLinkPredDataset()
@@ -265,28 +263,27 @@ if __name__ == "__main__":
265
  embeddings_cache_path=embeddings_dir,
266
  )
267
 
268
- my_title = "Polynomial Implicit Neural Representations For Large Diverse Datasets"
 
 
269
  my_abstract = """
270
- Implicit neural representations (INR) have gained significant popularity for signal and image representation for
271
- many end-tasks, such as superresolution, 3D modeling, and
272
- more. Most INR architectures rely on sinusoidal positional
273
- encoding, which accounts for high-frequency information in
274
- data. However, the finite encoding size restricts the model’s
275
- representational power. Higher representational power is
276
- needed to go from representing a single given image to representing large and diverse datasets. Our approach addresses
277
- this gap by representing an image with a polynomial function
278
- and eliminates the need for positional encodings. Therefore,
279
- to achieve a progressively higher degree of polynomial representation, we use element-wise multiplications between
280
- features and affine-transformed coordinate locations after
281
- every ReLU layer. The proposed method is evaluated qualitatively and quantitatively on large datasets like ImageNet.
282
- The proposed Poly-INR model performs comparably to stateof-the-art generative models without any convolution,
283
- normalization, or self-attention layers, and with far fewer trainable parameters. With much fewer training parameters and
284
- higher representative power, our approach paves the way
285
- for broader adoption of INR models for generative modeling tasks in complex domains. The code is available at
286
- https://github.com/Rajhans0/Poly_INR
287
  """
288
 
289
- top_k = 5
290
  print(f"\nTop {top_k} Citation Predictions:\n")
291
 
292
  top_papers = similarity_finder.find_similar_papers(
@@ -311,5 +308,3 @@ if __name__ == "__main__":
311
  for idx, score, text in top_papers_cached:
312
  title = text.split("\n")[0].strip()
313
  print(f"Title: '{title}'")
314
-
315
-
 
236
 
237
  def compare_methods(self, title: str, abstract: str, top_k: int = 5):
238
  """Compare TF-IDF vs sentence embeddings"""
239
+ if not hasattr(self, "corpus_vectors"):
240
  self._setup_tfidf()
241
+ if not hasattr(self, "corpus_embeddings"):
242
  self._setup_sentence_embeddings()
243
 
244
  query = f"{title}\n{abstract}"
 
246
  tfidf_results = self._find_similar_tfidf(query, top_k)
247
  sent_results = self._find_similar_sentence_transformer(query, top_k)
248
 
249
+ return {"tfidf": tfidf_results, "sentence_transformer": sent_results}
250
+
 
 
251
 
252
  if __name__ == "__main__":
253
  dataset = OGBNLinkPredDataset()
 
263
  embeddings_cache_path=embeddings_dir,
264
  )
265
 
266
+ my_title = (
267
+ "PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation"
268
+ )
269
  my_abstract = """
270
+ Point cloud is an important type of geometric data
271
+ structure. Due to its irregular format, most researchers
272
+ transform such data to regular 3D voxel grids or collections
273
+ of images. This, however, renders data unnecessarily
274
+ voluminous and causes issues. In this paper, we design a
275
+ novel type of neural network that directly consumes point
276
+ clouds, which well respects the permutation invariance of
277
+ points in the input. Our network, named PointNet, provides a unified architecture for applications ranging from
278
+ object classification, part segmentation, to scene semantic
279
+ parsing. Though simple, PointNet is highly efficient and
280
+ effective. Empirically, it shows strong performance on
281
+ par or even better than state of the art. Theoretically,
282
+ we provide analysis towards understanding of what the
283
+ network has learnt and why the network is r
 
 
 
284
  """
285
 
286
+ top_k = 10
287
  print(f"\nTop {top_k} Citation Predictions:\n")
288
 
289
  top_papers = similarity_finder.find_similar_papers(
 
308
  for idx, score, text in top_papers_cached:
309
  title = text.split("\n")[0].strip()
310
  print(f"Title: '{title}'")
 
 
model/simple_gcn_model.py DELETED
@@ -1,37 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from torch_geometric.nn import GCNConv
4
-
5
-
6
class EdgeDecoder(torch.nn.Module):
    """Score a candidate citation from the embeddings of its two nodes."""

    def __init__(self, in_channels):
        """Create the scoring layer.

        Args:
            in_channels: Width of a single node embedding; the linear layer
                takes two of them concatenated.
        """
        super().__init__()
        self.linear = torch.nn.Linear(in_channels * 2, 1)

    def forward(self, z, edge_index):
        """Return one raw score per edge in ``edge_index``."""
        src, dst = edge_index
        # Each edge is represented by its two endpoint embeddings side by side.
        pair_repr = torch.cat((z[src], z[dst]), dim=-1)
        logits = self.linear(pair_repr)
        return logits.squeeze(-1)
18
-
19
-
20
class SimpleGCN(torch.nn.Module):
    """Two-layer GCN encoder bundled with an edge decoder.

    ``forward`` produces node embeddings; ``decode`` turns pairs of those
    embeddings into link scores through an ``EdgeDecoder``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Create both graph convolutions and the decoder.

        Args:
            in_channels: Width of the input node features.
            hidden_channels: Output width of the first convolution.
            out_channels: Width of the final node embeddings.
        """
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.decoder = EdgeDecoder(out_channels)

    def forward(self, x, edge_index):
        """Encode nodes: conv1, ReLU, dropout (p=0.5, training only), conv2."""
        hidden = torch.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, edge_index)

    def decode(self, z, edge_label_index):
        """Score the edges listed in ``edge_label_index`` using the decoder."""
        # edge_label_index is passed through as-is; per the original comment
        # it holds both positive and negative edges.
        return self.decoder(z, edge_label_index)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/train.py DELETED
@@ -1,139 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from torch_geometric.loader import LinkNeighborLoader
4
- from sklearn.metrics import roc_auc_score, accuracy_score
5
- from tqdm import tqdm
6
- from model.simple_gcn_model import SimpleGCN
7
- from dataset.ogbn_link_pred_dataset import OGBNLinkPredDataset
8
-
9
-
10
# Training hyper-parameters.
BATCH_SIZE = 128
NUM_EPOCHS = 20
LR = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data
dataset = OGBNLinkPredDataset(val_size=0.1, test_size=0.2)
train_data, val_data, test_data = dataset.get_splits()


def _make_link_loader(split, neg_sampling_ratio, shuffle):
    """Build a LinkNeighborLoader over ``split`` with the shared settings."""
    return LinkNeighborLoader(
        split,
        num_neighbors=[-1, -1],  # Use all neighbors
        neg_sampling_ratio=neg_sampling_ratio,
        edge_label_index=split.edge_label_index,
        edge_label=split.edge_label,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=4,
    )


# 1 negative sample per positive edge while training.
train_loader = _make_link_loader(train_data, neg_sampling_ratio=1.0, shuffle=True)
# RandomLinkSplit already added negative edges to the held-out splits.
val_loader = _make_link_loader(val_data, neg_sampling_ratio=0.0, shuffle=False)
test_loader = _make_link_loader(test_data, neg_sampling_ratio=0.0, shuffle=False)

# model
model = SimpleGCN(
    in_channels=dataset.num_features,
    hidden_channels=256,
    out_channels=128,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = torch.nn.BCEWithLogitsLoss()
61
-
62
-
63
- # training
64
def train(train_loader, epoch):
    """Train the module-level ``model`` for one epoch.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``DEVICE``.

    Args:
        train_loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_label_index`` and ``edge_label``.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    model.train()
    total_loss = 0
    # NOTE(review): a new scaler is created each epoch, and loss scaling is
    # presumably unnecessary for bfloat16 - left as in the original.
    scaler = torch.GradScaler()

    pbar = tqdm(train_loader, desc=f"Training Epoch: {epoch}")
    for batch in pbar:
        batch = batch.to(DEVICE)
        optimizer.zero_grad()

        # Autocast on the device actually in use, as calc_metrics() does,
        # instead of assuming CUDA is present.
        with torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16, enabled=True):
            z = model(batch.x, batch.edge_index)
            out = model.decode(z, batch.edge_label_index)
            labels = batch.edge_label.float()

            loss = criterion(out, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    return total_loss / len(train_loader)
89
-
90
-
91
@torch.no_grad()
def calc_metrics(loader):
    """Evaluate the module-level ``model`` on every batch of ``loader``.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``edge_label_index`` and ``edge_label``.

    Returns:
        Tuple ``(roc_auc, accuracy)``; accuracy treats a sigmoid score
        above 0.5 as a positive prediction.
    """
    model.eval()
    score_chunks, label_chunks = [], []

    for batch in tqdm(loader, desc="Testing"):
        batch = batch.to(DEVICE)
        with torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16):
            node_emb = model(batch.x, batch.edge_index)
            logits = model.decode(node_emb, batch.edge_label_index)

        # Cast to float32 before converting to NumPy.
        score_chunks.append(torch.sigmoid(logits).float().cpu().numpy())
        label_chunks.append(batch.edge_label.cpu().numpy())

    y_score = np.concatenate(score_chunks)
    y_true = np.concatenate(label_chunks)

    auc = roc_auc_score(y_true, y_score)
    acc = accuracy_score(y_true, y_score > 0.5)
    return auc, acc
116
-
117
-
118
if __name__ == "__main__":
    # Train for NUM_EPOCHS epochs, checkpointing whenever validation AUC
    # improves, then report validation and test metrics.
    best_val_auc = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        loss = train(train_loader, epoch)
        val_auc, val_acc = calc_metrics(val_loader)

        # end=" " lets the "New best" marker share the epoch summary line.
        print(
            f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, Val acc: {val_acc:.4f}",
            end=" ",
        )
        if val_auc > best_val_auc:
            print("New best")
            best_val_auc = val_auc
            torch.save(model.state_dict(), "model.pth")

    # NOTE(review): this scores the weights from the last epoch, not the
    # checkpoint saved in "model.pth" - confirm that is intended.
    test_auc, test_acc = calc_metrics(test_loader)

    print("-" * 30)
    print(f"Best validation AUC: {best_val_auc:.4f}")
    # The test metrics were computed but never shown; report them.
    print(f"Test AUC: {test_auc:.4f}, Test acc: {test_acc:.4f}")