diff --git a/.github/workflows/full-test.yml b/.github/workflows/full-test.yml index c00b1853..f487a26b 100644 --- a/.github/workflows/full-test.yml +++ b/.github/workflows/full-test.yml @@ -12,7 +12,7 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest", "macos-13", "windows-latest"] - vers: [ {pt_ver: "1.6.0", tv_ver: "0.7.0"}, {pt_ver: "1.7.0", tv_ver: "0.8.1"}, {pt_ver: "1.8.0", tv_ver: "0.9.0"}, {pt_ver: "1.9.0", tv_ver: "0.10.0"}, {pt_ver: "1.10.0", tv_ver: "0.11.1"}, {pt_ver: "1.11.0", tv_ver: "0.12.0"}, {pt_ver: "1.12.0", tv_ver: "0.13.0"} ] + vers: [ {pt_ver: "1.10.0", tv_ver: "0.11.1"}, {pt_ver: "1.11.0", tv_ver: "0.12.0"}, {pt_ver: "1.12.0", tv_ver: "0.13.0"} ] include: - os: macos-latest vers: @@ -36,7 +36,7 @@ jobs: - uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - python-version: 3.8 + python-version: 3.9 - name: Install PyTorch env: PYTORCH_VER: ${{ matrix.vers.pt_ver }} @@ -45,19 +45,14 @@ jobs: if [ "$RUNNER_OS" == "macOS" ]; then if [[ "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then conda install pytorch::pytorch torchvision -c pytorch - elif [[ "$TORCHVISION_VER" == "0.9."* || "$TORCHVISION_VER" == "0.10."* ]]; then - conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER pillow=6 -c pytorch else conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER -c pytorch fi elif [ "$RUNNER_OS" == "Windows" ]; then if [[ "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then - conda install pytorch torchvision cpuonly pillow=6 -c pytorch - elif [[ "$TORCHVISION_VER" == "0.9."* || "$TORCHVISION_VER" == "0.10."* ]]; then - conda install pillow=6 -c conda-forge - conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER cpuonly -c pytorch + conda install pytorch torchvision cpuonly pillow=8 -c pytorch else - conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER cpuonly pillow=6 -c pytorch + conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER cpuonly pillow=8 -c pytorch fi else if [[ "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then @@ -72,14 +67,9 @@ jobs: env: PYTORCH_VER: ${{ matrix.vers.pt_ver }} TORCHVISION_VER: ${{ matrix.vers.tv_ver }} - run: | - if [[ "$RUNNER_OS" == "Linux" && "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then - pip install black 'ruff<0.0.234' 'tensorflow<2.12' pytest scipy interval - else - pip install black 'ruff<0.0.234' tensorflow pytest scipy interval - fi + run: pip install black==22.3.0 ruff tensorflow pytest scipy interval - name: Lint checks - run: python -m ruff . + run: python -m ruff check . - name: Run tests run: | source activate.sh diff --git a/.github/workflows/smoke-test.yml b/.github/workflows/smoke-test.yml index 4c82f6ca..e885a3e4 100644 --- a/.github/workflows/smoke-test.yml +++ b/.github/workflows/smoke-test.yml @@ -12,7 +12,7 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest", "macos-13", "macos-latest", "windows-latest"] - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10", "3.11"] defaults: run: shell: bash -l {0} diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 0b52e731..c557e686 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: ["ubuntu-latest", "macos-13", "windows-latest"] - vers: [ {pt_ver: "1.6.0", tv_ver: "0.7.0"}, {pt_ver: "latest", tv_ver: "latest"} ] + vers: [ {pt_ver: "1.10.0", tv_ver: "0.11.1"}, {pt_ver: "latest", tv_ver: "latest"} ] include: - os: macos-latest vers: @@ -30,7 +30,7 @@ jobs: - uses: conda-incubator/setup-miniconda@v3 with: auto-update-conda: true - python-version: 3.8 + python-version: 3.9 - name: Install PyTorch env: PYTORCH_VER: ${{ matrix.vers.pt_ver }} @@ -39,19 +39,14 @@ jobs: if [ "$RUNNER_OS" == "macOS" ]; then if [[ "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then conda install pytorch::pytorch torchvision -c pytorch - elif [[ "$TORCHVISION_VER" == "0.9."* || "$TORCHVISION_VER" == "0.10."* ]]; then - conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER pillow=6 -c pytorch else conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER -c pytorch fi elif [ "$RUNNER_OS" == "Windows" ]; then if [[ "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then - conda install pytorch torchvision cpuonly pillow=6 -c pytorch - elif [[ "$TORCHVISION_VER" == "0.9."* || "$TORCHVISION_VER" == "0.10."* ]]; then - conda install pillow=6 -c conda-forge - conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER cpuonly -c pytorch + conda install pytorch torchvision cpuonly pillow=8 -c pytorch else - conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER cpuonly pillow=6 -c pytorch + conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER cpuonly pillow=8 -c pytorch fi else if [[ "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then @@ -66,16 +61,11 @@ jobs: env: PYTORCH_VER: ${{ matrix.vers.pt_ver }} TORCHVISION_VER: ${{ matrix.vers.tv_ver }} - run: | - if [[ "$RUNNER_OS" == "Linux" && "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then - pip install black 'ruff<0.0.234' 'tensorflow<2.12' scipy interval - else - pip install black 'ruff<0.0.234' tensorflow scipy interval - fi + run: pip install black==22.3.0 ruff tensorflow scipy interval - name: Lint checks - run: python -m ruff . + run: python -m ruff check . - name: Style checks - run: python -m black . + run: python -m black --check . - name: Run unit tests run: | cd tests @@ -109,7 +99,7 @@ jobs: - uses: conda-incubator/setup-miniconda@v2 with: auto-update-conda: true - python-version: 3.8 + python-version: 3.9 - name: Install PyTorch run: conda install pytorch torchvision cpuonly -c pytorch - name: Install TinyNeuralNetwork diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c7d6ebc..6b4eed1f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,6 +11,6 @@ repos: - id: black exclude: ^tinynn/converter/schemas - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.150 + rev: v0.7.0 hooks: - id: ruff diff --git a/README.md b/README.md index 199eb556..1b72991c 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ TinyNeuralNetwork is an efficient and easy-to-use deep learning model compressio ## Installation -Python >= 3.8, PyTorch >= 1.4( PyTorch >= 1.6 if quantization-aware training is involved, see [here](docs/quantization_support.md) for details ) +Python >= 3.9, PyTorch >= 1.10 ```shell # Install the TinyNeuralNetwork framework diff --git a/README_zh-CN.md b/README_zh-CN.md index 28a5056d..edbf0751 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -5,7 +5,7 @@ TinyNeuralNetwork是一个高效、易用的深度学习模型压缩框架。它 ## 安装 -python >= 3.8, pytorch >= 1.4(如果使用量化训练 pytorch >= 1.6,详细可见[这里](docs/quantization_support.md) ) +python >= 3.9, pytorch >= 1.10 ```shell # 安装TinyNeuralNetwork软件包 diff --git a/pyproject.toml b/pyproject.toml index 2bca1672..b196a1f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ exclude = [ # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" -target-version = "py36" +target-version = "py39" [tool.ruff.per-file-ignores] "tinynn/converter/schemas/torch/*.py" = ["E501"] @@ -86,6 +86,7 @@ target-version = "py36" "examples/*.py" = ["E402"] "__init__.py" = ["F401", "F403"] "tests/import_test.py" = ["F401"] +"tutorials/quantization/basic.ipynb" = ["F811", "F401"] [tool.ruff.mccabe] # Unlike Flake8, default to a complexity level of 10. diff --git a/requirements.txt b/requirements.txt index e61fe38e..b23a7923 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ -numpy>=1.18.5,<=1.24.4; python_version < '3.9' -numpy>=1.18.5; python_version >= '3.9' +numpy>=1.18.5 PyYAML>=5.3.1 ruamel.yaml>=0.16.12 igraph>=0.9 diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py index fd6a1d3d..06e53d85 100644 --- a/tests/converter_op_test.py +++ b/tests/converter_op_test.py @@ -565,7 +565,7 @@ def test_reduce_ops_single_dim(self): def model(x): res = func(x, dim=1) - return res if type(res) == torch.Tensor else res[0] + return res if type(res) is torch.Tensor else res[0] model_path = get_model_path() converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False) @@ -599,7 +599,7 @@ def test_reduce_ops_single_dim_keepdim(self): def model(x): res = func(x, dim=1, keepdim=True) - return res if type(res) == torch.Tensor else res[0] + return res if type(res) is torch.Tensor else res[0] model_path = get_model_path() converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False) diff --git a/tinynn/converter/operators/optimize.py b/tinynn/converter/operators/optimize.py index 9d0ac09b..f6f7e060 100644 --- a/tinynn/converter/operators/optimize.py +++ b/tinynn/converter/operators/optimize.py @@ -4661,7 +4661,7 @@ def elinimate_sequences( first_node = seq[0] last_node = seq[-1] - if type(skip_pred) == bool: + if type(skip_pred) is bool: skip = skip_pred elif skip_pred is not None: skip = skip_pred(seq) @@ -4669,13 +4669,13 @@ def elinimate_sequences( if skip: continue - if type(remove_first_pred) == bool: + if type(remove_first_pred) is bool: remove_first = remove_first_pred custom_data = None elif remove_first_pred is not None: remove_first, custom_data = remove_first_pred(seq) - if type(remove_last_pred) == bool: + if type(remove_last_pred) is bool: remove_last = remove_last_pred custom_data_last = None elif remove_last_pred is not None: diff --git a/tinynn/converter/operators/tflite/base.py b/tinynn/converter/operators/tflite/base.py index 2818df77..3fdf0e1c 100644 --- a/tinynn/converter/operators/tflite/base.py +++ b/tinynn/converter/operators/tflite/base.py @@ -186,7 +186,7 @@ def __init__( self.index = 0 self.is_variable = is_variable - if type(tensor) == FakeQuantTensor: + if type(tensor) is FakeQuantTensor: self.quantization = QuantizationParameters(tensor.scale, tensor.zero_point, tensor.dim) tensor = tensor.tensor @@ -195,7 +195,7 @@ def __init__( if type(tensor).__module__ == 'numpy': self.tensor = tensor - elif type(tensor) == torch.Tensor: + elif type(tensor) is torch.Tensor: assert tensor.is_contiguous, "Tensor should be contiguous" if tensor.dtype == torch.quint8: self.tensor = torch.int_repr(tensor.detach()).numpy() @@ -253,7 +253,7 @@ def __init__( self.quantization = QuantizationParameters(scales, zero_points, dim) else: self.tensor = tensor.detach().numpy() - elif type(tensor) == torch.Size: + elif type(tensor) is torch.Size: self.tensor = np.asarray(tensor, dtype='int32') elif type(tensor) in (tuple, list): self.tensor = np.asarray(tensor, dtype=dtype) @@ -390,7 +390,7 @@ def build(self, builder: flatbuffers.Builder) -> Offset: def create_offset_vector(builder: flatbuffers.Builder, prop: typing.Callable, vec: typing.Iterable): if type(vec) not in (tuple, list): assert False, "type of vec unexpected, expected: list or tuple" - elif type(vec) == tuple: + elif type(vec) is tuple: vec = list(vec) prop_name = prop.__name__ @@ -426,7 +426,7 @@ def create_numpy_array(builder: flatbuffers.Builder, prop: typing.Callable, vec: def create_string(builder: flatbuffers.Builder, prop: typing.Callable, val: str): - if type(val) != str: + if type(val) is not str: assert False, "type of val unexpected, expected: str" prop_name = prop.__name__ diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py index 70897770..9126f1e1 100644 --- a/tinynn/converter/operators/torch/aten.py +++ b/tinynn/converter/operators/torch/aten.py @@ -1561,7 +1561,7 @@ def parse(self, node, attrs, args, graph_converter): self.run(node) dim = self.input_tensors[1] - assert type(dim) == int + assert type(dim) is int if dim < 0: dim += self.input_tensors[0][0].ndim + 1 @@ -1619,7 +1619,7 @@ def parse(self, node, attrs, args, graph_converter): self.run(node) dim = self.input_tensors[1] - assert type(dim) == int + assert type(dim) is int if dim < 0: dim += self.input_tensors[0][0].ndim @@ -2067,8 +2067,8 @@ def parse(self, node, attrs, args, graph_converter): input_tensor = self.find_or_create_input(0, graph_converter) dim, index = self.input_tensors[1:] - assert type(dim) == int - assert type(index) == int + assert type(dim) is int + assert type(index) is int if dim < 0: dim += input_tensor.tensor.ndim @@ -2166,11 +2166,11 @@ def parse(self, node, attrs, args, graph_converter): self.parse_common(node, attrs, args, graph_converter) def parse_common(self, node, attrs, args, graph_converter): - if type(self) == ATenClampOperator: + if type(self) is ATenClampOperator: min_value, max_value = self.input_tensors[1:] - elif type(self) == ATenClampMinOperator: + elif type(self) is ATenClampMinOperator: min_value, max_value = self.input_tensors[1], None - elif type(self) == ATenClampMaxOperator: + elif type(self) is ATenClampMaxOperator: min_value, max_value = None, self.input_tensors[1] has_min = min_value is not None @@ -3808,7 +3808,7 @@ def parse(self, node, attrs, args, graph_converter): def parse_common(self, graph_converter, input_idx=0, mask_idx=1, other_idx=2, out_idx=0): for i in (input_idx, other_idx): t = self.input_tensors[i] - if type(t) == torch.Tensor: + if type(t) is torch.Tensor: if t.dtype == torch.float64: self.input_tensors[i] = t.to(dtype=torch.float32) elif t.dtype == torch.int64: @@ -3826,7 +3826,7 @@ def parse_common(self, graph_converter, input_idx=0, mask_idx=1, other_idx=2, ou input_tensor, mask_tensor = [self.find_or_create_input(i, graph_converter) for i in (input_idx, mask_idx)] ops = [] - if type(other) == torch.Tensor: + if type(other) is torch.Tensor: other_t = self.find_or_create_input(other_idx, graph_converter) if out.dtype != other.dtype: casted = other.clone().to(dtype=out.dtype) @@ -4491,4 +4491,4 @@ def parse(self, node, attrs, args, graph_converter): ops.append(tfl.TileOperator([actual_input, repeat_tensor], [outp])) for op in ops: - graph_converter.add_operator(op) \ No newline at end of file + graph_converter.add_operator(op) diff --git a/tinynn/converter/operators/torch/base.py b/tinynn/converter/operators/torch/base.py index 0f561e01..5a61afa1 100644 --- a/tinynn/converter/operators/torch/base.py +++ b/tinynn/converter/operators/torch/base.py @@ -190,7 +190,7 @@ def to_tfl_tensors( tfl_tensors = [] if has_buffers is None: has_buffers = [None] * len(tensors) - elif type(has_buffers) == bool: + elif type(has_buffers) is bool: has_buffers = [has_buffers] * len(tensors) assert len(names) == len(tensors) == len(has_buffers) for n, t, b in zip(names, tensors, has_buffers): @@ -491,7 +491,7 @@ def handle_padding(self, pad_h, pad_w, pad_op_index, ops, ceil_mode=False): input_size = [input_tensor.shape[2], input_tensor.shape[3]] if not all((i + 2 * p - k) % s == 0 for i, p, k, s in zip(input_size, padding, kernel_size, stride)): - assert type(ops[1]) == tfl.MaxPool2dOperator, 'ceil_mode=True for AvgPool not supported' + assert type(ops[1]) is tfl.MaxPool2dOperator, 'ceil_mode=True for AvgPool not supported' fill_nan = True ceil_pad = get_pool_ceil_padding(input_tensor, kernel_size, stride, padding) ceil_pad = list(np.add(ceil_pad, padding)) @@ -503,7 +503,7 @@ def handle_padding(self, pad_h, pad_w, pad_op_index, ops, ceil_mode=False): pad_input = ops[pad_op_index - 1].outputs[0] inputs = [pad_input, pad_tensor] - if type(ops[1]) == tfl.MaxPool2dOperator: + if type(ops[1]) is tfl.MaxPool2dOperator: constant_tensor = self.get_minimum_constant(pad_input) inputs.append(constant_tensor) pad_array = np.pad(pad_input.tensor, pad, constant_values=constant_tensor.tensor[0]) diff --git a/tinynn/converter/operators/torch/quantized.py b/tinynn/converter/operators/torch/quantized.py index b2459b94..16dc2f39 100644 --- a/tinynn/converter/operators/torch/quantized.py +++ b/tinynn/converter/operators/torch/quantized.py @@ -153,7 +153,7 @@ def parse(self, node, attrs, args, graph_converter): self.run(node) dim = self.input_tensors[1] - assert type(dim) == int + assert type(dim) is int if dim < 0: dim += self.input_tensors[0][0].ndim diff --git a/tinynn/graph/configs/gen_creation_funcs_yml.py b/tinynn/graph/configs/gen_creation_funcs_yml.py index 7cb7280f..83bd7536 100644 --- a/tinynn/graph/configs/gen_creation_funcs_yml.py +++ b/tinynn/graph/configs/gen_creation_funcs_yml.py @@ -17,7 +17,7 @@ if k in block_list: continue c = getattr(torch, k) - if inspect.isclass(c) and k.endswith('Tensor') and c.__bases__[0] == object: + if inspect.isclass(c) and k.endswith('Tensor') and c.__bases__[0] is object: print(k) final_dict['torch'].append(k) elif inspect.isbuiltin(c): diff --git a/tinynn/graph/modifier.py b/tinynn/graph/modifier.py index 6e2bbc31..932109db 100644 --- a/tinynn/graph/modifier.py +++ b/tinynn/graph/modifier.py @@ -1314,7 +1314,7 @@ def apply_mask(self, modifiers): args_parsed = self.node.module.args_parsed_origin if len(args_parsed) > 1: - if type(args_parsed[1]) == list: + if type(args_parsed[1]) is list: ch = [int(i) for i in args_parsed[1]] ch_new = [] @@ -2556,7 +2556,7 @@ def register_mask(self, modifiers, importance, sparsity): def create_channel_modifier(n): for key in CHANNEL_MODIFIERS.keys(): - if type(key) == str: + if type(key) is str: if n.kind() == key: return CHANNEL_MODIFIERS[key](n) elif isinstance(n.module, key): @@ -2611,7 +2611,7 @@ def calc_prune_idx_by_bn_variance( ignored_bn = set() for leaf in self.leaf: - if type(leaf.module()) != nn.BatchNorm2d: + if type(leaf.module()) is not nn.BatchNorm2d: continue while True: @@ -2621,7 +2621,7 @@ def calc_prune_idx_by_bn_variance( break if leaf in self.leaf: - if type(leaf.module()) != nn.BatchNorm2d: + if type(leaf.module()) is not nn.BatchNorm2d: continue ignored_bn.add(leaf) @@ -2629,7 +2629,7 @@ def calc_prune_idx_by_bn_variance( for leaf in self.leaf: if leaf in ignored_bn: continue - if type(leaf.module()) != nn.BatchNorm2d: + if type(leaf.module()) is not nn.BatchNorm2d: continue is_real_leaf = True diff --git a/tinynn/graph/quantization/fused_modules.py b/tinynn/graph/quantization/fused_modules.py index da4c5db5..0b21270b 100644 --- a/tinynn/graph/quantization/fused_modules.py +++ b/tinynn/graph/quantization/fused_modules.py @@ -13,7 +13,7 @@ class ConvTransposeBn2d(_FusedModule): def __init__(self, conv, bn): assert ( - type(conv) == nn.ConvTranspose2d and type(bn) == nn.BatchNorm2d + type(conv) is nn.ConvTranspose2d and type(bn) is nn.BatchNorm2d ), 'Incorrect types for input modules{}{}'.format(type(conv), type(bn)) super(ConvTransposeBn2d, self).__init__(conv, bn) diff --git a/tinynn/graph/quantization/qat_modules.py b/tinynn/graph/quantization/qat_modules.py index 83fc892e..1ea2fbce 100644 --- a/tinynn/graph/quantization/qat_modules.py +++ b/tinynn/graph/quantization/qat_modules.py @@ -95,12 +95,12 @@ def from_float(cls, mod): Args: `mod` a float module, either produced by torch.quantization utilities or directly from user """ - assert type(mod) == cls._FLOAT_MODULE, ( + assert type(mod) is cls._FLOAT_MODULE, ( 'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__ ) assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' assert mod.qconfig, 'Input float module must have a valid qconfig' - if type(mod) == ConvReLU1d: + if type(mod) is ConvReLU1d: mod = mod[0] qconfig = mod.qconfig qat_conv = cls( @@ -224,7 +224,7 @@ def from_float(cls, mod): Args: `mod` a float module, either produced by torch.quantization utilities or directly from user """ - assert type(mod) == cls._FLOAT_MODULE, ( + assert type(mod) is cls._FLOAT_MODULE, ( 'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__ ) assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' @@ -335,7 +335,7 @@ def from_float(cls, mod): Args: `mod` a float module, either produced by torch.quantization utilities or directly from user """ - assert type(mod) == cls._FLOAT_MODULE, ( + assert type(mod) is cls._FLOAT_MODULE, ( 'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__ ) assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' @@ -566,7 +566,7 @@ def from_float(cls, mod): """ # The ignore is because _FLOAT_MODULE is a TypeVar here where the bound # has no __name__ (code is fine though) - assert type(mod) == cls._FLOAT_MODULE, ( + assert type(mod) is cls._FLOAT_MODULE, ( 'qat.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__ ) # type: ignore[attr-defined] assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined' diff --git a/tinynn/graph/quantization/quantizable/gru.py b/tinynn/graph/quantization/quantizable/gru.py index 016e61ea..135b4ffa 100644 --- a/tinynn/graph/quantization/quantizable/gru.py +++ b/tinynn/graph/quantization/quantizable/gru.py @@ -132,7 +132,7 @@ def from_params(cls, wi, wh, bi=None, bh=None): @classmethod def from_float(cls, other): - assert type(other) == cls._FLOAT_MODULE + assert type(other) is cls._FLOAT_MODULE assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'" observed = cls.from_params(other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh) observed.qconfig = other.qconfig diff --git a/tinynn/graph/quantization/quantizer.py b/tinynn/graph/quantization/quantizer.py index 03b12d32..c21e631f 100644 --- a/tinynn/graph/quantization/quantizer.py +++ b/tinynn/graph/quantization/quantizer.py @@ -584,7 +584,7 @@ def _qat_analysis(node: TraceNode, quantized: bool): log.debug(f"[QUANTIZED]{node.unique_name}:{quantized}") for i, n in enumerate(node.next_nodes): - if type(n.module) == TraceFunction: + if type(n.module) is TraceFunction: if n.kind() in ('shape', 'size', 'dtype', 'device'): continue if n.kind() == 'expand_as' and i > 0: @@ -647,7 +647,7 @@ def _is_params_in_module(node, custom_data): 'ConvTranspose2d', ) if is_known_mod and n.module.full_name == 'weight' and prev_node.quantized: - if next_node.type() == torch_q.QuantStub: + if next_node.type() is torch_q.QuantStub: mod = nn.Sequential(torch_q.DeQuantStub(), torch_q.QuantStub()) orig_mod = next_node.module next_node.module = mod @@ -1130,11 +1130,11 @@ def new_no_observer_set(): q_mod = new_mod[-1] elif isinstance(new_mod, torch_q.DeQuantStub): q_mod = new_mod - elif type(new_mod) != nn.Identity: + elif type(new_mod) is not nn.Identity: state = True else: is_prev_float_functional = ( - len(n.prev_nodes) > 1 and n.prev_nodes[0].type() == torch.nn.quantized.FloatFunctional + len(n.prev_nodes) > 1 and n.prev_nodes[0].type() is torch.nn.quantized.FloatFunctional ) if is_prev_float_functional: q_mod = getattr(n.prev_nodes[0].module, n.kind()) @@ -1294,7 +1294,7 @@ def _find_quantized_cat_nodes(node: TraceNode, custom_node): new_fq_count = 2 else: is_prev_float_functional = ( - len(n.prev_nodes) > 1 and n.prev_nodes[0].type() == torch.nn.quantized.FloatFunctional + len(n.prev_nodes) > 1 and n.prev_nodes[0].type() is torch.nn.quantized.FloatFunctional ) if n.type() == 'cat': mode = 'both' @@ -1528,7 +1528,7 @@ def rewrite_quantize_graph(self, graph: TraceGraph) -> None: break for n in graph.other_init_nodes: - if n.type() == nnq.FloatFunctional: + if n.type() is nnq.FloatFunctional: graph_quantized = True break @@ -1934,7 +1934,7 @@ def _is_convertible_node(node: TraceNode, custom_data): q.put(node.prev_nodes[0]) while not q.empty(): n = q.get() - if type(n.module) == TraceFunction: + if type(n.module) is TraceFunction: prev_aliases = n.module.get_aliases() if prev_aliases is not None: for pa in reversed(prev_aliases): @@ -2109,7 +2109,7 @@ def _is_add_relu_fusable_node(node: TraceNode, custom_data) -> bool: if not fuse: return False - if type(next_node.module) == TraceFunction: + if type(next_node.module) is TraceFunction: inplace = next_node.module.func_type == 'relu_' or 'True' in next_node.module.args_string else: inplace = getattr(next_node.module, 'inplace', False) @@ -2128,7 +2128,7 @@ def _is_add_relu_fusable_node(node: TraceNode, custom_data) -> bool: if last_order > node.forward_order: fuse = False break - if type(n.module) == TraceFunction and n.module.get_aliases(): + if type(n.module) is TraceFunction and n.module.get_aliases(): q.put(n.prev_nodes[0]) elif getattr(n.module, 'inplace', False): q.put(n.prev_nodes[0]) @@ -2143,10 +2143,10 @@ def _is_add_relu_fusable_node(node: TraceNode, custom_data) -> bool: func_type = kind is_class = False nodes_to_fuse = [node, next_node] - while next_node.type() == nn.Identity: + while next_node.type() is nn.Identity: next_node = next_node.next_nodes[0] nodes_to_fuse.append(next_node) - if type(next_node.module) == TraceFunction: + if type(next_node.module) is TraceFunction: inplace = next_node.module.func_type == 'relu_' or 'True' in next_node.module.args_string else: inplace = next_node.module.inplace @@ -2199,10 +2199,10 @@ def _parse_args(alpha=1.0, *args, **kwargs): # noqa: F811 elif node.module.kind == 'prelu': weight_t = node.prev_tensors[1] weight_node = node.prev_nodes[1] - if weight_node.type() == torch_q.QuantStub: + if weight_node.type() is torch_q.QuantStub: weight_node = weight_node.prev_nodes[0] - if weight_node.type() != ConstantNode or not weight_node.module.is_parameter: + if weight_node.type() is not ConstantNode or not weight_node.module.is_parameter: log.warning('Rewrite for F.prelu(x, buffer) to nn.PReLU is skipped as it changes the semantics') continue @@ -2427,7 +2427,7 @@ def _avgpool_kernel_size_and_stride(kernel_size, stride=None, *args, **kwargs): mod_fc = node_fc.module mod_bn = node_bn1d.module - assert type(mod_fc) == nn.Linear and type(mod_bn) == nn.BatchNorm1d, "the rewrite struct is\'t [fc-bn1d]" + assert type(mod_fc) is nn.Linear and type(mod_bn) is nn.BatchNorm1d, "the rewrite struct is\'t [fc-bn1d]" if len(node_fc.prev_tensors[0].shape) != 2: log.debug('the [fc-bn]\'s input dimension != 2') @@ -2498,7 +2498,7 @@ def _is_batch_norm_1d(node, custom_data): batch_norm_1d_nodes = graph.filter_forward_nodes(_is_batch_norm_1d) for idx, node in enumerate(batch_norm_1d_nodes): mod = node.module - if type(mod) == nn.BatchNorm1d: + if type(mod) is nn.BatchNorm1d: new_bn = torch.nn.BatchNorm2d( mod.num_features, mod.eps, @@ -2679,7 +2679,7 @@ def _is_not_quantizable(node, custom_data): unsupported_types = tuple( k for k, v in disable_quantize_op_list.items() - if type(k) != str + if type(k) is not str and k not in Q_MODULES_MAPPING and (v is None or LooseVersion(torch.__version__) < v) ) @@ -2695,7 +2695,7 @@ def _is_not_quantizable(node, custom_data): next_nodes = {n.unique_name: n for n in node.next_nodes}.values() for inner_idx, next_node in enumerate(next_nodes): prev_tensor_ptrs = [] - if type(next_node.module) == TraceFunction and next_node.module.is_property: + if type(next_node.module) is TraceFunction and next_node.module.is_property: continue for pt in next_node.prev_tensors: @@ -3096,7 +3096,7 @@ def forward_hook(module, input, output): device = get_module_device(model) - if type(dummy_input) == torch.Tensor: + if type(dummy_input) is torch.Tensor: actual_input = [dummy_input] elif isinstance(dummy_input, (tuple, list)): actual_input = list(dummy_input) @@ -3106,7 +3106,7 @@ def forward_hook(module, input, output): for i in range(len(actual_input)): dummy_input = actual_input[i] - if type(dummy_input) == torch.Tensor: + if type(dummy_input) is torch.Tensor: if dummy_input.device != device: actual_input[i] = dummy_input.to(device) @@ -4037,7 +4037,7 @@ def rewrite_quantize_graph(self, graph: TraceGraph) -> None: break for n in graph.other_init_nodes: - if n.type() == nnq.FloatFunctional: + if n.type() is nnq.FloatFunctional: graph_quantized = True break @@ -4062,7 +4062,7 @@ def _is_add_relu_node(node: TraceNode, custom_data): return ( cur_module.kind == 'add_relu' and len(node.prev_nodes) > 1 - and node.prev_nodes[0].type() == nnq.FloatFunctional + and node.prev_nodes[0].type() is nnq.FloatFunctional ) # Split FloatFunctional.add_relu to FloatFunctional.add and torch.relu @@ -4088,7 +4088,7 @@ def _is_add_mul_scalar_node(node: TraceNode, custom_data): return ( cur_module.kind in ('add_scalar', 'mul_scalar') and len(node.prev_nodes) > 1 - and node.prev_nodes[0].type() == nnq.FloatFunctional + and node.prev_nodes[0].type() is nnq.FloatFunctional ) add_mul_scalar_nodes = graph.filter_forward_nodes(_is_add_mul_scalar_node) @@ -4107,7 +4107,7 @@ def _is_add_mul_cat_node(node: TraceNode, custom_data): return ( cur_module.kind in ('add', 'mul', 'cat') and len(node.prev_nodes) > 1 - and node.prev_nodes[0].type() == nnq.FloatFunctional + and node.prev_nodes[0].type() is nnq.FloatFunctional ) add_mul_cat_nodes = graph.filter_forward_nodes(_is_add_mul_cat_node) @@ -4128,7 +4128,7 @@ def _is_add_mul_cat_node(node: TraceNode, custom_data): names = [n.unique_name for n in graph.other_init_nodes] for name in names: n = graph.nodes_map[name] - if n.type() == nnq.FloatFunctional: + if n.type() is nnq.FloatFunctional: graph.other_init_nodes.remove(n) del graph.nodes_map[n.unique_name] diff --git a/tinynn/graph/tracer.py b/tinynn/graph/tracer.py index 6deff9c9..77d28a2b 100644 --- a/tinynn/graph/tracer.py +++ b/tinynn/graph/tracer.py @@ -223,7 +223,7 @@ def __init__( if isinstance(module, nn.Module) and id(module) in cur_graph.module_original_name_dict: self.original_name = cur_graph.module_original_name_dict[id(module)] - elif type(module) == ConstantNode: + elif type(module) is ConstantNode: self.original_name = module.original_name else: self.original_name = self.unique_name @@ -246,21 +246,21 @@ def __init__( def type(self): """Returns the original name of the function or the type of the module""" - if type(self.module) == TraceFunction: + if type(self.module) is TraceFunction: return self.module.func_type return type(self.module) def kind(self): """Returns the kind of the function or the type of the module""" - if type(self.module) == TraceFunction: + if type(self.module) is TraceFunction: return self.module.kind return type(self.module) def is_class(self) -> bool: """Judges whether it is a class function or not""" - if type(self.module) == TraceFunction: + if type(self.module) is TraceFunction: return self.module.is_class else: return False @@ -282,7 +282,7 @@ def prev_node_unique_name(self, idx, inplace=False) -> str: getattr_on_module = False if ( isinstance(self.prev_nodes[idx].module, torch.nn.Module) - and type(self.module) == TraceFunction + and type(self.module) is TraceFunction and self.module.is_property and '.' not in self.module.full_name ): @@ -376,7 +376,7 @@ def _stringify_list(content) -> str: return f'[{inner_content}]' elif type(content) in (int, float, bool): return str(content) - elif type(content) == str: + elif type(content) is str: return f'"{content}"' # If `convert_to_parameter` is `True`, the content of the data will not be written inline. @@ -603,7 +603,7 @@ def _parse_args(arg): new_arg.append('None') elif a is Ellipsis: new_arg.append('...') - elif type(a) == slice: + elif type(a) is slice: t = (a.start, a.stop, a.step) parts = [] for x in t: @@ -753,7 +753,7 @@ def no_catch_handle_func(): def args_as_string(args, kwargs): """String representation of the args and the keyword args""" - cleaned_args = [f'"{arg}"' if type(arg) == str else str(arg) for arg in args] + cleaned_args = [f'"{arg}"' if type(arg) is str else str(arg) for arg in args] args_content = ', '.join(cleaned_args) kwargs_content = ', '.join((f'{k}="{v}"' if type(v) is str else f'{k}={v}' for k, v in kwargs.items())) args_connector = '' if args_content == '' or kwargs_content == '' else ', ' @@ -836,7 +836,7 @@ def new_getattr(obj, name): orig = result result = {} for k, v in orig.items(): - if type(v) == str and (not k.startswith('__') and not k.endswith('__')): + if type(v) is str and (not k.startswith('__') and not k.endswith('__')): result[k] = f'"{orig[k]}"' else: result[k] = v @@ -937,7 +937,7 @@ def new_getattr(obj, name): # then we connect it to the graph. # Otherwise, don't track it. old_result = None - if type(result) == torch.Size and isinstance(obj, torch.Tensor): + if type(result) is torch.Size and isinstance(obj, torch.Tensor): # Create a list of new tensors for the sake of tracking # The reason to use that instead of a tensor is stated below. # e.g. Users may use the following clause to deal with sizes @@ -1054,13 +1054,13 @@ def new_func(*args, **kwargs): if key == 'torch.Tensor.size' and len(args) > 1: # Tracking torch.Tensor.size with optional int argument result = torch.tensor(result) - if type(result) == torch.Size: + if type(result) is torch.Size: # Handling dynamic shape # If the torch.Size object is generated by a tensor, # then we connect it to the graph. # Otherwise, don't track it. - if len(args) > 0 and type(args[0]) == torch.Tensor: + if len(args) > 0 and type(args[0]) is torch.Tensor: # Create a list of new tensors for the sake of tracking # The reason to use that instead of a tensor is stated below. # e.g. Users may use the following clause to deal with sizes @@ -1444,7 +1444,7 @@ def _skip_ignored_args(name, *args, **kwargs): # Skip the arg if it has the same value with the default one if default_value == prop_value: continue - if type(prop_value) == str: + if type(prop_value) is str: prop_value_str = f'"{prop_value}"' else: prop_value_str = prop_value @@ -1702,13 +1702,13 @@ def _model_pre_tracer(module, inputs): def _model_tracer(module, inputs, outputs): log.debug('tracer in _model_tracer') - if type(outputs) == torch.Tensor: + if type(outputs) is torch.Tensor: node = TraceNode(TraceFunction("output")) add_output_node(node, outputs) elif isinstance(outputs, (list, tuple)): for i in outputs: - if type(i) == torch.Tensor or ( - isinstance(i, (list, tuple)) and all((type(x) == torch.Tensor for x in i)) + if type(i) is torch.Tensor or ( + isinstance(i, (list, tuple)) and all((type(x) is torch.Tensor for x in i)) ): node = TraceNode(TraceFunction("output")) add_output_node(node, i) @@ -1719,8 +1719,8 @@ def _model_tracer(module, inputs, outputs): ) elif isinstance(outputs, dict): for k, v in outputs.items(): - if type(v) == torch.Tensor or ( - isinstance(v, (list, tuple)) and all((type(x) == torch.Tensor for x in v)) + if type(v) is torch.Tensor or ( + isinstance(v, (list, tuple)) and all((type(x) is torch.Tensor for x in v)) ): node = TraceNode(TraceFunction("output")) add_output_node(node, v) @@ -1733,7 +1733,7 @@ def _model_tracer(module, inputs, outputs): log.warning(f'Output type is not supported: {type(outputs).__name__}, try to extract tensors from it') for k in outputs.__dir__(): v = getattr(outputs, k) - if type(v) == torch.Tensor or (type(v) in (list, tuple) and all((type(x) == torch.Tensor for x in v))): + if type(v) is torch.Tensor or (type(v) in (list, tuple) and all((type(x) is torch.Tensor for x in v))): node = TraceNode(TraceFunction("output")) add_output_node(node, v) @@ -2016,7 +2016,7 @@ def init(self) -> None: device = get_module_device(self.module) with self.__numbering_context(): - if type(self.dummy_input) == torch.Tensor: + if type(self.dummy_input) is torch.Tensor: actual_input = [self.dummy_input] elif isinstance(self.dummy_input, (tuple, list)): actual_input = list(self.dummy_input) @@ -2026,7 +2026,7 @@ def init(self) -> None: for i in range(len(actual_input)): dummy_input = actual_input[i] - if type(dummy_input) == torch.Tensor: + if type(dummy_input) is torch.Tensor: new_input = dummy_input.detach().clone() if new_input.is_floating_point(): new_input.requires_grad = True @@ -2192,8 +2192,8 @@ def reset_input_output_for_graph(self, input_names: typing.List[str], output_nam for node in output_nodes: node.next_nodes.clear() for i, t in enumerate(node.next_tensors): - if type(t) == torch.Tensor or ( - isinstance(t, (list, tuple)) and all((type(x) == torch.Tensor for x in t)) + if type(t) is torch.Tensor or ( + isinstance(t, (list, tuple)) and all((type(x) is torch.Tensor for x in t)) ): with override_current_trace_graph(self): new_node = TraceNode(TraceFunction("output")) @@ -2270,7 +2270,7 @@ def __gen_init_code(self) -> str: orig_constructor_line = module_constructor_lines[id(node.module)] line = f' self.{node.unique_name} = {orig_constructor_line}' lines.append(line) - elif type(node.module) == ConstantNode: + elif type(node.module) is ConstantNode: # Parameter generation self.used_namespaces.add('torch') @@ -2295,7 +2295,7 @@ def __gen_init_code(self) -> str: f' dtype={node.module.dtype}{requires_grad_prop}), persistent=False)' ) lines.append(line) - elif type(node.module) != TraceFunction: + elif type(node.module) is not TraceFunction: # Generate the module even if the constructor is not caught log.info( f'the constructor of the module {node.unique_name} of type {type(node.module).__name__} is not' @@ -2320,7 +2320,7 @@ def __gen_forward_code(self, inplace=False) -> str: output = ", ".join([node.unique_name]) param = ", ".join([node.prev_node_unique_name(i, inplace) for i in range(len(node.prev_nodes))]) - if type(node.module) == TraceFunction: + if type(node.module) is TraceFunction: full_name = node.full_name() if not full_name.startswith('torch.') and not full_name.startswith('self.') and '.' in full_name: ns = '.'.join(full_name.split('.')[:-1]) @@ -2352,7 +2352,7 @@ def __gen_forward_code(self, inplace=False) -> str: mod_name = mod_name_dict[node.module] if len(node.prev_tensors) == 0 and len(node.next_tensors) == 0: continue - if node.type() == nn.LSTM and len(node.prev_nodes) == 3 and len(node.prev_tensors) == 3: + if node.type() is nn.LSTM and len(node.prev_nodes) == 3 and len(node.prev_tensors) == 3: first_arg = node.prev_node_unique_name(0) param = ", ".join([node.prev_node_unique_name(i) for i in range(1, len(node.prev_nodes))]) line = f" {output} = self.{mod_name}({first_arg}, ({param}))" @@ -2555,9 +2555,9 @@ def get_submodule_with_parent_from_name(self, module_name: str, inplace: bool = for ns in module_name_parts: last_obj = cur_obj - if type(cur_obj) == nn.ModuleList: + if type(cur_obj) is nn.ModuleList: cur_obj = cur_obj[int(ns)] - elif type(cur_obj) == nn.ModuleDict: + elif type(cur_obj) is nn.ModuleDict: cur_obj = cur_obj[ns] else: cur_obj = getattr(cur_obj, ns) @@ -2574,18 +2574,18 @@ def update_submodule_in_node(self, node: TraceNode, module: nn.Module, inplace: cur_obj = self.module for ns in module_name_parts[:-1]: - if type(cur_obj) == nn.ModuleList: + if type(cur_obj) is nn.ModuleList: cur_obj = cur_obj[int(ns)] - elif type(cur_obj) == nn.ModuleDict: + elif type(cur_obj) is nn.ModuleDict: cur_obj = cur_obj[ns] else: cur_obj = getattr(cur_obj, ns) ns = module_name_parts[-1] new_obj = module - if type(cur_obj) == nn.ModuleList: + if type(cur_obj) is nn.ModuleList: cur_obj[int(ns)] = new_obj - elif type(cur_obj) == nn.ModuleDict: + elif type(cur_obj) is nn.ModuleDict: cur_obj[ns] = new_obj else: setattr(cur_obj, ns, new_obj) @@ -2604,7 +2604,7 @@ def filter_forward_nodes(self, predicate, custom_data=None, reverse=False) -> ty def insert_after(self, node: TraceNode, module, next_tensors: typing.Optional[typing.List[torch.Tensor]] = None): """Insert a module or an existing node after a node in the computation graph""" # Create a new node and connects it to the next node/tensors - if type(module) != TraceNode: + if type(module) is not TraceNode: new_node = TraceNode(module, cur_graph=self) if node in self.input_nodes or node in self.constant_nodes: self.forward_nodes.insert(0, new_node) @@ -2656,7 +2656,7 @@ def insert_after(self, node: TraceNode, module, next_tensors: typing.Optional[ty # Since the function calls are rendered beforehand, # we need to change them as well. - if type(next_node.module) == TraceFunction: + if type(next_node.module) is TraceFunction: if next_node.module.args_string is not None: for idx in updated_indices: old_unique_name = tensor_name_from_parts(node.unique_name, idx, is_constant_node) @@ -2673,7 +2673,7 @@ def insert_new_after( next_tensors: typing.Optional[typing.List[torch.Tensor]] = None, before_node: typing.Optional[TraceNode] = None, ): - assert type(module_or_func) != TraceNode + assert type(module_or_func) is not TraceNode new_node = TraceNode(module_or_func, cur_graph=self) if next_tensors is None: @@ -2715,7 +2715,7 @@ def insert_between( old_unique_name = prev_node.unique_name is_constant_node = type(prev_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional) - if type(module) != TraceNode: + if type(module) is not TraceNode: new_node = TraceNode(module, cur_graph=self) if prev_node not in next_node.prev_nodes or next_node not in prev_node.next_nodes: @@ -2813,7 +2813,7 @@ def insert_between( break # Update previous node name for next nodes (TraceFunction) - if type(next_node.module) == TraceFunction: + if type(next_node.module) is TraceFunction: for old_idx, new_idx in index_mapping: prev_unique_name = tensor_name_from_parts(old_unique_name, old_idx, is_constant_node) next_unique_name = tensor_name_from_parts(new_node.unique_name, new_idx, is_new_constant_node) @@ -2831,7 +2831,7 @@ def insert_before( ): """Insert a module or an existing node before a node in the computation graph""" # Create a new node and connects it to the previous node/tensors - if type(module) != TraceNode: + if type(module) is not TraceNode: if not isinstance(module, (tuple, list)): modules = [module] rev_mode = False @@ -2927,7 +2927,7 @@ def insert_before( break # Update previous node name for next nodes (TraceFunction) - if type(node.module) == TraceFunction and node not in self.output_nodes: + if type(node.module) is TraceFunction and node not in self.output_nodes: new_node = new_nodes[0] for i in range(len(new_node.prev_nodes)): old_unique_name = new_node.prev_nodes[i].unique_name @@ -3051,7 +3051,7 @@ def replace_node_module(self, node: TraceNode, module: torch.nn.Module) -> None: for nt in node.next_tensors: if id(nt) == id(pt): idx = n.prev_indices[i] - if type(n.module) == TraceFunction: + if type(n.module) is TraceFunction: prev_unique_name = tensor_name_from_parts(old_unique_name, idx, is_constant_node) next_unique_name = tensor_name_from_parts(node.unique_name, idx, is_new_constant_node) log.debug(f'node rename: {prev_unique_name} -> {next_unique_name}') @@ -3068,7 +3068,7 @@ def fuse_nodes_to_func( # Otherwise, we need to construct one. next_nodes = [] next_tensors = [] - if type(nodes[0].module) == TraceFunction: + if type(nodes[0].module) is TraceFunction: next_nodes.extend(nodes[-1].next_nodes) next_tensors.extend(nodes[-1].next_tensors) @@ -3114,7 +3114,7 @@ def fuse_nodes_to_func( if id(pt) == id(nt): idx = n.prev_indices[i] # Rewrite func calls in next nodes - if type(n.module) == TraceFunction: + if type(n.module) is TraceFunction: old_unique_name = tensor_name_from_parts(last_node_unique_name, idx, False) new_unique_name = tensor_name_from_parts(first_node_unique_name, idx, False) n.module.replace_tensor_name(old_unique_name, new_unique_name) @@ -3179,7 +3179,7 @@ def remove_node(self, node: TraceNode) -> None: n.prev_indices[i] = index_dict[pt] new_idx = n.prev_indices[i] # Rewrite func calls in next nodes - if type(n.module) == TraceFunction: + if type(n.module) is TraceFunction: if n.module.args_string is not None: prev_unique_name = tensor_name_from_parts(old_unique_name, old_idx, is_constant_node) new_unique_name = tensor_name_from_parts( @@ -3335,7 +3335,7 @@ def check_tensor_type(value) -> bool: res = check_tensor_type(item) if res: return res - elif type(value) == torch.Tensor: + elif type(value) is torch.Tensor: return True return False @@ -3346,7 +3346,7 @@ def check_creation_args(args: typing.Iterable) -> typing.Tuple: for arg in args: if isinstance(arg, (tuple, list)): new_args.append(check_creation_args(arg)) - elif type(arg) == torch.Tensor: + elif type(arg) is torch.Tensor: if arg.dim() == 0: new_args.append(arg.item()) else: diff --git a/tinynn/llm_quant/util.py b/tinynn/llm_quant/util.py index 366dcce1..6f3d76bf 100644 --- a/tinynn/llm_quant/util.py +++ b/tinynn/llm_quant/util.py @@ -39,9 +39,9 @@ def get_submodule_with_parent_from_name(model, module_name): for ns in module_name_parts: last_obj = cur_obj - if type(cur_obj) == nn.ModuleList: + if type(cur_obj) is nn.ModuleList: cur_obj = cur_obj[int(ns)] - elif type(cur_obj) == nn.ModuleDict: + elif type(cur_obj) is nn.ModuleDict: cur_obj = cur_obj[ns] else: cur_obj = getattr(cur_obj, ns) diff --git a/tinynn/util/bn_restore.py b/tinynn/util/bn_restore.py index 59594a95..89730b8e 100644 --- a/tinynn/util/bn_restore.py +++ b/tinynn/util/bn_restore.py @@ -97,9 +97,9 @@ def get_submodule_with_parent_from_name(model, module_name): for ns in module_name_parts: last_obj = cur_obj - if type(cur_obj) == nn.ModuleList: + if type(cur_obj) is nn.ModuleList: cur_obj = cur_obj[int(ns)] - elif type(cur_obj) == nn.ModuleDict: + elif type(cur_obj) is nn.ModuleDict: cur_obj = cur_obj[ns] else: cur_obj = getattr(cur_obj, ns) diff --git a/tinynn/util/converter_util.py b/tinynn/util/converter_util.py index 8a3e8176..c084bc06 100644 --- a/tinynn/util/converter_util.py +++ b/tinynn/util/converter_util.py @@ -71,17 +71,17 @@ def generate_converter_config( the size of the list should be the same of that of the inputs """ if type(input_transpose) in (tuple, list): - if len(input_transpose) != len(inputs) or not all((type(x) == bool for x in input_transpose)): + if len(input_transpose) != len(inputs) or not all((type(x) is bool for x in input_transpose)): raise AssertionError('input transpose should either be boolean or list of booleans') - elif type(input_transpose) == bool or input_transpose is None: + elif type(input_transpose) is bool or input_transpose is None: input_transpose = [input_transpose] * len(inputs) else: raise AssertionError('input transpose should either be boolean or list of booleans') if type(output_transpose) in (tuple, list): - if len(output_transpose) != len(outputs) or not all((type(x) == bool for x in output_transpose)): + if len(output_transpose) != len(outputs) or not all((type(x) is bool for x in output_transpose)): raise AssertionError('output transpose should either be boolean or list of booleans') - elif type(output_transpose) == bool or output_transpose is None: + elif type(output_transpose) is bool or output_transpose is None: output_transpose = [output_transpose] * len(inputs) else: raise AssertionError('output transpose should either be boolean or list of booleans') diff --git a/tinynn/util/quantization_analysis_util.py b/tinynn/util/quantization_analysis_util.py index 78dc2979..c47b8a9d 100644 --- a/tinynn/util/quantization_analysis_util.py +++ b/tinynn/util/quantization_analysis_util.py @@ -122,7 +122,7 @@ def forward_hook(module, input, output): device = get_module_device(model) - if type(dummy_input) == torch.Tensor: + if type(dummy_input) is torch.Tensor: actual_input = [dummy_input] elif isinstance(dummy_input, (tuple, list)): actual_input = list(dummy_input) @@ -132,7 +132,7 @@ def forward_hook(module, input, output): for i in range(len(actual_input)): dummy_input = actual_input[i] - if type(dummy_input) == torch.Tensor: + if type(dummy_input) is torch.Tensor: if dummy_input.device != device: actual_input[i] = dummy_input.to(device) @@ -223,7 +223,7 @@ def forward_hook(module, input, output): device = get_module_device(model) - if type(dummy_input) == torch.Tensor: + if type(dummy_input) is torch.Tensor: actual_input = [dummy_input] elif isinstance(dummy_input, (tuple, list)): actual_input = list(dummy_input) @@ -233,7 +233,7 @@ def forward_hook(module, input, output): for i in range(len(actual_input)): dummy_input = actual_input[i] - if type(dummy_input) == torch.Tensor: + if type(dummy_input) is torch.Tensor: if dummy_input.device != device: actual_input[i] = dummy_input.to(device) diff --git a/tutorials/model_conversion/basic.ipynb b/tutorials/model_conversion/basic.ipynb index 18a776b6..4e0c53ec 100644 --- a/tutorials/model_conversion/basic.ipynb +++ b/tutorials/model_conversion/basic.ipynb @@ -73,6 +73,9 @@ "source": [ "import os\n", "from torch.hub import download_url_to_file\n", + "from PIL import Image\n", + "from torchvision import transforms\n", + "import numpy as np\n", "\n", "cwd = os.path.abspath(os.getcwd())\n", "img_path = os.path.join(cwd, 'dog.jpg')\n", @@ -85,10 +88,6 @@ "# If you have diffculties accessing Github, then you may try out the second link\n", "download_url_to_file(img_urls[0], img_path)\n", "\n", - "from PIL import Image\n", - "from torchvision import transforms\n", - "import numpy as np\n", - "\n", "img = Image.open(img_path)\n", "\n", "mean = np.array([0.485, 0.456, 0.406], dtype='float32')\n",