improving documentations
erfanzar committed May 28, 2024
1 parent 88e9bb9 commit 143bb3a
Showing 7 changed files with 92 additions and 93 deletions.
14 changes: 7 additions & 7 deletions docs/Install.rst
@@ -41,18 +41,18 @@ pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-re
```


## Installing GO

#### Note: this library needs golang to run (for some tracking utilities on TPU/GPU/CPU)

#### Ubuntu GO installation
Installing GO
------
Note: this library needs golang to run (for some tracking utilities on TPU/GPU/CPU)
Ubuntu GO installation
------

```shell
sudo apt-get update && sudo apt-get upgrade -y
sudo apt-get install golang -y
```

#### Manjaro/Arch GO installation
Manjaro/Arch GO installation
------

```shell
sudo pacman -Syyuu go
@@ -1,8 +1,43 @@
AttentionModule
========
# AttentionModule: A Versatile Attention Mechanism Factory

The `AttentionModule` class is designed to simplify the creation and execution of different attention mechanisms within
your EasyDeL models. It provides a unified interface for working with various attention types, allowing you to easily
switch between them and experiment with different configurations.

**Key Features:**

* **Mechanism Selection:** The `attn_mechanism` argument lets you choose the specific attention algorithm you want to
use (e.g., "vanilla," "flash," "splash," "ring," "cudnn").
* **Sharding and Partitioning:** The class supports advanced JAX sharding techniques to distribute attention
computations across multiple devices for efficient processing of large models. It handles partitioning of query, key,
value, bias, and attention weight matrices using `PartitionSpec`.
* **Blockwise Attention:** Enables the use of blockwise attention for increased memory efficiency, especially with long
sequences.
* **Caching Support:** Facilitates the use of attention caching to speed up inference and generation tasks.
* **Dropout and Determinism:** Allows for applying dropout to attention weights and controlling the deterministic
behavior of the attention computation.
* **Testing Utility:** Provides a `test_attentions` method to compare different attention mechanisms in terms of
accuracy, gradient stability, and computation time.

**How it Works:**

1. **Initialization:**
- During initialization, you provide the desired `attn_mechanism`, JAX `mesh` for sharding, scaling
factor (`sm_scale`), number of attention heads, head dimensions, and other configuration parameters.
- The class automatically sets default values for many parameters based on the chosen attention mechanism and the
provided EasyDeL configuration (`base_module_class`).
2. **Calling the Module:**
- When you call the `AttentionModule` object, you pass in the query, key, and value states, along with optional
parameters like attention masks, biases, and causal flags.
- The module internally selects the appropriate attention function based on the specified `attn_mechanism`.
- It performs any necessary sharding and partitioning based on the configured partition specifications.
- The attention computation is executed, and the attention outputs (and optionally attention weights) are returned.

**Advantages:**

* **Flexibility:** Allows you to easily switch between different attention mechanisms without major code changes.
* **Efficiency:** Supports advanced JAX sharding for distributed computation, enabling the handling of large models.

what is `AttentionModule`
--------
AttentionModule is an EasyDeL module that can perform the attention operation with different strategies, helping users achieve
the best possible performance and numerical stability. Here are some of the strategies supported right now:

@@ -14,9 +49,28 @@ the best possible performance and numerical stability, here are some strategies
6. Local Ring attention via "local_ring"
7. Wise Ring attention via "wise_ring"
8. sharded Attention with shard map known as "sharded_vanilla"
9. Other attention modules may be added; check the source code for the full list (a sketch of the core computation they all share follows below).
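
The sketch below (plain JAX, not EasyDeL's code) shows the core scaled-dot-product computation that all of these strategies share; "flash", "ring", "blockwise", and the others produce the same result with different memory and communication trade-offs.

```python
import jax
import jax.numpy as jnp

def vanilla_attention(query, key, value, sm_scale):
    # query/key/value: [batch, seq_len, num_heads, head_dim]
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key) * sm_scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)

q = k = v = jnp.ones((1, 128, 32, 64))
out = vanilla_attention(q, k, v, sm_scale=1 / (64 ** 0.5))
print(out.shape)  # (1, 128, 32, 64)
```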

## Testing which Attention Module works best

To test which attention module and axis dims work best for your setup, you can run:
```python
from easydel import AttentionModule

print(
    AttentionModule.test_attentions(
        axis_dims=(1, 1, 1, -1),
        sequence_length=128 * 8,
        num_attention_heads=32,
        num_key_value_heads=32,
        chunk_size=128,
    )
)
```

## Example of Using Flash Attention on TPU

Example of Using Flash Attention on TPU
--------
```python
import jax
import flax.linen.attention as flt
3 changes: 1 addition & 2 deletions docs/data_processing.rst → docs/data_processing.md
@@ -1,5 +1,4 @@
Data Processing
======
# Data Processing

In this section, you will see an example of the data format required by EasyDeL to pre-train or fine-tune models.
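
As a rough illustration only (not necessarily the exact schema EasyDeL expects), fine-tuning data is usually a set of chat-style records while pre-training data is raw text, and both end up tokenized into `input_ids` and `attention_mask`.

```python
# Illustrative samples; the field names are assumptions, not a fixed EasyDeL schema.
finetune_sample = {
    "conversation": [
        {"role": "user", "content": "What is JAX?"},
        {"role": "assistant", "content": "JAX is a library for accelerated array computing."},
    ]
}
pretrain_sample = {"text": "JAX is a library for accelerated array computing ..."}

# Either form is eventually turned into token ids, e.g. with
# tokenizer.apply_chat_template(...) followed by tokenizer(...).
```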

4 changes: 2 additions & 2 deletions docs/finetuning_example.rst → docs/finetuning_example.md
@@ -1,5 +1,5 @@
FineTuning Causal Language Model 🥵
=====
# FineTuning Causal Language Model with EasyDeL

With EasyDeL, fine-tuning LLMs (causal language models) is as easy as possible using Jax and Flax,
with the benefit of `TPUs` for the best speed. Here's a simple example you can use to fine-tune your
own model.
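
A minimal setup sketch is shown below (model and data loading only; the trainer configuration is omitted, and the base-model choice and dataset field names are illustrative assumptions rather than requirements).

```python
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from easydel import AutoEasyDeLModelForCausalLM

repo_id = "mistralai/Mistral-7B-Instruct-v0.2"  # illustrative base model

# Convert the PyTorch checkpoint into an EasyDeL (Jax/Flax) model + params.
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,
    device_map="cpu",  # passed through to transformers.AutoModelForCausalLM
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Tokenize a chat dataset; the "messages" field name is an assumption for this dataset.
dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(
    lambda chunk: tokenizer(
        tokenizer.apply_chat_template(chunk["messages"], tokenize=False),
        max_length=2048,
        truncation=True,
    ),
    num_proc=12,
)
```
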
12 changes: 4 additions & 8 deletions docs/index.rst
@@ -4,8 +4,8 @@ EasyDeL is an open-source framework designed to enhance and streamline the train
With a primary focus on Jax/Flax, EasyDeL aims to provide convenient and effective solutions for training Flax/Jax
models on TPU/GPU for both serving and training purposes.

## Key Features

Key Features
---------
1. **Trainers**: EasyDeL offers a range of trainers, including DPOTrainer, ORPOTrainer, SFTTrainer, and VideoCLM
Trainer, tailored for specific training requirements.

@@ -79,8 +79,8 @@ the team with a particular toolset.
Hands on Code Kaggle Examples
---------------------------------------------------------------

1. for mindset of using EasyDeL CausalLanguageModelTrainer on kaggle, but you can do much more. CLMScript
2. SuperVised Finetuning with EasyDeL. SFTScript_
1. `script <https://www.kaggle.com/citifer/easydel-causal-language-model-trainer-example>`_ shows the mindset of using the EasyDeL CausalLanguageModelTrainer on Kaggle, but you can do much more.
2. `script <https://www.kaggle.com/code/citifer/easydel-sfttrainer-example>`_ demonstrates supervised fine-tuning with EasyDeL.

Citing EasyDeL 🥶
---------------------------------------------------------------
@@ -124,7 +124,3 @@ To cite this Project
attentionmodule_example
data_processing



.. _SFTScript https://www.kaggle.com/code/citifer/easydel-sfttrainer-example
.. _CLMScript https://www.kaggle.com/citifer/easydel-causal-language-model-trainer-example
@@ -1,5 +1,5 @@
EasyDeLXRapTure for layer tuning and LoRA
---------
# EasyDeLXRapTure for layer tuning and LoRA

When using LoRA with EasyDeL models, there are a few things you may need to configure on your own,
but a lot is handled by EasyDeL, so let's jump into an example of LoRA fine-tuning
that uses _EasyDeLXRapTure_ with a Mistral model and flash attention.
@@ -90,7 +90,6 @@ train_arguments = TrainArguments(
# What does this do? It merges the LoRA parameters with the original model parameters at the end of training
)

def ultra_chat_prompting_sample(
    data_chunk
):
@@ -117,7 +116,7 @@ tokenization_process = lambda data_chunk: tokenizer(
)

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_sample, num_proc=12)
dataset_train = dataset_train.map(
tokenization_process,
num_proc=12,
79 changes: 15 additions & 64 deletions docs/parameterquantization.rst → docs/parameterquantization.md
@@ -1,5 +1,7 @@
What's 8-bit quantization? How does it help ?
=======
# 8-bit EasyDeL Models

## What's 8-bit quantization? How does it help?

Quantization in the context of deep learning is the process of constraining the number of bits that represent the
weights and biases of the model.

@@ -8,8 +10,8 @@ Weights and Biases numbers that we need in backpropagation.
In 8-bit quantization, each weight or bias is represented using only 8 bits as opposed to the typical 32 bits used in
single-precision floating-point format (float32).

Why does it use less GPU/TPU Memory?
---------
## Why does it use less GPU/TPU Memory?

The primary advantage of using 8-bit quantization is the reduction in model size and memory usage. Here's a simple
explanation:

@@ -35,8 +37,8 @@ To convert these to bytes (since memory is often measured in bytes):
- An 8-bit integer would use ( 8/8 = 1 ) byte.
- A 16-bit integer would use ( 16/8 = 2 ) bytes.
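
To make the savings concrete, here is a small back-of-the-envelope calculation for a 7B-parameter model (parameter storage only; activations, optimizer state, and the KV-cache are extra):

```python
num_params = 7_000_000_000

bytes_fp32 = num_params * 4  # 32 bits = 4 bytes per parameter
bytes_int8 = num_params * 1  # 8 bits  = 1 byte per parameter

print(f"float32: {bytes_fp32 / 1024**3:.1f} GiB")  # ~26.1 GiB
print(f"int8:    {bytes_int8 / 1024**3:.1f} GiB")  # ~6.5 GiB
```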

Example of Using Parameters Quantization in EasyDeL
---------
## Example of Using Parameters Quantization in EasyDeL

For serving models or using them with `JAX`, the easiest and best option you can find
is EasyDeL (you can explore more if you want). You have 4 ways to use models:

@@ -47,17 +49,14 @@ is EasyDeL (you can explore more if you want) you have 4 ways to use models

Let's assume we want to run a 7B model on only 12 GB of VRAM; let's jump straight into coding.

Using Quantized Model via generate Function
---------
## Using Quantized Model via generate Function

Let's assume we want to run `Qwen/Qwen1.5-7B-Chat`.

```python
from jax import numpy as jnp
from easydel import AutoEasyDeLModelForCausalLM, create_generate_function
from easydel import AutoEasyDeLModelForCausalLM

from transformers import AutoTokenizer, GenerationConfig
import pickle
import torch

repo_id = "Qwen/Qwen1.5-7B-Chat"
@@ -77,58 +76,10 @@ model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    torch_dtype=torch.float16,
    device_map="cpu"  # this one will be passed to transformers.AutoModelForCausalLM
)
# Now params are loaded as an 8-bit pytree (using the fjformer kernel)

# params is now an 8 Bit pytree.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
mesh = model.config.jax_mesh()
gen_fn = create_generate_function(
    model,
    GenerationConfig(
        do_sample=True,
        max_new_tokens=512,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        temperature=0.2,
        top_p=0.95,
        top_k=10,
        num_beams=1
    ),
    {"params": params},
    return_prediction_only=True
)
tokenizer.padding_side = "left"
encoded = tokenizer.apply_chat_template(
    [{"role": "user", "content": "generate a story about stars"}],
    return_tensors="np",
    return_dict=True,
    max_length=512,
    padding="max_length",
    add_generation_prompt=True
)
rep = 1  # in case you are using FSDP instead of sequence sharding, change this to your FSDP mesh shape
input_ids, attention_mask = encoded.input_ids.repeat(rep, 0), encoded.attention_mask.repeat(rep, 0)
with mesh:
    response = gen_fn(
        {"params": params},
        input_ids,
        attention_mask
    )
response_string = tokenizer.decode(response[0], skip_special_tokens=True)
print(
    f"Model Response:\n{response_string}"
)
# you want to save these quantized parameters for later?
pickle.dump((model, params, tokenizer), open("EasyDeL-Qwen7B-Chat", "wb"))
# And load that like this ;)
```

(model, params, tokenizer) = pickle.load(open("EasyDeL-Qwen7B-Chat", "rb"))
### TIP

```
You can do everything except training with 8-bit params in EasyDeL, and use pickle for saving and loading them.
