# hubconf.py: torch.hub entry points for the lossyless CLIP compressors.
dependencies = [
"torch",
"compressai",
"clip",
"tqdm",
"numpy",
] # dependencies required for loading a model

import torch

from hub import ClipCompressor as _ClipCompressor

# Template of the pretrained checkpoint URLs; beta is formatted in scientific
# notation (e.g. 0.05 -> "5e-02").
PATH = "https://github.com/YannDubs/lossyless/releases/download/v1.0/beta{beta:0.0e}_factorized_rate.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# TODO: add JIT once https://github.com/InterDigitalInc/CompressAI/issues/72 is resolved.
# A possible issue with JIT is that CLIP requires PyTorch 1.7.1.


def _load_compressor(beta, device, **kwargs):
    """Download the checkpoint for the given beta and build the CLIP compressor."""
    ckpt_path = PATH.format(beta=beta)
    pretrained_state_dict = torch.hub.load_state_dict_from_url(
        ckpt_path, progress=False
    )
    compressor = _ClipCompressor(
        pretrained_state_dict=pretrained_state_dict, device=device, **kwargs
    )
    return compressor, compressor.preprocess


def clip_compressor_b005(device=DEVICE, **kwargs):
    return _load_compressor(0.05, device, **kwargs)


def clip_compressor_b001(device=DEVICE, **kwargs):
    return _load_compressor(0.01, device, **kwargs)


def clip_compressor_b01(device=DEVICE, **kwargs):
    return _load_compressor(0.1, device, **kwargs)


DOCSTRING = """
Load the invariant CLIP compressor with beta={beta:.0e} (here beta is proportional to the
compression, unlike in the paper).

Parameters
----------
device : str
    Device on which to load the model.

Returns
-------
compressor : nn.Module
    PyTorch module that, when called as `compressor(X)` on a batch of images, returns the
    decompressed representations. Use `compressor.compress(X)` to get a batch of compressed
    representations (in bytes). To save a compressed torch dataset to file, use
    `compressor.compress_dataset(dataset, file)` and `dataset = compressor.decompress_dataset(file)`.
    For more information, check the docstrings of the module's functions and the examples below.
transform : callable
    Transform that can be used directly in place of a torchvision transform. It resizes the
    image to (3, 224, 224), applies CLIP normalization, and converts it to a tensor.
"""
clip_compressor_b005.__doc__ = DOCSTRING.format(beta=0.05)
clip_compressor_b001.__doc__ = DOCSTRING.format(beta=0.01)
clip_compressor_b01.__doc__ = DOCSTRING.format(beta=0.1)
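

# Usage sketch (for illustration only, guarded so it never runs on import): roughly how
# these entry points can be loaded through torch.hub. The repo name "YannDubs/lossyless"
# is inferred from the release URL above (add a ":branch" suffix if torch.hub cannot find
# the default branch), torchvision.datasets.FakeData is just a stand-in dataset, and
# "compressed_clip.npz" is a hypothetical output path.
if __name__ == "__main__":
    import torchvision

    # Download the beta=5e-02 checkpoint together with its preprocessing transform.
    compressor, transform = torch.hub.load("YannDubs/lossyless", "clip_compressor_b005")

    # The returned transform replaces the usual torchvision transform.
    dataset = torchvision.datasets.FakeData(size=8, transform=transform)
    X = torch.stack([dataset[i][0] for i in range(4)]).to(DEVICE)

    Z_bytes = compressor.compress(X)  # batch of compressed representations (bytes)
    X_hat = compressor(X)  # decompressed representations

    # Compress an entire torch dataset to file and load it back (see docstring above).
    compressor.compress_dataset(dataset, "compressed_clip.npz")
    dataset_hat = compressor.decompress_dataset("compressed_clip.npz")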