Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- utils/conv_layer.py +33 -0
- utils/domain_configs.py +116 -0
- utils/selective_scan.py +55 -0
utils/conv_layer.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# utils/conv_layer.py
|
| 3 |
+
# =============================================================================
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
class Mamba1DConv(nn.Module):
|
| 8 |
+
def __init__(self, d_inner: int, d_conv: int = 4, bias: bool = True):
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.d_conv = d_conv
|
| 11 |
+
|
| 12 |
+
self.conv1d = nn.Conv1d(
|
| 13 |
+
in_channels=d_inner,
|
| 14 |
+
out_channels=d_inner,
|
| 15 |
+
kernel_size=d_conv,
|
| 16 |
+
bias=bias,
|
| 17 |
+
groups=d_inner, # Depthwise convolution
|
| 18 |
+
padding=d_conv - 1
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 22 |
+
"""
|
| 23 |
+
Args:
|
| 24 |
+
x: [batch, seq_len, d_inner]
|
| 25 |
+
Returns:
|
| 26 |
+
x: [batch, seq_len, d_inner]
|
| 27 |
+
"""
|
| 28 |
+
# Conv1d expects [batch, channels, seq_len]
|
| 29 |
+
x = x.transpose(1, 2) # [batch, d_inner, seq_len]
|
| 30 |
+
x = self.conv1d(x)
|
| 31 |
+
x = x[:, :, :-(self.d_conv-1)] # Remove padding
|
| 32 |
+
x = x.transpose(1, 2) # [batch, seq_len, d_inner]
|
| 33 |
+
return x
|
utils/domain_configs.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# utils/domain_configs.py
|
| 3 |
+
# =============================================================================
|
| 4 |
+
from typing import Dict, List
|
| 5 |
+
from core.config import MambaConfig
|
| 6 |
+
|
| 7 |
+
class DomainConfigs:
|
| 8 |
+
"""Configurations for different specialist domains"""
|
| 9 |
+
|
| 10 |
+
DOMAINS = {
|
| 11 |
+
# STEM domains
|
| 12 |
+
"mathematics": {
|
| 13 |
+
"keywords": ["equation", "theorem", "proof", "calculate", "derivative", "integral", "matrix", "algebra", "geometry", "statistics"],
|
| 14 |
+
"description": "Mathematical reasoning and computation"
|
| 15 |
+
},
|
| 16 |
+
"physics": {
|
| 17 |
+
"keywords": ["force", "energy", "momentum", "quantum", "relativity", "particle", "wave", "thermodynamics", "mechanics"],
|
| 18 |
+
"description": "Physics concepts and problems"
|
| 19 |
+
},
|
| 20 |
+
"chemistry": {
|
| 21 |
+
"keywords": ["molecule", "atom", "reaction", "compound", "bond", "element", "organic", "inorganic", "catalyst"],
|
| 22 |
+
"description": "Chemistry and molecular science"
|
| 23 |
+
},
|
| 24 |
+
"biology": {
|
| 25 |
+
"keywords": ["cell", "DNA", "protein", "organism", "evolution", "genetics", "ecology", "anatomy", "physiology"],
|
| 26 |
+
"description": "Biological sciences"
|
| 27 |
+
},
|
| 28 |
+
|
| 29 |
+
# Programming domains
|
| 30 |
+
"python": {
|
| 31 |
+
"keywords": ["def", "class", "import", "python", "pandas", "numpy", "matplotlib", "sklearn", "tensorflow"],
|
| 32 |
+
"description": "Python programming and data science"
|
| 33 |
+
},
|
| 34 |
+
"javascript": {
|
| 35 |
+
"keywords": ["function", "var", "let", "const", "javascript", "react", "node", "async", "promise"],
|
| 36 |
+
"description": "JavaScript and web development"
|
| 37 |
+
},
|
| 38 |
+
"systems": {
|
| 39 |
+
"keywords": ["linux", "server", "network", "database", "docker", "kubernetes", "cloud", "devops"],
|
| 40 |
+
"description": "Systems programming and infrastructure"
|
| 41 |
+
},
|
| 42 |
+
|
| 43 |
+
# Language domains
|
| 44 |
+
"writing": {
|
| 45 |
+
"keywords": ["essay", "article", "story", "paragraph", "thesis", "narrative", "prose", "literature"],
|
| 46 |
+
"description": "Creative and technical writing"
|
| 47 |
+
},
|
| 48 |
+
"translation": {
|
| 49 |
+
"keywords": ["translate", "language", "spanish", "french", "german", "chinese", "japanese", "korean"],
|
| 50 |
+
"description": "Language translation and linguistics"
|
| 51 |
+
},
|
| 52 |
+
|
| 53 |
+
# Business domains
|
| 54 |
+
"business": {
|
| 55 |
+
"keywords": ["market", "strategy", "finance", "management", "revenue", "profit", "customer", "sales"],
|
| 56 |
+
"description": "Business and economics"
|
| 57 |
+
},
|
| 58 |
+
"legal": {
|
| 59 |
+
"keywords": ["law", "contract", "court", "legal", "attorney", "judge", "case", "statute", "regulation"],
|
| 60 |
+
"description": "Legal reasoning and analysis"
|
| 61 |
+
},
|
| 62 |
+
|
| 63 |
+
# Other domains
|
| 64 |
+
"history": {
|
| 65 |
+
"keywords": ["war", "empire", "civilization", "century", "ancient", "medieval", "revolution", "dynasty"],
|
| 66 |
+
"description": "Historical knowledge and analysis"
|
| 67 |
+
},
|
| 68 |
+
"philosophy": {
|
| 69 |
+
"keywords": ["ethics", "moral", "logic", "metaphysics", "epistemology", "consciousness", "existence"],
|
| 70 |
+
"description": "Philosophical reasoning"
|
| 71 |
+
},
|
| 72 |
+
"medical": {
|
| 73 |
+
"keywords": ["patient", "diagnosis", "treatment", "disease", "medicine", "surgery", "therapy", "symptom"],
|
| 74 |
+
"description": "Medical knowledge and healthcare"
|
| 75 |
+
},
|
| 76 |
+
"arts": {
|
| 77 |
+
"keywords": ["painting", "music", "sculpture", "artist", "gallery", "museum", "aesthetic", "culture"],
|
| 78 |
+
"description": "Arts and cultural topics"
|
| 79 |
+
}
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
@classmethod
|
| 83 |
+
def get_domain_configs(cls, num_specialists: int = 100) -> List[Dict]:
|
| 84 |
+
"""Generate configurations for specialist domains"""
|
| 85 |
+
configs = []
|
| 86 |
+
base_domains = list(cls.DOMAINS.keys())
|
| 87 |
+
|
| 88 |
+
# Create configurations
|
| 89 |
+
for i in range(num_specialists):
|
| 90 |
+
if i < len(base_domains):
|
| 91 |
+
# Use predefined domains
|
| 92 |
+
domain_name = base_domains[i]
|
| 93 |
+
domain_info = cls.DOMAINS[domain_name]
|
| 94 |
+
else:
|
| 95 |
+
# Create sub-specializations or general domains
|
| 96 |
+
base_idx = i % len(base_domains)
|
| 97 |
+
domain_name = f"{base_domains[base_idx]}_sub_{i}"
|
| 98 |
+
domain_info = cls.DOMAINS[base_domains[base_idx]]
|
| 99 |
+
|
| 100 |
+
config = {
|
| 101 |
+
"id": i,
|
| 102 |
+
"name": domain_name,
|
| 103 |
+
"keywords": domain_info["keywords"],
|
| 104 |
+
"description": domain_info["description"],
|
| 105 |
+
"weight": 1.0 # Can be adjusted based on importance
|
| 106 |
+
}
|
| 107 |
+
configs.append(config)
|
| 108 |
+
|
| 109 |
+
return configs
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def create_specialist_config(cls, base_config: MambaConfig, domain_id: int) -> MambaConfig:
|
| 113 |
+
"""Create a specialist configuration for a specific domain"""
|
| 114 |
+
specialist_config = MambaConfig(**base_config.__dict__)
|
| 115 |
+
specialist_config.specialist_id = domain_id
|
| 116 |
+
return specialist_config
|
utils/selective_scan.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# utils/selective_scan.py
|
| 3 |
+
# =============================================================================
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from typing import Tuple
|
| 7 |
+
|
| 8 |
+
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False):
|
| 9 |
+
"""
|
| 10 |
+
Selective scan function - core of Mamba's state space model
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
u: input sequence [batch, seq_len, d_inner]
|
| 14 |
+
delta: time step [batch, seq_len, d_inner]
|
| 15 |
+
A: state matrix [d_inner, d_state]
|
| 16 |
+
B: input matrix [batch, seq_len, d_state]
|
| 17 |
+
C: output matrix [batch, seq_len, d_state]
|
| 18 |
+
D: skip connection [d_inner]
|
| 19 |
+
z: gating [batch, seq_len, d_inner] (optional)
|
| 20 |
+
delta_bias: bias for delta (optional)
|
| 21 |
+
delta_softplus: whether to apply softplus to delta
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
y: output [batch, seq_len, d_inner]
|
| 25 |
+
"""
|
| 26 |
+
batch_size, seq_len, d_inner = u.shape
|
| 27 |
+
d_state = A.shape[1]
|
| 28 |
+
|
| 29 |
+
if delta_bias is not None:
|
| 30 |
+
delta = delta + delta_bias[None, None, :]
|
| 31 |
+
|
| 32 |
+
if delta_softplus:
|
| 33 |
+
delta = F.softplus(delta)
|
| 34 |
+
|
| 35 |
+
# Discretization
|
| 36 |
+
deltaA = torch.exp(delta.unsqueeze(-1) * A) # [batch, seq_len, d_inner, d_state]
|
| 37 |
+
deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1) # [batch, seq_len, d_inner, d_state]
|
| 38 |
+
|
| 39 |
+
# Initialize hidden state
|
| 40 |
+
h = torch.zeros(batch_size, d_inner, d_state, device=u.device, dtype=u.dtype)
|
| 41 |
+
|
| 42 |
+
outputs = []
|
| 43 |
+
for i in range(seq_len):
|
| 44 |
+
h = deltaA[:, i] * h + deltaB_u[:, i] # State update
|
| 45 |
+
y = torch.sum(h * C[:, i].unsqueeze(1), dim=-1) # Output projection
|
| 46 |
+
if D is not None:
|
| 47 |
+
y = y + D * u[:, i]
|
| 48 |
+
outputs.append(y)
|
| 49 |
+
|
| 50 |
+
y = torch.stack(outputs, dim=1) # [batch, seq_len, d_inner]
|
| 51 |
+
|
| 52 |
+
if z is not None:
|
| 53 |
+
y = y * F.silu(z)
|
| 54 |
+
|
| 55 |
+
return y
|