From 8809d0371bedfc08f87d548c559930cab845d505 Mon Sep 17 00:00:00 2001 From: Zhang Jie Date: Sat, 30 Mar 2024 12:01:07 +0000 Subject: [PATCH] add sm2 key pre-check before using it --- poetry.lock | 252 ++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + src/pygmssl/sm2.py | 57 +++++++--- tests/test_sm2.py | 5 +- 4 files changed, 300 insertions(+), 15 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9ec31e1..f7701fd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,23 @@ # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +[[package]] +name = "asttokens" +version = "2.4.1" +description = "Annotate AST trees with source code positions" +optional = false +python-versions = "*" +files = [ + {file = "asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24"}, + {file = "asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0"}, +] + +[package.dependencies] +six = ">=1.12.0" + +[package.extras] +astroid = ["astroid (>=1,<2)", "astroid (>=2,<4)"] +test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] + [[package]] name = "autopep8" version = "2.1.0" @@ -89,6 +107,31 @@ files = [ [package.extras] toml = ["tomli"] +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + +[[package]] +name = "executing" +version = "2.0.1" +description = "Get the currently executing AST node of a frame, and other information" +optional = false +python-versions = ">=3.5" +files = [ + {file = "executing-2.0.1-py2.py3-none-any.whl", hash = "sha256:eac49ca94516ccc753f9fb5ce82603156e590b27525a8bc32cce8ae302eb61bc"}, + {file = "executing-2.0.1.tar.gz", hash = "sha256:35afe2ce3affba8ee97f2d69927fa823b08b472b7b994e36a52a964b93d16147"}, +] + +[package.extras] +tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] + [[package]] name = "flake8" version = "7.0.0" @@ -116,6 +159,41 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "ipython" +version = "8.22.2" +description = "IPython: Productive Interactive Computing" +optional = false +python-versions = ">=3.10" +files = [ + {file = "ipython-8.22.2-py3-none-any.whl", hash = "sha256:3c86f284c8f3d8f2b6c662f885c4889a91df7cd52056fd02b7d8d6195d7f56e9"}, + {file = "ipython-8.22.2.tar.gz", hash = "sha256:2dcaad9049f9056f1fef63514f176c7d41f930daa78d05b82a176202818f2c14"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +decorator = "*" +jedi = ">=0.16" +matplotlib-inline = "*" +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\" and sys_platform != \"emscripten\""} +prompt-toolkit = ">=3.0.41,<3.1.0" +pygments = ">=2.4.0" +stack-data = "*" +traitlets = ">=5.13.0" + +[package.extras] +all = ["ipython[black,doc,kernel,nbconvert,nbformat,notebook,parallel,qtconsole,terminal]", "ipython[test,test-extra]"] +black = ["black"] +doc = ["docrepr", "exceptiongroup", "ipykernel", "ipython[test]", "matplotlib", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "sphinxcontrib-jquery", "stack-data", "typing-extensions"] +kernel = ["ipykernel"] +nbconvert = ["nbconvert"] +nbformat = ["nbformat"] +notebook = ["ipywidgets", "notebook"] +parallel = ["ipyparallel"] +qtconsole = ["qtconsole"] +test = ["pickleshare", "pytest (<8)", "pytest-asyncio (<0.22)", "testpath"] +test-extra = ["curio", "ipython[test]", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"] + [[package]] name = "isort" version = "5.13.2" @@ -130,6 +208,39 @@ files = [ [package.extras] colors = ["colorama (>=0.4.6)"] +[[package]] +name = "jedi" +version = "0.19.1" +description = "An autocompletion tool for Python that can be used for text editors." +optional = false +python-versions = ">=3.6" +files = [ + {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, + {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, +] + +[package.dependencies] +parso = ">=0.8.3,<0.9.0" + +[package.extras] +docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] +testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] + +[[package]] +name = "matplotlib-inline" +version = "0.1.6" +description = "Inline Matplotlib backend for Jupyter" +optional = false +python-versions = ">=3.5" +files = [ + {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"}, + {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"}, +] + +[package.dependencies] +traitlets = "*" + [[package]] name = "mccabe" version = "0.7.0" @@ -152,6 +263,35 @@ files = [ {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, ] +[[package]] +name = "parso" +version = "0.8.3" +description = "A Python Parser" +optional = false +python-versions = ">=3.6" +files = [ + {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, + {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, +] + +[package.extras] +qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] +testing = ["docopt", "pytest (<6.0.0)"] + +[[package]] +name = "pexpect" +version = "4.9.0" +description = "Pexpect allows easy control of interactive console applications." +optional = false +python-versions = "*" +files = [ + {file = "pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523"}, + {file = "pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f"}, +] + +[package.dependencies] +ptyprocess = ">=0.5" + [[package]] name = "pluggy" version = "1.4.0" @@ -167,6 +307,45 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "prompt-toolkit" +version = "3.0.43" +description = "Library for building powerful interactive command lines in Python" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "prompt_toolkit-3.0.43-py3-none-any.whl", hash = "sha256:a11a29cb3bf0a28a387fe5122cdb649816a957cd9261dcedf8c9f1fef33eacf6"}, + {file = "prompt_toolkit-3.0.43.tar.gz", hash = "sha256:3527b7af26106cbc65a040bcc84839a3566ec1b051bb0bfe953631e704b0ff7d"}, +] + +[package.dependencies] +wcwidth = "*" + +[[package]] +name = "ptyprocess" +version = "0.7.0" +description = "Run a subprocess in a pseudo terminal" +optional = false +python-versions = "*" +files = [ + {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, + {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, +] + +[[package]] +name = "pure-eval" +version = "0.2.2" +description = "Safely evaluate AST nodes without side effects" +optional = false +python-versions = "*" +files = [ + {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, + {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"}, +] + +[package.extras] +tests = ["pytest"] + [[package]] name = "pycodestyle" version = "2.11.1" @@ -230,6 +409,21 @@ files = [ {file = "pyflakes-3.2.0.tar.gz", hash = "sha256:1c61603ff154621fb2a9172037d84dca3500def8c8b630657d1701f026f8af3f"}, ] +[[package]] +name = "pygments" +version = "2.17.2" +description = "Pygments is a syntax highlighting package written in Python." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pygments-2.17.2-py3-none-any.whl", hash = "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c"}, + {file = "pygments-2.17.2.tar.gz", hash = "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"}, +] + +[package.extras] +plugins = ["importlib-metadata"] +windows-terminal = ["colorama (>=0.4.6)"] + [[package]] name = "pytest" version = "8.1.1" @@ -268,7 +462,63 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + +[[package]] +name = "stack-data" +version = "0.6.3" +description = "Extract data from python stack frames and tracebacks for informative displays" +optional = false +python-versions = "*" +files = [ + {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, + {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, +] + +[package.dependencies] +asttokens = ">=2.1.0" +executing = ">=1.2.0" +pure-eval = "*" + +[package.extras] +tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] + +[[package]] +name = "traitlets" +version = "5.14.2" +description = "Traitlets Python configuration system" +optional = false +python-versions = ">=3.8" +files = [ + {file = "traitlets-5.14.2-py3-none-any.whl", hash = "sha256:fcdf85684a772ddeba87db2f398ce00b40ff550d1528c03c14dbf6a02003cd80"}, + {file = "traitlets-5.14.2.tar.gz", hash = "sha256:8cdd83c040dab7d1dee822678e5f5d100b514f7b72b01615b26fc5718916fdf9"}, +] + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] +test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.1)", "pytest-mock", "pytest-mypy-testing"] + +[[package]] +name = "wcwidth" +version = "0.2.13" +description = "Measures the displayed width of unicode strings in a terminal" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, + {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, +] + [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "9abf2dd59f578b20a7a5e8f0a885e2b18326020d95abf0966a71db8f336b37c8" +content-hash = "db26508f9a2ef7f26b3bddbf7f502ad7b310db7f075836bb6e7f9904ab1ddead" diff --git a/pyproject.toml b/pyproject.toml index d633df2..c060bf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.11" pycryptodomex = "^3.20.0" +ipython = "^8.22.2" [tool.poetry.group.dev.dependencies] autopep8 = "^2.1.0" diff --git a/src/pygmssl/sm2.py b/src/pygmssl/sm2.py index 4ecc5fe..939ee6d 100644 --- a/src/pygmssl/sm2.py +++ b/src/pygmssl/sm2.py @@ -1,7 +1,9 @@ import base64 from ctypes import byref, c_uint8, c_size_t, Structure, c_char_p, pointer +import functools import tempfile import os +from typing import Callable, Literal, Concatenate, Self from Cryptodome.Util.asn1 import DerSequence @@ -41,38 +43,63 @@ class _SM2_SIGN_CTX(Structure): class SM2: def __init__(self, pub_key: bytes | None = None, pri_key: bytes | None = None): self._sm2_key = _SM2_KEY() + self._has_pub = self._has_pri = False if pub_key and len(pub_key) == 65 and pub_key[0] == 4: # if 65 bytes, 0x04 + pub.x + pub.y pub_key = pub_key[1:] if pub_key: - if len(pub_key) != 64: - raise ValueError('the length of sm2 public key should be 64 bytes') - self._sm2_key.pub.x[:32] = pub_key[:32] - self._sm2_key.pub.y[:32] = pub_key[32:64] + self.set_pub(pub_key) if pri_key: - if len(pri_key) != 32: - raise ValueError('the length of sm2 private key should be 32 bytes') - self._sm2_key.pri[:32] = pri_key + self.set_pri(pri_key) + + def set_pub(self, pub_key: bytes): + if len(pub_key) != 64: + raise ValueError('the length of sm2 public key should be 64 bytes') + self._sm2_key.pub.x[:32] = pub_key[:32] + self._sm2_key.pub.y[:32] = pub_key[32:64] + self._has_pub = True + + def set_pri(self, pri_key: bytes): + if len(pri_key) != 32: + raise ValueError('the length of sm2 private key should be 32 bytes') + self._sm2_key.pri[:32] = pri_key + self._has_pri = True + + @staticmethod + def check(propery: Literal['_has_pri'] | Literal['_has_pub']): + def _func[**P, R](fn: Callable[Concatenate[Self, P], R]): + @functools.wraps(fn) + def wrapper(self: Self, *args: P.args, **kwargs: P.kwargs) -> R: + if not getattr(self, propery): + raise ValueError(f'{propery} not set') + return fn(self, *args, **kwargs) + return wrapper + return _func @classmethod def generate_new_pair(cls) -> 'SM2': obj = cls() _gm.sm2_key_generate(byref(obj._sm2_key)) + obj._has_pri = obj._has_pub = True return obj @property - def pub_key(self) -> bytes: + @check('_has_pub') + def pub_key(self: 'SM2') -> bytes: return bytes(self._sm2_key.pub) @property + @check('_has_pri') def pri_key(self) -> bytes: return bytes(self._sm2_key.pri) + @check('_has_pub') def compute_z(self, id: bytes = SM2_DEFAULT_ID) -> bytes: z = (c_uint8 * 32)() _gm.sm2_compute_z(byref(z), byref(self._sm2_key.pub), c_char_p(id), len(id)) return bytes(z) + @check('_has_pri') def sign(self, data: bytes, id: bytes = SM2_DEFAULT_ID, asn1: bool = False) -> bytes: _sign_ctx = _SM2_SIGN_CTX() _gm.sm2_sign_init(byref(_sign_ctx), byref(self._sm2_key), c_char_p(id), len(id)) @@ -88,18 +115,18 @@ def sign(self, data: bytes, id: bytes = SM2_DEFAULT_ID, asn1: bool = False) -> b if asn1: _k = DerSequence() _k.decode(sig) - r, s = _k[0], _k[1] - sig = r.to_bytes(32, 'big') + s.to_bytes(32, 'big') + sig = bytes(_k[0].to_bytes(32, 'big') + _k[1].to_bytes(32, 'big')) # type: ignore return sig + @check('_has_pub') def verify(self, data: bytes, sig: bytes, id: bytes = SM2_DEFAULT_ID, asn1: bool = False) -> bool: if len(sig) == 64: if not asn1: raise ValueError('when sig is 64 bytes, ans1 flag must be true') # asn1 der格式的, 通常是JAVA搞过来的 _k = DerSequence() - _k.append(int.from_bytes(sig[:32], 'big')) - _k.append(int.from_bytes(sig[32:], 'big')) + _k.append(int.from_bytes(sig[:32], 'big')) # type: ignore + _k.append(int.from_bytes(sig[32:], 'big')) # type: ignore sig = _k.encode() _verify_ctx = _SM2_SIGN_CTX() _gm.sm2_verify_init(byref(_verify_ctx), byref(self._sm2_key), c_char_p(id), len(id)) @@ -111,6 +138,7 @@ def verify(self, data: bytes, sig: bytes, id: bytes = SM2_DEFAULT_ID, asn1: bool ret = _gm.sm2_verify_finish(byref(_verify_ctx), c_char_p(sig), len(sig)) return ret == 1 + @check('_has_pub') def encrypt(self, data: bytes) -> bytes: if len(data) > SM2_MAX_PLAINTEXT_SIZE: raise ValueError('to encrypt data\'s length must <= sm2.SM2_MIN_PLAINTEXT_SIZE') @@ -121,6 +149,7 @@ def encrypt(self, data: bytes) -> bytes: _gm.sm2_encrypt(byref(self._sm2_key), byref(buff), len(data), byref(out), byref(length)) return bytes(out[:length.value]) + @check('_has_pri') def decrypt(self, data: bytes) -> bytes: if len(data) > SM2_MAX_CIPHERTEXT_SIZE: raise ValueError('to decrypt data\'s length must <= sm2.SM2_MAX_CIPHERTEXT_SIZE') @@ -193,6 +222,7 @@ def _nix_import_private_key_from_encrypted_pem(cls, pem: bytes, password: bytes) obj = SM2() assert _gm.sm2_private_key_info_decrypt_from_pem(byref(obj._sm2_key), c_char_p(password), fp) == 1 libc.fclose(fp) + obj._has_pri = obj._has_pub = True return obj @classmethod @@ -205,6 +235,7 @@ def _nix_import_public_key_from_pem(cls, pem: bytes) -> 'SM2': obj = SM2() assert _gm.sm2_public_key_info_from_pem(byref(obj._sm2_key), fp) == 1 libc.fclose(fp) + obj._has_pub = True return obj @staticmethod @@ -231,6 +262,7 @@ def _win_import_private_key_from_encrypted_pem(cls, pem: bytes, password: bytes) assert _gm.sm2_private_key_info_decrypt_from_der(byref(obj._sm2_key), byref( p), byref(attr_len), password, byref(cp), byref(buflen)) == 1 assert buflen.value == 0 + obj._has_pri = obj._has_pub = True return obj @classmethod @@ -243,6 +275,7 @@ def _win_import_public_key_from_pem(cls, pem: bytes) -> 'SM2': cp = pointer(buf) assert _gm.sm2_public_key_info_from_der(byref(obj._sm2_key), byref(cp), byref(vlen)) == 1 assert vlen.value == 0 + obj._has_pub = True return obj if win32: diff --git a/tests/test_sm2.py b/tests/test_sm2.py index 32b29b7..ee1dca3 100644 --- a/tests/test_sm2.py +++ b/tests/test_sm2.py @@ -92,8 +92,9 @@ def test_101_pub_pem_export_and_import(self): assert obj.pri_key != b'\x00' * 32 pem = obj.export_public_key_to_pem() new_obj = SM2.import_public_key_from_pem(pem) - assert new_obj.pri_key == b'\x00' * 32 - assert new_obj.pub_key != b'\x00' * 64 + with self.assertRaises(ValueError): + new_obj.pri_key + assert new_obj.pub_key assert new_obj.pub_key == obj.pub_key def test_102_error_import_private_pem(self):