Spaces:
Running
Running
| import json | |
| import os | |
| import pprint | |
| import streamlit as st | |
| import streamlit.components.v1 as components | |
| import requests | |
| from typing import Union | |
# Shared pretty-printer instance (indent of 2).
pp = pprint.PrettyPrinter(indent=2)

st.set_page_config(layout="wide", page_title="Gaia Search 🌖🌏")

# Write a Streamlit config file under the working directory that selects the
# light base theme. The file is rewritten on every run of the script.
_working_dir = os.getcwd()
os.makedirs(os.path.join(_working_dir, ".streamlit"), exist_ok=True)
with open(os.path.join(_working_dir, ".streamlit/config.toml"), "w") as file:
    file.write('[theme]\nbase="light"')

# Corpus label shown in the sidebar -> corpus identifier sent with the request.
# Insertion order matters: the corpus selectbox preselects an entry by index.
corpus_name_map = {
    "LAION": "laion",
    "ROOTS": "roots",
    "The Pile": "pile",
    "C4": "c4",
}
# Sidebar header: centered title followed by a short description and the
# Common Crawl terms-of-use notice.
st.sidebar.markdown(
    """
<style>
.aligncenter {
text-align: center;
font-weight: bold;
font-size: 36px;
}
</style>
<p class="aligncenter">Gaia Search 🌖🌏</p>
<p>A search engine for large scale textual
corpora. Most of the datasets included in the tool are based on Common
Crawl. By using the tool, you are also bound by the Common Crawl terms
of use in respect of the content contained in the datasets.
</p>
""",
    unsafe_allow_html=True,
)

# Sidebar link row (GitHub | Paper | Colab).
# TODO: the "Colab" anchor has an empty href; fill in the notebook URL once known.
st.sidebar.markdown(
    """
<style>
.aligncenter {
text-align: center;
}
</style>
<p style='text-align: center'>
<a href="https://github.com/huggingface/gaia" style="color:#7978FF;">GitHub</a> | <a href="https://arxiv.org/abs/2306.01481" style="color:#7978FF;" >Paper</a> | <a href="" style="color:#7978FF;" >Colab</a>
</p>
""",
    unsafe_allow_html=True,
)
# <p class="aligncenter">
# <a href="" target="_blank">
# <img src="https://colab.research.google.com/assets/colab-badge.svg"/>
# </a>
# </p>
# Free-text query box.
query = st.sidebar.text_input(placeholder="Type your query here", label="Query")

# Corpus picker: options are the display names from corpus_name_map, with the
# entry at index 2 preselected.
_corpus_options = tuple(corpus_name_map)
corpus = st.sidebar.selectbox("Corpus", _corpus_options, index=2)

# Upper bound on the number of documents requested from the backend.
max_results = st.sidebar.slider(
    "Max Results",
    value=10,
    min_value=1,
    max_value=100,
    step=1,
    help="Max Number of Documents to return",
)
# dark_mode_toggle = """
# <script>
# function load_image(id){
# console.log(id)
# var x = document.getElementById(id);
# console.log(x)
# if (x.style.display === "none") {
# x.style.display = "block";
# } else {
# x.style.display = "none";
# }
# };
# function myFunction() {
# var element = document.body;
# element.classList.toggle("dark-mode");
# }
# </script>
# <button onclick="myFunction()">Toggle dark mode</button>
# """
# st.sidebar.markdown(dark_mode_toggle, unsafe_allow_html=True)

# CSS that pins the footer to the bottom edge of the page.
_footer_css = """
<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: black;
text-align: center;
}
</style>
"""

# Footer markup with the "Powered by" credits.
_footer_markup = """<div class="footer">
<p>Powered by <a href="https://huggingface.co/" >HuggingFace 🤗</a> and <a href="https://github.com/castorini/pyserini" >Pyserini 🦆</a></p>
</div>
"""

footer = _footer_css + _footer_markup
st.sidebar.markdown(footer, unsafe_allow_html=True)
def scisearch(query, corpus, num_results=10):
    """Send a search request to the backend and return its hits.

    The backend URL is read from the ``address_roots`` environment variable
    when ``corpus`` is ``"roots"`` and from ``address`` otherwise.

    Args:
        query: Query string; surrounding whitespace is stripped before sending.
        corpus: Corpus identifier sent as ``corpus`` in the request body.
        num_results: Value sent as ``k`` in the request body.

    Returns:
        A ``(results, highlight_terms)`` tuple taken from the JSON response,
        or ``None`` when the query is empty/not a string, the backend address
        is not configured, or the request or response parsing fails.
    """
    # Validate before touching the query: the original stripped first, which
    # made its `is None` check unreachable.
    if not isinstance(query, str):
        return None
    query = query.strip()
    if not query:
        return None

    print(query, corpus, num_results)

    env_var = "address_roots" if corpus == "roots" else "address"
    address = os.environ.get(env_var)
    if not address:
        print(f"Search backend address is not configured (environment variable {env_var!r}).")
        return None

    post_data = {"query": query, "corpus": corpus, "k": num_results, "lang": "all"}
    try:
        output = requests.post(
            address,
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=60,
        )
        # Surface HTTP errors explicitly instead of trying to parse an error page.
        output.raise_for_status()
        payload = json.loads(output.text)
        return payload["results"], payload["highlight_terms"]
    except (requests.RequestException, ValueError, KeyError, TypeError) as e:
        # ValueError covers malformed JSON; KeyError/TypeError cover a payload
        # that lacks the expected fields. Callers must handle the None result.
        print(f"Search request failed: {e!r}")
        return None
