# Application bootstrap.
#
# Runs once at import time and, in order:
#   1. pins process-level environment variables (they must be set before the
#      heavy numeric libraries are imported, so this happens ahead of gradio),
#   2. reads feature flags / limits from the environment,
#   3. probes optional dependencies (gradio mount helper, spaces, CUDA
#      extensions) and records the outcome in HAS_* flags,
#   4. locates the SRMA-Mamba sources and imports build_SRMAMamba, attempting
#      a runtime installation of mamba-ssm when that is what is missing.
import importlib
import os
import subprocess
import sys


def _env_flag(name, default):
    """Return True when env var *name* (or *default*) equals 'true', ignoring case."""
    return os.environ.get(name, default).lower() == 'true'


def _env_int(name, default):
    """Return env var *name* parsed as an int.

    Falls back to *default* (with a warning) when the variable is unset or is
    not a valid integer, so a typo in the environment does not abort startup.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        print(f"Warning: {name}={raw!r} is not a valid integer. Using default {default}.")
        return default


def _pip_install_mamba_ssm(*extra_args):
    """Run ``pip install mamba-ssm>=2.2.2`` with the current interpreter.

    *extra_args* are extra pip options inserted before the requirement.
    Returns the ``subprocess.CompletedProcess`` (stdout/stderr captured as
    text). Raises ``subprocess.TimeoutExpired`` if pip runs longer than 900 s.
    """
    return subprocess.run(
        [sys.executable, "-m", "pip", "install", "--no-cache-dir", *extra_args, "mamba-ssm>=2.2.2"],
        capture_output=True,
        text=True,
        timeout=900,
    )


# --- Process-level environment -------------------------------------------------

# Only override OMP_NUM_THREADS when it is missing or not a plain number.
if 'OMP_NUM_THREADS' not in os.environ or not os.environ['OMP_NUM_THREADS'].isdigit():
    os.environ['OMP_NUM_THREADS'] = '1'
    print(f"✓ Set OMP_NUM_THREADS={os.environ['OMP_NUM_THREADS']}")

if 'PYTORCH_ALLOC_CONF' not in os.environ:
    # Allocator options are written as "<option>:<value>" separated by commas;
    # "max_split_size_mb=128" (with '=') is not a valid option token.
    os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:128'
    print(f"Set PYTORCH_ALLOC_CONF={os.environ['PYTORCH_ALLOC_CONF']}")

# --- Feature flags and limits --------------------------------------------------

ENABLE_TORCH_COMPILE = _env_flag('ENABLE_TORCH_COMPILE', 'false')
if ENABLE_TORCH_COMPILE:
    print("torch.compile enabled. First inference may take 30-60s for compilation.")
    print(" Set ENABLE_TORCH_COMPILE=false to disable for faster first run")
    print(" Set TORCH_COMPILE_MODE=max-autotune for maximum speed (slower first run)")
else:
    print("torch.compile disabled by default (set ENABLE_TORCH_COMPILE=true to enable)")

ENABLE_CUDNN_BENCHMARK = _env_flag('ENABLE_CUDNN_BENCHMARK', 'true')
INFERENCE_TIMEOUT = _env_int('INFERENCE_TIMEOUT', 1800)
MAX_GRADIO_CONCURRENCY = _env_int('MAX_GRADIO_CONCURRENCY', 1)

# --- Optional third-party dependencies -----------------------------------------

import gradio as gr

print(f"Gradio version: {gr.__version__}")

try:
    from gradio.routes import mount_gradio_app
    HAS_MOUNT_GRADIO_APP = True
except ImportError:
    HAS_MOUNT_GRADIO_APP = False
    print("⚠ CRITICAL: mount_gradio_app not available. Gradio version too old. Need >= 4.44.1")
    print(f"⚠ Current Gradio version: {gr.__version__}")
    print("⚠ Please ensure requirements.txt has gradio==4.44.1")

try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False
    print("Warning: spaces module not found. GPU decorator will not be used.")

# --- Locate the SRMA-Mamba sources ---------------------------------------------

srma_mamba_paths = [
    os.path.join(os.path.dirname(__file__), 'SRMA-Mamba'),
    os.path.join(os.path.dirname(__file__), '../../SRMA-Mamba'),
    'SRMA-Mamba',
]

SRMA_MAMBA_DIR = None
for path in srma_mamba_paths:
    if os.path.exists(path):
        # First existing candidate wins and is put at the front of sys.path.
        sys.path.insert(0, path)
        SRMA_MAMBA_DIR = path
        print(f"Found model code at: {path}")
        break
else:
    # Loop finished without `break`: none of the candidates exist.
    print("Warning: SRMA-Mamba directory not found. Model imports may fail.")

BUILD_SRMAMAMBA_AVAILABLE = False
build_SRMAMamba = None

# --- CUDA extension probes -----------------------------------------------------
# A missing extension is only fatal when REQUIRE_CUDA_EXTENSIONS=true.

try:
    import mamba_ssm
except ImportError:
    HAS_MAMBA_SSM = False
    print("ERROR: mamba_ssm not found. This is CRITICAL for speed. Model will use slow fallback.")
    # The requirement is quoted so the hint is safe to paste into a shell
    # (an unquoted '>' would be treated as an output redirection).
    print(" To install: Run setup.sh or: pip install 'mamba-ssm>=2.2.2'")
    if _env_flag('REQUIRE_CUDA_EXTENSIONS', 'false'):
        raise ImportError("mamba_ssm is required but not installed. Set REQUIRE_CUDA_EXTENSIONS=false to allow fallback.")
else:
    HAS_MAMBA_SSM = True
    # The version attribute is optional; report it only when present.
    version = getattr(mamba_ssm, '__version__', None)
    if version is not None:
        print(f"mamba_ssm CUDA extension loaded (version: {version}) - fast path enabled")
    else:
        print("mamba_ssm CUDA extension loaded - fast path enabled")

try:
    import selective_scan_cuda_oflex
    HAS_SELECTIVE_SCAN_CUDA = True
    print("selective_scan_cuda_oflex CUDA extension loaded - fast path enabled")
except ImportError:
    HAS_SELECTIVE_SCAN_CUDA = False
    print("ERROR: selective_scan_cuda_oflex not found. This is CRITICAL for speed. Model will use slow fallback.")
    print(" To install: Run setup.sh or: cd SRMA-Mamba/selective_scan && pip install -e .")
    if _env_flag('REQUIRE_CUDA_EXTENSIONS', 'false'):
        raise ImportError("selective_scan_cuda_oflex is required but not installed. Set REQUIRE_CUDA_EXTENSIONS=false to allow fallback.")

# --- Model factory import (with runtime mamba-ssm installation fallback) -------

try:
    from configs.model_configs import build_SRMAMamba
    BUILD_SRMAMAMBA_AVAILABLE = True
    print("✓ Successfully imported build_SRMAMamba")
except ImportError as e:
    error_str = str(e)
    print(f"Import error: {error_str}")
    if 'mamba_ssm' in error_str or 'mamba-ssm' in error_str:
        print("⚠ mamba-ssm not found. Attempting runtime installation...")
        print("This may take 5-10 minutes. Please wait...")
        os.environ['FORCE_CUDA'] = '1'
        if 'CUDA_HOME' not in os.environ:
            os.environ['CUDA_HOME'] = '/usr/local/cuda'
        try:
            print("Attempting mamba-ssm installation (method 1)...")
            result = _pip_install_mamba_ssm()
            if result.returncode != 0:
                print("Method 1 failed. Trying method 2 (no build isolation)...")
                result = _pip_install_mamba_ssm("--no-build-isolation")
            if result.returncode == 0:
                print("✓ mamba-ssm installed successfully")
                # Packages installed after interpreter start may be invisible
                # to the import system until its finder caches are reset.
                importlib.invalidate_caches()
                try:
                    from configs.model_configs import build_SRMAMamba
                    BUILD_SRMAMAMBA_AVAILABLE = True
                    print("✓ Successfully imported build_SRMAMamba after installation")
                except ImportError as e2:
                    print(f"⚠ Still cannot import after installation: {e2}")
            else:
                print(f"⚠ Installation failed. Output: {result.stdout[:500]}")
                print(f"⚠ Error: {result.stderr[:500]}")
        except subprocess.TimeoutExpired:
            print("⚠ Installation timed out after 15 minutes")
        except Exception as install_error:
            # Best-effort install: any failure here must not prevent the app
            # from starting, so it is reported and startup continues.
            print(f"⚠ Installation error: {install_error}")
    else:
        print(f"⚠ Import error (not mamba-ssm related): {e}")
    if not BUILD_SRMAMAMBA_AVAILABLE:
        # Every failure path above ends with this same notice.
        print("⚠ App will start but model loading will fail")