"""
Optimizers for LLM training.
"""

import jax
import jax.numpy as jnp
import optax
from typing import Callable, Optional, Union
import flax


def create_adamw_optimizer(
    learning_rate: Union[float, Callable],
    weight_decay: float = 0.01,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    mask: Optional[Callable] = None
) -> optax.GradientTransformation:
    """
    Create AdamW optimizer.
    
    Args:
        learning_rate: Learning rate or learning rate schedule
        weight_decay: Weight decay coefficient
        b1: First moment decay
        b2: Second moment decay
        eps: Epsilon for numerical stability
        mask: Function to mask parameters from weight decay
        
    Returns:
        AdamW optimizer
    """
    if mask is None:
        # Default mask: exclude bias and layer-norm parameters (any path
        # component starting with "layer_norm") from weight decay.
        def mask(params):
            flat_params = flax.traverse_util.flatten_dict(params)
            flat_mask = {
                k: (k[-1] != "bias"
                    and not any(name.startswith("layer_norm") for name in k))
                for k in flat_params
            }
            # optax expects the mask to mirror the structure of the params.
            return flax.traverse_util.unflatten_dict(flat_mask)
    
    # Create optimizer chain
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip gradients to a global norm of 1.0
        optax.adamw(
            learning_rate=learning_rate,
            b1=b1,
            b2=b2,
            eps=eps,
            weight_decay=weight_decay,
            mask=mask
        )
    )
    
    return optimizer
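
# Illustrative usage (sketch): pair the factory with an optax learning-rate
# schedule. The schedule values below are placeholders, not tuned settings.
#
#     schedule = optax.warmup_cosine_decay_schedule(
#         init_value=0.0,
#         peak_value=3e-4,
#         warmup_steps=1_000,
#         decay_steps=100_000,
#         end_value=3e-5,
#     )
#     optimizer = create_adamw_optimizer(learning_rate=schedule, weight_decay=0.1)
#     opt_state = optimizer.init(params)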


def create_lion_optimizer(
    learning_rate: Union[float, Callable],
    weight_decay: float = 0.01,
    b1: float = 0.9,
    b2: float = 0.99,
    mask: Optional[Callable] = None
) -> optax.GradientTransformation:
    """
    Create Lion optimizer.
    
    Args:
        learning_rate: Learning rate or learning rate schedule
        weight_decay: Weight decay coefficient
        b1: First moment decay
        b2: Second moment decay
        mask: Function to mask parameters from weight decay
        
    Returns:
        Lion optimizer
    """
    if mask is None:
        # Default mask: exclude bias and layer-norm parameters (any path
        # component starting with "layer_norm") from weight decay.
        def mask(params):
            flat_params = flax.traverse_util.flatten_dict(params)
            flat_mask = {
                k: (k[-1] != "bias"
                    and not any(name.startswith("layer_norm") for name in k))
                for k in flat_params
            }
            # optax expects the mask to mirror the structure of the params.
            return flax.traverse_util.unflatten_dict(flat_mask)
    
    # Create optimizer chain
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip gradients to a global norm of 1.0
        optax.lion(
            learning_rate=learning_rate,
            b1=b1,
            b2=b2,
            weight_decay=weight_decay,
            mask=mask
        )
    )
    
    return optimizer
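

# Minimal smoke test (a sketch, not part of the library API): shows how the
# factories plug into a single optax update step. The parameter pytree and
# gradients below are made-up placeholders, not a real model.
if __name__ == "__main__":
    params = {
        "dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
        "layer_norm": {"scale": jnp.ones((4,)), "bias": jnp.zeros((4,))},
    }
    # Fake gradients with the same structure as the parameters.
    grads = jax.tree_util.tree_map(jnp.ones_like, params)

    for optimizer in (
        create_adamw_optimizer(learning_rate=3e-4, weight_decay=0.1),
        create_lion_optimizer(learning_rate=1e-4, weight_decay=0.3),
    ):
        opt_state = optimizer.init(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        print(jax.tree_util.tree_map(jnp.shape, new_params))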