#!/bin/bash
#
# Prepare a TPU VM for LLM training with MaxText.
#
# Steps, in order:
#   1. Validate required configuration.
#   2. Create the local logs/ and checkpoints/ directories.
#   3. Clone the MaxText repository, or update it if it is already present.
#   4. Install MaxText dependencies and its pre-commit hooks.
#   5. Export PYTHONPATH / JAX_PLATFORMS and print what JAX sees.
#   6. Make sure the GCS bucket is listable, creating it if it is not.
#   7. Mount the bucket through Cloud Storage FUSE.
#
# Environment:
#   GCS_BUCKET  (required) bucket URL, e.g. gs://your-bucket-name
#   PROJECT_ID  (optional) project passed to `gsutil mb -p`; consulted only
#               when the bucket has to be created
#
# Run this from the directory that should contain logs/, checkpoints/ and
# maxtext/; every path below is relative to the current working directory.

# -e: stop on the first failing command; -u: treat unset variables as errors;
# pipefail: a pipeline fails if any stage fails.
set -euo pipefail

echo "Setting up environment for LLM training on TPU v4-32..."

# Fail fast: check required configuration before the slow clone/install
# steps so a missing variable is reported immediately.
if [[ -z "${GCS_BUCKET:-}" ]]; then
  echo "Please set the GCS_BUCKET environment variable" >&2
  echo "Example: export GCS_BUCKET=gs://your-bucket-name" >&2
  exit 1
fi

# Local output directories, created next to the maxtext/ checkout.
mkdir -p logs
mkdir -p checkpoints

# Fetch MaxText, or refresh an existing checkout. Either way the script ends
# up inside maxtext/, which the following steps rely on.
if [[ ! -d "maxtext" ]]; then
  echo "Cloning MaxText repository..."
  git clone https://github.com/AI-Hypercomputer/maxtext.git
  cd maxtext
else
  echo "MaxText repository already exists, updating..."
  cd maxtext
  git pull
fi

echo "Installing dependencies..."
bash setup.sh
pre-commit install

echo "Setting up environment variables..."
# Append the MaxText checkout to PYTHONPATH. The ${VAR:+...} form avoids a
# leading ':' (an empty path entry) when PYTHONPATH was not set beforehand.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}$(pwd)"
export JAX_PLATFORMS="tpu"
# NOTE: these exports are visible to this script and its child processes
# only. A training script started later from the parent shell does not
# inherit them unless it sets them itself.

echo "Checking TPU configuration..."
# A single interpreter start prints both facts, so JAX initialises only once.
python3 -c "
import jax
print(f'TPU devices: {jax.device_count()}')
print(f'TPU type: {jax.devices()[0].platform}')
"

echo "Checking GCS bucket access..."
# Any failure to list the bucket is treated as "bucket does not exist" and
# triggers a create attempt; gsutil's own error output stays visible.
if ! gsutil ls "${GCS_BUCKET}" > /dev/null; then
  echo "Creating GCS bucket ${GCS_BUCKET}..."
  # Pass -p only when PROJECT_ID is set. With an empty value, -p would
  # swallow the bucket URL as its argument.
  mb_args=(mb)
  if [[ -n "${PROJECT_ID:-}" ]]; then
    mb_args+=(-p "${PROJECT_ID}")
  fi
  gsutil "${mb_args[@]}" "${GCS_BUCKET}"
fi

echo "Setting up Cloud Storage FUSE..."
# setup_gcsfuse.sh receives the bare bucket name, so drop the gs:// scheme.
BUCKET_NAME="${GCS_BUCKET#gs://}"
MOUNT_PATH="/tmp/gcsfuse"
bash setup_gcsfuse.sh "DATASET_GCS_BUCKET=${BUCKET_NAME}" "MOUNT_PATH=${MOUNT_PATH}"

echo "Environment setup complete!"

# Return to the directory the script was started from.
cd ..

echo "You can now run the training script with: bash tpu_train.sh"