dmis-lab/Monet


Monet: Mixture of Monosemantic Experts for Transformers


Introduction

Monet presents a new approach to enhancing mechanistic interpretability in large language models (LLMs) through a Sparse Mixture-of-Experts (SMoE) architecture. By incorporating sparse dictionary learning directly into end-to-end pretraining, Monet addresses the fundamental challenge of polysemanticity (individual neurons responding to multiple unrelated concepts) while maintaining model performance.

✨Key Highlights

  • 📈 Scalable Expert Architecture: Monet introduces parameter-efficient expert decomposition methods that enable scaling to 262,144 experts per layer while keeping the total parameter count proportional to the square root of the number of experts.
  • 📊 Monosemantic Experts: Through fine-grained expert specialization, Monet achieves monosemantic experts that demonstrate mutual exclusivity of knowledge, allowing transparent observation of model behavior and parametric knowledge.
  • 🛠️ Robust Knowledge Control: The architecture enables precise manipulation of domain-specific knowledge, language capabilities, and toxicity mitigation without compromising general performance.
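The square-root scaling in the first highlight comes from composing each expert out of two smaller pieces. As a rough sketch, assuming (as the routing example under "Get Expert Routing Probabilities" suggests, where the full expert axis has size moe_experts**2 = 262,144) that every expert is indexed by a pair (i, j) drawn from two groups of 512 sub-experts, the counting works out as follows. The function name is illustrative and not part of the Monet codebase.

```python
import math


def composed_expert_counts(num_experts: int) -> tuple[int, int]:
    """Return (sub_experts_per_group, stored_sub_experts) for a composition
    in which every expert is a pair (i, j) of sub-experts from two groups.

    Assumes num_experts is a perfect square, as with 262,144 = 512 ** 2.
    """
    per_group = math.isqrt(num_experts)
    if per_group * per_group != num_experts:
        raise ValueError("num_experts must be a perfect square")
    # Only the two groups of sub-experts are stored; the experts themselves
    # are all pairwise combinations of one sub-expert from each group.
    return per_group, 2 * per_group


per_group, stored = composed_expert_counts(262_144)
print(per_group, stored)  # 512 1024
```

Quadrupling the number of experts only doubles the number of stored sub-experts, which is the square-root relationship. The exact parameter count per sub-expert depends on the VD or HD decomposition and is not modeled here.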

Why Monet?

Unlike traditional approaches that rely on post-hoc reconstruction (such as Sparse Autoencoders), Monet builds interpretability directly into its architecture. This enables both a transparent view of model internals and direct control over model behavior. By scaling monosemantic experts, Monet paves the way for more transparent and controllable language models.

News

  • 2024-12-06: Released Monet: Mixture of Monosemantic Experts for Transformers on arXiv, together with the GitHub code, models, and demo.

Model Checkpoints

Base Models

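| Model | Dataset | #Params | #Tokens | Checkpoint | Demo |
|---|---|---|---|---|---|
| Monet-VD | FineWeb-Edu | 850M | 100BT | 🤗monet-vd-850M-100BT-hf | |
| Monet-VD | FineWeb-Edu | 1.4B | 100BT | 🤗monet-vd-1.4B-100BT-hf | 🔍Viewer |
| Monet-VD | FineWeb-Edu | 4.1B | 100BT | 🤗monet-vd-4.1B-100BT-hf | |
| Monet-VD | StarCoderData | 1.4B | 100BT | 🤗codemonet-vd-1.4B-100BT-hf | 🔍Viewer |
| Monet-HD | FineWeb-Edu | 850M | 100BT | 🤗monet-hd-850M-100BT-hf | |
| Monet-HD | FineWeb-Edu | 1.4B | 100BT | 🤗monet-hd-1.4B-100BT-hf | |
| Monet-HD | FineWeb-Edu | 4.1B | 100BT | 🤗monet-hd-4.1B-100BT-hf | |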
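The Quickstart examples load these checkpoints from the MonetLLM organization on the Hugging Face Hub (for example, MonetLLM/monet-vd-1.4B-100BT-hf). The hypothetical helper below only reproduces the naming pattern visible in the base-model table; it is not part of the Monet codebase, and only the combinations listed in the table are known to exist.

```python
# Hypothetical helper: builds Hub repository ids following the naming
# pattern of the base-model checkpoints listed in the table.
BASE_SIZES = {"850M", "1.4B", "4.1B"}


def base_checkpoint_id(variant: str, size: str, code: bool = False) -> str:
    variant = variant.lower()
    if variant not in {"vd", "hd"}:
        raise ValueError("variant must be 'vd' or 'hd'")
    if size not in BASE_SIZES:
        raise ValueError(f"size must be one of {sorted(BASE_SIZES)}")
    if code and (variant, size) != ("vd", "1.4B"):
        # The table lists a single StarCoderData (code) checkpoint.
        raise ValueError("only codemonet-vd-1.4B is listed")
    prefix = "codemonet" if code else "monet"
    return f"MonetLLM/{prefix}-{variant}-{size}-100BT-hf"


print(base_checkpoint_id("vd", "1.4B"))  # MonetLLM/monet-vd-1.4B-100BT-hf
```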

Instruction-Tuned Models

| Model | Purpose | Recipe | #Params | Checkpoint |
|---|---|---|---|---|
| Monet-VD | Chat Completion | SmolLM | 1.4B | 🤗monet-vd-1.4B-100BT-chat-hf |
| Monet-VD | Vision-Language Model | LLaVA | 1.6B | 🤗visionmonet-vd-1.4B-100BT-hf |

Quickstart

You can explore the core implementation of Monet in modeling_monet.py. Our custom modeling code is also published alongside the checkpoints on the 🤗Hugging Face Hub, so you only need to set trust_remote_code=True when loading the models through the Transformers library.

Text Generation

```python
import torch
from transformers import AutoTokenizer, pipeline

model_name = "MonetLLM/monet-vd-1.4B-100BT-hf"
pipe = pipeline(
    "text-generation",
    model_name,
    tokenizer=AutoTokenizer.from_pretrained(model_name),
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # load Monet's custom modeling code from the Hub
)
print(pipe("The key to life is", max_new_tokens=20, do_sample=True)[0]["generated_text"])
```

Output:

<s> The key to life is learning how to live creatively. The question is: how do we do that, and what will

Code Generation

```python
import torch
from transformers import AutoTokenizer, pipeline

model_name = "MonetLLM/codemonet-vd-1.4B-100BT-hf"
pipe = pipeline(
    "text-generation",
    model_name,
    tokenizer=AutoTokenizer.from_pretrained(model_name),
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

text = '''
def print_len(x: str):
    """For a given string x, print the length of x."""
'''
# Keep only the first generated block (everything before the first blank line).
print(pipe(text, max_new_tokens=10)[0]["generated_text"].split("\n\n")[0])
```

Output:

<s>
def print_len(x: str):
    """For a given string x, print the length of x."""
    print(len(x))
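The `.split("\n\n")[0]` call in the code-generation example acts as a simple stopping heuristic: the model may continue past the function body into a new top-level block, so the output is cut at the first blank line. A standalone version of that post-processing step looks like this (the function name and sample text are illustrative):

```python
def first_block(generated_text: str) -> str:
    """Keep everything before the first blank line, like .split("\\n\\n")[0]."""
    return generated_text.split("\n\n", 1)[0]


# Illustrative completion: the prompt, a function body, then an unwanted call.
sample = '\ndef print_len(x: str):\n    """For a given string x, print the length of x."""\n    print(len(x))\n\nprint_len("abc")'
print(first_block(sample))
```

Passing `maxsplit=1` avoids splitting the rest of the text, which is discarded anyway; the result is the same as indexing the full split.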

Chat Completion

```python
import torch
from transformers import AutoTokenizer, pipeline

model_name = "MonetLLM/monet-vd-1.4B-100BT-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
pipe = pipeline(
    "text-generation",
    model_name,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Format the conversation with the model's chat template, returned as a string.
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hi! How are you?"}],
    add_generation_prompt=True,
    tokenize=False,
)
print(pipe(text, max_new_tokens=30, do_sample=True)[0]["generated_text"])
```

Output:

<s>[INST] Hi! How are you? [/INST] I'm good, thanks! How can I help you today? </s>

Using vLLM

For faster inference, Monet can be integrated with the vLLM engine. Note that Monet must be registered manually with vLLM's ModelRegistry before the engine is initialized. The custom implementation is provided in modeling_monet_vllm.py.

```python
from vllm import LLM, ModelRegistry, SamplingParams
from modeling_monet_vllm import MonetForCausalLM

# Register the Monet architecture with vLLM before creating the engine
ModelRegistry.register_model("MonetForCausalLM", MonetForCausalLM)

model = LLM(
    "MonetLLM/monet-vd-1.4B-100BT-hf",
    trust_remote_code=True,
    dtype="bfloat16",
    gpu_memory_utilization=0.8,
)
sampling_params = SamplingParams(max_tokens=20, temperature=1.0)
print(model.generate("The key to life is", sampling_params)[0].outputs[0].text)
```

Output:

 what you’re born with. If you think that you don’t have the same control and

Get Expert Routing Probabilities

Monet supports mechanistic interpretability through its expert routing probabilities, which reveal which sparse experts are activated for each token. Following the standard MoE approach, you can obtain expert routing probabilities for all layers by setting output_router_probs=True. The example below shows how to compute and analyze the expert activation patterns:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "MonetLLM/monet-vd-1.4B-100BT-hf",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("MonetLLM/monet-vd-1.4B-100BT-hf")

inputs = tokenizer("City and County of San Francisco", return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(**inputs.to(model.device), output_router_probs=True)

# router_probs is indexed by layer; each entry holds the two sub-expert routing
# distributions (g1, g2). This example inspects the first layer.
# Their outer product gives the full expert routing probabilities:
# [batch_size, seq_len, moe_heads, moe_experts**2]
g1, g2 = outputs.router_probs[0][0], outputs.router_probs[0][1]
g = torch.einsum("bthi,bthj->bthij", g1, g2).flatten(-2)
print(g.shape)

# Print the number of activated experts per token (summed over heads, threshold 1e-2).
for token, routing in zip(inputs.input_ids.squeeze(0), g.squeeze(0)):
    token = tokenizer.decode(token).ljust(16, " ")
    expert_indices = (routing.sum(0) > 1e-2).argwhere().squeeze(-1)
    print(f"Token: {token} Activated Experts: {len(expert_indices)}")
```

Output:

torch.Size([1, 7, 8, 262144])
Token: <s>              Activated Experts: 62
Token: City             Activated Experts: 60
Token: and              Activated Experts: 16
Token: County           Activated Experts: 102
Token: of               Activated Experts: 11
Token: San              Activated Experts: 70
Token: Francisco        Activated Experts: 67
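The einsum in the routing example combines two per-head routing distributions, g1 and g2 (each over moe_experts sub-experts), into a distribution over all moe_experts**2 experts by taking their outer product. Because `flatten(-2)` merges the `[..., i, j]` axes in row-major order, flattened expert index `i * moe_experts + j` receives probability `g1[i] * g2[j]`. The pure-Python sketch below mirrors that step for a single token and a single head using tiny made-up distributions, applies the same 1e-2 threshold, and recovers the (i, j) pair from a flattened index. The sizes and values are illustrative, not taken from the model, and unlike the example above it does not sum over heads.

```python
def compose_routing(g1: list[float], g2: list[float]) -> list[float]:
    """Outer product of two routing distributions, flattened row-major.

    For one token and one head this matches
    torch.einsum("bthi,bthj->bthij", g1, g2).flatten(-2).
    """
    return [p * q for p in g1 for q in g2]


def activated_experts(g: list[float], threshold: float = 1e-2) -> list[int]:
    """Indices of experts whose routing probability exceeds the threshold."""
    return [idx for idx, p in enumerate(g) if p > threshold]


def split_index(idx: int, num_sub_experts: int) -> tuple[int, int]:
    """Map a flattened expert index back to its (i, j) sub-expert pair."""
    return divmod(idx, num_sub_experts)


# Toy example with 4 sub-experts per group (the released models use 512).
g1 = [0.9, 0.1, 0.0, 0.0]
g2 = [0.0, 0.5, 0.5, 0.0]
g = compose_routing(g1, g2)
active = activated_experts(g)
print(active)                                     # [1, 2, 5, 6]
print([split_index(i, len(g2)) for i in active])  # [(0, 1), (0, 2), (1, 1), (1, 2)]
```

Since both inputs sum to 1, the composed distribution also sums to 1, so sparse sub-expert routing in each group yields a sparse set of activated experts overall.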

Citation

If you find this work useful for your research and applications, please cite it using the following BibTeX.

@article{park2024monet,
      title={{Monet: Mixture of Monosemantic Experts for Transformers}}, 
      author={Jungwoo Park and Young Jin Ahn and Kee-Eung Kim and Jaewoo Kang},
      journal={arXiv preprint arXiv:2412.04139},
      year={2024}
}