-
-
Notifications
You must be signed in to change notification settings - Fork 171
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Prabhu Subramanian <[email protected]>
- Loading branch information
Showing
6 changed files
with
135 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
dataset/ | ||
prabhuat/ | ||
HuggingFaceTB/* | ||
adapters/ | ||
out.txt | ||
unsloth/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#!/usr/bin/env bash | ||
set -e | ||
TUNING_TOOL=mlx | ||
BASE_MODEL=unsloth/phi-4 | ||
BASE_MODEL_MLX=${BASE_MODEL}-${TUNING_TOOL} | ||
HF_ORG=prabhuat | ||
TOOL_BASE_MODEL=cdx1 | ||
ITERATION=1000 | ||
NUM_LAYERS=16 | ||
|
||
FUSED_MODEL=${HF_ORG}/${TOOL_BASE_MODEL}-${TUNING_TOOL} | ||
QUANT_MODEL_8BIT=${HF_ORG}/${TOOL_BASE_MODEL}-${TUNING_TOOL}-8bit | ||
QUANT_MODEL_6BIT=${HF_ORG}/${TOOL_BASE_MODEL}-${TUNING_TOOL}-6bit | ||
QUANT_MODEL_4BIT=${HF_ORG}/${TOOL_BASE_MODEL}-${TUNING_TOOL}-4bit | ||
|
||
### mlx-lm needs train.jsonl and valid.jsonl | ||
rm -rf dataset prabhuat adapters | ||
mkdir -p dataset prabhuat | ||
|
||
node prepare.js dataset | ||
node validator.js dataset | ||
|
||
echo "Test base model" | ||
mlx_lm.generate --model ${BASE_MODEL} --prompt "Tell me about cdxgen" --temp 0.05 | ||
|
||
# We first convert from HF to mlx | ||
rm -rf ${BASE_MODEL} ${BASE_MODEL_MLX} | ||
mlx_lm.convert --hf-path ${BASE_MODEL} --mlx-path ${BASE_MODEL_MLX} | ||
|
||
echo "LoRA fine-tuning ${BASE_MODEL_MLX} with cdx1 dataset" | ||
mlx_lm.lora --model ${BASE_MODEL_MLX} --train --data dataset --fine-tune-type dora --batch-size 1 --num-layers ${NUM_LAYERS} --iters ${ITERATION} --grad-checkpoint | ||
|
||
echo "Fuse model to ${FUSED_MODEL} using the cdx1 adapters" | ||
rm -rf ${FUSED_MODEL} | ||
mlx_lm.fuse --model ${BASE_MODEL_MLX} --adapter-path adapters --hf-path ${FUSED_MODEL} --save-path ${FUSED_MODEL} | ||
|
||
echo "Test fused model" | ||
mlx_lm.generate --model ${FUSED_MODEL} --prompt "Tell me about cdxgen" --temp 0.05 | ||
|
||
echo "Create quantized models" | ||
rm -rf ${QUANT_MODEL_8BIT} | ||
mlx_lm.convert --hf-path ${FUSED_MODEL} --mlx-path ${QUANT_MODEL_8BIT} -q --q-bits 8 --dtype bfloat16 | ||
echo "${QUANT_MODEL_8BIT} created successfully." | ||
rm -rf ${QUANT_MODEL_6BIT} | ||
mlx_lm.convert --hf-path ${FUSED_MODEL} --mlx-path ${QUANT_MODEL_6BIT} -q --q-bits 6 --dtype bfloat16 | ||
echo "${QUANT_MODEL_6BIT} created successfully." | ||
rm -rf ${QUANT_MODEL_4BIT} | ||
mlx_lm.convert --hf-path ${FUSED_MODEL} --mlx-path ${QUANT_MODEL_4BIT} -q --q-bits 4 --dtype bfloat16 | ||
echo "${QUANT_MODEL_4BIT} created successfully." | ||
|
||
rm -rf dataset adapters |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import { mkdirSync, readFileSync, writeFileSync } from "node:fs"; | ||
import { join } from "node:path"; | ||
import process from "node:process"; | ||
|
||
import { dirNameStr, getAllFiles } from "../../lib/helpers/utils.js"; | ||
const jsonlFiles = getAllFiles(dirNameStr, "**/*.jsonl", { | ||
exclude: ["**/{train, valid}.jsonl"], | ||
}); | ||
let datasetDir = "dataset"; | ||
const argv = process.argv.slice(2); | ||
if (argv.length > 1) { | ||
datasetDir = argv[1]; | ||
} | ||
|
||
const TRAIN_FILE = join(datasetDir, "train.jsonl"); | ||
const VALID_FILE = join(datasetDir, "valid.jsonl"); | ||
const trainData = []; | ||
const validData = []; | ||
|
||
for (const jf of jsonlFiles) { | ||
const lines = readFileSync(jf, "utf-8"); | ||
trainData.push(lines); | ||
if (jf.includes("readme") || jf.includes("cdxgen") || jf.includes("cli")) { | ||
validData.push(lines); | ||
} | ||
} | ||
|
||
mkdirSync(datasetDir, { recursive: true }); | ||
if (trainData.length) { | ||
writeFileSync(TRAIN_FILE, trainData.join("\n")); | ||
} | ||
if (validData.length) { | ||
writeFileSync(VALID_FILE, validData.join("\n")); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters