| | """
|
| | Functional API for BitLinear operations.
|
| |
|
| | This module provides the core functional implementations that will be called
|
| | by the nn.Module wrappers. These functions implement the mathematical operations
|
| | described in the BitNet and ternary neural network papers.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn.functional as F
|
| | from typing import Optional, Tuple
|
| |
|
| |
|
def bitlinear_python(
    x: torch.Tensor,
    W: torch.Tensor,
    gamma: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Pure PyTorch reference implementation of BitLinear forward pass.

    This implements the core BitLinear computation:
        output = x @ W^T * gamma + bias

    where W is a ternary weight matrix ({-1, 0, +1}), and gamma is a per-output
    scaling factor that compensates for the quantization.

    Args:
        x: Input tensor of shape [..., in_features]
        W: Ternary weight matrix of shape [out_features, in_features]
            with values in {-1, 0, +1}. It may be stored in any dtype
            (e.g. int8); it is cast to ``x.dtype`` before the matmul.
        gamma: Scaling factors of shape [out_features] or [1, out_features].
            Any other shape is broadcast against the output as-is.
        bias: Optional bias tensor of shape [out_features]

    Returns:
        Output tensor of shape [..., out_features]. A 1-D input of shape
        [in_features] yields a 1-D output of shape [out_features].

    Notes:
        - This is the reference implementation for correctness
        - CUDA kernels will optimize the ternary matrix multiplication
        - Gamma scaling is applied per output channel, after the matmul and
          before the bias is added
    """
    # torch.matmul requires both operands to share a dtype; ternary weights are
    # often stored compactly (int8), so bring them to the activation dtype.
    # Tensor.to is a no-op (no copy) when the dtype already matches.
    output = torch.matmul(x, W.to(x.dtype).t())

    # Normalize a [1, out_features] gamma to [out_features]. Plain broadcasting
    # then scales the last dim for any number of leading dims without ever
    # *adding* a dimension (unsqueezing gamma turned a 1-D result into [1, out]).
    scale = gamma
    if scale.dim() == 2 and scale.size(0) == 1:
        scale = scale.squeeze(0)
    output = output * scale

    if bias is not None:
        output = output + bias

    return output
