adding-support-for-mamba2 #1009

Draft
wants to merge 28 commits into base: main
28 commits
49b9fc1
Create mamba2.py
Goekdeniz-Guelmez Oct 2, 2024
409ddc4
updating ACKNOWLEDGMENTS.md file
Goekdeniz-Guelmez Oct 2, 2024
264ba43
update trainer/lora.py and adding DepthWiseConv1d because mlx 0.18.0 …
Goekdeniz-Guelmez Oct 2, 2024
52d6ca0
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 4, 2024
4e1236c
fixing loading the model
Goekdeniz-Guelmez Oct 11, 2024
9c075a7
Merge branch 'adding-support-for-mamba2' of https://github.com/Goekde…
Goekdeniz-Guelmez Oct 11, 2024
6f88dd5
quick clean up and fix
Goekdeniz-Guelmez Oct 11, 2024
00ba27f
adding debug statements
Goekdeniz-Guelmez Oct 11, 2024
3f1c1dd
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 14, 2024
855fcc4
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 16, 2024
8073cb4
adding debug statements (somehiw generating only goes through the fis…
Goekdeniz-Guelmez Oct 16, 2024
181d6ab
Merge branch 'adding-support-for-mamba2' of https://github.com/Goekde…
Goekdeniz-Guelmez Oct 16, 2024
cd036cc
fix generation works too (almost)
Goekdeniz-Guelmez Oct 16, 2024
4ab5139
quick save
Goekdeniz-Guelmez Oct 20, 2024
ab4cf1d
generation works but outputs gibberish
Goekdeniz-Guelmez Oct 20, 2024
c1634ce
still generating gibberish
Goekdeniz-Guelmez Oct 20, 2024
0ef73f3
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 21, 2024
b9c57cd
generation works! trying training now
Goekdeniz-Guelmez Oct 22, 2024
5326d93
Merge branch 'adding-support-for-mamba2' of https://github.com/Goekde…
Goekdeniz-Guelmez Oct 22, 2024
758597e
adding multi token input and correct cache handling in ssm step
Goekdeniz-Guelmez Oct 22, 2024
55485b9
update
Goekdeniz-Guelmez Oct 22, 2024
e43a2ab
not working, incorrect handling with cache probably
Goekdeniz-Guelmez Oct 22, 2024
9ab581d
notes
Goekdeniz-Guelmez Oct 22, 2024
a677638
inference works but is hella slow
Goekdeniz-Guelmez Oct 22, 2024
7c8849e
update
Goekdeniz-Guelmez Oct 24, 2024
3b70708
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 25, 2024
ffc7ab0
Merge branch 'ml-explore:main' into adding-support-for-mamba2
Goekdeniz-Guelmez Oct 30, 2024
58b448d
updates
Goekdeniz-Guelmez Oct 30, 2024
ACKNOWLEDGMENTS.md (2 changes: 1 addition & 1 deletion)
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba version 1`, `Mamba version 2` and support for `full-fine-tuning`.
llms/mlx_lm/models/cache.py (13 changes: 13 additions & 0 deletions)
@@ -324,6 +324,7 @@ def trim(self, n):
 class MambaCache(_BaseCache):
     def __init__(self):
         self.cache = [None, None]
+        self.offset = 0

     def __setitem__(self, idx, value):
         self.cache[idx] = value
@@ -338,3 +339,15 @@ def state(self):
     @state.setter
     def state(self, v):
         self.cache = v
+
+
+class Mamba2Cache:
+    def __init__(self, batch_size, conv_dim, kernel_size, num_heads, head_dim, state_size):
+        self.conv_states = mx.zeros((batch_size, conv_dim, kernel_size - 1))
+        self.ssm_states = mx.zeros((batch_size, num_heads, head_dim, state_size))
+        self.seqlen_offset = 0
+
+    def update(self, new_conv_state, new_ssm_state):
+        self.conv_states = new_conv_state
+        self.ssm_states = new_ssm_state
+        self.seqlen_offset += 1
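
For context, a minimal sketch of how the new Mamba2Cache could be exercised during decoding. This is an assumption-based illustration, not code from the PR: the dimensions below are placeholders, and the import path only works once this branch's cache.py is installed.

```python
# Usage sketch for the Mamba2Cache added above (hypothetical sizes; a real
# model would derive them from its configuration).
import mlx.core as mx

from mlx_lm.models.cache import Mamba2Cache  # class introduced in this PR

cache = Mamba2Cache(
    batch_size=1,
    conv_dim=1536,    # placeholder, not from the PR
    kernel_size=4,    # placeholder
    num_heads=24,     # placeholder
    head_dim=64,      # placeholder
    state_size=128,   # placeholder
)

# One decoding step: the layer would compute fresh convolution and SSM states,
# store them back, and advance the sequence offset by one token.
new_conv = mx.zeros_like(cache.conv_states)
new_ssm = mx.zeros_like(cache.ssm_states)
cache.update(new_conv, new_ssm)
assert cache.seqlen_offset == 1
```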