"""
Optimizers for LLM training.
"""
from typing import Any, Callable, Dict, Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp
import optax


def create_adamw_optimizer(
    learning_rate: Union[float, Callable],
    weight_decay: float = 0.01,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    mask: Optional[Callable] = None,
) -> optax.GradientTransformation:
    """Create an AdamW optimizer with gradient clipping.

    Args:
        learning_rate: Learning rate or learning rate schedule.
        weight_decay: Weight decay coefficient.
        b1: Exponential decay rate for the first moment estimates.
        b2: Exponential decay rate for the second moment estimates.
        eps: Epsilon for numerical stability.
        mask: Function mapping params to a boolean pytree selecting which
            parameters receive weight decay. Defaults to excluding biases
            and layer-norm parameters.

    Returns:
        AdamW optimizer as an optax.GradientTransformation.
    """
    if mask is None:
        # Default mask: apply weight decay to everything except biases and
        # parameters that live under a module whose name starts with
        # "layer_norm" (assumes layer-norm modules are named that way).
        def mask(params):
            flat_params = flax.traverse_util.flatten_dict(params)
            flat_mask = {
                path: (
                    path[-1] != "bias"
                    and not any(part.startswith("layer_norm") for part in path)
                )
                for path in flat_params
            }
            # optax expects the mask to mirror the structure of params,
            # so unflatten the boolean dict before returning it.
            return flax.traverse_util.unflatten_dict(flat_mask)

    # Clip gradients by global norm, then apply AdamW with the decay mask.
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(
            learning_rate=learning_rate,
            b1=b1,
            b2=b2,
            eps=eps,
            weight_decay=weight_decay,
            mask=mask,
        ),
    )
    return optimizer
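

# Illustrative sketch (not part of the training pipeline): shows how the
# AdamW factory above is typically wired to a learning-rate schedule and a
# Flax-style parameter pytree. The schedule values and the toy parameter
# shapes below are placeholder assumptions.
def _example_adamw_usage() -> optax.OptState:
    # Warmup + cosine decay is a common choice for LLM training; the exact
    # numbers here are arbitrary.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=3e-4,
        warmup_steps=1_000,
        decay_steps=100_000,
        end_value=3e-5,
    )
    optimizer = create_adamw_optimizer(learning_rate=schedule, weight_decay=0.1)

    # Toy parameter pytree; the "bias" and "layer_norm" entries are excluded
    # from weight decay by the default mask.
    params = {
        "dense": {"kernel": jnp.ones((8, 8)), "bias": jnp.zeros((8,))},
        "layer_norm": {"scale": jnp.ones((8,)), "bias": jnp.zeros((8,))},
    }
    return optimizer.init(params)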


def create_lion_optimizer(
    learning_rate: Union[float, Callable],
    weight_decay: float = 0.01,
    b1: float = 0.9,
    b2: float = 0.99,
    mask: Optional[Callable] = None,
) -> optax.GradientTransformation:
    """Create a Lion optimizer with gradient clipping.

    Args:
        learning_rate: Learning rate or learning rate schedule.
        weight_decay: Weight decay coefficient.
        b1: Rate for combining the momentum and the current gradient.
        b2: Decay rate for the exponential moving average of gradients.
        mask: Function mapping params to a boolean pytree selecting which
            parameters receive weight decay. Defaults to excluding biases
            and layer-norm parameters.

    Returns:
        Lion optimizer as an optax.GradientTransformation.
    """
    if mask is None:
        # Default mask: apply weight decay to everything except biases and
        # parameters that live under a module whose name starts with
        # "layer_norm" (assumes layer-norm modules are named that way).
        def mask(params):
            flat_params = flax.traverse_util.flatten_dict(params)
            flat_mask = {
                path: (
                    path[-1] != "bias"
                    and not any(part.startswith("layer_norm") for part in path)
                )
                for path in flat_params
            }
            # optax expects the mask to mirror the structure of params,
            # so unflatten the boolean dict before returning it.
            return flax.traverse_util.unflatten_dict(flat_mask)

    # Clip gradients by global norm, then apply Lion with the decay mask.
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.lion(
            learning_rate=learning_rate,
            b1=b1,
            b2=b2,
            weight_decay=weight_decay,
            mask=mask,
        ),
    )
    return optimizer
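

# Illustrative sketch (not part of the training pipeline): a minimal update
# step using the Lion factory above with a constant learning rate. The loss
# function and parameter shapes below are placeholder assumptions.
def _example_lion_train_step():
    optimizer = create_lion_optimizer(learning_rate=1e-4, weight_decay=0.1)

    # Toy quadratic "loss" over a small parameter pytree.
    params = {"linear": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))}}
    opt_state = optimizer.init(params)

    def loss_fn(p):
        return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(p))

    grads = jax.grad(loss_fn)(params)
    # Lion's decoupled weight decay needs the current params in update().
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state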