# main.py
import streamlit as st
from requests import JSONDecodeError
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_community.llms import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage
# ─────── supabase + secrets ────────────────────────────────────────────────────
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key = st.secrets.hf_api_key
username = st.secrets.username
supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)
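# The secret reads above assume a .streamlit/secrets.toml along these lines
# (key names taken from the accesses above; values are placeholders):
#
#   SUPABASE_URL = "https://<project>.supabase.co"
#   SUPABASE_KEY = "<supabase-key>"
#   openai_api_key = "<unused in this file>"
#   anthropic_api_key = "<unused in this file>"
#   hf_api_key = "<huggingface-api-token>"
#   username = "<user-filter-for-retrieval>"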
# ─────── embeddings ─────────────────────────────────────────────────────────────
# Local BGE embeddings avoid the JSONDecode errors and HTTP-batch issues seen
# with the hosted embedding endpoint.
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)
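# Note: BAAI/bge-large-en-v1.5 emits 1024-dimensional vectors, so the Supabase
# `documents` table and `match_documents` RPC must be declared with
# vector(1024). A sketch of the usual LangChain/Supabase pgvector table
# (assumed setup, not part of this file):
#
#   create table documents (
#     id uuid primary key default gen_random_uuid(),
#     content text,
#     metadata jsonb,
#     embedding vector(1024)
#   );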
# ─────── vector store + memory ─────────────────────────────────────────────────
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    query_name="match_documents",
    table_name="documents",
)
# Streamlit re-runs this script on every interaction, so keep the chain's
# memory in session_state; a plain module-level object would be reset on
# every turn.
if "buffer_memory" not in st.session_state:
    st.session_state.buffer_memory = ConversationBufferMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        return_messages=True,
    )
memory = st.session_state.buffer_memory
# ─────── LLM setup ──────────────────────────────────────────────────────────────
model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
temperature = 0.1
max_tokens = 500
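# A low temperature keeps answers close to the retrieved context; max_tokens
# caps the length of each generated answer.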
def response_generator(query: str) -> str:
    """Ask the RAG chain to answer `query`, with JSON-error fallback."""
    # log usage
    add_usage(supabase, "chat", "prompt:" + query, {"model": model, "temperature": temperature})
    logger.info("Using HF model %s", model)
    # prepare the HF text-generation LLM; temperature, max_new_tokens and
    # return_full_text are explicit fields on HuggingFaceEndpoint, so they
    # are passed directly rather than buried in model_kwargs
    hf = HuggingFaceEndpoint(
        endpoint_url=f"https://api-inference.huggingface.co/models/{model}",
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        temperature=temperature,
        max_new_tokens=max_tokens,
        return_full_text=False,
    )
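    # The api-inference.huggingface.co URL targets the serverless Inference
    # API; a dedicated Inference Endpoint URL can be substituted unchanged.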
    # conversational RAG chain over this user's documents; score_threshold
    # only takes effect with the "similarity_score_threshold" search type
    qa = ConversationalRetrievalChain.from_llm(
        llm=hf,
        retriever=vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}},
        ),
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )
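    # The chain first condenses the new question plus the chat history into a
    # standalone question, then retrieves against it and generates the final
    # answer from the retrieved documents.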
    try:
        result = qa.invoke({"question": query})
    except JSONDecodeError as e:
        # fallback logging
        logger.error("Embedding JSONDecodeError: %s", e)
        return "Sorry, I had trouble understanding the embedded data. Please try again."
    answer = result.get("answer", "")
    sources = result.get("source_documents", [])
    if not sources:
        return (
            "I’m sorry, I don’t have enough information to answer that. "
            "If you have a public data source to add, please email copilot@securade.ai."
        )
    return answer
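# Example round trip (hypothetical query, assuming valid secrets and an
# indexed corpus):
#   response_generator("What PPE is required when working at height?")
# returns either a grounded answer or the fallback message above.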
# ─────── Streamlit UI ──────────────────────────────────────────────────────────
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n[https://securade.ai](https://securade.ai)",
        "Get Help": "https://securade.ai",
        "Report a Bug": "mailto:hello@securade.ai",
    },
)
st.title("👷‍♂️ Safety Copilot 🦺")
stats = get_usage(supabase)
st.markdown(f"_{stats} queries answered!_")
st.markdown(
    "Chat with your personal safety assistant about any health & safety related queries. "
    "[[blog](https://securade.ai/blog/how-securade-ai-safety-copilot-transforms-worker-safety.html)"
    "|[paper](https://securade.ai/assets/pdfs/Securade.ai-Safety-Copilot-Whitepaper.pdf)]"
)
| if "chat_history" not in st.session_state: | |
| st.session_state.chat_history = [] | |
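# chat_history holds the rendered transcript across reruns; it is separate
# from the chain's ConversationBufferMemory, which feeds prior turns back to
# the LLM.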
# show history
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# new user input
if prompt := st.chat_input("Ask a question"):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.spinner("Safety briefing in progress..."):
        answer = response_generator(prompt)
    with st.chat_message("assistant"):
        st.markdown(answer)
    st.session_state.chat_history.append({"role": "assistant", "content": answer})
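# Launch locally with:
#   streamlit run main.py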