Permit loading of models at different precision at load time for sentence_transformers #331

Open
dawu415 opened this issue Aug 2, 2024 · 2 comments

dawu415 commented Aug 2, 2024

Feature request

Pass torch_dtype through model_kwargs (as supported by sentence_transformers) when a dtype is specified in the infinity_emb v2 CLI and the InferenceEngine type is torch.

This would allow the Transformer model to be loaded at a lower precision directly at load time, instead of converting it after loading, which can trigger an OOM error.

Post-load quantization, e.g. self.half(), would still be needed: the non-Transformer PyTorch modules in some models appear to remain in fp32, and downstream matrix computations would otherwise fail due to mixed-dtype operands.
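
For illustration, a minimal sketch of the two steps combined, assuming a sentence_transformers release whose constructor accepts model_kwargs (the model name below is only a placeholder):

    import torch
    from sentence_transformers import SentenceTransformer

    # Load the transformer weights directly in fp16 so that a full fp32 copy
    # never has to fit in GPU memory (model_kwargs is forwarded to
    # transformers' from_pretrained).
    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
        device="cuda",
        model_kwargs={"torch_dtype": torch.float16},
    )

    # Non-transformer modules (e.g. Dense layers) may still hold fp32 weights,
    # so a post-load cast keeps every module at a consistent dtype.
    if len({p.dtype for p in model.parameters()}) > 1:
        model = model.half()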

Motivation

While the current code reduces precision after the model is loaded, a full 32-bit float model may fail to load on a GPU with little memory, since its total size can exceed the GPU's memory limit. A concrete case is a small multi-instance GPU, e.g. a 10 GB MIG instance of an NVIDIA A100.

By specifying the precision at load time, we would be able to load the model without OOM errors and use it successfully.

Your contribution

Yes. Happy to submit a PR.

Current tests on my end use the following code modification in the __init__ function of the SentenceTransformerPatched class:

        if engine_args.engine == InferenceEngine.torch and \
           engine_args.device in [Device.auto, Device.cuda] and \
           engine_args.dtype == Dtype.float16:
            model_kwargs["torch_dtype"] = torch.float16

but this would need further work to better support/handle the other dtypes.
Some thoughts on what could be done for each dtype, per the torch tensor attributes (https://pytorch.org/docs/stable/tensor_attributes.html); a sketch of such a mapping follows the list below:

auto -> (leave blank)
float32 -> torch.float
float16 -> torch.half
float8 -> (leave blank) ?
int8 -> (leave blank) ?

The 8-bit options were left blank since quantization for these types is performed later; I'm unsure if there is a better approach here.
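
A sketch of what such a mapping could look like, e.g. as a small helper next to (or on) Dtype; the import path and the exact enum member names for the 8-bit options are assumptions:

    from typing import Optional

    import torch

    from infinity_emb.primitives import Dtype  # assumed import path

    def to_torch_dtype(dtype: Dtype) -> Optional[torch.dtype]:
        """Return the dtype to pass as torch_dtype, or None to leave it unset."""
        mapping = {
            Dtype.float32: torch.float32,
            Dtype.float16: torch.float16,
            # auto / float8 / int8: return None and leave torch_dtype unset,
            # since 8-bit quantization is applied after loading.
        }
        return mapping.get(dtype)

    # Illustrative usage inside the patched __init__:
    # torch_dtype = to_torch_dtype(engine_args.dtype)
    # if torch_dtype is not None:
    #     model_kwargs["torch_dtype"] = torch_dtype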


greptile-apps bot commented Aug 2, 2024

To permit loading of models at different precision at load time for sentence_transformers, follow these steps:

  1. Modify the SentenceTransformerPatched class initialization to include torch_dtype in model_kwargs based on engine_args.dtype.

    if engine_args.engine == InferenceEngine.torch and \
       engine_args.device in [Device.auto, Device.cuda]:
        if engine_args.dtype == Dtype.float16:
            model_kwargs["torch_dtype"] = torch.float16
        elif engine_args.dtype == Dtype.float32:
            model_kwargs["torch_dtype"] = torch.float
        # Add other dtype handling as needed

  2. Ensure that model_kwargs is passed correctly when initializing the model.

Refer to the PyTorch tensor attributes documentation for handling other dtypes: https://pytorch.org/docs/stable/tensor_attributes.html
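
As a hedged sketch of step 2, assuming the kwargs built in step 1 end up being handed to the SentenceTransformer constructor (the model name is a placeholder):

    import torch
    from sentence_transformers import SentenceTransformer

    model_kwargs = {}
    # ... step 1 populates model_kwargs based on engine_args.dtype ...
    model_kwargs["torch_dtype"] = torch.float16

    # sentence_transformers forwards model_kwargs to transformers'
    # from_pretrained, so the weights are created at the requested dtype.
    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",  # placeholder
        model_kwargs=model_kwargs,
    )
    print(next(model.parameters()).dtype)  # expected: torch.float16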

References

/.github/ISSUE_TEMPLATE/new-model-addition.yml
/libs/infinity_emb/tests/end_to_end/test_optimum_embedding.py
/docs/docs/index.md
/docs/assets/openapi.json
/libs/infinity_emb/tests/conftest.py
/README.md
/docs/docs/deploy.md
/.github/ISSUE_TEMPLATE/bug-report.yml
/.github/ISSUE_TEMPLATE
/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
/libs/infinity_emb/infinity_emb/transformer/quantization/interface.py
/docs/docs
/libs/infinity_emb
/libs/infinity_emb/infinity_emb/transformer/embedder/dummytransformer.py
/libs/infinity_emb/infinity_emb/inference/select_model.py
/docs/benchmarks/simple_app.py
/libs/infinity_emb/tests/unit_test/inference/test_select_model.py
/libs/infinity_emb/tests/unit_test/transformer/quantization
/libs/infinity_emb/infinity_emb/fastapi_schemas
/docs/docs/python_engine.md
/libs/infinity_emb/infinity_emb/transformer/classifier
/libs/infinity_emb/Makefile
/libs/infinity_emb/tests/end_to_end/test_torch_classify.py
/docs


michaelfeil (Owner) commented

So your idea is to save memory at loading time before performing model.half()? Potentially a method that could be combined with device. Need to wrap my head around it; maybe add a function at Dtype.
