#!/usr/bin/env bash
#
# Build and install the CUDA extensions used by the app:
#   1. mamba-ssm (installed from PyPI, "mamba-ssm>=2.2.2")
#   2. selective_scan_cuda_oflex (built from SRMA-Mamba/selective_scan)
#
# Run from the directory that contains SRMA-Mamba/. Honors CUDA_HOME
# (defaults to /usr/local/cuda). Failing to install either extension is
# fatal; the Gradio upgrade is best-effort.
set -euo pipefail

# CUDA toolchain environment for the extension builds.
export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}"
export PATH="${CUDA_HOME}/bin:${PATH}"
# Append the previous value only when it is non-empty: a trailing ':' would
# put the current directory on the shared-library search path.
export LD_LIBRARY_PATH="${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
export FORCE_CUDA=1

# Fix OMP_NUM_THREADS warning
export OMP_NUM_THREADS=1

# Print the value of one torch expression ($1), or "unknown" if torch cannot
# be imported/queried. Diagnostics only, so a failure must not abort the run.
torch_query() {
  python -c "import torch; print($1)" || echo "unknown"
}

echo "=========================================="
echo "Building CUDA Extensions for Speed"
echo "=========================================="
echo "CUDA_HOME: $CUDA_HOME"
echo "PyTorch version: $(torch_query 'torch.__version__')"
cuda_available=$(torch_query 'torch.cuda.is_available()')
echo "CUDA available: ${cuda_available}"
if [[ "$cuda_available" == "True" ]]; then
  echo "CUDA version: $(torch_query 'torch.version.cuda')"
  echo "GPU: $(torch_query 'torch.cuda.get_device_name(0)')"
fi
echo "=========================================="

# Upgrade Gradio to latest compatible version (required for UI to work).
# Best-effort: on failure keep whatever version is already installed.
# 'python -m pip' guarantees we install into the same interpreter we query.
echo ""
echo "Upgrading Gradio to >=4.44.0..."
if ! python -m pip install --upgrade --no-cache-dir "gradio>=4.44.0,<5.0.0"; then
  echo "Warning: Failed to upgrade Gradio, continuing with existing version" >&2
fi

# Install mamba-ssm (CUDA extension #1)
echo ""
echo "=========================================="
echo "Installing mamba-ssm (CUDA extension #1)"
echo "=========================================="
if ! python -m pip install "mamba-ssm>=2.2.2"; then
  echo "ERROR: mamba-ssm installation failed. This is fatal." >&2
  exit 1
fi

# Build selective_scan_cuda_oflex (CUDA extension #2)
echo ""
echo "=========================================="
echo "Building selective_scan_cuda_oflex (CUDA extension #2)"
echo "=========================================="
readonly selective_scan_dir="SRMA-Mamba/selective_scan"
if [[ ! -d "$selective_scan_dir" ]]; then
  echo "ERROR: ${selective_scan_dir} directory not found" >&2
  exit 1
fi
# Build inside a subshell so our working directory is never changed. errexit
# is ignored in an 'if !' condition, hence the explicit '&&' after cd.
if ! (cd "$selective_scan_dir" && python -m pip install -e .); then
  echo "ERROR: selective_scan_cuda_oflex build failed. This is fatal." >&2
  exit 1
fi
# Import-test both extensions (fatal if either fails).
# The python call is the 'if' condition on purpose: under 'set -e' a bare
# failing command would exit before a later '$?' check could report it.
echo ""
echo "=========================================="
echo "Verifying CUDA Extensions"
echo "=========================================="
if ! python - <<'PY'; then
# Import one module at a time so a traceback points at the broken extension.
import mamba_ssm
print("mamba_ssm OK", getattr(mamba_ssm, "__version__", "unknown"))
import selective_scan_cuda_oflex
print("selective_scan_cuda_oflex OK")
PY
  echo "ERROR: CUDA extension import test failed. This is fatal." >&2
  exit 1
fi
echo "All CUDA extensions verified successfully."