|
|
import streamlit as st |
|
|
from config import MODEL_OPTIONS |
|
|
import traceback |
|
|
from services.mcp_service import connect_to_mcp_servers |
|
|
from services.chat_service import create_chat, delete_chat, get_current_chat |
|
|
from services.export_service import create_download_button_for_chat |
|
|
from services.import_service import create_import_widget |
|
|
from services.logging_service import get_logger |
|
|
from utils.tool_schema_parser import extract_tool_parameters |
|
|
from utils.async_helpers import reset_connection_state |
|
|
|
|
|
|
|
|
def create_history_chat_container(): |
|
|
history_container = st.sidebar.container(height=400, border=None) |
|
|
with history_container: |
|
|
chat_history_menu = [ |
|
|
f"{chat['chat_name']}_::_{chat['chat_id']}" |
|
|
for chat in st.session_state["history_chats"] |
|
|
] |
|
|
chat_history_menu = chat_history_menu[:50][::-1] |
|
|
|
|
|
if chat_history_menu: |
|
|
current_chat = st.radio( |
|
|
label="History Chats", |
|
|
format_func=lambda x: x.split("_::_")[0] + '...' if "_::_" in x else x, |
|
|
options=chat_history_menu, |
|
|
label_visibility="collapsed", |
|
|
index=st.session_state["current_chat_index"], |
|
|
key="current_chat" |
|
|
) |
|
|
|
|
|
if current_chat: |
|
|
new_chat_id = current_chat.split("_::_")[1] |
|
|
|
|
|
if st.session_state.get('current_chat_id') != new_chat_id: |
|
|
logger = get_logger() |
|
|
logger.log_system_status(f"Switching from chat {st.session_state.get('current_chat_id')} to {new_chat_id}") |
|
|
|
|
|
st.session_state['current_chat_id'] = new_chat_id |
|
|
|
|
|
for i, chat in enumerate(st.session_state["history_chats"]): |
|
|
if chat['chat_id'] == new_chat_id: |
|
|
st.session_state["current_chat_index"] = i |
|
|
break |
|
|
|
|
|
st.session_state["messages"] = get_current_chat(new_chat_id) |
|
|
|
|
|
logger.log_system_status(f"Loaded {len(st.session_state['messages'])} messages for chat {new_chat_id}") |
|
|
|
|
|
for i, msg in enumerate(st.session_state["messages"]): |
|
|
has_tool = "tool" in msg and msg["tool"] |
|
|
has_content = "content" in msg and msg["content"] |
|
|
logger.log_system_status(f"Message {i}: role={msg.get('role')}, has_tool={has_tool}, has_content={has_content}") |
|
|
|
|
|
|
|
|
chat_id = current_chat.split("_::_")[1] |
|
|
st.markdown("---") |
|
|
st.markdown("**π₯ Export Chat History:**") |
|
|
|
|
|
|
|
|
create_download_button_for_chat(chat_id, "json") |
|
|
|
|
|
|
|
|
st.markdown("---") |
|
|
create_import_widget() |
|
|
|
|
|
|
|
|
def create_sidebar_chat_buttons(): |
|
|
with st.sidebar: |
|
|
c1, c2 = st.columns(2) |
|
|
create_chat_button = c1.button( |
|
|
"New Chat", use_container_width=True, key="create_chat_button" |
|
|
) |
|
|
if create_chat_button: |
|
|
create_chat() |
|
|
st.rerun() |
|
|
|
|
|
delete_chat_button = c2.button( |
|
|
"Delete Chat", use_container_width=True, key="delete_chat_button" |
|
|
) |
|
|
if delete_chat_button and st.session_state.get('current_chat_id'): |
|
|
delete_chat(st.session_state['current_chat_id']) |
|
|
st.rerun() |
|
|
|
|
|
def create_model_select_widget(): |
|
|
params = st.session_state["params"] |
|
|
params['model_id'] = st.sidebar.selectbox('π Choose model', |
|
|
options=MODEL_OPTIONS.keys(), |
|
|
index=0) |
|
|
|
|
|
def create_provider_select_widget(): |
|
|
params = st.session_state.setdefault('params', {}) |
|
|
|
|
|
default_provider = params.get("model_id", list(MODEL_OPTIONS.keys())[0]) |
|
|
default_index = list(MODEL_OPTIONS.keys()).index(default_provider) |
|
|
|
|
|
selected_provider = st.sidebar.selectbox( |
|
|
'π Choose Provider', |
|
|
options=list(MODEL_OPTIONS.keys()), |
|
|
index=default_index, |
|
|
key="provider_selection", |
|
|
on_change=reset_connection_state |
|
|
) |
|
|
|
|
|
if selected_provider: |
|
|
params['model_id'] = selected_provider |
|
|
params['provider_index'] = list(MODEL_OPTIONS.keys()).index(selected_provider) |
|
|
st.sidebar.success(f"Model: {MODEL_OPTIONS[selected_provider]}") |
|
|
|
|
|
|
|
|
with st.sidebar.container(): |
|
|
if selected_provider == "Bedrock": |
|
|
with st.expander("π Bedrock Credentials", expanded=True): |
|
|
|
|
|
config_mode = st.radio( |
|
|
"Configuration Mode", |
|
|
["π Default", "βοΈ Custom"], |
|
|
key="bedrock_config_mode", |
|
|
horizontal=True |
|
|
) |
|
|
|
|
|
if config_mode == "π Default": |
|
|
|
|
|
from config import DEFAULT_ENV_CONFIG |
|
|
env_config = DEFAULT_ENV_CONFIG.get('Bedrock', {}) |
|
|
|
|
|
|
|
|
params['region_name'] = env_config.get('region_name', '') |
|
|
params['aws_access_key'] = env_config.get('aws_access_key', '') |
|
|
params['aws_secret_key'] = env_config.get('aws_secret_key', '') |
|
|
|
|
|
st.info("π Using configuration from environment variables") |
|
|
if env_config.get('region_name'): |
|
|
st.success(f"Region: {env_config.get('region_name')}") |
|
|
else: |
|
|
st.warning("β οΈ AWS_REGION environment variable not set") |
|
|
if env_config.get('aws_access_key'): |
|
|
st.success("β
AWS Access Key configured") |
|
|
else: |
|
|
st.warning("β οΈ AWS_ACCESS_KEY_ID environment variable not set") |
|
|
if env_config.get('aws_secret_key'): |
|
|
st.success("β
AWS Secret Key configured") |
|
|
else: |
|
|
st.warning("β οΈ AWS_SECRET_ACCESS_KEY environment variable not set") |
|
|
|
|
|
else: |
|
|
|
|
|
params['region_name'] = st.text_input("AWS Region", value='', placeholder="Enter AWS Region", key="region_name") |
|
|
params['aws_access_key'] = st.text_input("AWS Access Key", value='', type="password", placeholder="Enter AWS Access Key", key="aws_access_key") |
|
|
params['aws_secret_key'] = st.text_input("AWS Secret Key", value='', type="password", placeholder="Enter AWS Secret Key", key="aws_secret_key") |
|
|
|
|
|
|
|
|
if st.button("π§ͺ Test Connection", key="bedrock_test"): |
|
|
from services.ai_service import test_llm_connection |
|
|
test_params = { |
|
|
'region_name': params.get('region_name'), |
|
|
'aws_access_key': params.get('aws_access_key'), |
|
|
'aws_secret_key': params.get('aws_secret_key') |
|
|
} |
|
|
success, message = test_llm_connection(selected_provider, test_params) |
|
|
|
|
|
|
|
|
logger = get_logger() |
|
|
logger.log_llm_test(selected_provider, success, None if success else message) |
|
|
|
|
|
if success: |
|
|
st.success(message) |
|
|
else: |
|
|
st.error(message) |
|
|
else: |
|
|
with st.expander("π API Key", expanded=True): |
|
|
|
|
|
config_mode = st.radio( |
|
|
"Configuration Mode", |
|
|
["π Default", "βοΈ Custom"], |
|
|
key=f"{selected_provider.lower()}_config_mode", |
|
|
horizontal=True |
|
|
) |
|
|
|
|
|
from config import DEFAULT_ENV_CONFIG |
|
|
env_config = DEFAULT_ENV_CONFIG.get(selected_provider, {}) |
|
|
|
|
|
if config_mode == "π Default": |
|
|
|
|
|
|
|
|
params['api_key'] = env_config.get('api_key', '') |
|
|
params['base_url'] = env_config.get('base_url', '') |
|
|
|
|
|
st.info("π Using configuration from environment variables") |
|
|
if env_config.get('api_key'): |
|
|
if selected_provider == "OpenAI": |
|
|
st.success("β
OpenAI API Key configured (hidden display)") |
|
|
else: |
|
|
st.success(f"β
{selected_provider} API Key configured") |
|
|
else: |
|
|
st.warning(f"β οΈ {selected_provider.upper()}_API_KEY environment variable not set") |
|
|
if env_config.get('base_url'): |
|
|
st.success(f"Base URL: {env_config.get('base_url')}") |
|
|
else: |
|
|
st.info(f"Using default Base URL: {env_config.get('base_url', 'N/A')}") |
|
|
|
|
|
else: |
|
|
|
|
|
params['api_key'] = st.text_input( |
|
|
f"{selected_provider} API Key", |
|
|
value='', |
|
|
type="password", |
|
|
placeholder=f"Enter {selected_provider} API Key", |
|
|
key="api_key" |
|
|
) |
|
|
params['base_url'] = st.text_input( |
|
|
f"{selected_provider} Base URL", |
|
|
value='', |
|
|
placeholder=env_config.get('base_url', f"Enter {selected_provider} Base URL"), |
|
|
key="base_url" |
|
|
) |
|
|
|
|
|
|
|
|
if st.button("π§ͺ Test Connection", key=f"{selected_provider.lower()}_test"): |
|
|
from services.ai_service import test_llm_connection |
|
|
test_params = { |
|
|
'api_key': params.get('api_key'), |
|
|
'base_url': params.get('base_url') |
|
|
} |
|
|
success, message = test_llm_connection(selected_provider, test_params) |
|
|
|
|
|
|
|
|
logger = get_logger() |
|
|
logger.log_llm_test(selected_provider, success, None if success else message) |
|
|
|
|
|
if success: |
|
|
st.success(message) |
|
|
else: |
|
|
st.error(message) |
|
|
|
|
|
|
|
|
def create_advanced_configuration_widget(): |
|
|
params = st.session_state["params"] |
|
|
with st.sidebar.expander("βοΈ Basic config", expanded=False): |
|
|
params['max_tokens'] = st.number_input("Max tokens", |
|
|
min_value=1024, |
|
|
max_value=10240, |
|
|
value=4096, |
|
|
step=512,) |
|
|
params['temperature'] = st.slider("Temperature", 0.0, 1.0, step=0.05, value=1.0) |
|
|
|
|
|
def create_mcp_connection_widget(): |
|
|
|
|
|
if not st.session_state.get("auto_connect_attempted", False): |
|
|
from services.chat_service import auto_connect_to_mcp |
|
|
auto_connect_to_mcp() |
|
|
|
|
|
with st.sidebar: |
|
|
st.subheader("Server Management") |
|
|
with st.expander(f"MCP Servers ({len(st.session_state.servers)})"): |
|
|
for name, config in st.session_state.servers.items(): |
|
|
with st.container(border=True): |
|
|
st.markdown(f"**Server:** {name}") |
|
|
st.markdown(f"**URL:** {config['url']}") |
|
|
if st.button(f"Remove {name}", key=f"remove_{name}"): |
|
|
del st.session_state.servers[name] |
|
|
st.rerun() |
|
|
|
|
|
if st.session_state.get("agent"): |
|
|
st.success(f"πΆ Connected to {len(st.session_state.servers)} MCP servers!" |
|
|
f" Found {len(st.session_state.tools)} tools.") |
|
|
if st.button("Disconnect to MCP Servers"): |
|
|
with st.spinner("Connecting to MCP servers..."): |
|
|
try: |
|
|
logger = get_logger() |
|
|
logger.log_system_status("Disconnecting from MCP servers") |
|
|
|
|
|
reset_connection_state() |
|
|
|
|
|
|
|
|
logger.log_system_status("Successfully disconnected from MCP servers") |
|
|
|
|
|
st.rerun() |
|
|
except Exception as e: |
|
|
|
|
|
logger.log_error( |
|
|
"MCP_Disconnection_Error", |
|
|
str(e), |
|
|
{'servers': list(st.session_state.servers.keys())} |
|
|
) |
|
|
|
|
|
st.error(f"Error disconnecting to MCP servers: {str(e)}") |
|
|
st.code(traceback.format_exc(), language="python") |
|
|
else: |
|
|
st.warning("β οΈ Not connected to MCP server") |
|
|
if st.button("Connect to MCP Servers"): |
|
|
with st.spinner("Connecting to MCP servers..."): |
|
|
try: |
|
|
logger = get_logger() |
|
|
logger.log_system_status("Attempting to connect to MCP servers") |
|
|
|
|
|
connect_to_mcp_servers() |
|
|
|
|
|
|
|
|
logger.log_system_status("Successfully connected to MCP servers", { |
|
|
'servers_count': len(st.session_state.servers), |
|
|
'tools_count': len(st.session_state.tools) |
|
|
}) |
|
|
|
|
|
st.rerun() |
|
|
except Exception as e: |
|
|
|
|
|
logger.log_error( |
|
|
"MCP_Connection_Error", |
|
|
str(e), |
|
|
{'servers': list(st.session_state.servers.keys())} |
|
|
) |
|
|
|
|
|
st.error(f"Error connecting to MCP servers: {str(e)}") |
|
|
st.code(traceback.format_exc(), language="python") |
|
|
|
|
|
def create_mcp_tools_widget(): |
|
|
with st.sidebar: |
|
|
if st.session_state.tools: |
|
|
st.subheader("π§° Available Tools") |
|
|
|
|
|
selected_tool_name = st.selectbox( |
|
|
"Select a Tool", |
|
|
options=[tool.name for tool in st.session_state.tools], |
|
|
index=0 |
|
|
) |
|
|
|
|
|
if selected_tool_name: |
|
|
selected_tool = next( |
|
|
(tool for tool in st.session_state.tools if tool.name == selected_tool_name), |
|
|
None |
|
|
) |
|
|
|
|
|
if selected_tool: |
|
|
with st.container(): |
|
|
st.write("**Description:**") |
|
|
st.write(selected_tool.description) |
|
|
|
|
|
parameters = extract_tool_parameters(selected_tool) |
|
|
|
|
|
if parameters: |
|
|
st.write("**Parameters:**") |
|
|
for param in parameters: |
|
|
st.code(param) |