# Tag names that may follow PII_PREFIX in the corpus text.
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
# Prefix that, together with a tag, forms a placeholder such as "PI:EMAIL".
PII_PREFIX = "PI:"


def process_pii(text):
    """Replace each ``PI:<TAG>`` placeholder in *text* with a highlighted HTML marker.

    Every occurrence of ``PII_PREFIX + tag`` (for each tag in ``PII_TAGS``) is
    swapped for a bold ``<mark>`` element reading ``REDACTED <tag>``.
    """
    redacted = text
    for pii_tag in PII_TAGS:
        placeholder = PII_PREFIX + pii_tag
        badge = (
            '<b><mark style="background: Fuchsia; color: Lime;">REDACTED '
            + pii_tag
            + "</mark></b>"
        )
        redacted = redacted.replace(placeholder, badge)
    return redacted
def highlight_string(paragraph: str, highlight_terms: list) -> str:
    """Return *paragraph* as an HTML fragment with matching tokens in bold.

    The paragraph is split on whitespace and re-joined with single spaces.
    Each token is HTML-escaped so markup present in the corpus text is shown
    literally instead of being interpreted by the browser. Tokens that equal
    one of ``highlight_terms`` exactly are wrapped in ``<b>``. PII placeholders
    are then replaced via ``process_pii``.

    Args:
        paragraph: Raw document text.
        highlight_terms: Terms to emphasise; assumed to be hashable (strings).

    Returns:
        The HTML fragment.
    """
    import html  # local import keeps this block self-contained

    # Build the lookup once; membership is then O(1) per token.
    terms = set(highlight_terms or ())
    rendered = []
    for token in paragraph.split():
        # Match on the raw token, emit the escaped one.
        safe_token = html.escape(token)
        rendered.append(f"<b>{safe_token}</b>" if token in terms else safe_token)
    # Escaping leaves "PI:<TAG>" placeholders intact, so redaction still applies.
    return process_pii(" ".join(rendered))
def extract_lang_from_docid(docid):
    """Return the second ``_``-separated field of *docid*.

    Raises:
        IndexError: if *docid* contains no underscore.
    """
    # Only the second field is needed, so stop splitting after two separators.
    fields = docid.split("_", 2)
    return fields[1]
def format_result(result, highlight_terms):
    """Render a single result dict as an HTML paragraph.

    Reads ``result["text"]`` and ``result["docid"]``; the language label is
    derived from the docid via ``extract_lang_from_docid`` and the text is
    passed through ``highlight_string``.
    """
    raw_text = result["text"]
    doc_identifier = result["docid"]
    body_html = highlight_string(raw_text, highlight_terms)
    lang_label = extract_lang_from_docid(doc_identifier)

    # Leading/trailing empty entries reproduce the newlines that bracket the markup.
    pieces = [
        "",
        f"<span style='font-size:14px; font-family: Arial; color:MediumAquaMarine'>Language: {lang_label} | </span>",
        f"<span style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {doc_identifier} | </span><br>",
        f"<span style='font-family: Arial;'>{body_html}</span><br>",
        "<br>",
        "",
    ]
    return "<p>" + "\n".join(pieces) + "</p>"
def _render_roots_results(hits, highlight_terms):
    """Render ROOTS hits (language -> list of results) as one collapsible section per language."""
    sections = []
    for lang, results_for_lang in hits.items():
        print("Processing language", lang)
        if not results_for_lang:
            sections.append(
                "<div style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>\n"
                f"No results for language: <b>{lang}</b></div>"
            )
            continue
        results_html = "".join(
            format_result(result, highlight_terms) for result in results_for_lang
        )
        sections.append(
            "\n<details>\n"
            "<summary style='font-family: Arial; color:MediumAquaMarine; text-align: left; line-height: 3em'>\n"
            f"Results for language: <b>{lang}</b>\n"
            "</summary>\n"
            f"{results_html}\n"
            "</details>"
        )
    return "".join(sections)


