Updated quantization to support more OS and architectures #118

Closed
wants to merge 30 commits
Commits (30)
- 60b37c1: test status updated (Apr 24, 2024)
- 68db81a: Merge branch 'main' of https://github.com/arjbingly/Capstone_5 (sanchitvj, Apr 26, 2024)
- 81c8dc4: Merge branch 'main' of https://github.com/arjbingly/Capstone_5 (sanchitvj, Apr 26, 2024)
- 08bd61f: Merge branch 'main' of https://github.com/arjbingly/Capstone_5 (sanchitvj, Apr 29, 2024)
- 9b1af78: quantization modified (sanchitvj, Apr 30, 2024)
- f8923a1: added inference in quantization (sanchitvj, Apr 30, 2024)
- f7a0fb8: Merge remote-tracking branch 'origin/main' into quantize (sanchitvj, May 1, 2024)
- cf05b50: added some tests (sanchitvj, May 1, 2024)
- a235393: ruff & type checked, all tests passed (sanchitvj, May 1, 2024)
- 118b2ff: Update pyproject.toml (sanchitvj, May 1, 2024)
- 82c21b4: Update branch_Jenkinsfile (sanchitvj, May 1, 2024)
- b06280a: Update branch_Jenkinsfile (sanchitvj, May 1, 2024)
- 783919e: Update branch_Jenkinsfile (sanchitvj, May 1, 2024)
- afd150e: added mypy type-requests (sanchitvj, May 1, 2024)
- 12cc461: Change min python version to 3.10 (arjbingly, May 2, 2024)
- b386cf1: added exception for root path (sanchitvj, May 2, 2024)
- 2a8c71c: Syntax error for multiple exceptions (arjbingly, May 2, 2024)
- fa6757d: Parse yes/no (arjbingly, May 2, 2024)
- bc9a48d: gated repo exception handling (sanchitvj, May 2, 2024)
- 26037b8: Error handling output_dir in quantize_model (arjbingly, May 2, 2024)
- bb6d0c8: Huggingface-cli login response handling (arjbingly, May 2, 2024)
- 5847a3a: HuggingFace url resolver (arjbingly, May 3, 2024)
- 305801c: modified url resolver (sanchitvj, May 4, 2024)
- c4bc51e: quantize test passed (sanchitvj, May 4, 2024)
- 89ca88c: lower cased the system and arch (sanchitvj, May 7, 2024)
- 9ab317a: support for AMD architecture (sanchitvj, May 7, 2024)
- 688bdd1: Update get_started.llms.rst (sanchitvj, May 7, 2024)
- 5a570f6: quantize compatible for windows (sanchitvj, May 7, 2024)
- 6f8bc90: Merge branch 'quantize' of https://github.com/arjbingly/Capstone_5 in… (sanchitvj, May 7, 2024)
- 78f1d9b: Update get_started.llms.rst (sanchitvj, May 9, 2024)

Files changed
1 change: 1 addition & 0 deletions ci/Jenkinsfile
@@ -70,6 +70,7 @@ pipeline {
steps {
withPythonEnv(PYTHONPATH){
sh 'pip install mypy'
sh 'python3 -m pip install types-requests'
catchError(buildResult: 'SUCCESS', stageResult: 'FAILURE'){
sh 'python3 -m mypy -p src.grag --junit-xml mypy-report.xml'
}
1 change: 1 addition & 0 deletions ci/branch_Jenkinsfile
@@ -70,6 +70,7 @@ pipeline {
steps {
withPythonEnv(PYTHONPATH){
sh 'pip install mypy'
sh 'python3 -m pip install types-requests'
catchError(buildResult: 'SUCCESS', stageResult: 'FAILURE'){
sh 'python3 -m mypy -p src.grag --junit-xml mypy-report.xml'
}
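Both Jenkins pipelines now install the `types-requests` stub package before the mypy run. For context, here is a minimal sketch of the kind of code that makes the stubs necessary; `fetch_url` is a hypothetical helper, not something added by this PR. Without `types-requests` in the same environment as mypy, the `requests` import is treated as untyped and mypy suggests installing the stubs, which would pollute the `--junit-xml` report used by the pipeline.

```python
# Hypothetical example: code like this is only fully type-checked by mypy
# when the types-requests stub package is installed alongside requests.
import requests


def fetch_url(url: str, timeout: float = 10.0) -> str:
    """Return the body of `url`, raising on HTTP errors."""
    response: requests.Response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.text
```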
6 changes: 3 additions & 3 deletions config.ini
@@ -12,7 +12,7 @@ n_ctx : 6000
n_gpu_layers : -1
# The number of layers to put on the GPU. Mixtral-18, gemma-20
std_out : True
base_dir : ${root:root_path}/models
;base_dir : ${root:root_path}/models

[chroma_client]
host : localhost
@@ -64,5 +64,5 @@ env_path : ${root:root_path}/.env
[root]
root_path : /home/ubuntu/volume_2k/Capstone_5

[quantize]
llama_cpp_path : ${root:root_path}
;[quantize]
;llama_cpp_path : ${root:root_path}
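With the `[quantize]` section commented out, `config['quantize']['llama_cpp_path']` is no longer guaranteed to exist, which is why the rewritten quantize script further down wraps the lookup in a try/except. A minimal sketch of that fallback, assuming the project's `get_config` is backed by `configparser` with `ExtendedInterpolation` (an assumption; only the try/except logic mirrors the diff in quantize.py):

```python
from configparser import ConfigParser, ExtendedInterpolation
from pathlib import Path

# Assumption: get_config() returns a ConfigParser-like object. With [quantize]
# commented out in config.ini, the lookup below raises KeyError.
config = ConfigParser(interpolation=ExtendedInterpolation())
config.read("config.ini")

try:
    root_path = Path(config["quantize"]["llama_cpp_path"])
    print(f"Using {root_path} from config.ini")
except (KeyError, TypeError):
    # Mirrors the default path added in quantize.py when the section is absent.
    root_path = Path("./grag-quantize")
    print(f"Using {root_path}, default.")
```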
6 changes: 2 additions & 4 deletions pyproject.toml
@@ -7,7 +7,7 @@ name = "grag"
dynamic = ["version"]
description = 'A simple package for implementing RAG'
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.10"
license = { file = 'LICENSE' }
keywords = ["RAG", "Retrieval Augmented Generation", "LLM", "retrieval", "quantization"]
authors = [
@@ -17,8 +17,6 @@ authors = [
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
@@ -46,7 +44,7 @@ dependencies = [
"bitsandbytes>=0.42.0",
"accelerate>=0.28.0",
"poppler-utils>=0.1.0",
"tesseract>=0.1.3"
"tesseract>=0.1.3",
]

[project.optional-dependencies]
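Raising `requires-python` to 3.10 and dropping the 3.8/3.9 classifiers is consistent with the rewritten quantize script below, which selects a llama.cpp release asset with a `match` statement over `(os_name, architecture)`; structural pattern matching only exists from Python 3.10 onward. A minimal sketch of the construct, trimmed to two of the cases used in the diff:

```python
import platform

# match/case over a (system, machine) tuple requires Python >= 3.10.
os_name = platform.system().lower()
architecture = platform.machine().lower()

match os_name, architecture:
    case ("linux", "x86_64"):
        asset_suffix = "-ubuntu-x64"
    case ("darwin", "arm64"):
        asset_suffix = "-macos-arm64"
    case _:
        raise ValueError(f"{os_name=}, {architecture=} is not supported.")
```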
10 changes: 8 additions & 2 deletions src/docs/get_started.llms.rst
@@ -30,8 +30,14 @@ After running the above command, user will be prompted with the following:

2. Input the **model path**:

* If user wants to download a model from `HuggingFace <https://huggingface.co/models>`_, the user should provide the repository path from HuggingFace.
* If the user wants to download a model from `HuggingFace <https://huggingface.co/models>`_, they should provide the repository path or URL from HuggingFace.

* If the user has the model downloaded locally, then user will be instructed to copy the model and input the name of the model directory.

3. Finally, the user will be prompted to enter **quantization** settings (recommended Q5_K_M or Q4_K_M, etc.). For more details, check `llama.cpp/examples/quantize/quantize.cpp <https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19>`_.
3. The user will be asked where to save the quantized model; otherwise it will be placed in the directory where the model repository was downloaded.

4. Finally, the user will be prompted to enter **quantization** settings (recommended Q5_K_M or Q4_K_M, etc.). For more details, check `llama.cpp/examples/quantize/quantize.cpp <https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19>`_.

5. Optionally, the user can run inference on the quantized model at the next prompt. This inference runs on the CPU, so it can take time for large models.

Note: Windows users have to use WSL and follow the Linux guidelines for quantizing models.
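The docs stop at "inference on CPU" without showing what that looks like outside the interactive script. As a rough sketch only: CPU inference on the quantized GGUF file produced by llama.cpp can be done with the llama-cpp-python bindings. The model path and prompt below are placeholders, and this is not the `inference_quantized_model` helper used by the script, whose implementation is not shown in this diff.

```python
from llama_cpp import Llama

# Placeholder path: point this at the quantized .gguf file produced above.
# n_gpu_layers=0 keeps inference on the CPU, matching the behaviour described in the docs.
llm = Llama(model_path="./grag-quantize/models/model-Q5_K_M.gguf", n_ctx=512, n_gpu_layers=0)

output = llm("Q: What is retrieval augmented generation? A:", max_tokens=64)
print(output["choices"][0]["text"])
```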
2 changes: 1 addition & 1 deletion src/grag/components/llm.py
@@ -50,7 +50,7 @@ def __init__(
device_map: str = "auto",
task: str = "text-generation",
max_new_tokens: str = "1024",
temperature: Union[str, int] = 0.1,
temperature: Union[str, float] = 0.1,
n_batch: Union[str, int] = 1024,
n_ctx: Union[str, int] = 6000,
n_gpu_layers: Union[str, int] = -1,
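The only change here widens `temperature` from `Union[str, int]` to `Union[str, float]`, so that the default of `0.1` actually matches the declared type; under the mypy stage run in CI, a float default on a parameter annotated `Union[str, int]` is flagged as an incompatible default. A minimal standalone sketch (not the full `__init__` signature):

```python
from typing import Union


def bad(temperature: Union[str, int] = 0.1) -> None:
    # mypy flags this default: 0.1 is a float, which is neither str nor int.
    ...


def good(temperature: Union[str, float] = 0.1) -> None:
    # Accepted: the default value now matches the annotation.
    ...
```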
95 changes: 67 additions & 28 deletions src/grag/quantize/quantize.py
@@ -1,52 +1,91 @@
"""Interactive file for quantizing models."""

import platform
import sys
from pathlib import Path

from grag.components.utils import get_config
from grag.quantize.utils import (
building_llamacpp,
download_release_asset,
fetch_model_repo,
get_asset_download_url,
get_llamacpp_repo,
inference_quantized_model,
quantize_model,
repo_id_resolver,
)

config = get_config()
root_path = Path(config["quantize"]["llama_cpp_path"])

if __name__ == "__main__":
user_input = input(
"Enter the path to the llama_cpp cloned repo, or where you'd like to clone it. Press Enter to use the default config path: "
).strip()
"Enter the path which you want to download all the source files. Press Enter to use the default path: ").strip()

if user_input != "":
if user_input == "":
try:
root_path = Path(config["quantize"]["llama_cpp_path"])
print(f'Using {root_path} from config.ini')
except (KeyError, TypeError):
root_path = Path('./grag-quantize')
print(f'Using {root_path}, default.')
else:
root_path = Path(user_input)

res = get_llamacpp_repo(root_path)
get_llamacpp_repo(destination_folder=root_path)
os_name = str(platform.system()).lower()
architecture = str(platform.machine()).lower()
asset_name_pattern = 'bin'
match os_name, architecture:
case ('darwin', 'x86_64'):
asset_name_pattern += '-macos-x64'
case ('darwin', 'arm64'):
asset_name_pattern += '-macos-arm64'
case ('windows', 'x86_64'):
asset_name_pattern += '-win-arm64-x64'
case ('windows', 'arm64'):
asset_name_pattern += '-win-arm64-x64'
case ('windows', 'amd64'):
asset_name_pattern += '-win-arm64-x64'
case ('linux', 'x86_64'):
asset_name_pattern += '-ubuntu-x64'
case _:
raise ValueError(f"{os_name=}, {architecture=} is not supported by llama.cpp releases.")

if "Already up to date." in str(res.stdout):
print("Repository is already up to date. Skipping build.")
else:
print("Updates found. Starting build...")
building_llamacpp(root_path)

response = (
input("Do you want us to download the model? (y/n) [Enter for yes]: ")
.strip()
.lower()
)
if response == "n":
print("Please copy the model folder to 'llama.cpp/models/' folder.")
_ = input("Enter if you have already copied the model:")
model_dir = Path(input("Enter the model directory name: "))
elif response == "y" or response == "":
download_url = get_asset_download_url(asset_name_pattern)
if download_url:
download_release_asset(download_url, root_path)

response = input("Do you want us to download the model? (yes[y]/no[n]) [Enter for yes]: ").strip().lower()
if response == '':
response = 'yes'
if response.lower()[0] == "n":
model_dir = Path(input("Enter path to the model directory: "))
elif response.lower()[0] == "y":
repo_id = input(
"Please enter the repo_id for the model (you can check on https://huggingface.co/models): "
"Please enter the repo_id or the url for the model (you can check on https://huggingface.co/models): "
).strip()
fetch_model_repo(repo_id, root_path)
# model_dir = repo_id.split('/')[1]
model_dir = root_path / "llama.cpp" / "models" / repo_id.split("/")[1]
if repo_id == "":
raise ValueError("Repo ID you entered was empty. Please enter the repo_id for the model.")
repo_id = repo_id_resolver(repo_id)
model_dir = fetch_model_repo(repo_id, root_path / 'models')
else:
raise ValueError("Please enter either 'yes', 'y' or 'no', 'n'.")

sys.stdin.flush()

output_dir = input(
f"Enter path where you want to save the quantized model, else the following path will be used [{model_dir}]: ").strip()
quantization = input(
"Enter quantization, recommended - Q5_K_M or Q4_K_M for more check https://github.com/ggerganov/llama.cpp/blob/master/examples/quantize/quantize.cpp#L19 : "
)
quantize_model(model_dir, quantization, root_path)
).strip()

target_path, quantized_model_file = quantize_model(model_dir, quantization, root_path, output_dir)

inference = input(
"Do you want to inference the quantized model to check if quantization is successful? Warning: It takes time as it inferences on CPU. (y/n) [Enter for yes]: ").strip().lower()
if inference == '':
inference = 'yes'
if inference.lower()[0] == "y":
inference_quantized_model(target_path, quantized_model_file)
else:
print("Model quantized, but not tested.")