Harshith Reddy
Fix: Exact import test format, TF32 set directly (no try/except), verify allocator config includes max_split_size_mb=128
5f56790
# Abort as soon as a command fails outside of a condition or ||/&& list.
set -e

# Set CUDA environment variables.
# CUDA_HOME keeps any value inherited from the caller and otherwise falls back
# to /usr/local/cuda.
export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}"
# Prepend the toolkit directories. The ${VAR:+:...} form adds the ':' separator
# only when the inherited value is non-empty; a bare trailing ':' would leave
# an empty entry, which is searched as the current directory.
export PATH="${CUDA_HOME}/bin${PATH:+:${PATH}}"
export LD_LIBRARY_PATH="${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
export FORCE_CUDA=1

# Fix OMP_NUM_THREADS warning
export OMP_NUM_THREADS=1
# Print a banner describing the build environment: CUDA_HOME, the installed
# PyTorch version, and whether PyTorch reports CUDA as available.
# A failing command substitution inside an echo argument does not abort the
# script (echo itself still succeeds), so these lines are informational only.
echo "=========================================="
echo "Building CUDA Extensions for Speed"
echo "=========================================="
echo "CUDA_HOME: $CUDA_HOME"
echo "PyTorch version: $(python -c 'import torch; print(torch.__version__)')"
echo "CUDA available: $(python -c 'import torch; print(torch.cuda.is_available())')"
# Only query CUDA version / device name when PyTorch reports CUDA as available.
if python -c 'import torch; print(torch.cuda.is_available())' | grep -q True; then
  echo "CUDA version: $(python -c 'import torch; print(torch.version.cuda)')"
  # The Python snippet is single-quoted, so its string literal uses plain
  # double quotes; backslash-escaped quotes would be passed to Python verbatim
  # and be rejected as a syntax error.
  echo "GPU: $(python -c 'import torch; print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A")')"
fi
echo "=========================================="
# Best-effort Gradio upgrade (the UI requires a 4.x release >= 4.44.0).
# A failed upgrade is reported as a warning and the script carries on with
# whatever version is already installed.
printf '%s\n' \
  "" \
  "Upgrading Gradio to >=4.44.0..."
if ! pip install --upgrade --no-cache-dir "gradio>=4.44.0,<5.0.0"; then
  echo "Warning: Failed to upgrade Gradio, continuing with existing version"
fi
# CUDA extension #1: mamba-ssm, installed with pip.
# A failed install is fatal: report it and stop with exit status 1.
printf '%s\n' \
  "" \
  "==========================================" \
  "Installing mamba-ssm (CUDA extension #1)" \
  "=========================================="
if ! pip install "mamba-ssm>=2.2.2"; then
  echo "ERROR: mamba-ssm installation failed. This is fatal."
  exit 1
fi
# CUDA extension #2: selective_scan_cuda_oflex, built from the sources under
# SRMA-Mamba/selective_scan. The path is relative, so the script has to be
# started from the directory that contains SRMA-Mamba/.
printf '%s\n' \
  "" \
  "==========================================" \
  "Building selective_scan_cuda_oflex (CUDA extension #2)" \
  "=========================================="
if ! [ -d "SRMA-Mamba/selective_scan" ]; then
  echo "ERROR: SRMA-Mamba/selective_scan directory not found"
  exit 1
fi

# Editable install, run from inside the extension's source directory.
cd SRMA-Mamba/selective_scan
if ! pip install -e .; then
  echo "ERROR: selective_scan_cuda_oflex build failed. This is fatal."
  cd ../..
  exit 1
fi
# Step back up the two directory levels entered above.
cd ../..
# Import-test both extensions (fatal if either fails).
# The interpreter is deliberately run as the `if` condition: with `set -e`
# active, a bare failing command would abort the script on the spot, so a
# separate `$?` check afterwards would never run and the error message below
# would never be printed.
echo ""
echo "=========================================="
echo "Verifying CUDA Extensions"
echo "=========================================="
# The quoted 'PY' delimiter keeps the shell from expanding anything inside the
# Python snippet.
if ! python - <<'PY'
import sys
import mamba_ssm, selective_scan_cuda_oflex
print("mamba_ssm OK", mamba_ssm.__version__)
print("selective_scan_cuda_oflex OK")
PY
then
  echo "ERROR: CUDA extension import test failed. This is fatal."
  exit 1
fi
echo "All CUDA extensions verified successfully."