Harshith Reddy
Fix: Exact import test format, TF32 set directly (no try/except), verify allocator config includes max_split_size_mb=128
5f56790
set -euo pipefail

# --- CUDA toolchain environment ---------------------------------------------
# Respect a caller-provided CUDA_HOME; otherwise fall back to /usr/local/cuda.
export CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
export PATH=${CUDA_HOME}/bin:${PATH}
# Append the previous LD_LIBRARY_PATH only when it is non-empty. Writing
# "${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" unconditionally leaves a trailing
# ':' (an empty entry) when the variable is unset/empty, and the dynamic
# loader treats an empty entry as the current working directory.
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
# Presumably read by the extension build scripts to force a CUDA build even
# when no GPU is visible at build time (consumer not visible here).
export FORCE_CUDA=1
# Pin OpenMP to a single thread; set explicitly to silence the
# OMP_NUM_THREADS warning.
export OMP_NUM_THREADS=1

echo "=========================================="
echo "Building CUDA Extensions for Speed"
echo "=========================================="
echo "CUDA_HOME: $CUDA_HOME"
# Query all PyTorch/CUDA facts in ONE interpreter: importing torch is slow, and
# this also avoids shell-quoting Python string literals inside `python -c`.
# Informational only, so a failure (e.g. torch missing) is reported, not fatal.
python - <<'PY' || echo "Warning: could not query PyTorch/CUDA information" >&2
import torch

print("PyTorch version:", torch.__version__)
available = torch.cuda.is_available()
print("CUDA available:", available)
if available:
    print("CUDA version:", torch.version.cuda)
    print("GPU:", torch.cuda.get_device_name(0))
PY
echo "=========================================="
# Upgrade Gradio to the newest 4.x release >= 4.44.0 (required for the UI to
# work). Best-effort: a failed upgrade is reported but does not stop the build.
# `python -m pip` installs into the same interpreter the script later uses for
# its import checks, instead of whichever `pip` happens to be first on PATH.
echo ""
echo "Upgrading Gradio to >=4.44.0..."
if ! python -m pip install --upgrade --no-cache-dir "gradio>=4.44.0,<5.0.0"; then
  echo "Warning: Failed to upgrade Gradio, continuing with existing version" >&2
fi
# Install mamba-ssm (CUDA extension #1). Failure is fatal.
# `python -m pip` targets the same interpreter used by the import test, so a
# successful install here is guaranteed to be visible to that check.
echo ""
echo "=========================================="
echo "Installing mamba-ssm (CUDA extension #1)"
echo "=========================================="
if ! python -m pip install "mamba-ssm>=2.2.2"; then
  echo "ERROR: mamba-ssm installation failed. This is fatal." >&2
  exit 1
fi
# Build selective_scan_cuda_oflex (CUDA extension #2) as an editable install
# from SRMA-Mamba/selective_scan, resolved relative to the current working
# directory. Failure is fatal.
echo ""
echo "=========================================="
echo "Building selective_scan_cuda_oflex (CUDA extension #2)"
echo "=========================================="
selective_scan_dir="SRMA-Mamba/selective_scan"
if [[ ! -d "$selective_scan_dir" ]]; then
  echo "ERROR: SRMA-Mamba/selective_scan directory not found" >&2
  exit 1
fi
# Run the build in a subshell: the `cd` cannot leak, so the caller's working
# directory is unchanged on both success and failure (no `cd ../..` needed).
if ! (cd "$selective_scan_dir" && python -m pip install -e .); then
  echo "ERROR: selective_scan_cuda_oflex build failed. This is fatal." >&2
  exit 1
fi
# Import-test both extensions (fatal if either fails).
# The heredoc runs as the condition of `if !`: under `set -e` a bare failing
# command exits the script immediately, so a separate `$?` check afterwards
# would never run and the error message below would never be printed.
echo ""
echo "=========================================="
echo "Verifying CUDA Extensions"
echo "=========================================="
if ! python - <<'PY'
import sys
import mamba_ssm, selective_scan_cuda_oflex
print("mamba_ssm OK", mamba_ssm.__version__)
print("selective_scan_cuda_oflex OK")
PY
then
  echo "ERROR: CUDA extension import test failed. This is fatal." >&2
  exit 1
fi
echo "All CUDA extensions verified successfully."