
Can I convert a siglip only and not a siglip based LLM? #16

Closed · aliencaocao opened this issue Jun 6, 2024 · 18 comments
@aliencaocao

Based on the supported models, SigLIP-to-TRT conversion is already done, but can I use it standalone for a SigLIP model only?

@dusty-nv (Owner) commented Jun 6, 2024 via email

@aliencaocao (Author)

Do I have to convert SigLIP to TRT myself, or can NanoLLM handle the conversion if I supply a PyTorch model from HF?

@dusty-nv (Owner) commented Jun 6, 2024 via email

@aliencaocao (Author)

Thanks!

@aliencaocao (Author)

Sorry, another question: where is the code that does inference on the TRT engine? I am trying to write my own inference script but ran into some issues. Since you are experienced with TRT inference, I wanted to try my luck and see how you implemented it.

This is the script I have, for context:

from collections import namedtuple, OrderedDict

import numpy as np
import pycuda.autoinit  # noqa: F401  (initializes the CUDA context)
import pycuda.driver as cuda
import tensorrt as trt


class TRTInference:
    """Minimal wrapper for running a serialized TensorRT engine via the tensor-name API."""

    def __init__(self, engine_path, output_names_mapping: dict = None, fp16=True, verbose=False):
        self.engine_path = engine_path
        self.output_names_mapping = output_names_mapping or {}
        self.logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger(trt.Logger.INFO)
        self.engine = None
        self.load_engine()
        assert self.engine is not None, 'Failed to load TensorRT engine.'

        self.context = self.engine.create_execution_context()

        self.input_names = self.get_input_names()
        self.output_names = self.get_output_names()

        self.dtype = np.float16 if fp16 else np.float32

    def load_engine(self):
        with open(self.engine_path, 'rb') as f, trt.Runtime(self.logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())

    def get_input_names(self):
        names = []
        for _, name in enumerate(self.engine):
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                names.append(name)
        return names

    def get_output_names(self):
        names = []
        for _, name in enumerate(self.engine):
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
                names.append(name)
        return names

    def get_bindings(self):
        Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
        bindings = OrderedDict()

        for i, name in enumerate(self.engine):
            shape = self.engine.get_tensor_shape(name)
            shape = tuple(shape)
            if any(s < 0 for s in shape):  # set dynamic axis to be 1
                shape = tuple(1 if s < 0 else s for s in shape)
            dtype = trt.nptype(self.engine.get_tensor_dtype(name))
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                data = np.random.randn(*shape).astype(dtype)
                ptr = cuda.mem_alloc(data.nbytes)
                bindings[name] = Binding(name, dtype, shape, data, ptr)
            else:
                data = cuda.pagelocked_empty(trt.volume(shape), dtype)
                ptr = cuda.mem_alloc(data.nbytes)
                bindings[name] = Binding(name, dtype, shape, data, ptr)

        return bindings

    def __call__(self, blob):
        # blob: dict mapping input tensor names to numpy arrays
        self.stream = cuda.Stream()
        input_bs = next(iter(blob.values())).shape[0]  # batch size taken from the first input
        blob = {n: np.ascontiguousarray(v) for n, v in blob.items()}
        for n in self.input_names:
            input_shape = self.engine.get_tensor_shape(n)
            input_shape = (input_bs, *input_shape[1:])
            self.context.set_input_shape(n, input_shape)
            d_input = cuda.mem_alloc(np.random.randn(*input_shape).astype(self.dtype).nbytes)
            self.context.set_tensor_address(n, int(d_input))
            cuda.memcpy_htod_async(d_input, blob[n], self.stream)

        output_shape_list = []
        for n in self.output_names:
            output_shape = self.engine.get_tensor_shape(n)
            output_shape = (input_bs, *output_shape[1:])
            d_output = int(cuda.mem_alloc(np.random.randn(*output_shape).astype(self.dtype).nbytes))
            self.context.set_tensor_address(n, d_output)
            output_shape_list.append(output_shape)

        assert self.context.all_binding_shapes_specified

        self.context.execute_async_v3(stream_handle=self.stream.handle)

        outputs = {}
        for n, output_shape in zip(self.output_names, output_shape_list):
            output = np.empty(output_shape, dtype=self.dtype)
            cuda.memcpy_dtoh_async(output, self.context.get_tensor_address(n), self.stream)
            outputs[self.output_names_mapping.get(n, n)] = output

        self.stream.synchronize()

        return outputs
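
The class above would be invoked roughly like this (the engine path and input tensor name are placeholders and depend on how the model was exported):

trt_infer = TRTInference('siglip_vision.engine', fp16=True)
blob = {'pixel_values': np.random.randn(1, 3, 384, 384).astype(np.float16)}
outputs = trt_infer(blob)  # dict of output tensor name -> numpy array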

@dusty-nv (Owner) commented Jun 8, 2024

@aliencaocao I am using torch2trt from @jaybdub, which gives you a model object with the same interface as PyTorch, so you can just transparently replace your PyTorch model with the TRT version. That is what I do here:

self.model = trt_model

If I wanted to make something that didn't depend on PyTorch, then yeah, I would use the TRT API directly, or onnxruntime if I wanted to be able to fall back to cuDNN.
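
For reference, a minimal sketch of that transparent-replacement pattern (the module name and input shape here are placeholders, not the actual NanoLLM code):

import torch
from torch2trt import torch2trt

vision = my_vision_encoder.cuda().eval()  # placeholder: any PyTorch vision module
x = torch.randn(1, 3, 384, 384, device='cuda')

# torch2trt returns a TRTModule that is called exactly like the original nn.Module
vision_trt = torch2trt(vision, [x], fp16_mode=True)
out = vision_trt(x)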

@aliencaocao (Author)

Thanks for pointing that out. Do you know how I can load a .engine directly using torch2trt? I exported it separately to use my own shape configs.

@aliencaocao (Author)

Actually, I managed to port it over, but I'm getting:
[06/08/2024-12:35:15] [TRT] [E] 1: [convBaseRunner.cpp::nvinfer1::rt::task::CaskConvBaseRunner::execute::300] Error Code 1: Cask (Cask convolution execution)
This is probably not related, but have you seen this before?

@dusty-nv (Owner) commented Jun 8, 2024

@aliencaocao Not specifically in relation to this (i.e., CLIP/SigLIP).

@aliencaocao (Author)

Thank you. I will try to convert it using torch2trt myself and see how it goes.

I tried to use NanoLLM, but it is missing the text model of SigLIP, which I also need; otherwise I would have used it straight away.

Thank you for your help!

@dusty-nv (Owner) commented Jun 8, 2024

OK, thanks. Let me know if you find that TRT or torch2trt can build/run the SigLIP text encoder; that would be good for me to add too.

@aliencaocao (Author)

Got it to work after NVIDIA-AI-IOT/torch2trt#932

@aliencaocao (Author) commented Jun 8, 2024

Full conversion script; it requires pip install git+https://github.com/aliencaocao/torch2trt.git@patch-1:

import torch
from torch2trt import torch2trt
from transformers import SiglipModel

model = SiglipModel.from_pretrained('siglip ckpt path/HF id', torch_dtype=torch.float16).cuda().eval()

text_model = model.text_model
vision_model = model.vision_model

dummy = torch.ones(1, 3, 384, 384, dtype=torch.float16, device='cuda')
model_trt = torch2trt(vision_model, [dummy], fp16_mode=True, min_shapes=[(1, 3, 384, 384)], opt_shapes=[(4, 3, 384, 384)], max_shapes=[(10, 3, 384, 384)], use_onnx=True)  # change the shapes here; my max batch size is 10, but you probably shouldn't limit yours
y = vision_model(dummy).pooler_output
y_trt = model_trt(dummy)['pooler_output']
torch.save(model_trt.state_dict(), 'vision_trt.pth')
print('Vision model exported. atol:', torch.max(torch.abs(y - y_trt)))

dummy = torch.ones(1, 64, dtype=torch.long, device='cuda')  # siglip tokenizer should always pad to 64
model_trt = torch2trt(text_model, [dummy], fp16_mode=True, min_shapes=[(1, 64)], opt_shapes=[(1, 64)], max_shapes=[(1, 64)], use_onnx=True)
y = text_model(dummy).pooler_output
y_trt = model_trt(dummy)['pooler_output']
torch.save(model_trt.state_dict(), 'text_trt.pth')
print('Text model exported. atol:', torch.max(torch.abs(y - y_trt)))

Remove all the torch.float16 occurrences if you want to run in fp32.

Then you can just load it via TRTModule.load_state_dict(torch.load(os.path.join(clip_path, 'vision_trt.pth')))
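
For completeness, the full loading step looks roughly like this (clip_path is a placeholder directory; the file names match the save calls in the script above):

import os
import torch
from torch2trt import TRTModule

clip_path = '/path/to/exported/models'  # placeholder

vision_trt = TRTModule()
vision_trt.load_state_dict(torch.load(os.path.join(clip_path, 'vision_trt.pth')))

text_trt = TRTModule()
text_trt.load_state_dict(torch.load(os.path.join(clip_path, 'text_trt.pth')))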

To get the logits:

            vision_input = self.clip_image_processor(images=boxes, return_tensors='pt').to(self.device)
            text_inputs = self.clip_tokenizer(im_captions_templated, return_tensors='pt', padding='max_length', truncation=True).to(self.device)  # the processor won't work since it doesn't pad to max_length=64
            vision_input = vision_input['pixel_values'].type(torch.float16)
            image_feat = self.clip_vision_trt(vision_input)['pooler_output']
            text_feat = self.clip_text_trt(text_inputs['input_ids'])['pooler_output']
            image_feat /= image_feat.norm(p=2, dim=-1, keepdim=True)
            text_feat /= text_feat.norm(p=2, dim=-1, keepdim=True)
            scores = image_feat @ text_feat.T * self.clip_logit_scale_exp + self.clip_logit_bias
            scores = scores.squeeze(-1).tolist()  # sigmoid not needed, as it doesn't change the ranking

@dusty-nv (Owner) commented Jun 8, 2024

Thanks @aliencaocao, that's great! I'm going to unify the various CLIP/SigLIP implementations I have floating around between NanoLLM/NanoDB, with support for the text encoder in TRT alongside the vision encoder 👍

@dusty-nv (Owner) commented Jun 9, 2024

@aliencaocao did you get the text encoder working in TRT with real token IDs? The output delta is small when the input_ids are all 1, but when I actually tokenize a real string it doesn't work. Which version of TensorRT are you using?

Edit: I also tried using the attention_mask from the tokenizer.

@aliencaocao (Author)

TRT 10.1. Yes, with real token IDs; I am using it on over 20k samples already.

@aliencaocao (Author) commented Jun 10, 2024

One very important thing: do not use the HF Processor, use the HF tokenizer instead. The processor does not pad the input to 64 ('max_length') tokens, which is what SigLIP has been trained on. Also double-check that the padding token id is 1. This differs from other CLIPs, where the padding is the max length within a batch and not always 64 (or whatever the context length is).

Also note that I hard-coded the text batch size to 1, so you may have to change it to a more dynamic one.

The attention mask is not needed.

Also make sure you are using the right logit_scale_exp and logit_bias from the original HF model; they change for fine-tuned models. And don't forget to call .exp() on model.logit_scale. I usually just precalculate it since it is not exported to TRT.

And to get the same output as the HF pipeline (Python/JS), you need to add torch.sigmoid(scores) before the last line. For image reranking that is not needed, since sigmoid won't change the ranking.
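
A minimal sketch tying these points together (the checkpoint id is a placeholder; logit_scale and logit_bias are the attributes on the HF SiglipModel):

import torch
from transformers import AutoTokenizer, SiglipModel

ckpt = 'google/siglip-so400m-patch14-384'  # placeholder: use your own checkpoint / fine-tune
tokenizer = AutoTokenizer.from_pretrained(ckpt)
assert tokenizer.pad_token_id == 1  # double-check the padding token id, as noted above

# always pad to the 64-token context length SigLIP was trained with
text_inputs = tokenizer(['a photo of a cat'], padding='max_length', max_length=64, truncation=True, return_tensors='pt')

model = SiglipModel.from_pretrained(ckpt)
logit_scale_exp = model.logit_scale.exp().item()  # precompute the .exp(); it is not exported to TRT
logit_bias = model.logit_bias.item()

# scores = image_feat @ text_feat.T * logit_scale_exp + logit_bias   (as in the snippet above)
# probs = torch.sigmoid(scores)  # only needed to match the HF pipeline output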
