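"""Streamlit chat UI for the Safety Copilot: answers questions over documents
in a Supabase vector store, using an OpenAI, Anthropic, or Hugging Face LLM
behind a LangChain ConversationalRetrievalChain."""
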
import anthropic
import streamlit as st
from streamlit.logger import get_logger
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import OpenAI, HuggingFaceEndpoint
from langchain.chat_models import ChatAnthropic
from langchain.vectorstores import SupabaseVectorStore
from stats import add_usage
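
# LangChain conversation memory; the input/output keys match what
# ConversationalRetrievalChain produces ('question' in, 'answer' out).
# Note: module-level state in Streamlit is shared by every session served
# by this process, not scoped per user.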
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key='question',
    output_key='answer',
    return_messages=True,
)

openai_api_key = st.secrets.openai_api_key
anthropic_api_key = st.secrets.anthropic_api_key
hf_api_key = st.secrets.hf_api_key

logger = get_logger(__name__)


def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
    """Render the Q&A widgets and answer the question against `vector_store`.

    `model` selects the backend (OpenAI, Anthropic, or a Hugging Face
    Inference API endpoint); `stats_db` receives usage records via `add_usage`.
    """
    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []

    question = st.text_area("## Ask a question")
    columns = st.columns(2)
    with columns[0]:
        button = st.button("Ask")
    with columns[1]:
        clear_history = st.button("Clear History", type='secondary')
    st.markdown("---\n\n")
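
    # The transcript lives in two places: LangChain's buffer memory (used to
    # condense follow-up questions) and st.session_state (used for display);
    # clearing must reset both.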
    if clear_history:
        # Clear memory in LangChain
        memory.clear()
        st.session_state['chat_history'] = []
        st.experimental_rerun()

    if button:
        qa = None
        add_usage(stats_db, "chat", "prompt: " + question, {"model": model, "temperature": st.session_state['temperature']})
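        # Pick a backend from the model-name prefix; fall back to the Hugging
        # Face Inference API when an HF token is configured.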
        if model.startswith("gpt"):
            logger.info('Using OpenAI model %s', model)
            qa = ConversationalRetrievalChain.from_llm(
                OpenAI(model_name=st.session_state['model'], openai_api_key=openai_api_key,
                       temperature=st.session_state['temperature'], max_tokens=st.session_state['max_tokens']),
                vector_store.as_retriever(), memory=memory, verbose=True)
        elif anthropic_api_key and model.startswith("claude"):
            logger.info('Using Anthropic model %s', model)
            qa = ConversationalRetrievalChain.from_llm(
                ChatAnthropic(model=st.session_state['model'], anthropic_api_key=anthropic_api_key,
                              temperature=st.session_state['temperature'], max_tokens_to_sample=st.session_state['max_tokens']),
                vector_store.as_retriever(), memory=memory, verbose=True, max_tokens_limit=102400)
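            # max_tokens_limit caps how many tokens of retrieved context the
            # chain stuffs into the prompt; it is sized generously here for
            # Claude's long context window.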
        elif hf_api_key:
            logger.info('Using HF model %s', model)
            endpoint_url = "https://api-inference.huggingface.co/models/" + model
            model_kwargs = {
                "temperature": st.session_state['temperature'],
                "max_new_tokens": st.session_state['max_tokens'],
                "return_full_text": False,
            }
            hf = HuggingFaceEndpoint(
                endpoint_url=endpoint_url,
                task="text-generation",
                huggingfacehub_api_token=hf_api_key,
                model_kwargs=model_kwargs,
            )
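            # The retriever is limited to 4 chunks, a 0.6 similarity score
            # threshold, and documents belonging to the signed-in user.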
            qa = ConversationalRetrievalChain.from_llm(
                hf,
                retriever=vector_store.as_retriever(
                    search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": st.session_state["username"]}}),
                memory=memory, verbose=True, return_source_documents=True)

        if qa is None:
            st.error(f"No backend is configured for model '{model}'.")
            return
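
        # Record both sides of the exchange so the transcript survives reruns.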
        st.session_state['chat_history'].append(("You", question))

        # Generate the model's response and append it to the chat history.
        model_response = qa({"question": question})
        logger.info('Result: %s', model_response["answer"])
        st.session_state['chat_history'].append(("Safety Copilot", model_response["answer"]))

        # Only the HF chain sets return_source_documents=True, so guard the
        # lookup instead of indexing unconditionally.
        sources = model_response.get("source_documents")
        if sources:
            logger.info('Sources: %s', sources)

    # Display chat history
    st.empty()
    chat_history = st.session_state['chat_history']
    for speaker, text in chat_history:
        st.markdown(f"**{speaker}:** {text}")
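

# Minimal wiring sketch (hypothetical: the secret names, embeddings, and
# Supabase client are assumptions, created elsewhere in this app). It only
# illustrates how chat_with_doc is expected to be called:
#
#   from langchain.embeddings import OpenAIEmbeddings
#   from supabase.client import create_client
#
#   supabase = create_client(st.secrets.supabase_url, st.secrets.supabase_service_key)
#   vector_store = SupabaseVectorStore(supabase, OpenAIEmbeddings(openai_api_key=openai_api_key),
#                                      table_name="documents")
#   chat_with_doc(st.session_state['model'], vector_store, stats_db=supabase)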