|
| |
|
| |
|
def greedy_ternary_decomposition(
    W: torch.Tensor,
    k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Greedy ternary decomposition of a weight matrix.

    Decomposes a dense weight matrix W into a sum of k ternary matrices:
        W ≈ sum_{i=1}^k gamma_i * W_i^ternary

    This follows the greedy residual minimization approach:
        1. Quantize W to ternary → W_1, compute gamma_1
        2. Compute residual R_1 = W - gamma_1 * W_1
        3. Quantize R_1 to ternary → W_2, compute gamma_2
        4. Repeat for k iterations

    Args:
        W: Dense weight matrix of shape [out_features, in_features]
        k: Number of ternary components (typically 2-4 for BitNet); must be >= 1

    Returns:
        W_ternary: Stacked ternary matrices of shape [k, out_features, in_features]
        gammas: Scaling factors of shape [k, out_features]

    Raises:
        ValueError: If ``k < 1``.

    Notes:
        - Each iteration reduces the residual error
        - Larger k provides better approximation but more computation
        - This is used in MultiTernaryLinear for improved expressiveness
        - Assumes ``weight_to_ternary(..., per_channel=True)`` returns a 1-D
          gamma of shape [out_features] (the ``unsqueeze(1)`` below relies on
          it); TODO confirm against .quantization.

    References:
        - BitNet paper: "BitNet: Scaling 1-bit Transformers for Large Language Models"
        - JMLR paper: https://jmlr.org/papers/volume26/24-2050/24-2050.pdf
    """
    # With k < 1 the loop never runs and torch.stack would fail on an empty
    # list with an unhelpful message; reject it up front, before any work.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    from .quantization import weight_to_ternary

    # Work on a copy so the caller's W is never touched, even if the
    # quantizer were to operate in place.
    residual = W.clone()

    ternary_weights = []
    gammas = []

    for _ in range(k):
        W_t, gamma = weight_to_ternary(residual, per_channel=True)
        ternary_weights.append(W_t)
        gammas.append(gamma)

        # Subtract this component's reconstruction; gamma scales each output row.
        residual = residual - gamma.unsqueeze(1) * W_t

    return torch.stack(ternary_weights, dim=0), torch.stack(gammas, dim=0)
|
| |
|
| |
|
| |
|
def multi_ternary_linear_python(
    x: torch.Tensor,
    W_ternary: torch.Tensor,
    gammas: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Forward pass for multi-component ternary linear layer.

    Computes the sum of k ternary linear transformations:
        output = sum_{i=1}^k (x @ W_i^T * gamma_i) + bias

    Args:
        x: Input tensor of shape [..., in_features]
        W_ternary: Stacked ternary weights of shape [k, out_features, in_features].
            May be stored in any dtype (e.g. int8); each component is cast to
            ``x.dtype`` before its matmul.
        gammas: Scaling factors of shape [k, out_features]
        bias: Optional bias tensor of shape [out_features]

    Returns:
        Output tensor of shape [..., out_features]. A 1-D input of shape
        [in_features] yields a 1-D output of shape [out_features]. With k == 0
        the result is zeros (plus bias, if given).

    Raises:
        ValueError: If ``W_ternary`` is not 3-D or ``gammas`` does not have
            exactly k rows.
    """
    if W_ternary.dim() != 3:
        raise ValueError(
            "W_ternary must have shape [k, out_features, in_features], "
            f"got {tuple(W_ternary.shape)}"
        )
    k, out_features, _ = W_ternary.shape
    if gammas.dim() == 0 or gammas.size(0) != k:
        raise ValueError(
            f"gammas must have k={k} rows, got shape {tuple(gammas.shape)}"
        )

    # Accumulate components one at a time (same summation order as a per-
    # component reference); starting from zeros keeps k == 0 well defined.
    output = x.new_zeros((*x.shape[:-1], out_features))

    for W_i, gamma_i in zip(W_ternary, gammas):
        # matmul needs matching dtypes; broadcasting a 1-D gamma_i keeps the
        # output rank equal to x's rank (no spurious leading dim for 1-D x).
        component_output = torch.matmul(x, W_i.to(x.dtype).t()) * gamma_i
        output = output + component_output

    if bias is not None:
        output = output + bias

    return output
|
| |
|
| |
|
def activation_quant(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    """
    Quantize activations for BitLinear.

    BitNet uses activation quantization in addition to weight quantization.
    This function implements per-token absmax quantization for activations:
    each slice along the last dim is scaled by its own max |x|, rounded onto
    the symmetric integer grid [-(2^(bits-1) - 1), 2^(bits-1) - 1], and
    mapped back to the original range.

    Args:
        x: Input activations of shape [..., features]
        bits: Number of bits for quantization (default: 8); must be >= 2

    Returns:
        Quantized activations (as float, not int), same shape as ``x``

    Raises:
        ValueError: If ``bits < 2``. With 1 bit the symmetric grid collapses
            to {0} and the dequantization would divide by zero (NaN output).

    Notes:
        - Uses absmax scaling per token; the scale is clamped to >= 1e-5 so
          all-zero tokens stay zero instead of producing NaN
        - Returns a float tensor so it can sit inside an autograd graph, but
          ``torch.round`` has zero gradient: gradients do not pass through
          the rounding step. Use a straight-through estimator at the call
          site if training through this op is intended.
        - Simulates quantization effects without actual INT8 storage
    """
    if bits < 2:
        raise ValueError(f"bits must be >= 2, got {bits}")

    Q_max = 2 ** (bits - 1) - 1
    Q_min = -Q_max

    # Per-token absmax along the feature dim; the floor avoids 0/0 for
    # all-zero tokens.
    scale = torch.max(torch.abs(x), dim=-1, keepdim=True).values
    scale = torch.clamp(scale, min=1e-5)

    # Normalize to [-1, 1], snap to the integer grid, then dequantize.
    x_quant_int = torch.clamp(torch.round((x / scale) * Q_max), min=Q_min, max=Q_max)
    return (x_quant_int / Q_max) * scale
|
| |
|