import streamlit as st
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_aws import ChatBedrock
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from typing import Optional

from config import MODEL_OPTIONS


def create_llm_model(llm_provider: str, **kwargs):
    """Create a language model based on the selected provider."""
    params = st.session_state.get('params', {})

    if llm_provider == "OpenAI":
        return ChatOpenAI(
            base_url=params.get("base_url"),
            openai_api_key=params.get("api_key"),
            model=MODEL_OPTIONS['OpenAI'],
            temperature=kwargs.get('temperature', 0.7),
        )
    elif llm_provider == "Antropic":  # key spelling matches MODEL_OPTIONS in config.py
        return ChatAnthropic(
            base_url=params.get("base_url"),
            anthropic_api_key=params.get("api_key"),
            model=MODEL_OPTIONS['Antropic'],
            temperature=kwargs.get('temperature', 0.7),
        )
    elif llm_provider == "Bedrock":
        import boto3

        # Initialize the Bedrock runtime client with explicit credentials
        _bedrock = boto3.client(
            'bedrock-runtime',
            region_name=params.get("region_name"),
            aws_access_key_id=params.get("aws_access_key"),
            aws_secret_access_key=params.get("aws_secret_key"),
        )
        return ChatBedrock(
            client=_bedrock,
            model_id=MODEL_OPTIONS['Bedrock'],
            **kwargs,
        )
    elif llm_provider == "Google":
        return ChatGoogleGenerativeAI(
            google_api_key=params.get("api_key"),
            model=MODEL_OPTIONS['Google'],
            temperature=kwargs.get('temperature', 0.7),
            max_tokens=kwargs.get('max_tokens', 4096),
            max_retries=2,
        )
    elif llm_provider == "Groq":
        return ChatGroq(
            api_key=params.get("api_key"),  # accepted as an alias for groq_api_key
            model=MODEL_OPTIONS['Groq'],
            temperature=kwargs.get("temperature", 0.7),
            streaming=kwargs.get("streaming", False),
        )
    else:
        raise ValueError(f"Unsupported LLM provider: {llm_provider}")


def get_response(prompt: str, llm_provider: str):
    """Get a response from the LLM using the standard LangChain interface."""
    try:
        # Create the LLM instance dynamically
        llm = create_llm_model(llm_provider)

        # Wrap the prompt in a HumanMessage
        message = HumanMessage(content=prompt)

        # Invoke the model and return the output content
        response = llm.invoke([message])
        return response.content
    except Exception as e:
        return f"Error during LLM invocation: {str(e)}"


def get_response_stream(
    prompt: str,
    llm_provider: str,
    system: Optional[str] = '',
    temperature: float = 1.0,
    max_tokens: int = 4096,
    **kwargs,
):
    """
    Get a streaming response from the selected LLM provider.

    Provider-specific connection/auth is read from st.session_state['params']
    inside create_llm_model(); generation options are passed via kwargs.
    """
    try:
        # Add streaming and generation params to kwargs
        kwargs.update({
            "temperature": temperature,
            "max_tokens": max_tokens,
            "streaming": True,
        })

        # Create the LLM with streaming enabled
        llm = create_llm_model(llm_provider, **kwargs)

        # Compose messages
        messages = []
        if system:
            messages.append(SystemMessage(content=system))
        messages.append(HumanMessage(content=prompt))

        # Stream the response
        stream_response = llm.stream(messages)
        return stream_response
    except Exception as e:
        st.error(f"[Error during streaming: {str(e)}]")
        st.stop()


def test_llm_connection(llm_provider: str, test_params: dict = None):
    """
    Test the connection to the specified LLM provider.
    Returns a tuple of (success: bool, message: str).
    """
    # Remember the current params so they can be restored after the test
    original_params = st.session_state.get('params', {})
    try:
        # Use test_params if provided, otherwise fall back to session state
        if test_params:
            st.session_state['params'] = test_params

        # Create LLM instance
        llm = create_llm_model(llm_provider)

        # Test with a simple message
        test_message = HumanMessage(
            content="Hello, this is a connection test. Please respond with 'OK'."
        )
        response = llm.invoke([test_message])

        # Restore original params if we used test_params
        if test_params:
            st.session_state['params'] = original_params

        return True, f"✅ Connection successful! Model response: {response.content[:100]}..."
    except Exception as e:
        # Restore original params if we used test_params
        if test_params:
            st.session_state['params'] = original_params
        return False, f"❌ Connection failed: {str(e)}"
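

# ---------------------------------------------------------------------------
# Minimal usage sketch (commented out, illustrative only): one way this module
# might be wired into a Streamlit chat page. The provider name "OpenAI", the
# params dict, and the system prompt below are assumptions for the example,
# not part of this module; real credentials come from however the app
# populates st.session_state['params'].
# ---------------------------------------------------------------------------
#
# st.session_state['params'] = {"api_key": "sk-...", "base_url": None}  # hypothetical
#
# if prompt := st.chat_input("Ask something"):
#     with st.chat_message("assistant"):
#         stream = get_response_stream(prompt, "OpenAI", system="You are a helpful assistant.")
#         # Each chunk is an AIMessageChunk; pass its text content to st.write_stream
#         st.write_stream(chunk.content for chunk in stream)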