diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 41557c29..1f1fe7d6 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals: - Markus Enzweiler: Added the `cvae` examples. - Prince Canuma: Helped add support for `Starcoder2` models. - Shiyu Li: Added the `Segment Anything Model`. -- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`. \ No newline at end of file +- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba version 1`, `Mamba version 2` and support for `full-fine-tuning`. diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index a6a56e0a..2764ce44 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -324,6 +324,7 @@ def trim(self, n): class MambaCache(_BaseCache): def __init__(self): self.cache = [None, None] + self.offset = 0 def __setitem__(self, idx, value): self.cache[idx] = value @@ -338,3 +339,15 @@ def state(self): @state.setter def state(self, v): self.cache = v + + +class Mamba2Cache: + def __init__(self, batch_size, conv_dim, kernel_size, num_heads, head_dim, state_size): + self.conv_states = mx.zeros((batch_size, conv_dim, kernel_size - 1)) + self.ssm_states = mx.zeros((batch_size, num_heads, head_dim, state_size)) + self.seqlen_offset = 0 + + def update(self, new_conv_state, new_ssm_state): + self.conv_states = new_conv_state + self.ssm_states = new_ssm_state + self.seqlen_offset += 1 \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2 copy.py b/llms/mlx_lm/models/mamba2 copy.py new file mode 100644 index 00000000..fc3f23d8 --- /dev/null +++ b/llms/mlx_lm/models/mamba2 copy.py @@ -0,0 +1,424 @@ +import math +from dataclasses import dataclass, field +from typing import Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import Mamba2Cache + +@dataclass +class ModelArgs(BaseModelArgs): + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + time_step_min: float + time_step_max: float + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + rms_norm: bool + chunk_size: int + tie_word_embeddings: bool + intermediate_size: int = None + time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) + time_step_rank: Union[int, str] = "auto" + model_type: str = "mamba2" + + def __post_init__(self): + self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED + + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +def selective_scan(x, A, B, C, chunk_size): + """ + Selective scan implementation for training. 
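+    The scan is evaluated chunk by chunk (the SSD formulation used by Mamba-2):
+    intra-chunk terms are built with a masked cumulative sum (`selective_cumsum`)
+    and the inter-chunk contribution is carried through the running SSM state,
+    which is why `seqlen` must be a multiple of `chunk_size` (see the assert below).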
+ + Arguments + x: (batch, seqlen, n_heads, d_head) + A: (batch, seqlen, n_heads) + B: (batch, seqlen, n_heads, d_state) + C: (batch, seqlen, n_heads, d_state) + + Return + y: (batch, seqlen, n_heads, d_head) + """ + assert x.shape[1] % chunk_size == 0 + + # Reshape into chunks + def chunk_reshape(m): + shape = list(m.shape) + shape[1:2] = [shape[1] // chunk_size, chunk_size] + return m.reshape(shape) + + x, A, B, C = map(chunk_reshape, (x, A, B, C)) + A = mx.transpose(A, [0, 3, 1, 2]) + + # Compute cumulative sums + A_cumsum = mx.cumsum(A, axis=-1) + + # Process chunks + L = mx.exp(selective_cumsum(A)) + Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x) + + decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum) + states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x) + + initial_states = mx.zeros_like(states[:, :1]) + states = mx.concatenate([initial_states, states], axis=1) + decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0,0), (0,0), (1,0))))) + new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states) + states = new_states[:, :-1] + + state_decay_out = mx.exp(A_cumsum) + Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + + Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:])) + return Y + +def selective_cumsum(x: mx.array) -> mx.array: + """Stable selective cumulative sum calculation.""" + T = x.shape[-1] + x = mx.repeat(x[..., None], T, axis=-1) + mask = mx.tril(mx.ones((T, T)), k=-1) + x = x * mask + x_cumsum = mx.cumsum(x, axis=-2) + mask = mx.tril(mx.ones((T, T)), k=0) + return mx.where(mask, x_cumsum, float('-inf')) + + +class Mamba2Block(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + + # Project input to get various components [z, x, B, C, dt] + projection_size = (2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads) + self.in_proj = nn.Linear( + args.hidden_size, + projection_size, + bias=args.use_bias + ) + + # Convolution layer + conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size + self.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + kernel_size=args.conv_kernel, + groups=conv_dim, + padding=args.conv_kernel - 1, + bias=args.use_conv_bias + ) + + # SSM parameters + self.dt_bias = mx.zeros(args.num_heads) + self.A_log = mx.zeros(args.num_heads) + self.D = mx.ones(args.num_heads) + + # Output projections + self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon) + self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) + + def __call__(self, u: mx.array, cache=None) -> mx.array: + # return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache) + + # def forward_training(self, u: mx.array) -> mx.array: + # # Reset cache during training + # self.cache = None + + # # Input projection and splitting + # zxbcdt = self.in_proj(u) + # z, xBC, dt = mx.split( + # zxbcdt, + # [ + # self.args.hidden_size, + # self.args.hidden_size + 2 * self.args.state_size + # ], + # axis=-1 + # ) + + # # Time step processing + # dt = mx.clip( + # nn.softplus(dt + self.dt_bias), + # self.args.time_step_min, + # self.args.time_step_max + # ) + + # # Convolution processing + # xBC_t = mx.transpose(xBC, [0, 2, 1]) + # conv_out = self.conv1d(xBC_t) + # xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]] + # xBC = mx.sigmoid(xBC) * xBC # SiLU + + # # Split states + # x, B, C = mx.split( + # xBC, + # 
[self.args.hidden_size, self.args.state_size], + # axis=-1 + # ) + + # # Reshape for selective scan + # x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim)) + # A = -mx.exp(self.A_log) + + # # Apply selective scan + # y = selective_scan( + # x * dt[..., None], + # A * dt, + # B[..., None, :], + # C[..., None, :], + # self.args.chunk_size + # ) + + # # Output processing + # y = y + x * self.D[None, None, :, None] + # y = y.reshape((-1, y.shape[1], self.args.hidden_size)) + # y = self.norm(y, z) + # y = self.out_proj(y) + + # return y + + # def forward_inference(self, u: mx.array, cache=None) -> mx.array: + # """ + # u: (B, 1, D) + # cache: (h_cache, conv_cache) + # """ + # """Single token processing during inference.""" + # assert u.shape[1] == 1, "Inference mode expects single token" + + # batch_size = u.shape[0] + # # Use provided cache or create new one + # self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None) + + # # Project input + # zxbcdt = self.in_proj(u.squeeze(1)) # (B, 2D) + # d_mlp = (zxbcdt.shape[-1] - 2 * self.args.hidden_size - 2 * self.args.n_groups * self.args.state_size - self.args.num_heads) // 2 + + # # (1, 768) (1, 0) (1, 0) (1, 256) (1, 0) (1, 3328) + # y0, z0, x0, z, xBC, dt = mx.split( + # zxbcdt, + # [ + # d_mlp, + # d_mlp, + # self.args.hidden_size, + # self.args.hidden_size + 2 * self.args.n_groups * self.args.state_size, + # self.args.num_heads + # ], + # axis=-1 + # ) + + # # Update convolution state and apply + # conv_state = self.cache.update_conv_state(xBC) + # xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1) # (B, D) (4, 1792) + + # if self.args.use_conv_bias: + # xBC = xBC + self.conv1d.bias + + # xBC = mx.sigmoid(xBC) * xBC # SiLU (4, 1792) + + # # Split states and ensure proper shapes + # a0, x, B, C = mx.split( + # xBC, # (4, 1792) + # [ + # self.args.hidden_size, + # self.args.n_groups * self.args.state_size, + # self.args.n_groups * self.args.state_size + # ], + # axis=-1 + # ) + + # # SSM step with explicit shapes + # A = -mx.exp(self.A_log) # (num_heads) (24,) + # print(A.shape) # (24,) + # print(dt.shape) # (1, 3328) + # dA = mx.exp(dt * A[None, :]) # Shape: (batch_size, num_heads) <------- her eis the error + + # # Reshape x considering intermediate size + # # x shape should be (batch_size * num_heads, head_dim) + # x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) + # assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}" + + # B = mx.reshape(B, (batch_size, -1)) # Should be (batch_size, state_size) + # C = mx.reshape(C, (batch_size, -1)) # Should be (batch_size, state_size) + + # # Compute dBx with explicit shapes + # dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x) + + # ssm_state = self.cache.update_ssm_state(dA, dBx) + + # y = mx.einsum('bhds,bs->bhd', ssm_state, C) + # y = y + x * self.D[None, :, None] + # y = mx.reshape(y, (batch_size, self.args.hidden_size)) + + # # Output processing + # y = self.norm(y, z) + + # if d_mlp > 0: + # y = mx.cat([nn.silu(z0) * x0, y], axis=-1) + + # y = self.out_proj(y) + + # return mx.expand_dims(y, 1) + + assert u.shape[1] == 1, "Inference mode expects single token" + + batch_size = u.shape[0] + # Use provided cache or create new one + self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None) + + # Project input + zxbcdt = self.in_proj(u.squeeze(1)) # (B, projection_size) + + # Calculate splits based on model 
dimensions + d_mlp = self.args.intermediate_size + d_state = self.args.state_size * self.args.n_groups + + # Split the projection into its components + splits = [ + d_mlp, # y0 + d_mlp, # z0 + self.args.hidden_size, # x0 + self.args.hidden_size, # z + d_state * 2, # xBC (includes both B and C) + self.args.num_heads # dt + ] + + y0, z0, x0, z, xBC, dt = mx.split(zxbcdt, splits[:-1], axis=-1) + + # Update convolution state and apply + conv_state = self.cache.update_conv_state(xBC) + xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1) + + if self.args.use_conv_bias: + xBC = xBC + self.conv1d.bias + + xBC = mx.sigmoid(xBC) * xBC # SiLU + + # Split states and reshape + x, BC = mx.split(xBC, [self.args.intermediate_size], axis=-1) + B, C = mx.split(BC, [d_state], axis=-1) + + # Reshape for SSM computation + x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) # (B, H, head_dim) + B = mx.reshape(B, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head) + C = mx.reshape(C, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head) + + # Process dt to match expected shape + dt = mx.reshape(dt, (batch_size, self.args.num_heads)) # (B, H) + dt = mx.clip( + nn.softplus(dt + self.dt_bias), + self.args.time_step_min, + self.args.time_step_max + ) + + # SSM step + A = -mx.exp(self.A_log) # (H,) + dA = mx.exp(dt * A[None, :]) # (B, H) + + # Compute dBx + dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, x) + + # Update SSM state and compute output + ssm_state = self.cache.update_ssm_state(dA, dBx) + y = mx.einsum('bhds,bhs->bhd', ssm_state, C) + y = y + x * self.D[None, :, None] + + # Reshape output + y = mx.reshape(y, (batch_size, self.args.hidden_size)) + + # Final output processing + y = self.norm(y, z) + + if d_mlp > 0: + y = mx.concat([nn.silu(z0) * x0, y], axis=-1) + + y = self.out_proj(y) + + return mx.expand_dims(y, 1) # (B, 1, D) + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = Mamba2Block(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache=None) -> mx.array: + # x : (B, L, D) + return self.mixer(self.norm(x), cache) + x # (B, L, D) + + +class Mamba2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache=None) -> mx.array: + # x : (B, L) + x = self.embeddings(x) + # x : (B, L, D) + if cache is None: + cache = [None] * len(self.layers) + + for layer, layer_cache in zip(self.layers, cache): + x = layer(x, layer_cache) + return self.norm_f(x) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.backbone = Mamba2Model(args) + + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None) -> mx.array: + # inputs : (B, L) + B, T = inputs.shape + + x = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def make_cache(self, batch_size=1): + return [Mamba2Cache( + batch_size=batch_size, + hidden_size=self.args.hidden_size, + state_size=self.args.state_size, + 
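+            # NOTE: `cache.Mamba2Cache` is defined with (batch_size, conv_dim, kernel_size,
+            # num_heads, head_dim, state_size); the `hidden_size=` and `conv_kernel=` keywords
+            # passed here do not match that signature.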
conv_kernel=self.args.conv_kernel, + num_heads=self.args.num_heads, + head_dim=self.args.head_dim + ) for _ in range(len(self.backbone.layers))] + + def sanitize(self, weights): + for k, v in weights.items(): + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2-other.py b/llms/mlx_lm/models/mamba2-other.py new file mode 100644 index 00000000..22064021 --- /dev/null +++ b/llms/mlx_lm/models/mamba2-other.py @@ -0,0 +1,288 @@ +# Copyright © 2024 Apple Inc. + +import math +from dataclasses import dataclass, field +from typing import Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "mamba2" + num_heads: int = 128 + head_dim: int = 64 + vocab_size: int = 32768 + hidden_size: int = 4096 + state_size: int = 128 + num_hidden_layers: int = 64 + layer_norm_epsilon: float = 1e-5 + expand: int = 2 + conv_kernel: int = 4 + n_groups: int = 8 + use_bias: bool = False + use_conv_bias: bool = True + initializer_range: float = 0.1 + residual_in_fp32: bool = True + time_step_rank: Union[int, str] = "auto" + time_step_min: float = 0.001 + time_step_max: float = 0.1 + time_step_floor: float = 1e-4 + time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) + rescale_prenorm_residual: bool = False + use_cache: bool = True + rms_norm: bool = True + chunk_size: int = 256 + tie_word_embeddings: bool = False + + def __post_init__(self): + if not hasattr(self, "intermediate_size"): + self.intermediate_size = int(self.expand * self.hidden_size) + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +class Mamba2Cache: + def __init__(self): + self.cache = [None, None] + + def __setitem__(self, idx, value): + self.cache[idx] = value + + def __getitem__(self, idx): + return self.cache[idx] + + @property + def state(self): + return self.cache + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = mx.ones((hidden_size,)) + self.variance_epsilon = eps + + def __call__(self, hidden_states, gate=None): + if gate is not None: + hidden_states = hidden_states * nn.silu(gate) + variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) + hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + +class DepthWiseConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.groups = groups if groups is not None else in_channels + + # Ensure in_channels and out_channels are the same for depthwise conv + assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution" + # Ensure groups is equal to in_channels for depthwise conv + assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" + + # Initialize weight with shape (out_channels, kernel_size, 1) + self.weight = mx.random.normal((out_channels, kernel_size, 1)) + self.bias = mx.zeros((out_channels,)) if bias else None + + def __call__(self, x, cache=None): + B, L, C = x.shape + _, K, _ = 
self.weight.shape + + if cache is not None: + x = mx.concatenate([cache, x], axis=1) + else: + x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) + + y = mx.conv_general(x, self.weight, groups=self.groups) + + if self.bias is not None: + y = y + self.bias + + return y, x[:, -K + 1 :, :] + + +class Mamba2Mixer(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.intermediate_size = args.intermediate_size + self.time_step_rank = args.time_step_rank + self.conv_kernel_size = args.conv_kernel + self.hidden_size = args.hidden_size + self.state_size = args.state_size + self.num_heads = args.num_heads + self.head_dim = args.hidden_size // args.num_heads + self.n_groups = args.n_groups + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size + self.conv1d = DepthWiseConv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=args.use_conv_bias, + kernel_size=args.conv_kernel, + groups=self.conv_dim, + padding=args.conv_kernel - 1 + ) + + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=args.use_bias + ) + + self.dt_bias = mx.ones((self.num_heads,)) + self.A_log = mx.log(mx.arange(1, self.num_heads + 1)) + self.D = mx.ones((self.num_heads,)) + + self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon) + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias) + + def ssm_step(self, x, state, dt_proj): + A = -mx.exp(self.A_log) + D = self.D + delta = nn.softplus(dt_proj + self.dt_bias) + + B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1) + + B = B.reshape(-1, self.n_groups, self.state_size) + C = C.reshape(-1, self.n_groups, self.state_size) + + if state is None: + new_state = mx.expand_dims(delta, -1) * B + else: + new_state = mx.expand_dims(delta, -1) * (B + state * mx.exp(mx.expand_dims(delta, -1) * A)) + + y = mx.sum(new_state * C, axis=-1) + y = y + D * x[:, :self.num_heads] + return y, new_state + + def __call__(self, x, cache): + B, T, D = x.shape + if cache is None: + cache = [None, None] + + outputs = [] + for t in range(T): + xt = x[:, t, :] + xz = self.in_proj(xt) + + x_t, z_t, dt_proj = mx.split( + xz, + indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size], + axis=-1 + ) + + conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) + x_t = conv_out.squeeze(1) + x_t = nn.silu(x_t) + y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj) + z_t = nn.silu(z_t) + + # Print shapes for debugging + print(f"y_t shape: {y_t.shape}") + print(f"z_t shape: {z_t.shape}") + + # Reshape y_t to (B, num_heads, head_dim) + y_t_reshaped = y_t.reshape(B, self.num_heads, -1) + + # Reshape z_t to (B, num_heads, intermediate_size // num_heads) + z_t_reshaped = z_t.reshape(B, self.num_heads, -1) + + print(f"y_t_reshaped shape: {y_t_reshaped.shape}") + print(f"z_t_reshaped shape: {z_t_reshaped.shape}") + + # Element-wise multiplication (broadcasting across the last dimension) + output_t = y_t_reshaped * z_t_reshaped + + # Reshape to match the expected input of out_proj + output_t = output_t.reshape(B, -1) + + print(f"output_t shape before out_proj: {output_t.shape}") + print(f"out_proj weight shape: {self.out_proj.weight.shape}") + + output_t = self.out_proj(output_t) + outputs.append(output_t) + + output = mx.stack(outputs, axis=1) + return output + + +class Mamba2Block(nn.Module): + def __init__(self, args: 
ModelArgs): + super().__init__() + self.mixer = Mamba2Mixer(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache): + return self.mixer(self.norm(x), cache) + x + + +class Mamba2(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__( + self, + inputs: mx.array, + cache=None + ): + hidden_states = self.embeddings(inputs) + + if cache is None: + cache = Mamba2Cache(len(self.layers)) + + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, cache[i]) + + hidden_states = self.norm_f(hidden_states) + return hidden_states + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba2(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None): + B, T = inputs.shape + + x = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def sanitize(self, weights): + for k, v in weights.items(): + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights + + def make_cache(self, batch_size: int = 1): + return [Mamba2Cache() for _ in range(len(self.layers))] + + @property + def layers(self): + return self.backbone.layers diff --git a/llms/mlx_lm/models/mamba2-prch.py b/llms/mlx_lm/models/mamba2-prch.py new file mode 100644 index 00000000..84bf2174 --- /dev/null +++ b/llms/mlx_lm/models/mamba2-prch.py @@ -0,0 +1,490 @@ +# coding=utf-8 +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MAMBA2 model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +logger = logging.get_logger(__name__) + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] 
-> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +class Mamba2Cache: + """ + Arguments: + config: ModelArgs + batch_size: int + dtype: torch.dtype + device: torch.device + + Attributes: + seqlen_offset: int + dtype: torch.dtype + conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, conv_kernel_size] + ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size, ssm_state_size] + """ + + def __init__( + self, config: ModelArgs, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None + ): + self.seqlen_offset = 0 + self.dtype = dtype + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = int(config.expand * config.hidden_size) + + self.conv_states = { + i: torch.zeros( + batch_size, + self.intermediate_size + 2 * config.n_groups * config.state_size, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros( + batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype + ) + for i in range(config.num_hidden_layers) + } + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, 
hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states + + +class Mamba2Mixer(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.num_heads = config.num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = int(config.expand * self.hidden_size) + self.time_step_rank = int(config.time_step_rank) + self.use_conv_bias = config.use_conv_bias + self.act = nn.silu + + self.layer_norm_epsilon = config.layer_norm_epsilon + self.rms_norm = config.rms_norm + + self.n_groups = config.n_groups + self.head_dim = config.head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + + self.dt_bias = torch.ones(self.num_heads) + A = torch.arange(1, self.num_heads + 1) + self.A_log = torch.log(A) + self.D = torch.ones(self.num_heads) + + self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + + def forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # Gated MLP's linear projection + projected_states = self.in_proj(input_states.squeeze(1)) + d_mlp = ( + projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 + _, _, gate, hidden_states, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + if cache_params.seqlen_offset > 0: + conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + # handle batched generation - states are copied through + conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding + else: + hidden_states = hidden_states.transpose(1,2) + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] + else: + ssm_state = torch.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), + device=hidden_states.device + ) + hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) + A = -torch.exp(self.A_log.float()) # [num_heads] + + if cache_params is not None and cache_params.seqlen_offset > 0: + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] + dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = torch.exp(dt[..., None] * A) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = dB * hidden_states[..., None] + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
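+            # Above: the cached single-token (decode) path, which advances the SSM state one
+            # step in closed form. Below: the chunked SSD path used to process a full prompt
+            # when no populated cache is available.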
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_min) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = self.chunk_size - (seq_len % self.chunk_size) + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # First, contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + + # Step 2: Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Step 3: Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + + # (right term of low-rank factorization of off-diagonal blocks; B terms) + + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] + # permute back B * decay states + states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) + if cache_params is not None and cache_params.seqlen_offset > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] 
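+                # Seed the chunk recurrence with the cached SSM state when decoding
+                # continues an existing sequence; otherwise start from zeros below.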
+ else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + + states_permuted = states.permute(0, 2, 1, 3, 4) + result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) + new_states = result.permute(0, 2, 1, 3, 4) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + # compute Yoff + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output) # [batch, seq_len, hidden_size] + return contextualized_states + + +class Mamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states + + +class Mamba2Block(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = Mamba2Mixer(config) + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + ): + x = self.mixer( + self.norm(hidden_states), cache_params=cache_params, cache_position=cache_position + ) + return x + hidden_states + + +class Mamba2Model(nn.Module): + def __init__(self, config): + super().__init__(config) + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ): + inputs_embeds = self.embeddings(input_ids) + + if use_cache: + if cache_params is None: + cache_params = Mamba2Cache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device) + else: + cache_params = None + + hidden_states = 
inputs_embeds + for mixer_block in self.layers: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + ) + + if use_cache: + cache_params.seqlen_offset += inputs_embeds.shape[1] + + return self.norm_f(hidden_states), cache_params if use_cache else None + + + +class Mamba2ForCausalLM(nn.Module): + def __init__(self, config): + super().__init__(config) + self.backbone = Mamba2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[Mamba2Cache] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + ): + mamba2_outputs = self.backbone( + input_ids, + cache_params=cache_params, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = mamba2_outputs[0] + + logits = self.lm_head(hidden_states) + return logits, mamba2_outputs.cache_params, mamba2_outputs.hidden_states \ No newline at end of file diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py new file mode 100644 index 00000000..bd0f17ee --- /dev/null +++ b/llms/mlx_lm/models/mamba2.py @@ -0,0 +1,361 @@ +import math +from dataclasses import dataclass, field +from typing import Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs +from .cache import Mamba2Cache + +@dataclass +class ModelArgs(BaseModelArgs): + num_heads: int + head_dim: int + vocab_size: int + hidden_size: int + state_size: int + num_hidden_layers: int + layer_norm_epsilon: float + expand: int + conv_kernel: int + n_groups: int + use_bias: bool + use_conv_bias: bool + initializer_range: float + residual_in_fp32: bool + time_step_min: float + time_step_max: float + time_step_floor: float + rescale_prenorm_residual: bool + use_cache: bool + rms_norm: bool + chunk_size: int + tie_word_embeddings: bool + intermediate_size: int = None + time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf"))) + time_step_rank: Union[int, str] = "auto" + model_type: str = "mamba2" + + def __post_init__(self): + self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED + + if not hasattr(self, "head_dim"): + self.head_dim = self.hidden_size // self.num_heads + if self.time_step_rank == "auto": + self.time_step_rank = math.ceil(self.hidden_size / 16) + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = mx.ones(hidden_size) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(mx.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(mx.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * math.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class Mamba2Mixer(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + # Model dimensions + self.hidden_size = args.hidden_size + self.num_heads = args.num_heads + self.head_dim = args.head_dim + self.ssm_state_size = args.state_size + self.n_groups = args.n_groups + self.intermediate_size = int(args.expand * args.hidden_size) + + # Convolution parameters + self.conv_kernel = args.conv_kernel + self.use_conv_bias = args.use_conv_bias + + # Time step parameters + self.time_step_rank = 
int(args.time_step_rank) + self.time_step_min = args.time_step_min + self.time_step_max = args.time_step_max + + # Processing parameters + self.chunk_size = args.chunk_size + self.layer_norm_epsilon = args.layer_norm_epsilon + + # Calculate dimensions + self.conv_dim = (self.intermediate_size + + 2 * self.n_groups * self.ssm_state_size) + projection_size = (self.intermediate_size + + self.conv_dim + + self.num_heads) + + # Initialize layers + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=args.use_bias + ) + + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + kernel_size=self.conv_kernel, + groups=self.conv_dim, + padding=self.conv_kernel - 1, + bias=self.use_conv_bias + ) + + # Initialize parameters + self.dt_bias = mx.ones(self.num_heads) + A = mx.arange(1, self.num_heads + 1) + self.A_log = mx.log(A) + self.D = mx.ones(self.num_heads) + + # Output layers + self.norm = MambaRMSNormGated( + self.intermediate_size, + eps=self.layer_norm_epsilon + ) + self.out_proj = nn.Linear( + self.intermediate_size, + self.hidden_size, + bias=args.use_bias + ) + + def reshape_into_chunks(self, tensor, pad_size, chunk_size): + if pad_size > 0: + pad_shape = list(tensor.shape) + pad_shape[1] = pad_size + padding = mx.zeros(pad_shape, dtype=tensor.dtype) + tensor = mx.concatenate([tensor, padding], axis=1) + + chunk_shape = list(tensor.shape) + chunk_shape[1] = -1 + chunk_shape.insert(2, chunk_size) + return tensor.reshape(chunk_shape) + + def segment_sum(self, x): + return mx.cumsum(x, axis=-1) + + def process_single_token(self, hidden_states, B, C, dt, cache): + batch_size = hidden_states.shape[0] + + # Process convolution state + if cache is not None: + conv_state = cache.conv_states + # Roll the conv state and update the last position + conv_state = mx.roll(conv_state, shift=-1, axis=-1) + # Create new conv state with updated last position + new_conv_state = mx.array(conv_state) + new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states) + conv_state = new_conv_state + + # Compute convolution + conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1) + if self.use_conv_bias: + conv_out = conv_out + self.conv1d.bias + + # Apply SiLU activation + conv_out = mx.sigmoid(conv_out) * conv_out + + else: + # Initialize new cache + conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1)) + conv_out = self.conv1d(hidden_states) + conv_out = mx.sigmoid(conv_out) * conv_out + + # Process SSM + dt = mx.clip( + nn.softplus(dt + self.dt_bias), + self.time_step_min, + self.time_step_max + ) + + A = -mx.exp(self.A_log) + dA = mx.exp(dt * A[None, :]) + + if cache is not None: + ssm_state = cache.ssm_states + else: + ssm_state = mx.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size) + ) + + # Compute SSM updates + dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states) + next_state = ssm_state * dA[:, :, None, None] + dBx + y = mx.einsum('bhds,bhs->bhd', next_state, C) + + # Add skip connection + y = y + hidden_states * self.D[None, :, None] + + return y, conv_state, next_state + + def process_long_sequence(self, hidden_states, B, C, dt, ssm_state): + batch_size, seq_len = hidden_states.shape[:2] + pad_size = self.chunk_size - (seq_len % self.chunk_size) + + # Reshape into chunks + x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size) + B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size) + C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size) + + # 
Process time steps + dt = nn.softplus(dt + self.dt_bias) + dt = mx.clip(dt, self.time_step_min) + + # Prepare matrices + A = -mx.exp(self.A_log) + A = A * dt[:, None] + + # Process chunks + A_chunks = self.reshape_into_chunks( + mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)), + pad_size, + self.chunk_size + ) + + # Compute cumulative sums + A_cumsum = mx.cumsum(A_chunks, axis=-1) + L = mx.exp(self.segment_sum(A_chunks)) + + # Process diagonal blocks + G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks) + M = G * L[..., None, :] + Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks) + + # Process off-diagonal blocks + decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum) + B_decay = B_chunks * decay_states[..., None] + states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks) + + # Combine results + y = Y_diag + states + + # Remove padding if necessary + if pad_size > 0: + y = y[:, :seq_len] + + return y, ssm_state + + def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: + batch_size, seq_len, _ = x.shape + + # Project input + projected_states = self.in_proj(x.squeeze(1)) + + # Calculate d_mlp based on projection size + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * + self.n_groups * self.ssm_state_size - self.num_heads) // 2 + + # Split projections with corrected dimensions + splits = [ + d_mlp, # z0 + d_mlp, # x0 + self.intermediate_size, # gate + self.conv_dim, # hidden_states + self.num_heads # dt + ] + + z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1) + + # Split hidden states into components + x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1) + B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1) + + # Process based on sequence length + if seq_len > 1 and cache is None: + y, next_state = self.process_long_sequence( + x_conv, B, C, dt, + mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size)) + ) + else: + # Reshape for single token processing + x_conv = x_conv.reshape(batch_size, -1, self.head_dim) + B = B.reshape(batch_size, self.num_heads, -1) + C = C.reshape(batch_size, self.num_heads, -1) + y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache) + + if cache is not None: + cache.update(conv_state, next_state) + + # Apply normalization and final projection + y = self.norm(y) * gate + return self.out_proj(y) + + +class ResidualBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.mixer = Mamba2Mixer(args) + self.norm = nn.RMSNorm(args.hidden_size) + + def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array: + return self.mixer(self.norm(x), cache) + x + +class Mamba2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)] + self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__(self, x: mx.array, cache=None) -> mx.array: + x = self.embeddings(x) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, layer_cache in zip(self.layers, cache): + x = layer(x, layer_cache) + + return self.norm_f(x) + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.backbone = Mamba2Model(args) + if not args.tie_word_embeddings: + self.lm_head = 
nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__(self, inputs: mx.array, cache=None) -> mx.array: + B, T = inputs.shape + + x = self.backbone(inputs, cache) + + if self.args.tie_word_embeddings: + logits = self.backbone.embeddings.as_linear(x) + else: + logits = self.lm_head(x) + + return logits + + def make_cache(self, batch_size=1): + return [ + Mamba2Cache( + batch_size=batch_size, + conv_dim=self.args.intermediate_size + 2 * self.args.n_groups * self.args.state_size, + kernel_size=self.args.conv_kernel, + num_heads=self.args.num_heads, + head_dim=self.args.head_dim, + state_size=self.args.state_size + ) + for _ in range(len(self.backbone.layers)) + ] + + def sanitize(self, weights): + for k, v in weights.items(): + if "conv1d.weight" in k and v.ndim == 3: + weights[k] = v.moveaxis(2, 1) + return weights + + @property + def layers(self): + return self.backbone.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 7c78ee91..a44663fb 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -143,6 +143,13 @@ def to_lora(layer): "mixer.out_proj", ] ) + elif model.model_type == "mamba2": + keys = set( + [ + "mixer.in_proj", + "mixer.out_proj", + ] + ) else: raise ValueError(f"Lora does not support {model.model_type}")
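
Reviewer note (illustrative, not part of the patch): the decode path added in mamba2.py reduces to the discretized SSM recurrence h_t = exp(dt*A) * h_{t-1} + dt * B * x_t with y_t = C · h_t + D * x_t, applied per head. A minimal MLX sketch of that single-token update follows; the function name and shapes are illustrative assumptions, not the exact code from this diff.

    import mlx.core as mx
    import mlx.nn as nn

    def ssm_step(x, B, C, dt, A_log, D, state):
        # x: (batch, heads, head_dim); B, C: (batch, heads, state_size)
        # dt: (batch, heads); A_log, D: (heads,); state: (batch, heads, head_dim, state_size)
        dt = nn.softplus(dt)                         # keep step sizes positive
        A = -mx.exp(A_log)                           # negative per-head decay rates
        dA = mx.exp(dt * A)                          # discretized decay, shape (batch, heads)
        dBx = dt[..., None, None] * x[..., None] * B[:, :, None, :]
        state = state * dA[..., None, None] + dBx    # h_t = exp(dt*A) * h_{t-1} + dt * B * x_t
        y = (state * C[:, :, None, :]).sum(axis=-1)  # y_t = C · h_t
        return y + x * D[None, :, None], state       # D skip connection

    # Example call with the default sizes from mamba2-other.py's ModelArgs
    b, h, d, n = 1, 128, 64, 128
    y, new_state = ssm_step(
        mx.random.normal((b, h, d)), mx.random.normal((b, h, n)),
        mx.random.normal((b, h, n)), mx.zeros((b, h)),
        mx.zeros((h,)), mx.ones((h,)), mx.zeros((b, h, d, n)),
    )

With the tuner/utils.py change above, LoRA fine-tuning for `mamba2` targets `mixer.in_proj` and `mixer.out_proj`, matching the convention of the preceding model branch.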