def _render_hit(corpus, hit, highlight_terms):
    """Render one non-ROOTS hit: header line, highlighted text and (LAION only) image links."""
    import html  # local import keeps this block self-contained

    docid = hit["docid"]
    score = round(hit["score"], 2)
    parts = [
        f'\n<p class="searchresult" style="color: #7978FF;">Document ID: {docid} | Score: {score}</p>\n'
    ]
    text_html = highlight_string(hit["text"], highlight_terms)

    if corpus != "laion":
        # The original emitted a closing </div> here with no matching opening tag.
        parts.append(f"<p>{text_html}</p><hr>")
        return "".join(parts)

    parts.append(
        '\n<p style="color: #7978FF;">Caption:</p>\n'
        f"<p>{text_html}</p>\n"
    )
    # "meta" may be absent or None; "docs" may be absent or empty.
    docs = (hit.get("meta") or {}).get("docs") or []
    if docs:
        parts.append('<p style="color: #7978FF;"> Image links:</p><ul>')
        for subhit in docs:
            # Quote and escape the URL so characters such as spaces, quotes or
            # ">" cannot break out of the href attribute.
            url = html.escape(str(subhit["URL"]), quote=True)
            parts.append(
                f'<li><a href="{url}" target="_blank" style="color:#ffcdf8; ">{url}</a></li>'
            )
        parts.append("</ul>")
    parts.append("<hr>")
    return "".join(parts)


def process_results(corpus: str, hits: Union[list, dict], highlight_terms: list) -> str:
    """Render backend hits as an HTML fragment.

    Args:
        corpus: Corpus identifier. ``"roots"`` selects per-language rendering;
            ``"laion"`` adds a caption heading and, when present, image links.
        hits: For ``"roots"``, a mapping of language -> list of result dicts
            (each with ``text`` and ``docid``). Otherwise an iterable of hit
            dicts with ``docid``, ``score`` and ``text`` keys.
        highlight_terms: Terms forwarded to ``highlight_string``.

    Returns:
        The HTML for all hits; an empty string when there are none.
    """
    if corpus == "roots":
        return _render_roots_results(hits, highlight_terms)
    return " ".join(_render_hit(corpus, hit, highlight_terms) for hit in hits)
# Stylesheet prepended to the rendered results inside the components.html frame.
_RESULTS_PAGE_HEAD = """
<head>
<link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
</head>
<style>
#searchresultsarea {
font-family: "Source Sans Pro", sans-serif;
}
#searchresultsnumber {
font-size: 0.8rem;
color: gray;
}
.searchresult h2 {
font-size: 19px;
line-height: 18px;
font-weight: normal;
color: rgb(7, 111, 222);
margin-bottom: 0px;
margin-top: 25px;
color: #7978FF;
}
.searchresult a {
font-size: 12px;
line-height: 12px;
color: green;
margin-bottom: 0px;
}
</style>
"""

submit_button = st.sidebar.button("Search", type="primary")

# Run a search when the button is pressed or a query is already present.
if submit_button or query:
    query = query.strip()
    if not query:
        components.html(
            "<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>\n"
            "Please provide a non-empty query.\n"
            "</p><br><hr><br>"
        )
    else:
        search_output = scisearch(query, corpus_name_map[corpus], max_results)
        if search_output is None:
            # scisearch returns None when the request fails; unpacking that
            # result directly would raise a TypeError.
            components.html(
                "<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>\n"
                "The search request failed. Please try again later.\n"
                "</p><br><hr><br>"
            )
        else:
            hits, highlight_terms = search_output
            html_results = process_results(corpus_name_map[corpus], hits, highlight_terms)
            rendered_results = (
                "\n"
                '<div id="searchresultsarea">\n'
                "<br>\n"
                f'<p id="searchresultsnumber">About {max_results} results</p>\n'
                f"{html_results}\n"
                "</div>"
            )
            # st.markdown(
            # """
            # <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.2/dist/css/bootstrap.min.css" rel="stylesheet"
            # integrity="sha384-EVSTQN3/azprG1Anm3QDgpJLIm9Nao0Yz1ztcQTwFspd3yD65VohhpuuCOmLASjC" crossorigin="anonymous">
            # """,
            # unsafe_allow_html=True,
            # )
            # st.markdown(
            # """
            # <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
            # """,
            # unsafe_allow_html=True,
            # )
            # st.markdown(
            # f"""
            # <div class="row no-gutters mt-3 align-items-center">
            # Gaia Search 🌖🌏
            # <div class="col col-md-4">
            # <input class="form-control border-secondary rounded-pill pr-5" type="search" value="{query}" id="example-search-input2">
            # </div>
            # <div class="col-auto">
            # <button class="btn btn-outline-light text-dark border-0 rounded-pill ml-n5" type="button">
            # <i class="fa fa-search"></i>
            # </button>
            # </div>
            # </div>
            # """,
            # unsafe_allow_html=True,
            # )
            # .bk-root{position:relative;width:auto;height:auto;box-sizing:border-box;font-family:Helvetica, Arial, sans-serif;font-size:13px;}.bk-root .bk,.bk-root .bk:before,.bk-root .bk:after{box-sizing:inherit;margin:0;border:0;padding:0;background-image:none;font-family:inherit;font-size:100%;line-height:1.42857143;}.bk-root pre.bk{font-family:Courier, monospace;}
            components.html(
                _RESULTS_PAGE_HEAD + rendered_results,
                height=800,
                scrolling=True,
            )