Improving EasyDeLState, AttentionModule, and Docs
erfanzar committed May 28, 2024
1 parent 7e1be21 commit 1e5390a
Showing 139 changed files with 596 additions and 762 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -7,7 +7,7 @@ __pycache__/
.idea
*.so
env.ipynb
env.py
pallas_env.py
test_EasyDeLState.py
io.py
10 changes: 10 additions & 0 deletions .readthedocs.yaml
@@ -0,0 +1,10 @@
version: 2
build:
  os: ubuntu-22.04
  tools:
    python: "3.12"
sphinx:
  configuration: docs/conf.py
python:
  install:
    - requirements: docs/requirements.txt
94 changes: 80 additions & 14 deletions README.md
@@ -85,6 +85,47 @@ APIs are changing ...
> versioned
> API and efficient.

## EasyDeLState: A Snapshot of Your EasyDeL Model

The `EasyDeLState` class acts like a comprehensive container that holds all the essential information about your EasyDeL
model at a given point in time. Think of it as a snapshot of your model. It includes:

* **Training Progress:**
* `step`: Tracks the current training step.
* **Model Itself:**
* `module`: Holds the actual instance of your EasyDeL model.
* `module_config`: Stores the model's configuration settings.
* `module_config_args`: Keeps track of arguments used to create the configuration (useful for reloading).
* `apply_fn`: References the core function that applies your model to data.
* **Learned Parameters:**
* `params`: Contains the trained weights and biases of your model.
* **Optimizer Information:**
* `tx`: Stores the optimizer you're using to update the model's parameters (e.g., AdamW).
* `opt_state`: Keeps track of the optimizer's internal state (this is important for things like momentum in
optimizers).
* `tx_init`: Remembers the initial settings used to create the optimizer (again, for reloading purposes).
* **Additional Settings:**
* `hyperparameters`: Provides a flexible place to store other hyperparameters related to your model or training
process.

**Key Capabilities of EasyDeLState:**

* **Initialization (`create`)**: Lets you create a brand new `EasyDeLState` to start training.
* **Loading (`load`, `load_state`, `from_pretrained`)**: Enables you to reload a saved model from a checkpoint file or
even a pre-trained model from a repository like Hugging Face Hub.
* **Saving (`save_state`)**: Allows you to save your model's current state, including its parameters and optimizer
state.
* **Optimizer Management (`apply_gradients`, `free_opt_state`, `init_opt_state`)**: Provides methods for updating the
model's parameters using gradients, releasing optimizer memory, and re-initializing the optimizer if needed.
* **Sharding (`shard_params`)**: Helps you distribute your model's parameters efficiently across multiple devices (
important for training large models).
* **PyTorch Conversion (`to_pytorch`)**: Gives you a way to convert your EasyDeL model to its PyTorch equivalent.

**In Essence:**

`EasyDeLState` streamlines the process of managing, saving, loading, and even converting your EasyDeL models. It ensures
that you can easily work with your models and maintain consistency throughout your machine learning workflow.
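
A minimal sketch of that lifecycle, using the method names listed above (the repo id and checkpoint file name are placeholders, and keyword arguments may differ between EasyDeL versions):

```python
from easydel import EasyDeLState

# Load a pre-trained model from the Hugging Face Hub into a state snapshot.
state = EasyDeLState.from_pretrained(
    pretrained_model_name_or_path="REPO_ID",  # placeholder repo id
)

# ... train, then persist the snapshot (parameters plus optimizer state).
state.save_state("easydel-checkpoint.easy")  # placeholder file name

# Later, restore the snapshot; dropping the optimizer state frees memory
# when the model is only needed for inference.
state = EasyDeLState.load_state("easydel-checkpoint.easy")
state = state.free_opt_state()
```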

## Supervised Fine-Tuning with EasyDeL

EasyDeL ships both DPO and SFT trainers, so working with LLMs in JAX is now much easier.
@@ -630,20 +671,45 @@ model = model.half() # it's a huggingface model now
`EasyDeLState` is general-purpose: you can use it everywhere in EasyDeL, for example for a stand-alone model,
serving, fine-tuning, and many other features; it's up to you to test how creative you can be 😇.

## Flash Attention and Splash Attention Are Here 🥵

Here's a simple example of how you can use Flash Attention in EasyDeL:

```python
# `config` is the built-in configuration that every EasyDeL model carries
# (an `EasyDeLPretrainedConfig` instance).
config.add_basic_configurations(
    attn_mechanism="flash",  # any supported attention mechanism
    block_b=1,
    block_q=512,
    block_k=512,
    block_k_major=512
)
```
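
As a rule of thumb (a general property of block-wise flash attention rather than an EasyDeL-specific guarantee), the `block_q`, `block_k`, and `block_k_major` arguments set the tile sizes the kernel sweeps over; the values above are a sensible starting point, but the best settings depend on your hardware and sequence length.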
## AttentionModule: A Versatile Attention Mechanism Factory

The `AttentionModule` class is designed to simplify the creation and execution of different attention mechanisms within
your EasyDeL models. It provides a unified interface for working with various attention types, allowing you to easily
switch between them and experiment with different configurations.

**Key Features:**

* **Mechanism Selection:** The `attn_mechanism` argument lets you choose the specific attention algorithm you want to
use (e.g., "vanilla," "flash," "splash," "ring," "cudnn").
* **Sharding and Partitioning:** The class supports advanced JAX sharding techniques to distribute attention
computations across multiple devices for efficient processing of large models. It handles partitioning of query, key,
value, bias, and attention weight matrices using `PartitionSpec`.
* **Blockwise Attention:** Enables the use of blockwise attention for increased memory efficiency, especially with long
sequences.
* **Caching Support:** Facilitates the use of attention caching to speed up inference and generation tasks.
* **Dropout and Determinism:** Allows for applying dropout to attention weights and controlling the deterministic
behavior of the attention computation.
* **Testing Utility:** Provides a `test_attentions` method to compare different attention mechanisms in terms of
accuracy, gradient stability, and computation time.

**How it Works:**

1. **Initialization:**
- During initialization, you provide the desired `attn_mechanism`, JAX `mesh` for sharding, scaling
factor (`sm_scale`), number of attention heads, head dimensions, and other configuration parameters.
- The class automatically sets default values for many parameters based on the chosen attention mechanism and the
provided EasyDeL configuration (`base_module_class`).
2. **Calling the Module:**
- When you call the `AttentionModule` object, you pass in the query, key, and value states, along with optional
parameters like attention masks, biases, and causal flags.
- The module internally selects the appropriate attention function based on the specified `attn_mechanism`.
- It performs any necessary sharding and partitioning based on the configured partition specifications.
- The attention computation is executed, and the attention outputs (and optionally attention weights) are returned.
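
A minimal sketch of that flow follows. The import path, constructor and call keyword names, and the `causal` flag are assumptions based on the description above; check your EasyDeL version for the exact signature.

```python
import math

import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh

from easydel import AttentionModule  # assumed import path

batch, seq, heads, head_dims = 1, 512, 8, 64

# A device mesh matching the (dp, fsdp, tp, sp) axes EasyDeL uses elsewhere.
mesh = Mesh(np.array(jax.devices()).reshape(1, -1, 1, 1), ("dp", "fsdp", "tp", "sp"))

# 1. Initialization: choose the mechanism and describe the attention shapes.
attention = AttentionModule(
    attn_mechanism="vanilla",          # or "flash", "splash", "ring", "cudnn"
    mesh=mesh,
    sm_scale=1 / math.sqrt(head_dims),
    num_attention_heads=heads,
    head_dims=head_dims,
)

# 2. Calling: pass query/key/value states plus optional masks and flags.
q = k = v = jnp.zeros((batch, seq, heads, head_dims), dtype=jnp.bfloat16)
outputs = attention(
    query_states=q,
    key_states=k,
    value_states=v,
    causal=True,  # assumed flag name for causal masking
)
```

The `test_attentions` utility mentioned above can be run in a similar setting to compare mechanisms on accuracy, gradient stability, and speed.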

**Advantages:**

* **Flexibility:** Allows you to easily switch between different attention mechanisms without major code changes.
* **Efficiency:** Supports advanced JAX sharding for distributed computation, enabling the handling of large models.

_Flash Attention works on TPU with ease, but for GPU some improvements are still in progress._

File renamed without changes.
File renamed without changes.
File renamed without changes.
149 changes: 40 additions & 109 deletions docs/EasyStateExample.md
@@ -1,109 +1,40 @@
## EasyDeLState

`EasyDeLState` is a handy feature in EasyDeL with a lot of options, such as storing
`Model Parameters`, _Optimizer State, Model Config, Model Type, Optimizer and Scheduler Configs_.

Let's look at some examples of using `EasyDeLState`.

### Fine-tuning

Fine-tuning from a previous state or a fresh one.

```python
from easydel import (
    AutoEasyDeLConfig,
    EasyDeLState
)
from transformers import AutoTokenizer
from jax import numpy as jnp, lax
import jax

huggingface_model_repo_id = "REPO_ID"
checkpoint_name = "CKPT_NAME"

state = EasyDeLState.from_pretrained(
    pretrained_model_name_or_path=huggingface_model_repo_id,
    filename=checkpoint_name,
    optimizer="adamw",
    scheduler="none",
    tx_init=None,
    device=jax.devices('cpu')[0],  # Offload Device
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    precision=lax.Precision("fastest"),
    sharding_axis_dims=(1, -1, 1, 1),
    sharding_axis_names=("dp", "fsdp", "tp", "sp"),
    query_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
    key_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
    value_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
    bias_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), None, None, None),
    attention_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
    shard_attention_computation=False,
    input_shape=(1, 1),
    backend=None,
    init_optimizer_state=False,
    free_optimizer_state=True,
    verbose=True,
    state_shard_fns=None,
)

config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_model_repo_id,
    trust_remote_code=True
)

max_length = config.max_position_embeddings

configs_to_initialize_model_class = {
    'config': config,
    'dtype': jnp.bfloat16,
    'param_dtype': jnp.bfloat16,
    'input_shape': (8, 8)
}
```

`EasyDeLState` also has `.load_state()` and `.save_state()`, along with other useful options such as
`.free_opt_state()`, which frees the optimizer state, and `.shard_params()`, which shards the parameters.
You can read the docs to find out more about these options.
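
For example, a quick sketch (method names are those listed above; exact signatures and return conventions may differ between versions):

```python
# Sketch only: method names come from the docs above; whether these return a
# new state or mutate in place may vary, so treat the assignments as illustrative.
state = state.shard_params()            # distribute parameters across devices
state = state.free_opt_state()          # release optimizer memory
state.save_state("my_checkpoint.easy")  # placeholder file name
```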

### Converting to Hugging Face and PyTorch

Let's see how you can convert an EasyDeL Mistral model to a Hugging Face PyTorch Mistral model from a trained state.

```python

from transformers import MistralForCausalLM
from easydel import (
    AutoEasyDeLConfig,
    EasyDeLState,
    easystate_to_huggingface_model
)
import jax

huggingface_model_repo_id = "REPO_ID"

config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDeLState.load_state(
            "PATH_TO_CKPT",
            input_shape=(8, 2048)
        ),  # You can pass an EasyDeLState here
        base_huggingface_module=MistralForCausalLM,
        config=config,
    )

model = model.half()  # it's a huggingface model now
```
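
Once converted, the result behaves like any Hugging Face `MistralForCausalLM`: for example, you can persist it with `model.save_pretrained("SAVE_PATH")` and later reload it with `MistralForCausalLM.from_pretrained("SAVE_PATH")` (`SAVE_PATH` is a placeholder).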

### Other Use Cases

`EasyDeLState` is general-purpose: you can use it everywhere in EasyDeL, for example for a stand-alone model,
serving, fine-tuning, and many other features; it's up to you to test how creative you can be 😇.
**EasyDeLState: A Snapshot of Your EasyDeL Model**

The `EasyDeLState` class acts like a comprehensive container that holds all the essential information about your EasyDeL
model at a given point in time. Think of it as a snapshot of your model. It includes:

* **Training Progress:**
* `step`: Tracks the current training step.
* **Model Itself:**
* `module`: Holds the actual instance of your EasyDeL model.
* `module_config`: Stores the model's configuration settings.
* `module_config_args`: Keeps track of arguments used to create the configuration (useful for reloading).
* `apply_fn`: References the core function that applies your model to data.
* **Learned Parameters:**
* `params`: Contains the trained weights and biases of your model.
* **Optimizer Information:**
* `tx`: Stores the optimizer you're using to update the model's parameters (e.g., AdamW).
* `opt_state`: Keeps track of the optimizer's internal state (this is important for things like momentum in
optimizers).
* `tx_init`: Remembers the initial settings used to create the optimizer (again, for reloading purposes).
* **Additional Settings:**
* `hyperparameters`: Provides a flexible place to store other hyperparameters related to your model or training
process.

**Key Capabilities of EasyDeLState:**

* **Initialization (`create`)**: Lets you create a brand new `EasyDeLState` to start training.
* **Loading (`load`, `load_state`, `from_pretrained`)**: Enables you to reload a saved model from a checkpoint file or
even a pre-trained model from a repository like Hugging Face Hub.
* **Saving (`save_state`)**: Allows you to save your model's current state, including its parameters and optimizer
state.
* **Optimizer Management (`apply_gradients`, `free_opt_state`, `init_opt_state`)**: Provides methods for updating the
model's parameters using gradients, releasing optimizer memory, and re-initializing the optimizer if needed.
* **Sharding (`shard_params`)**: Helps you distribute your model's parameters efficiently across multiple devices (
important for training large models).
* **PyTorch Conversion (`to_pytorch`)**: Gives you a way to convert your EasyDeL model to its PyTorch equivalent.

**In Essence:**

`EasyDeLState` streamlines the process of managing, saving, loading, and even converting your EasyDeL models. It ensures
that you can easily work with your models and maintain consistency throughout your machine learning workflow.
File renamed without changes.
File renamed without changes.
File renamed without changes.
20 changes: 20 additions & 0 deletions docs/Makefile
@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
33 changes: 33 additions & 0 deletions docs/conf.py
@@ -0,0 +1,33 @@
import os
import sys

sys.path.insert(0, os.path.abspath(".."))

project = "EasyDeL"
copyright = "2023, Erfan Zare Chavoshi - EasyDeL"
author = "Erfan Zare Chavoshi"

extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.intersphinx",
    "sphinx_autodoc_typehints",
]

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

intersphinx_mapping = {
    "jax": ("https://jax.readthedocs.io/en/latest/", None),
    "pytorch": ("https://pytorch.org/docs/stable/", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
}

html_theme = "sphinx_book_theme"
html_static_path = ["_static"]
html_css_files = [
    "custom.css",
]

source_suffix = [".rst", ".md", ".ipynb"]
2 changes: 0 additions & 2 deletions docs/generated-cli-cli.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-cli-train-cl_train_cli.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-data_preprocessing-data_processor.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-etils-auto_tx.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-etils-configs.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-etils-easystate.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-etils-errors.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-etils-etils.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-modules-_blockwise_attention.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-modules-_ring_attention.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-modules-_vanilla_attention.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-modules-arctic-arctic_configuration.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-modules-arctic-modelling_arctic_flax.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-modules-attention_module.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-modules-auto_easydel_model.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/generated-modules-cohere-cohere_configuration.md

This file was deleted.
