diff --git a/.circleci/config.yml b/.circleci/config.yml index 90d715db..d4baa859 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,7 +7,7 @@ version: 2 jobs: build: docker: - - image: circleci/python:2.7 + - image: circleci/python:2.7.14-jessie working_directory: ~/python-mbedtls steps: - checkout @@ -22,63 +22,63 @@ jobs: - restore_cache: keys: - - py27-2.7.14 + - py27-v1-{{ arch }}-2.7.15 - run: - name: install python 2.7.14 + name: install python 2.7 command: | - if [ ! -d ".pyenv/versions/2.7.14" ]; then + if [ ! -d ".pyenv/versions/2.7.15" ]; then eval "$(pyenv init -)" - pyenv install 2.7.14 + pyenv install 2.7.15 fi - save_cache: - key: py27-2.7.14 + key: py27-v1-{{ arch }}-2.7.15 paths: - - .pyenv/versions/2.7.14 + - .pyenv/versions/2.7.15 - restore_cache: keys: - - py34-3.4.8 + - py34-v1-{{ arch }}-3.4.8 - run: - name: install python 3.4.8 + name: install python 3.4 command: | if [ ! -d ".pyenv/versions/3.4.8" ]; then eval "$(pyenv init -)" pyenv install 3.4.8 fi - save_cache: - key: py34-3.4.8 + key: py34-v1-{{ arch }}-3.4.8 paths: - .pyenv/versions/3.4.8 - restore_cache: keys: - - py35-3.5.5 + - py35-v1-{{ arch }}-3.5.5 - run: - name: install python 3.5.5 + name: install python 3.5 command: | if [ ! -d ".pyenv/versions/3.5.5" ]; then eval "$(pyenv init -)" pyenv install 3.5.5 fi - save_cache: - key: py35-3.5.5 + key: py35-v1-{{ arch }}-3.5.5 paths: - .pyenv/versions/3.5.5 - restore_cache: keys: - - py36-3.6.4 + - py36-v1-{{ arch }}-3.6.5 - run: - name: install python 3.6.4 + name: install python 3.6 command: | - if [ ! -d ".pyenv/versions/3.6.4" ]; then + if [ ! -d ".pyenv/versions/3.6.5" ]; then eval "$(pyenv init -)" - pyenv install 3.6.4 + pyenv install 3.6.5 fi - save_cache: - key: py36-3.6.4 + key: py36-v1-{{ arch }}-3.6.5 paths: - - .pyenv/versions/3.6.4 + - .pyenv/versions/3.6.5 - run: name: setup environment @@ -117,7 +117,7 @@ jobs: name: run tests command: | eval "$(pyenv init -)" - pyenv shell 2.7.14 3.4.8 3.5.5 3.6.4 + pyenv shell 2.7.15 3.4.8 3.5.5 3.6.5 . venv/bin/activate detox @@ -133,6 +133,20 @@ jobs: twine upload dist/* fi + - run: + name: save logs + command: | + mkdir -p out/log + cp .tox/*/log/py*.log out/log || true + when: on_fail + + - run: + name: save dist + command: | + mkdir -p out/dist + cp dist/* out/dist + when: on_success + - store_artifacts: - path: dist - destination: dist + path: out + destination: artifacts diff --git a/ChangeLog b/ChangeLog index ec969192..c94ad7ec 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,19 @@ +[next] + +* Support Diffie-Hellman-Merkle key exchange. +* MPIs (multi-precision integers) now implement the full +`numbers.Integral` API. +* MPIs are erased from memory upon garbage collection. +* The `mpi` library is now public (renamed `_mpi` -> `mpi`). + +API Changes + +* pk: Methods that were previously returning a long integer now +return an MPI. +* exceptions: Rename `_ErrorBase` -> `MbedTLSError`. It is now +the only new exception. +* exceptions: `mbedtls_strerror()` generates the error message. + [0.10.0] - 2018-05-07 Support elliptic curve cryptography diff --git a/README.rst b/README.rst index 435109a6..e2a29d0b 100644 --- a/README.rst +++ b/README.rst @@ -242,6 +242,46 @@ Now, client and server may generate their shared secret:: True +Diffie-Hellman-Merkle key exchange +---------------------------------- + +The classes DHServer and DHClient may be used for DH Key exchange. The +classes have the same API as ECDHServer and ECDHClient, respectively. + +The key exchange is as follow:: + + >>> from mbedtls import pk + >>> srv = pk.DHServer(23, 5) + >>> cli = pk.DHClient(23, 5) + +The values 23 and 5 are the prime modulus (P) and the generator (G). + +The server generates the ServerKeyExchange payload:: + + >>> ske = srv.generate() + >>> cli.import_SKE(ske) + +The payload ends with :math:`G^X mod P` where `X` is the secret value of +the server. + +:: + + >>> cke = cli.generate() + >>> srv.import_CKE(cke) + +`cke` is :math:`G^Y mod P` (with `Y` the secret value from the client) +returned as its representation in bytes so that it can be readily +transported over the network. + +As in ECDH, client and server may now generate their shared secret:: + + >>> secret = srv.generate_secret() + >>> cli.generate_secret() == secret + True + >>> srv.shared_secret == cli.shared_secret + True + + X.509 Certificate writing and parsing with `mbedtls.x509` --------------------------------------------------------- diff --git a/mbedtls/_mpi.pyx b/mbedtls/_mpi.pyx deleted file mode 100644 index da3e84d5..00000000 --- a/mbedtls/_mpi.pyx +++ /dev/null @@ -1,105 +0,0 @@ -"""Multi-precision integer library (MPI).""" - -__author__ = "Mathias Laurin" -__copyright__ = "Copyright 2018, Mathias Laurin" -__license__ = "MIT License" - - -cimport mbedtls._mpi as _mpi -from libc.stdlib cimport malloc, free - -import numbers -from binascii import hexlify, unhexlify - -from mbedtls.exceptions import * - -try: - long -except NameError: - long = int - - -cdef to_bytes(value): - return unhexlify("{0:02x}".format(value).encode("ascii")) - - -cdef from_bytes(value): - return long(hexlify(value), 16) - - -cdef class MPI: - """Multi-precision integer. - - Only minimal bindings here because Python already has - arbitrary-precision integers. - - """ - def __init__(self, value): - if value is None: - return # Implementation detail. - try: - value = to_bytes(value) - except TypeError: - pass - self._from_bytes(value) - - def __cinit__(self): - """Initialize one MPI.""" - _mpi.mbedtls_mpi_init(&self._ctx) - - def __dealloc__(self): - """Unallocate one MPI.""" - _mpi.mbedtls_mpi_free(&self._ctx) - - cdef _len(self): - """Return the total size in bytes.""" - return _mpi.mbedtls_mpi_size(&self._ctx) - - cpdef _from_bytes(self, const unsigned char[:] bytes): - check_error( - _mpi.mbedtls_mpi_read_binary(&self._ctx, &bytes[0], bytes.shape[0])) - return self - - def __str__(self): - return "%i" % long(self) - - def bit_length(self): - """Return the number of bits necessary to represent MPI in binary.""" - return _mpi.mbedtls_mpi_bitlen(&self._ctx) - - def __eq__(self, other): - if not isinstance(other, numbers.Integral): - raise NotImplemented - return long(self) == other - - @classmethod - def from_int(cls, value): - # mbedtls_mpi_lset is 'limited' to 64 bits. - return cls.from_bytes(to_bytes(value), byteorder="big") - - def __int__(self): - return from_bytes(self.to_bytes(self._len(), byteorder="big")) - - @classmethod - def from_bytes(cls, bytes, byteorder): - assert byteorder in {"big", "little"} - order = slice(None, None, -1 if byteorder is "little" else None) - return cls(None)._from_bytes(bytes[order]) - - def to_bytes(self, length, byteorder): - assert byteorder in {"big", "little"} - order = slice(None, None, -1 if byteorder is "little" else None) - cdef unsigned char* output = malloc( - length * sizeof(unsigned char)) - if not output: - raise MemoryError() - try: - check_error(_mpi.mbedtls_mpi_write_binary( - &self._ctx, output, length)) - return bytes(output[:length])[order] - except Exception as exc: - raise OverflowError from exc - finally: - free(output) - - __bytes__ = to_bytes diff --git a/mbedtls/cipher/AES.pyx b/mbedtls/cipher/AES.pyx index 20dfdd57..b848e3fd 100644 --- a/mbedtls/cipher/AES.pyx +++ b/mbedtls/cipher/AES.pyx @@ -35,8 +35,8 @@ def new(key, mode, iv=None): """ if len(key) not in {16, 24, 32}: - raise InvalidKeyLengthError( - "key size must 16, 24, or 32 bytes, got %i" % len(key)) + raise MbedTLSError( + msg="key size must 16, 24, or 32 bytes, got %i" % len(key)) if mode not in { _cipher.MODE_ECB, _cipher.MODE_CBC, @@ -45,7 +45,7 @@ def new(key, mode, iv=None): _cipher.MODE_GCM, _cipher.MODE_CCM }: - raise FeatureUnavailableError("unsupported mode %r" % mode) + raise MbedTLSError(msg="unsupported mode %r" % mode) mode_name = _cipher._get_mode_name(mode) if mode is _cipher.MODE_CFB: mode_name += "128" diff --git a/mbedtls/cipher/ARC4.pyx b/mbedtls/cipher/ARC4.pyx index 1226ab17..4550ffde 100644 --- a/mbedtls/cipher/ARC4.pyx +++ b/mbedtls/cipher/ARC4.pyx @@ -31,7 +31,7 @@ def new(key, mode=None, iv=None): """ if len(key) != key_size: - raise InvalidKeyLengthError( - "key size must be %i bytes, got %i" % (key_size, len(key))) + raise MbedTLSError( + msg="key size must be %i bytes, got %i" % (key_size, len(key))) name = ("ARC4-%i" % (len(key) * 8)).encode("ascii") return _cipher.Cipher(name, key, mode, iv) diff --git a/mbedtls/cipher/Blowfish.pyx b/mbedtls/cipher/Blowfish.pyx index fed74b86..8f5e0703 100644 --- a/mbedtls/cipher/Blowfish.pyx +++ b/mbedtls/cipher/Blowfish.pyx @@ -32,15 +32,16 @@ def new(key, mode, iv=None): """ if len(key) not in range(4, 57): - raise InvalidKeyLengthError( - "key size must be 4 to 57 bytes, got %i" % (key_size, len(key))) + raise MbedTLSError( + msg="key size must be 4 to 57 bytes, got %i" % ( + key_size, len(key))) if mode not in { _cipher.MODE_ECB, _cipher.MODE_CBC, _cipher.MODE_CFB, _cipher.MODE_CTR, }: - raise FeatureUnavailableError("unsupported mode %r" % mode) + raise MbedTLSError(msg="unsupported mode %r" % mode) mode_name = _cipher._get_mode_name(mode) if mode is _cipher.MODE_CFB: mode_name += "64" diff --git a/mbedtls/cipher/Camellia.pyx b/mbedtls/cipher/Camellia.pyx index 04e7499d..80dbcd17 100644 --- a/mbedtls/cipher/Camellia.pyx +++ b/mbedtls/cipher/Camellia.pyx @@ -31,8 +31,8 @@ def new(key, mode, iv=None): """ if len(key) not in {16, 24, 32}: - raise InvalidKeyLengthError( - "key size must 16, 24, or 32 bytes, got %r" % len(key)) + raise MbedTLSError( + msg="key size must 16, 24, or 32 bytes, got %r" % len(key)) if mode not in { _cipher.MODE_ECB, _cipher.MODE_CBC, @@ -41,7 +41,7 @@ def new(key, mode, iv=None): _cipher.MODE_GCM, _cipher.MODE_CCM, }: - raise FeatureUnavailableError("unsupported mode %r" % mode) + raise MbedTLSError(msg="unsupported mode %r" % mode) mode_name = _cipher._get_mode_name(mode) if mode is _cipher.MODE_CFB: mode_name += "128" diff --git a/mbedtls/cipher/DES.pyx b/mbedtls/cipher/DES.pyx index 86f6efe2..66ea8d46 100644 --- a/mbedtls/cipher/DES.pyx +++ b/mbedtls/cipher/DES.pyx @@ -43,13 +43,12 @@ def new(key, mode, iv=None): """ if len(key) != key_size: - raise InvalidKeyLengthError( - "key size must be 16 bytes, got %r" % len(key)) + raise MbedTLSError(msg="key size must be 16 bytes, got %r" % len(key)) if mode not in { _cipher.MODE_ECB, _cipher.MODE_CBC, }: - raise FeatureUnavailableError("unsupported mode %r" % mode) + raise MbedTLSError(msg="unsupported mode %r" % mode) mode_name = _cipher._get_mode_name(mode) name = ("DES-%s" % mode_name).encode("ascii") return _cipher.Cipher(name, key, mode, iv) diff --git a/mbedtls/cipher/DES3.pyx b/mbedtls/cipher/DES3.pyx index aa379004..89c9e83f 100644 --- a/mbedtls/cipher/DES3.pyx +++ b/mbedtls/cipher/DES3.pyx @@ -34,13 +34,13 @@ def new(key, mode, iv=None): """ if len(key) != key_size: - raise InvalidKeyLengthError( - "key size must be %i bytes, got %i" % (key_size, len(key))) + raise MbedTLSError( + msg="key size must be %i bytes, got %i" % (key_size, len(key))) if mode not in { _cipher.MODE_ECB, _cipher.MODE_CBC, }: - raise FeatureUnavailableError("unsupported mode %r" % mode) + raise MbedTLSError(msg="unsupported mode %r" % mode) mode_name = _cipher._get_mode_name(mode) name = ("DES-EDE3-%s" % mode_name).encode("ascii") return _cipher.Cipher(name, key, mode, iv) diff --git a/mbedtls/cipher/DES3dbl.pyx b/mbedtls/cipher/DES3dbl.pyx index 5e17b79b..2b466ab4 100644 --- a/mbedtls/cipher/DES3dbl.pyx +++ b/mbedtls/cipher/DES3dbl.pyx @@ -34,13 +34,13 @@ def new(key, mode, iv=None): """ if len(key) != key_size: - raise InvalidKeyLengthError( - "key size must be %i bytes, got %i" % (key_size, len(key))) + raise MbedTLSError( + msg="key size must be %i bytes, got %i" % (key_size, len(key))) if mode not in { _cipher.MODE_ECB, _cipher.MODE_CBC, }: - raise FeatureUnavailableError("unsupported mode %r" % mode) + raise MbedTLSError(msg="unsupported mode %r" % mode) mode_name = _cipher._get_mode_name(mode) name = ("DES-EDE-%s" % mode_name).encode("ascii") return _cipher.Cipher(name, key, mode, iv) diff --git a/mbedtls/cipher/_cipher.pyx b/mbedtls/cipher/_cipher.pyx index 378b3887..f33a8653 100644 --- a/mbedtls/cipher/_cipher.pyx +++ b/mbedtls/cipher/_cipher.pyx @@ -144,11 +144,11 @@ cdef class Cipher: cipher_name, const unsigned char[:] key, mode, - const unsigned char[:] iv): - if mode in {MODE_CBC, MODE_CFB} and iv is None: + const unsigned char[:] iv not None): + if mode in {MODE_CBC, MODE_CFB} and iv.size == 0: raise ValueError("mode requires an IV") if cipher_name not in get_supported_ciphers(): - raise CipherError(-1, "unsupported cipher: %r" % cipher_name) + raise MbedTLSError(msg="unsupported cipher: %r" % cipher_name) check_error(_cipher.mbedtls_cipher_setup( &self._enc_ctx, diff --git a/mbedtls/exceptions.pxd b/mbedtls/exceptions.pxd new file mode 100644 index 00000000..150d9a99 --- /dev/null +++ b/mbedtls/exceptions.pxd @@ -0,0 +1,9 @@ +"""Declarations from `mbedtls/error.h`.""" + +__author__ = "Mathias Laurin" +__copyright__ = "Copyright 2018, Mathias Laurin" +__license__ = "MIT License" + + +cdef extern from "mbedtls/error.h": + void mbedtls_strerror(int errnum, char *buffer, size_t buflen) diff --git a/mbedtls/exceptions.pyx b/mbedtls/exceptions.pyx index 83c35eb3..23b34646 100644 --- a/mbedtls/exceptions.pyx +++ b/mbedtls/exceptions.pyx @@ -2,201 +2,57 @@ __author__ = "Mathias Laurin" -__copyright__ = "Copyright 2015, Elaborated Networks GmbH" +__copyright__ = "Copyright 2018, Mathias Laurin" __license__ = "MIT License" +from libc.stdlib cimport malloc, free -__all__ = ("CipherError", "InvalidInputLengthError", "InvalidKeyLengthError", - "EntropyError", "MessageDigestError", "PkError", - "check_error", - ) +cimport mbedtls.exceptions as _err -class _ErrorBase(ValueError): - """Base class for cipher exceptions.""" - - def __init__(self, err=None, msg="", *args): - super().__init__(*args) - self.err = err - self.msg = msg - - def __str__(self): - return "%s([0x%04X] %s)" % (self.__class__.__name__, - self.err, self.msg) - - -class Asn1Error(_ErrorBase): - """Errors defined in `asn1.h`.""" - - -class Base64Error(_ErrorBase): - """Errors defined in `base64.h`.""" - - -class CipherError(_ErrorBase): - """Errors defined in the cipher module.""" - - -class InvalidInputLengthError(CipherError): - """Invalid input length.""" - - -class InvalidKeyLengthError(CipherError): - """Invalid key length.""" +__all__ = ("MbedTLSError", "check_error") -class EntropyError(_ErrorBase): - """Errors defined in the entropy module.""" - - -class MessageDigestError(_ErrorBase): - """Errors defined in the md module.""" - - -class PkError(_ErrorBase): - """Errors defined in the pk module.""" - - -class PemError(PkError): - """Errors defined in the pem module.""" - - -class RsaError(PkError): - """Errors defined in the rsa module.""" - - -class EcError(PkError): - """Errors defined in the ecp module.""" - - -class X509Error(_ErrorBase): - """Errors defined in the x509 module.""" +class MbedTLSError(Exception): + """Exception raise by Mbed TLS.""" + def __init__(self, err=None, msg=""): + super(MbedTLSError, self).__init__() + if err is not None: + assert err >= 0 + self.err = err + self._msg = msg + + @property + def msg(self): + if self.err is None: + return self._msg + + # Set buflen to 200 as in `strerror.c`. + cdef size_t buflen = 200 + cdef char* buffer = malloc(buflen * sizeof(char)) + if not buffer: + raise MemoryError() + try: + _err.mbedtls_strerror(self.err, &buffer[0], buflen) + output = bytes(buffer[:buflen]) + try: + olen = output.index(b"\0") + except ValueError: + olen = buflen + return output[:olen].decode("ascii") + finally: + free(buffer) -__lookup = { - # Blowfish-specific - 0x0016: (InvalidKeyLengthError, "invalid key length"), - 0x0018: (InvalidInputLengthError, "invalid data input length"), - # Base64 - 0x002a: (Base64Error, "output buffer too small"), - 0x002c: (Base64Error, "invalid character in input"), - # DES - 0x0032: (InvalidInputLengthError, "the data input has an invalid length"), - # Entropy - 0x003C: (EntropyError, "critical entropy source failure"), - 0x003D: (EntropyError, "no strong source have been added to poll"), - 0x003E: (EntropyError, "no more source can be added"), - 0x003F: (EntropyError, "read/write error in file"), - 0x0040: (EntropyError, "no sources have been added to poll"), - # ASN1 - 0x0060: (Asn1Error, "out of data when parsing and ASN1 data structure"), - 0x0062: (Asn1Error, "ASN.1 tag was of an unexpected value"), - 0x0064: (Asn1Error, - "error when trying to determine the length" + - "or invalid length"), - 0x0066: (Asn1Error, "actual length differs from expected length"), - 0x0068: (Asn1Error, "data is invalid"), - 0x006A: (Asn1Error, "memory allocation failed"), - 0x006c: (Asn1Error, "buffer too small when writing ASN.1 data structure"), - # PEM errors - 0x1080: (PemError, "no PEM header or footer found"), - 0x1100: (PemError, "PEM string is not as expected"), - 0x1180: (PemError, "failed to allocate memory"), - 0x1200: (PemError, "RSA IV is not in hex-format"), - 0x1280: (PemError, "unsupported key encryption algorithm"), - 0x1300: (PemError, "private key password can't be empty"), - 0x1380: (PemError, - "given private key password does not allow for" + - "correct decryption"), - 0x1400: (PemError, - "unavailable feature, e.g. hashing/decryption combination"), - 0x1480: (PemError, "bad input parameters to function"), - # X509 errors - 0x2080: (X509Error, "feature unavailable"), - 0x2100: (X509Error, "unknown OID"), - 0x2180: (X509Error, "invalid format"), - 0x2200: (X509Error, "invalid version"), - 0x2280: (X509Error, "invalid, serial"), - 0x2300: (X509Error, "invalid alg"), - 0x2380: (X509Error, "invalid name"), - 0x2400: (X509Error, "invalid date"), - 0x2480: (X509Error, "invalid signature"), - 0x2500: (X509Error, "invalid extensions"), - 0x2580: (X509Error, "unknown version"), - 0x2600: (X509Error, "unknown sig alg"), - 0x2680: (X509Error, "sig mismatch"), - 0x2700: (X509Error, "cert verify failed"), - 0x2780: (X509Error, "cert unknown format"), - 0x2800: (X509Error, "bad input data"), - 0x2880: (X509Error, "alloc failed"), - 0x2900: (X509Error, "file io error"), - 0x2980: (X509Error, "buffer too small"), - # PK errors - 0x3f80: (PkError, "memory allocation failed"), - 0x3f00: (PkError, - "type mismatch, eg attempt to encrypt with an ECDSA key"), - 0x3e80: (PkError, "bad input parameters to function"), - 0x3e00: (PkError, "read/write of file failed"), - 0x3d80: (PkError, "unsupported key version"), - 0x3d00: (PkError, "invalid key tag or value"), - 0x3c80: (PkError, - "key algorithm is unsupported" + - "(only RSA and EC are supported)"), - 0x3c00: (PkError, "private key password can't be empty"), - 0x3b80: (PkError, - "given private key password does not allow" + - "for correct decryption"), - 0x3b00: (PkError, - "the pubkey tag or value is invalid" + - "(only RSA and EC are supported)"), - 0x3a80: (PkError, "the algorithm tag or value is invalid"), - 0x3a00: (PkError, - "elliptic curve is unsupported" + - "(only NIST curves are supported)"), - 0x3980: (PkError, - "unavailable feature, eg RSA disabled for RSA key"), - 0x3900: (PkError, - "the signature is valid but its length" + - "is less than expected"), - # RSA errors - 0x4080: (RsaError, "bad input parameters to function"), - 0x4100: (RsaError, "input data contains invalid padding and is rejected"), - 0x4180: (RsaError, "something failed during generation of a key"), - 0x4200: (RsaError, "key failed to pass the library's validity check"), - 0x4280: (RsaError, "the public key operation failed"), - 0x4300: (RsaError, "the private key operation failed"), - 0x4380: (RsaError, "the PKCS#1 verification failed"), - 0x4400: (RsaError, - "the output buffer for decryption is not large enough"), - 0x4480: (RsaError, "the random generator failed to generate non-zeros"), - # ECP errors - 0x4f80: (EcError, "bad input parameters to function"), - 0x4f00: (EcError, "the buffer is too small to write to"), - 0x4e80: (EcError, "requested curve not available"), - 0x4e00: (EcError, "the signature is not valid"), - 0x4d80: (EcError, "memory allocation failed"), - 0x4d00: (EcError, - "generation of random value, such as (ephemeral) key, failed"), - 0x4c80: (EcError, "invalid private or public key"), - 0x4c00: (EcError, - "signature is valid but shorter than the user-specified length"), - # MD errors - 0x5080: (MessageDigestError, "the selected feature is not available"), - 0x5100: (MessageDigestError, "bad input parameter to function"), - 0x5180: (MessageDigestError, "failed to allocate memory"), - 0x5200: (MessageDigestError, "opening or reading of file failed"), - # Cipher errors - 0x6080: (CipherError, "the selected feature is not available"), - 0x6100: (CipherError, "bad input parameter to function"), - 0x6180: (CipherError, "failed to allocate memory"), - 0x6200: (CipherError, "input contains invalid padding and is rejected"), - 0x6280: (CipherError, "decryption of block requires a full block"), - 0x6300: (CipherError, "authentication failed (for AEAD modes)"), -} + def __str__(self): + if self.err is None: + return "%s(%s)" % (type(self).__name__, self.msg) + else: + return "%s([0x%04X] %r)" % (self.__class__.__name__, + self.err, self.msg) cpdef check_error(const int err): - if err < 0: - exc, msg = __lookup.get(-err, (_ErrorBase, "")) - raise exc(-err, msg) - return err + if err >= 0: + return err + raise MbedTLSError(-err) diff --git a/mbedtls/_mpi.pxd b/mbedtls/mpi.pxd similarity index 55% rename from mbedtls/_mpi.pxd rename to mbedtls/mpi.pxd index 7fc36615..a5635f9e 100644 --- a/mbedtls/_mpi.pxd +++ b/mbedtls/mpi.pxd @@ -6,12 +6,15 @@ __license__ = "MIT License" cdef extern from "mbedtls/bignum.h": + int MBEDTLS_MPI_MAX_SIZE + # Multi-precision integer library # ------------------------------- ctypedef struct mbedtls_mpi: pass - ctypedef enum mbedtls_mpi_sint: pass + ctypedef enum mbedtls_mpi_sint: + pass # mbedtls_mpi # ----------- @@ -48,8 +51,12 @@ cdef extern from "mbedtls/bignum.h": unsigned char *buf, size_t buflen) - # mbedtls_mpi_shift_l - # mbedtls_mpi_shift_r + int mbedtls_mpi_shift_l( + mbedtls_mpi *X, + size_t count) + int mbedtls_mpi_shift_r( + mbedtls_mpi *X, + size_t count) # mbedtls_mpi_cmp_abs int mbedtls_mpi_cmp_mpi( const mbedtls_mpi *X, @@ -57,31 +64,58 @@ cdef extern from "mbedtls/bignum.h": # mbedtls_mpi_cmp_int # mbedtls_mpi_add_abs # mbedtls_mpi_sub_abs - # mbedtls_mpi_add_mpi - # mbedtls_mpi_sub_mpi + int mbedtls_mpi_add_mpi( + mbedtls_mpi *X, + const mbedtls_mpi *A, + const mbedtls_mpi *B) + int mbedtls_mpi_sub_mpi( + mbedtls_mpi *X, + const mbedtls_mpi *A, + const mbedtls_mpi *B) # mbedtls_mpi_add_int # mbedtls_mpi_sub_int - # mbedtls_mpi_mul_mpi + int mbedtls_mpi_mul_mpi( + mbedtls_mpi *X, + const mbedtls_mpi *A, + const mbedtls_mpi *B) # mbedtls_mpi_mul_int - # mbedtls_mpi_div_mpi + int mbedtls_mpi_div_mpi( + mbedtls_mpi *Q, + mbedtls_mpi *R, + const mbedtls_mpi *A, + const mbedtls_mpi *B) # mbedtls_mpi_div_int - # mbedtls_mpi_mod_mpi + int mbedtls_mpi_mod_mpi( + mbedtls_mpi *X, + const mbedtls_mpi *A, + const mbedtls_mpi *B) # mbedtls_mpi_mod_int - # mbedtls_mpi_exp_mod - # mbedtls_mpi_fill_random + int mbedtls_mpi_exp_mod( + mbedtls_mpi *X, + const mbedtls_mpi *A, + const mbedtls_mpi *E, + const mbedtls_mpi *N, + mbedtls_mpi *_RR) + int mbedtls_mpi_fill_random( + mbedtls_mpi *X, size_t size, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) # mbedtls_mpi_gcd # mbedtls_mpi_inv_mod - # mbedtls_mpi_is_prime - # mbedtls_mpi_gen_prime + int mbedtls_mpi_is_prime( + mbedtls_mpi *X, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) + int mbedtls_mpi_gen_prime( + mbedtls_mpi *X, size_t size, int dh_flag, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) cdef class MPI: cdef mbedtls_mpi _ctx - cdef _len(self) + cdef size_t _len(self) cpdef _from_bytes(self, const unsigned char[:] bytes) cdef inline from_mpi(mbedtls_mpi *c_mpi): - new_mpi = MPI(0) + new_mpi = MPI() mbedtls_mpi_copy(&new_mpi._ctx, c_mpi) return new_mpi diff --git a/mbedtls/mpi.pyx b/mbedtls/mpi.pyx new file mode 100644 index 00000000..6ddcb452 --- /dev/null +++ b/mbedtls/mpi.pyx @@ -0,0 +1,313 @@ +"""Multi-precision integer library (MPI).""" + +__author__ = "Mathias Laurin" +__copyright__ = "Copyright 2018, Mathias Laurin" +__license__ = "MIT License" + + +cimport mbedtls.mpi as _mpi +cimport mbedtls.random as _random +from libc.stdlib cimport malloc, free + +import numbers +from binascii import hexlify, unhexlify + +import mbedtls.random as _random +from mbedtls.exceptions import * + +try: + long +except NameError: + long = int + + +cdef _random.Random __rng = _random.Random() + + +cdef to_bytes(value): + return unhexlify("{0:02x}".format(value).encode("ascii")) + + +cdef from_bytes(value): + return long(hexlify(value), 16) + + +cdef class MPI: + """Multi-precision integer. + + This class implements `numbers.Integral`. The representation + of the MPI is overwritten with random bytes when the MPI is + garbage collected. + + The bitwise operations are not implemented. + + """ + def __init__(self, value=0): + if isinstance(value, MPI): + value_ = value + check_error(mbedtls_mpi_copy(&self._ctx, &value_._ctx)) + else: + value = to_bytes(value) + self._from_bytes(value) + + def __del__(self): + """Fill the MPI with random data.""" + check_error(mbedtls_mpi_fill_random( + &self._ctx, self._len(), + &_random.mbedtls_ctr_drbg_random, &__rng._ctx)) + + def __cinit__(self): + """Initialize one MPI.""" + _mpi.mbedtls_mpi_init(&self._ctx) + + def __dealloc__(self): + """Unallocate one MPI.""" + _mpi.mbedtls_mpi_free(&self._ctx) + + cdef size_t _len(self): + """Return the total size in bytes.""" + return _mpi.mbedtls_mpi_size(&self._ctx) + + cpdef _from_bytes(self, const unsigned char[:] bytes): + check_error( + _mpi.mbedtls_mpi_read_binary(&self._ctx, &bytes[0], bytes.shape[0])) + return self + + def __str__(self): + return "%i" % long(self) + + def __repr__(self): + return "%s(%i)" % (type(self).__name__, long(self)) + + def bit_length(self): + """Return the number of bits necessary to represent MPI in binary.""" + return _mpi.mbedtls_mpi_bitlen(&self._ctx) + + @classmethod + def from_int(cls, value): + # mbedtls_mpi_lset is 'limited' to 64 bits. + return cls.from_bytes(to_bytes(value), byteorder="big") + + @classmethod + def from_bytes(cls, bytes, byteorder): + assert byteorder in {"big", "little"} + order = slice(None, None, -1 if byteorder is "little" else None) + return cls()._from_bytes(bytes[order]) + + def to_bytes(self, length, byteorder): + assert byteorder in {"big", "little"} + order = slice(None, None, -1 if byteorder is "little" else None) + cdef unsigned char* output = malloc( + length * sizeof(unsigned char)) + if not output: + raise MemoryError() + try: + check_error(_mpi.mbedtls_mpi_write_binary( + &self._ctx, output, length)) + return bytes(output[:length])[order] + except Exception as exc: + raise OverflowError from exc + finally: + free(output) + + __bytes__ = to_bytes + + @classmethod + def prime(cls, size): + """Return an MPI that is probably prime.""" + cdef MPI self_ = cls() + check_error(mbedtls_mpi_gen_prime( + &self_._ctx, size, 0, + &_random.mbedtls_ctr_drbg_random, &__rng._ctx)) + return self_ + + def is_prime(self): + """Miller-Rabin primality test.""" + return check_error(mbedtls_mpi_is_prime( + &self._ctx, + &_random.mbedtls_ctr_drbg_random, &__rng._ctx)) == 0 + + def __hash__(self): + return long(self) + + def __bool__(self): + return self != 0 + + def __add__(self, other): + if not all((isinstance(self, numbers.Integral), + isinstance(other, numbers.Integral))): + return NotImplemented + cdef MPI self_ = MPI(self) + cdef MPI other_ = MPI(other) + cdef MPI result = MPI() + check_error(mbedtls_mpi_add_mpi( + &result._ctx, &self_._ctx, &other_._ctx)) + return result + + def __neg__(self): + raise TypeError("negative value") + + def __pos__(self): + return self + + def __sub__(self, other): + if not all((isinstance(self, numbers.Integral), + isinstance(other, numbers.Integral))): + return NotImplemented + cdef MPI self_ = MPI(self) + cdef MPI other_ = MPI(other) + cdef MPI result = MPI() + check_error(mbedtls_mpi_sub_mpi( + &result._ctx, &self_._ctx, &other_._ctx)) + return result + + def __mul__(self, other): + if not all((isinstance(self, numbers.Integral), + isinstance(other, numbers.Integral))): + return NotImplemented + cdef MPI self_ = MPI(self) + cdef MPI other_ = MPI(other) + cdef MPI result = MPI() + check_error(mbedtls_mpi_mul_mpi( + &result._ctx, &self_._ctx, &other_._ctx)) + return result + + def __truediv__(self, other): + return NotImplemented + + def __pow__(self, exponent, modulus): + if exponent < 0 or not all(isinstance(_, numbers.Integral) + for _ in (self, exponent, modulus)): + raise TypeError("invalid argument") + cdef MPI result = MPI() + cdef MPI self_ = MPI(self) + cdef MPI exponent_ = MPI(exponent) + cdef MPI modulus_ = MPI(modulus) + check_error(mbedtls_mpi_exp_mod( + &result._ctx, &self_._ctx, &exponent_._ctx, &modulus_._ctx, NULL)) + return result + + def __abs__(self): + # Negative values are not supported. + return self + + def __eq__(self, other): + if not all((isinstance(self, numbers.Integral), + isinstance(other, numbers.Integral))): + return NotImplemented + cdef MPI self_ = MPI(self) + cdef MPI other_ = MPI(other) + return mbedtls_mpi_cmp_mpi(&self_._ctx, &other_._ctx) == 0 + + def __float__(self): + return float(long(self)) + + def __trunc__(self): + return self + + def __floor__(self): + return self + + def __ceil__(self): + return self + + def __round__(self, ndigits=None): + return self + + def __divmod__(self, other): + if not all((isinstance(self, numbers.Integral), + isinstance(other, numbers.Integral))): + return NotImplemented + cdef MPI self_ = MPI(self) + cdef MPI other_ = MPI(other) + cdef MPI quotient = MPI() + cdef MPI rest = MPI() + check_error(mbedtls_mpi_div_mpi( + "ient._ctx, &rest._ctx, &self_._ctx, &other_._ctx)) + return quotient, rest + + def __floordiv__(self, other): + return divmod(self, other)[0] + + def __mod__(self, other): + if not all((isinstance(self, numbers.Integral), + isinstance(other, numbers.Integral))): + return NotImplemented + cdef MPI self_ = MPI(self) + cdef MPI other_ = MPI(other) + cdef MPI result = MPI() + check_error(mbedtls_mpi_mod_mpi( + &result._ctx, &self_._ctx, &other_._ctx)) + return result + + def __lt__(self, other): + if not all((isinstance(self, numbers.Integral), + isinstance(other, numbers.Integral))): + return NotImplemented + cdef MPI self_ = MPI(self) + cdef MPI other_ = MPI(other) + return mbedtls_mpi_cmp_mpi(&self_._ctx, &other_._ctx) == -1 + + def __le__(self, other): + return any((self < other, self == other)) + + def __complex__(self): + return complex(float(self)) + + def real(self): + return self + + def imag(self): + return 0 + + def conjugate(self): + return self + + def __int__(self): + return from_bytes(self.to_bytes(self._len(), byteorder="big")) + + def __index__(self): + return long(self) + + def __lshift__(self, other): + if not isinstance(self, MPI): + return NotImplemented + cdef MPI self_ = MPI(self) + check_error(mbedtls_mpi_shift_l(&self_._ctx, long(other))) + return self_ + + def __rshift__(self, other): + if not isinstance(self, MPI): + return NotImplemented + cdef MPI self_ = MPI(self) + check_error(mbedtls_mpi_shift_r(&self_._ctx, long(other))) + return self_ + + def __and__(self, other): + raise NotImplementedError + + def __xor__(self, other): + raise NotImplementedError + + def __or__(self, other): + raise NotImplementedError + + def __invert__(self): + raise NotImplementedError + + @property + def numerator(self): + return self + + @property + def denominator(self): + return 1 + + def __gt__(self, other): + return not self <= other + + def __ge__(self, other): + return any((self > other, self == other)) + + +numbers.Integral.register(MPI) diff --git a/mbedtls/pk.pxd b/mbedtls/pk.pxd index afe7ece7..4f48afcb 100644 --- a/mbedtls/pk.pxd +++ b/mbedtls/pk.pxd @@ -16,7 +16,49 @@ cdef extern from "mbedtls/bignum.h": ctypedef struct mbedtls_mpi: pass - int MBEDTLS_MPI_MAX_SIZE + +cdef extern from "mbedtls/dhm.h": + ctypedef struct mbedtls_dhm_context: + mbedtls_mpi P + mbedtls_mpi G + mbedtls_mpi X + mbedtls_mpi GX + mbedtls_mpi GY + mbedtls_mpi K + + void mbedtls_dhm_init(mbedtls_dhm_context *ctx) + void mbedtls_dhm_free(mbedtls_dhm_context *ctx) + + int mbedtls_dhm_make_params( + mbedtls_dhm_context *ctx, + int x_size, + unsigned char *output, size_t *olen, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng) + int mbedtls_dhm_make_public( + mbedtls_dhm_context *ctx, + int x_size, + unsigned char *output, size_t olen, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng) + + int mbedtls_dhm_read_params( + mbedtls_dhm_context *ctx, + unsigned char **p, + const unsigned char *end) + int mbedtls_dhm_read_public( + mbedtls_dhm_context *ctx, + const unsigned char *input, size_t ilen) + + int mbedtls_dhm_calc_secret( + mbedtls_dhm_context *ctx, + unsigned char *output, size_t output_size, size_t *olen, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng) + + # int mbedtls_dhm_parse_dhm( + # mbedtls_dhm_context *dhm, + # const unsigned char *dhmin, size_t dhminlen) + # int mbedtls_dhm_parse_dhmfile( + # mbedtls_dhm_context *dhm, + # const char *path) cdef extern from "mbedtls/ecp.h": @@ -349,6 +391,10 @@ cdef class ECPoint: cdef mbedtls_ecp_point _ctx +cdef class DHBase: + cdef mbedtls_dhm_context _ctx + + cdef class ECDHBase: cdef mbedtls_ecdh_context _ctx cdef curve diff --git a/mbedtls/pk.pyx b/mbedtls/pk.pyx index 3e3c638f..4f67d88f 100644 --- a/mbedtls/pk.pyx +++ b/mbedtls/pk.pyx @@ -19,7 +19,7 @@ __license__ = "MIT License" from libc.stdlib cimport malloc, free -cimport mbedtls._mpi as _mpi +cimport mbedtls.mpi as _mpi cimport mbedtls.pk as _pk cimport mbedtls.random as _random @@ -33,17 +33,13 @@ import enum from functools import partial import mbedtls.random as _random -from mbedtls.exceptions import check_error, PkError +from mbedtls.exceptions import check_error, MbedTLSError import mbedtls.hash as _hash -try: - long -except NameError: - long = int - __all__ = ("check_pair", "get_supported_ciphers", "get_supported_curves", - "Curve", "RSA", "ECC", "ECDHServer", "ECDHClient") + "Curve", "RSA", "ECC", "DHServer", "DHClient", + "ECDHServer", "ECDHClient") CIPHER_NAME = ( @@ -76,9 +72,9 @@ class Curve(bytes, enum.Enum): # The following calculations come from mbedtls/library/pkwrite.c. -RSA_PUB_DER_MAX_BYTES = 38 + 2 * _pk.MBEDTLS_MPI_MAX_SIZE -MPI_MAX_SIZE_2 = MBEDTLS_MPI_MAX_SIZE / 2 + MBEDTLS_MPI_MAX_SIZE % 2 -RSA_PRV_DER_MAX_BYTES = 47 + 3 * _pk.MBEDTLS_MPI_MAX_SIZE + 5 * MPI_MAX_SIZE_2 +RSA_PUB_DER_MAX_BYTES = 38 + 2 * _mpi.MBEDTLS_MPI_MAX_SIZE +MPI_MAX_SIZE_2 = _mpi.MBEDTLS_MPI_MAX_SIZE / 2 + _mpi.MBEDTLS_MPI_MAX_SIZE % 2 +RSA_PRV_DER_MAX_BYTES = 47 + 3 * _mpi.MBEDTLS_MPI_MAX_SIZE + 5 * MPI_MAX_SIZE_2 ECP_PUB_DER_MAX_BYTES = 30 + 2 * _pk.MBEDTLS_ECP_MAX_BYTES ECP_PRV_DER_MAX_BYTES = 29 + 3 * _pk.MBEDTLS_ECP_MAX_BYTES @@ -172,7 +168,7 @@ cdef class CipherBase: return NotImplemented try: return self.to_DER() == other.to_DER() - except PkError: + except MbedTLSError: return False property _type: @@ -225,7 +221,7 @@ cdef class CipherBase: cdef const unsigned char[:] hash_ = md_alg.digest() cdef size_t sig_len = 0 cdef unsigned char* output = malloc( - _pk.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) + _mpi.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) if not output: raise MemoryError() try: @@ -272,7 +268,7 @@ cdef class CipherBase: """ cdef size_t olen = 0 cdef unsigned char* output = malloc( - _pk.MBEDTLS_MPI_MAX_SIZE // 2 * sizeof(unsigned char)) + _mpi.MBEDTLS_MPI_MAX_SIZE // 2 * sizeof(unsigned char)) if not output: raise MemoryError() try: @@ -293,7 +289,7 @@ cdef class CipherBase: """ cdef size_t olen = 0 cdef unsigned char* output = malloc( - _pk.MBEDTLS_MPI_MAX_SIZE // 2 * sizeof(unsigned char)) + _mpi.MBEDTLS_MPI_MAX_SIZE // 2 * sizeof(unsigned char)) if not output: raise MemoryError() try: @@ -347,7 +343,7 @@ cdef class CipherBase: check_error(_pk.mbedtls_pk_parse_key( &self._ctx, &key_[0], key_.size, &pwd_[0] if pwd_.size else NULL, pwd_.size)) - except PkError: + except MbedTLSError: check_error(_pk.mbedtls_pk_parse_public_key( &self._ctx, &key_[0], key_.size)) @@ -498,25 +494,25 @@ cdef class ECPoint: """Return the X coordinate.""" def __get__(self): try: - return long(_mpi.from_mpi(&self._ctx.X)) + return _mpi.from_mpi(&self._ctx.X) except ValueError: - return 0 + return _mpi.MPI() property y: """Return the Y coordinate.""" def __get__(self): try: - return long(_mpi.from_mpi(&self._ctx.Y)) + return _mpi.from_mpi(&self._ctx.Y) except ValueError: - return 0 + return _mpi.MPI() property z: """Return the Z coordinate.""" def __get__(self): try: - return long(_mpi.from_mpi(&self._ctx.Z)) + return _mpi.from_mpi(&self._ctx.Z) except ValueError: - return 0 + return _mpi.MPI() def _tuple(self): return (self.x, self.y) @@ -579,7 +575,7 @@ cdef class ECC(CipherBase): def _has_private(self): """Return `True` if the key contains a valid private half.""" cdef const mbedtls_ecp_keypair* ecp = _pk.mbedtls_pk_ec(self._ctx) - return _mpi.mbedtls_mpi_cmp_mpi(&ecp.d, &_mpi.MPI(0)._ctx) != 0 + return _mpi.mbedtls_mpi_cmp_mpi(&ecp.d, &_mpi.MPI()._ctx) != 0 def _has_public(self): """Return `True` if the key contains a valid public half.""" @@ -603,9 +599,9 @@ cdef class ECC(CipherBase): def _private_to_num(self): try: - return long(_mpi.from_mpi(&_pk.mbedtls_pk_ec(self._ctx).d)) + return _mpi.from_mpi(&_pk.mbedtls_pk_ec(self._ctx).d) except ValueError: - return 0 + return _mpi.MPI() def export_key(self, format="DER"): """Return the private key. @@ -653,6 +649,150 @@ cdef class ECC(CipherBase): return ecdh +cdef class DHBase: + + """Base class to DH key exchange: client and server. + + Args: + modulus (int): The prime modulus P. + generator (int): The generator G, a primitive root modulo P. + + See Also: + DHServer, DHClient: The derived classes. + + """ + def __init__(self, modulus, generator): + super().__init__() + check_error(_mpi.mbedtls_mpi_copy( + &self._ctx.P, &_mpi.MPI(modulus)._ctx)) + check_error(_mpi.mbedtls_mpi_copy( + &self._ctx.G, &_mpi.MPI(generator)._ctx)) + + def __cinit__(self): + """Initialize the context.""" + _pk.mbedtls_dhm_init(&self._ctx) + + def __dealloc__(self): + """Free and clear the context.""" + _pk.mbedtls_dhm_free(&self._ctx) + + property key_size: + """Return the size of the key, in bytes.""" + def __get__(self): + return _mpi.mbedtls_mpi_size(&self._ctx.P) + + property modulus: + """Return the prime modulus, P.""" + def __get__(self): + return _mpi.from_mpi(&self._ctx.P) + + property generator: + """Return the generator, G.""" + def __get__(self): + return _mpi.from_mpi(&self._ctx.G) + + property _secret: + """Return the secret (int).""" + def __get__(self): + return _mpi.from_mpi(&self._ctx.X) + + property shared_secret: + """The shared secret (int). + + The shared secret is 0 if the TLS handshake is not finished. + + """ + def __get__(self): + try: + return _mpi.from_mpi(&self._ctx.K) + except ValueError: + return _mpi.MPI() + + def generate_secret(self): + """Generate the shared secret.""" + cdef _mpi.MPI mpi + cdef unsigned char* output = malloc( + _mpi.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) + cdef size_t olen = 0 + if not output: + raise MemoryError() + try: + check_error(mbedtls_dhm_calc_secret( + &self._ctx, &output[0], _mpi.MBEDTLS_MPI_MAX_SIZE, &olen, + &_random.mbedtls_ctr_drbg_random, &__rng._ctx)) + assert olen != 0 + mpi = _mpi.MPI() + _mpi.mbedtls_mpi_read_binary(&mpi._ctx, &output[0], olen) + return mpi + finally: + free(output) + + +cdef class DHServer(DHBase): + + """The server side of the DH key exchange.""" + + def generate(self): + """Generate a public key. + + Return: + bytes: A TLS ServerKeyExchange payload. + + """ + cdef unsigned char* output = malloc( + _mpi.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) + cdef size_t olen = 0 + if not output: + raise MemoryError() + try: + check_error(_pk.mbedtls_dhm_make_params( + &self._ctx, self.key_size, &output[0], &olen, + &_random.mbedtls_ctr_drbg_random, &__rng._ctx)) + assert olen != 0 + return bytes(output[:olen]) + finally: + free(output) + + def import_CKE(self, const unsigned char[:] buffer): + """Read the ClientKeyExchange payload.""" + check_error(_pk.mbedtls_dhm_read_public( + &self._ctx, &buffer[0], buffer.size)) + + +cdef class DHClient(DHBase): + + """The client side of the DH key exchange.""" + + def generate(self): + """Generate the public key. + + Return: + bytes: The byte representation (big endian) of: G^X mod P. + + """ + cdef _mpi.MPI mpi + cdef unsigned char* output = malloc( + _mpi.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) + if not output: + raise MemoryError() + try: + check_error(_pk.mbedtls_dhm_make_public( + &self._ctx, self.key_size, &output[0], self.key_size, + &_random.mbedtls_ctr_drbg_random, &__rng._ctx)) + mpi = _mpi.from_mpi(&self._ctx.GX) + return mpi.to_bytes( + _mpi.mbedtls_mpi_size(&mpi._ctx), "big") + finally: + free(output) + + def import_SKE(self, const unsigned char[:] buffer): + """Read the ServerKeyExchange payload.""" + cdef const unsigned char* first = &buffer[0] + cdef const unsigned char* end = &buffer[-1] + 1 + check_error(_pk.mbedtls_dhm_read_params( + &self._ctx, &first, end)) + + cdef class ECDHBase: """Base class to ECDH(E) key exchange: client and server. @@ -683,7 +823,7 @@ cdef class ECDHBase: def _has_private(self): """Return `True` if the key contains a valid private half.""" - return _mpi.mbedtls_mpi_cmp_mpi(&self._ctx.d, &_mpi.MPI(0)._ctx) != 0 + return _mpi.mbedtls_mpi_cmp_mpi(&self._ctx.d, &_mpi.MPI()._ctx) != 0 def _has_public(self): """Return `True` if the key contains a valid public half.""" @@ -695,19 +835,19 @@ cdef class ECDHBase: def generate_secret(self): """Generate the shared secret.""" - cdef _mpi.MPI mpi = _mpi.MPI(0) + cdef _mpi.MPI mpi = _mpi.MPI() cdef unsigned char* output = malloc( - _pk.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) + _mpi.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) cdef size_t olen = 0 if not output: raise MemoryError() try: check_error(mbedtls_ecdh_calc_secret( - &self._ctx, &olen, &output[0], _pk.MBEDTLS_MPI_MAX_SIZE, + &self._ctx, &olen, &output[0], _mpi.MBEDTLS_MPI_MAX_SIZE, &_random.mbedtls_ctr_drbg_random, &__rng._ctx)) assert olen != 0 _mpi.mbedtls_mpi_read_binary(&mpi._ctx, &output[0], olen) - return long(mpi) + return mpi finally: free(output) @@ -719,9 +859,9 @@ cdef class ECDHBase: """ def __get__(self): try: - return long(_mpi.from_mpi(&self._ctx.z)) + return _mpi.from_mpi(&self._ctx.z) except ValueError: - return 0 + return _mpi.MPI() cdef class ECDHServer(ECDHBase): @@ -740,13 +880,13 @@ cdef class ECDHServer(ECDHBase): """ cdef unsigned char* output = malloc( - _pk.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) + _mpi.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) cdef size_t olen = 0 if not output: raise MemoryError() try: - check_error(mbedtls_ecdh_make_params( - &self._ctx, &olen, &output[0], _pk.MBEDTLS_MPI_MAX_SIZE, + check_error(_pk.mbedtls_ecdh_make_params( + &self._ctx, &olen, &output[0], _mpi.MBEDTLS_MPI_MAX_SIZE, &_random.mbedtls_ctr_drbg_random, &__rng._ctx)) assert olen != 0 return bytes(output[:olen]) @@ -755,7 +895,7 @@ cdef class ECDHServer(ECDHBase): def import_CKE(self, const unsigned char[:] buffer): """Read the ClientKeyExchange payload.""" - check_error(mbedtls_ecdh_read_public( + check_error(_pk.mbedtls_ecdh_read_public( &self._ctx, &buffer[0], buffer.size)) @@ -775,13 +915,13 @@ cdef class ECDHClient(ECDHBase): """ cdef unsigned char* output = malloc( - _pk.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) + _mpi.MBEDTLS_MPI_MAX_SIZE * sizeof(unsigned char)) cdef size_t olen = 0 if not output: raise MemoryError() try: - check_error(mbedtls_ecdh_make_public( - &self._ctx, &olen, &output[0], _pk.MBEDTLS_MPI_MAX_SIZE, + check_error(_pk.mbedtls_ecdh_make_public( + &self._ctx, &olen, &output[0], _mpi.MBEDTLS_MPI_MAX_SIZE, &_random.mbedtls_ctr_drbg_random, &__rng._ctx)) assert olen != 0 return bytes(output[:olen]) @@ -792,5 +932,5 @@ cdef class ECDHClient(ECDHBase): """Read the ServerKeyExchange payload.""" cdef const unsigned char* first = &buffer[0] cdef const unsigned char* end = &buffer[-1] + 1 - check_error(mbedtls_ecdh_read_params( + check_error(_pk.mbedtls_ecdh_read_params( &self._ctx, &first, end)) diff --git a/mbedtls/x509.pyx b/mbedtls/x509.pyx index 40eecc0e..1a5f05f6 100644 --- a/mbedtls/x509.pyx +++ b/mbedtls/x509.pyx @@ -8,7 +8,7 @@ __license__ = "MIT License" from libc.stdlib cimport malloc, free cimport mbedtls.x509 as x509 -cimport mbedtls._mpi as _mpi +cimport mbedtls.mpi as _mpi cimport mbedtls.pk as _pk import base64 diff --git a/setup.py b/setup.py index 9888d276..2cf97673 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ import sys from setuptools import setup, Extension -version = "0.10.0" +version = "0.11.0" download_url = "https://github.com/Synss/python-mbedtls/tarball/%s" % version diff --git a/tests/test_cipher.py b/tests/test_cipher.py index 1b601394..7c2ff79b 100644 --- a/tests/test_cipher.py +++ b/tests/test_cipher.py @@ -29,18 +29,18 @@ def test_get_supported_ciphers(): assert cl and set(cl).issubset(set(CIPHER_NAME)) -def test_wrong_size_raises_cipher_error(): - with pytest.raises(CipherError): +def test_wrong_size_raises_exception(): + with pytest.raises(MbedTLSError): Cipher(b"AES-512-ECB", b"", 0, b"") -def test_random_name_raises_cipher_error(): - with pytest.raises(CipherError): +def test_random_name_raises_exception(): + with pytest.raises(MbedTLSError): Cipher(b"RANDOM TEXT IS NOT A CIPHER", b"", 0, b"") -def test_zero_length_raises_cipher_error(): - with pytest.raises(CipherError): +def test_zero_length_raises_exception(): + with pytest.raises(MbedTLSError): Cipher(b"", b"", 0, b"") @@ -99,11 +99,11 @@ def test_module_level_key_size_variable(cipher): assert cipher.key_size == mod.key_size -def test_wrong_key_size_raises_invalid_key_size_error(cipher, randbytes): +def test_wrong_key_size_raises_exception(cipher, randbytes): mod = module_from_name(cipher.name) if mod.key_size is None: pytest.skip("module defines variable-length key") - with pytest.raises(InvalidKeyLengthError): + with pytest.raises(MbedTLSError): mod.new(randbytes(cipher.key_size) + b"\x00", cipher.mode, randbytes(cipher.iv_size)) @@ -219,19 +219,19 @@ def test_streaming_ciphers(cipher, randbytes): assert cipher.decrypt(cipher.encrypt(block)) == block -def test_fixed_block_size_ciphers_long_block_raise_ciphererror( +def test_fixed_block_size_ciphers_long_block_raise_exception( cipher, randbytes): if is_streaming(cipher): pytest.skip("streaming cipher") - with pytest.raises(CipherError): + with pytest.raises(MbedTLSError): block = randbytes(cipher.block_size) + randbytes(1) cipher.encrypt(block) -def test_fixed_block_size_ciphers_short_block_raise_ciphererror( +def test_fixed_block_size_ciphers_short_block_raise_exception( cipher, randbytes): if is_streaming(cipher): pytest.skip("streaming cipher") - with pytest.raises(CipherError): + with pytest.raises(MbedTLSError): block = randbytes(cipher.block_size)[1:] cipher.encrypt(block) diff --git a/tests/test_error.py b/tests/test_error.py new file mode 100644 index 00000000..cb7b621c --- /dev/null +++ b/tests/test_error.py @@ -0,0 +1,18 @@ +"""Unit test mbedtls.exceptions.""" + +import pytest + +from mbedtls.exceptions import check_error, MbedTLSError + + +@pytest.mark.parametrize( + "err, msg", + ((0x003C, "ENTROPY"), (0x1080, "PEM"), (0x2200, "X509"))) +def test_mbedtls_error(err, msg): + with pytest.raises(MbedTLSError, match=r"%s - .+" % msg) as exc: + check_error(-err) + + +def test_other_error(): + with pytest.raises(MbedTLSError, match="error message") as exc: + raise MbedTLSError(msg="error message") diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 6316c02c..f2e15ff6 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -1,14 +1,147 @@ +import numbers from binascii import hexlify, unhexlify import pytest -from mbedtls._mpi import MPI +from mbedtls.mpi import MPI @pytest.mark.parametrize("value", (12, 2**32 - 1, 10**100)) def test_from_int(value): mpi = MPI.from_int(value) assert mpi == value + assert value == mpi + assert mpi == mpi + + +def test_is_integral(): + assert isinstance(MPI(42), numbers.Integral) + + +def test_prime(): + assert MPI.prime(512).is_prime() + + +def test_add(): + assert MPI(12) + MPI(12) == 24 + assert MPI(12) + 12 == 24 + assert 12 + MPI(12) == 24 + + +def test_sub(): + assert MPI(12) - MPI(5) == 7 + assert MPI(12) - 5 == 7 + assert 12 - MPI(5) == 7 + + +def test_mul(): + assert MPI(12) * MPI(2) == 24 + assert MPI(12) * 2 == 24 + assert 12 * MPI(2) == 24 + + +def test_eq_same_number_is_true(): + assert (MPI(12) == MPI(12)) is True + assert (MPI(12) == 12) is True + assert (12 == MPI(12)) is True + + +def test_eq_different_numbers_is_false(): + assert (MPI(12) == MPI(42)) is False + assert (MPI(12) == 42) is False + assert (12 == MPI(42)) is False + + +def test_neq_same_numbers_is_false(): + assert (MPI(12) != MPI(12)) is False + assert (MPI(12) != 12) is False + assert (12 != MPI(12)) is False + + +def test_neq_different_numbers_is_true(): + assert (MPI(12) != MPI(42)) is True + assert (MPI(12) != 42) is True + assert (12 != MPI(42)) is True + + +def test_lt_larger_number_is_true(): + assert (MPI(12) < MPI(42)) is True + assert (MPI(12) < 42) is True + assert (12 < MPI(42)) is True + + +def test_lt_smaller_number_is_false(): + assert (MPI(42) < MPI(12)) is False + assert (MPI(42) < 12) is False + assert (42 < MPI(12)) is False + + +def test_lt_same_number_is_false(): + assert (MPI(12) < MPI(12)) is False + assert (MPI(12) < 12) is False + assert (12 < MPI(12)) is False + + +def test_gt_larger_number_is_false(): + assert (MPI(12) > MPI(42)) is False + assert (MPI(12) > 42) is False + assert (12 > MPI(42)) is False + + +def test_gt_smaller_number_is_true(): + assert (MPI(42) > MPI(12)) is True + assert (MPI(42) > 12) is True + assert (42 > MPI(12)) is True + + +def test_gt_same_number_is_false(): + assert (MPI(12) > MPI(12)) is False + assert (MPI(12) > 12) is False + assert (12 > MPI(12)) is False + + +def test_le(): + assert (MPI(12) <= MPI(42)) is True + assert (MPI(12) <= MPI(12)) is True + assert (MPI(42) <= MPI(12)) is False + + +def test_ge(): + assert (MPI(42) >= MPI(12)) is True + assert (MPI(42) >= MPI(42)) is True + assert (MPI(12) >= MPI(42)) is False + + +def test_bool(): + assert bool(MPI(0)) is False + + +def test_float(): + assert float(MPI(12)) == 12.0 + + +def test_rshift(): + assert MPI(12) >> MPI(2) == 3 + assert MPI(12) >> 2 == 3 + assert 12 >> int(MPI(2)) == 3 + + +def test_lshift(): + assert MPI(12) << MPI(2) == 48 + assert MPI(12) << 2 == 48 + assert 12 << int(MPI(2)) == 48 + + +def test_floordiv(): + assert MPI(24) // MPI(2) == 12 + assert MPI(24) // 2 == 12 + assert 24 // MPI(2) == 12 + + +def test_mod(): + assert MPI(12) % MPI(10) == 2 + assert MPI(12) % 10 == 2 + assert 12 % MPI(10) == 2 @pytest.mark.parametrize("value", (12, 2**32 - 1, 10**100)) diff --git a/tests/test_pk.py b/tests/test_pk.py index 330b3429..4be6fa58 100644 --- a/tests/test_pk.py +++ b/tests/test_pk.py @@ -1,6 +1,7 @@ """Unit tests for mbedtls.pk.""" +import numbers from itertools import product from functools import partial from tempfile import TemporaryFile @@ -9,7 +10,6 @@ import mbedtls.hash as _hash from mbedtls.exceptions import * -from mbedtls.exceptions import _ErrorBase from mbedtls.pk import _type_from_name, _get_md_alg, CipherBase from mbedtls.pk import * @@ -157,9 +157,9 @@ def test_cipher_without_key(self): @pytest.mark.usefixtures("key") def test_public_value_accessor(self): pub = self.cipher.export_public_key("POINT") - assert isinstance(pub.x, long) - assert isinstance(pub.y, long) - assert isinstance(pub.z, long) + assert isinstance(pub.x, numbers.Integral) + assert isinstance(pub.y, numbers.Integral) + assert isinstance(pub.z, numbers.Integral) assert pub.x not in (0, pub.y, pub.z) assert pub.y not in (0, pub.x, pub.z) assert pub.z in (0, 1) @@ -167,7 +167,7 @@ def test_public_value_accessor(self): @pytest.mark.usefixtures("key") def test_private_value_accessor(self): prv = self.cipher.export_key("NUM") - assert isinstance(prv, long) + assert isinstance(prv, numbers.Integral) assert prv != 0 @@ -193,6 +193,29 @@ def test_exchange(self): assert srv_sec == cli_sec +class TestDH: + @pytest.fixture(autouse=True) + def _setup(self): + self.srv = DHServer(23, 5) + self.cli = DHClient(23, 5) + + def test_key_accessors_without_key(self): + for cipher in (self.srv, self.cli): + assert cipher.shared_secret == 0 + + def test_exchange(self): + ske = self.srv.generate() + self.cli.import_SKE(ske) + cke = self.cli.generate() + self.srv.import_CKE(cke) + + srv_sec = self.srv.generate_secret() + cli_sec = self.cli.generate_secret() + assert srv_sec == cli_sec + assert srv_sec == self.srv.shared_secret + assert cli_sec == self.cli.shared_secret + + class TestECDH: @pytest.fixture(autouse=True, params=get_supported_curves()) def _setup(self, request): diff --git a/tests/test_random.py b/tests/test_random.py index d39f421c..4a6d55b6 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -1,6 +1,7 @@ """Unit tests for mbedtls.random.""" # pylint: disable=missing-docstring +import random as _random # pylint: disable=import-error import mbedtls.random as _drbg @@ -8,7 +9,11 @@ import pytest -from mbedtls.exceptions import EntropyError +from mbedtls.exceptions import MbedTLSError + + +def sample(start, end, k=20): + return _random.sample(range(start, end), k) @pytest.fixture @@ -26,14 +31,14 @@ def test_entropy_gather(entropy): entropy.gather() -@pytest.mark.parametrize("length", range(64)) +@pytest.mark.parametrize("length", sample(0, 64)) def test_entropy_retrieve(entropy, length): assert len(entropy.retrieve(length)) == length @pytest.mark.parametrize("length", (100, )) -def test_entropy_retrieve_long_block_raises_entropyerror(entropy, length): - with pytest.raises(EntropyError): +def test_entropy_retrieve_long_block_raises_exception(entropy, length): + with pytest.raises(MbedTLSError): entropy.retrieve(length) @@ -70,11 +75,11 @@ def test_initial_values(random): assert random.token_bytes(8) != other.token_bytes(8) -@pytest.mark.parametrize("length", range(1024)) +@pytest.mark.parametrize("length", sample(0, 1024)) def test_token_bytes(random, length): assert len(random.token_bytes(length)) == length -@pytest.mark.parametrize("length", range(1024)) +@pytest.mark.parametrize("length", sample(0, 1024)) def test_token_hex(random, length): assert len(random.token_hex(length)) == 2 * length