From d1984d4625fae43ab6939be0955a3a0a03466295 Mon Sep 17 00:00:00 2001 From: Mattia Rizzolo Date: Wed, 22 Apr 2020 20:45:09 +0200 Subject: [PATCH] unbundle jwcrypto Signed-off-by: Mattia Rizzolo --- .gitlab-ci.yml | 1 + jwcrypto/README.md | 12 - jwcrypto/__init__.py | 0 jwcrypto/common.py | 106 ----- jwcrypto/jwa.py | 1072 ------------------------------------------ jwcrypto/jwe.py | 498 -------------------- jwcrypto/jwk.py | 875 ---------------------------------- jwcrypto/jws.py | 611 ------------------------ jwcrypto/jwt.py | 506 -------------------- 9 files changed, 1 insertion(+), 3680 deletions(-) delete mode 100644 jwcrypto/README.md delete mode 100644 jwcrypto/__init__.py delete mode 100644 jwcrypto/common.py delete mode 100644 jwcrypto/jwa.py delete mode 100644 jwcrypto/jwe.py delete mode 100644 jwcrypto/jwk.py delete mode 100644 jwcrypto/jws.py delete mode 100644 jwcrypto/jwt.py diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c7b8c85..00e366b 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -21,6 +21,7 @@ libjs-jquery-flot python3-model-mommy python3-requests-oauthlib + python3-jwcrypto $ADDITIONAL_PACKAGES script: - python3 manage.py test -v2 diff --git a/jwcrypto/README.md b/jwcrypto/README.md deleted file mode 100644 index dde0f90..0000000 --- a/jwcrypto/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# Local bundling of jwcrypto - -This is a copy of jwcrypto 0.6.0-2 from bullseye, bundled locally until a backport is available. - -See [#956560](https://bugs.debian.org/956560) to track this. - -Sources are at - -Debian package is at - -Copyright: 2015 JWCrypto Project Contributors -License: LGPL-3+ diff --git a/jwcrypto/__init__.py b/jwcrypto/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/jwcrypto/common.py b/jwcrypto/common.py deleted file mode 100644 index af6be92..0000000 --- a/jwcrypto/common.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (C) 2015 JWCrypto Project Contributors - see LICENSE file - -import json -from base64 import urlsafe_b64decode, urlsafe_b64encode - - -# Padding stripping versions as described in -# RFC 7515 Appendix C - - -def base64url_encode(payload): - if not isinstance(payload, bytes): - payload = payload.encode('utf-8') - encode = urlsafe_b64encode(payload) - return encode.decode('utf-8').rstrip('=') - - -def base64url_decode(payload): - size = len(payload) % 4 - if size == 2: - payload += '==' - elif size == 3: - payload += '=' - elif size != 0: - raise ValueError('Invalid base64 string') - return urlsafe_b64decode(payload.encode('utf-8')) - - -# JSON encoding/decoding helpers with good defaults - -def json_encode(string): - if isinstance(string, bytes): - string = string.decode('utf-8') - return json.dumps(string, separators=(',', ':'), sort_keys=True) - - -def json_decode(string): - if isinstance(string, bytes): - string = string.decode('utf-8') - return json.loads(string) - - -class JWException(Exception): - pass - - -class InvalidJWAAlgorithm(JWException): - def __init__(self, message=None): - msg = 'Invalid JWA Algorithm name' - if message: - msg += ' (%s)' % message - super(InvalidJWAAlgorithm, self).__init__(msg) - - -class InvalidCEKeyLength(JWException): - """Invalid CEK Key Length. - - This exception is raised when a Content Encryption Key does not match - the required lenght. - """ - - def __init__(self, expected, obtained): - msg = 'Expected key of length %d bits, got %d' % (expected, obtained) - super(InvalidCEKeyLength, self).__init__(msg) - - -class InvalidJWEOperation(JWException): - """Invalid JWS Object. - - This exception is raised when a requested operation cannot - be execute due to unsatisfied conditions. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = message - else: - msg = 'Unknown Operation Failure' - if exception: - msg += ' {%s}' % repr(exception) - super(InvalidJWEOperation, self).__init__(msg) - - -class InvalidJWEKeyType(JWException): - """Invalid JWE Key Type. - - This exception is raised when the provided JWK Key does not match - the type required by the sepcified algorithm. - """ - - def __init__(self, expected, obtained): - msg = 'Expected key type %s, got %s' % (expected, obtained) - super(InvalidJWEKeyType, self).__init__(msg) - - -class InvalidJWEKeyLength(JWException): - """Invalid JWE Key Length. - - This exception is raised when the provided JWK Key does not match - the lenght required by the sepcified algorithm. - """ - - def __init__(self, expected, obtained): - msg = 'Expected key of lenght %d, got %d' % (expected, obtained) - super(InvalidJWEKeyLength, self).__init__(msg) diff --git a/jwcrypto/jwa.py b/jwcrypto/jwa.py deleted file mode 100644 index 5785295..0000000 --- a/jwcrypto/jwa.py +++ /dev/null @@ -1,1072 +0,0 @@ -# Copyright (C) 2016 JWCrypto Project Contributors - see LICENSE file - -import abc -import os -import struct -from binascii import hexlify, unhexlify - -from cryptography.exceptions import InvalidSignature -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import constant_time, hashes, hmac -from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.asymmetric import padding -from cryptography.hazmat.primitives.asymmetric import utils as ec_utils -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -from cryptography.hazmat.primitives.keywrap import aes_key_unwrap, aes_key_wrap -from cryptography.hazmat.primitives.padding import PKCS7 - -import six - -from jwcrypto.common import InvalidCEKeyLength -from jwcrypto.common import InvalidJWAAlgorithm -from jwcrypto.common import InvalidJWEKeyLength -from jwcrypto.common import InvalidJWEKeyType -from jwcrypto.common import InvalidJWEOperation -from jwcrypto.common import base64url_decode, base64url_encode -from jwcrypto.common import json_decode -from jwcrypto.jwk import JWK - -# Implements RFC 7518 - JSON Web Algorithms (JWA) - - -@six.add_metaclass(abc.ABCMeta) -class JWAAlgorithm(object): - - @abc.abstractproperty - def name(self): - """The algorithm Name""" - pass - - @abc.abstractproperty - def description(self): - """A short description""" - pass - - @abc.abstractproperty - def keysize(self): - """The actual/recommended/minimum key size""" - pass - - @abc.abstractproperty - def algorithm_usage_location(self): - """One of 'alg', 'enc' or 'JWK'""" - pass - - @abc.abstractproperty - def algorithm_use(self): - """One of 'sig', 'kex', 'enc'""" - pass - - -def _bitsize(x): - return len(x) * 8 - - -def _inbytes(x): - return x // 8 - - -def _randombits(x): - if x % 8 != 0: - raise ValueError("lenght must be a multiple of 8") - return os.urandom(_inbytes(x)) - - -# Note: the number of bits should be a multiple of 16 -def _encode_int(n, bits): - e = '{:x}'.format(n) - ilen = ((bits + 7) // 8) * 2 # number of bytes rounded up times 2 bytes - return unhexlify(e.rjust(ilen, '0')[:ilen]) - - -def _decode_int(n): - return int(hexlify(n), 16) - - -class _RawJWS(object): - - def sign(self, key, payload): - raise NotImplementedError - - def verify(self, key, payload, signature): - raise NotImplementedError - - -class _RawHMAC(_RawJWS): - - def __init__(self, hashfn): - self.backend = default_backend() - self.hashfn = hashfn - - def _hmac_setup(self, key, payload): - h = hmac.HMAC(key, self.hashfn, backend=self.backend) - h.update(payload) - return h - - def sign(self, key, payload): - skey = base64url_decode(key.get_op_key('sign')) - h = self._hmac_setup(skey, payload) - return h.finalize() - - def verify(self, key, payload, signature): - vkey = base64url_decode(key.get_op_key('verify')) - h = self._hmac_setup(vkey, payload) - h.verify(signature) - - -class _RawRSA(_RawJWS): - def __init__(self, padfn, hashfn): - self.padfn = padfn - self.hashfn = hashfn - - def sign(self, key, payload): - skey = key.get_op_key('sign') - return skey.sign(payload, self.padfn, self.hashfn) - - def verify(self, key, payload, signature): - pkey = key.get_op_key('verify') - pkey.verify(signature, payload, self.padfn, self.hashfn) - - -class _RawEC(_RawJWS): - def __init__(self, curve, hashfn): - self._curve = curve - self.hashfn = hashfn - - @property - def curve(self): - return self._curve - - def sign(self, key, payload): - skey = key.get_op_key('sign', self._curve) - signature = skey.sign(payload, ec.ECDSA(self.hashfn)) - r, s = ec_utils.decode_dss_signature(signature) - size = key.get_curve(self._curve).key_size - return _encode_int(r, size) + _encode_int(s, size) - - def verify(self, key, payload, signature): - pkey = key.get_op_key('verify', self._curve) - r = signature[:len(signature) // 2] - s = signature[len(signature) // 2:] - enc_signature = ec_utils.encode_dss_signature( - int(hexlify(r), 16), int(hexlify(s), 16)) - pkey.verify(enc_signature, payload, ec.ECDSA(self.hashfn)) - - -class _RawNone(_RawJWS): - - def sign(self, key, payload): - return '' - - def verify(self, key, payload, signature): - raise InvalidSignature('The "none" signature cannot be verified') - - -class _HS256(_RawHMAC, JWAAlgorithm): - - name = "HS256" - description = "HMAC using SHA-256" - keysize = 256 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - super(_HS256, self).__init__(hashes.SHA256()) - - -class _HS384(_RawHMAC, JWAAlgorithm): - - name = "HS384" - description = "HMAC using SHA-384" - keysize = 384 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - super(_HS384, self).__init__(hashes.SHA384()) - - -class _HS512(_RawHMAC, JWAAlgorithm): - - name = "HS512" - description = "HMAC using SHA-512" - keysize = 512 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - super(_HS512, self).__init__(hashes.SHA512()) - - -class _RS256(_RawRSA, JWAAlgorithm): - - name = "RS256" - description = "RSASSA-PKCS1-v1_5 using SHA-256" - keysize = 2048 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - super(_RS256, self).__init__(padding.PKCS1v15(), hashes.SHA256()) - - -class _RS384(_RawRSA, JWAAlgorithm): - - name = "RS384" - description = "RSASSA-PKCS1-v1_5 using SHA-384" - keysize = 2048 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - super(_RS384, self).__init__(padding.PKCS1v15(), hashes.SHA384()) - - -class _RS512(_RawRSA, JWAAlgorithm): - - name = "RS512" - description = "RSASSA-PKCS1-v1_5 using SHA-512" - keysize = 2048 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - super(_RS512, self).__init__(padding.PKCS1v15(), hashes.SHA512()) - - -class _ES256(_RawEC, JWAAlgorithm): - - name = "ES256" - description = "ECDSA using P-256 and SHA-256" - keysize = 256 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - super(_ES256, self).__init__('P-256', hashes.SHA256()) - - -class _ES384(_RawEC, JWAAlgorithm): - - name = "ES384" - description = "ECDSA using P-384 and SHA-384" - keysize = 384 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - super(_ES384, self).__init__('P-384', hashes.SHA384()) - - -class _ES512(_RawEC, JWAAlgorithm): - - name = "ES512" - description = "ECDSA using P-521 and SHA-512" - keysize = 512 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - super(_ES512, self).__init__('P-521', hashes.SHA512()) - - -class _PS256(_RawRSA, JWAAlgorithm): - - name = "PS256" - description = "RSASSA-PSS using SHA-256 and MGF1 with SHA-256" - keysize = 2048 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - padfn = padding.PSS(padding.MGF1(hashes.SHA256()), - hashes.SHA256.digest_size) - super(_PS256, self).__init__(padfn, hashes.SHA256()) - - -class _PS384(_RawRSA, JWAAlgorithm): - - name = "PS384" - description = "RSASSA-PSS using SHA-384 and MGF1 with SHA-384" - keysize = 2048 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - padfn = padding.PSS(padding.MGF1(hashes.SHA384()), - hashes.SHA384.digest_size) - super(_PS384, self).__init__(padfn, hashes.SHA384()) - - -class _PS512(_RawRSA, JWAAlgorithm): - - name = "PS512" - description = "RSASSA-PSS using SHA-512 and MGF1 with SHA-512" - keysize = 2048 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - def __init__(self): - padfn = padding.PSS(padding.MGF1(hashes.SHA512()), - hashes.SHA512.digest_size) - super(_PS512, self).__init__(padfn, hashes.SHA512()) - - -class _None(_RawNone, JWAAlgorithm): - - name = "none" - description = "No digital signature or MAC performed" - keysize = 0 - algorithm_usage_location = 'alg' - algorithm_use = 'sig' - - -class _RawKeyMgmt(object): - - def wrap(self, key, bitsize, cek, headers): - raise NotImplementedError - - def unwrap(self, key, bitsize, ek, headers): - raise NotImplementedError - - -class _RSA(_RawKeyMgmt): - - def __init__(self, padfn): - self.padfn = padfn - - def _check_key(self, key): - if not isinstance(key, JWK): - raise ValueError('key is not a JWK object') - if key.key_type != 'RSA': - raise InvalidJWEKeyType('RSA', key.key_type) - - # FIXME: get key size and insure > 2048 bits - def wrap(self, key, bitsize, cek, headers): - self._check_key(key) - if not cek: - cek = _randombits(bitsize) - rk = key.get_op_key('wrapKey') - ek = rk.encrypt(cek, self.padfn) - return {'cek': cek, 'ek': ek} - - def unwrap(self, key, bitsize, ek, headers): - self._check_key(key) - rk = key.get_op_key('decrypt') - cek = rk.decrypt(ek, self.padfn) - if _bitsize(cek) != bitsize: - raise InvalidJWEKeyLength(bitsize, _bitsize(cek)) - return cek - - -class _Rsa15(_RSA, JWAAlgorithm): - - name = 'RSA1_5' - description = "RSAES-PKCS1-v1_5" - keysize = 2048 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - def __init__(self): - super(_Rsa15, self).__init__(padding.PKCS1v15()) - - def unwrap(self, key, bitsize, ek, headers): - self._check_key(key) - # Address MMA attack by implementing RFC 3218 - 2.3.2. Random Filling - # provides a random cek that will cause the decryption engine to - # run to the end, but will fail decryption later. - - # always generate a random cek so we spend roughly the - # same time as in the exception side of the branch - cek = _randombits(bitsize) - try: - cek = super(_Rsa15, self).unwrap(key, bitsize, ek, headers) - # always raise so we always run through the exception handling - # code in all cases - raise Exception('Dummy') - except Exception: # pylint: disable=broad-except - return cek - - -class _RsaOaep(_RSA, JWAAlgorithm): - - name = 'RSA-OAEP' - description = "RSAES OAEP using default parameters" - keysize = 2048 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - def __init__(self): - super(_RsaOaep, self).__init__( - padding.OAEP(padding.MGF1(hashes.SHA1()), - hashes.SHA1(), None)) - - -class _RsaOaep256(_RSA, JWAAlgorithm): # noqa: ignore=N801 - - name = 'RSA-OAEP-256' - description = "RSAES OAEP using SHA-256 and MGF1 with SHA-256" - keysize = 2048 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - def __init__(self): - super(_RsaOaep256, self).__init__( - padding.OAEP(padding.MGF1(hashes.SHA256()), - hashes.SHA256(), None)) - - -class _AesKw(_RawKeyMgmt): - - keysize = None - - def __init__(self): - self.backend = default_backend() - - def _get_key(self, key, op): - if not isinstance(key, JWK): - raise ValueError('key is not a JWK object') - if key.key_type != 'oct': - raise InvalidJWEKeyType('oct', key.key_type) - rk = base64url_decode(key.get_op_key(op)) - if _bitsize(rk) != self.keysize: - raise InvalidJWEKeyLength(self.keysize, _bitsize(rk)) - return rk - - def wrap(self, key, bitsize, cek, headers): - rk = self._get_key(key, 'encrypt') - if not cek: - cek = _randombits(bitsize) - - ek = aes_key_wrap(rk, cek, default_backend()) - - return {'cek': cek, 'ek': ek} - - def unwrap(self, key, bitsize, ek, headers): - rk = self._get_key(key, 'decrypt') - - cek = aes_key_unwrap(rk, ek, default_backend()) - if _bitsize(cek) != bitsize: - raise InvalidJWEKeyLength(bitsize, _bitsize(cek)) - return cek - - -class _A128KW(_AesKw, JWAAlgorithm): - - name = 'A128KW' - description = "AES Key Wrap using 128-bit key" - keysize = 128 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - -class _A192KW(_AesKw, JWAAlgorithm): - - name = 'A192KW' - description = "AES Key Wrap using 192-bit key" - keysize = 192 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - -class _A256KW(_AesKw, JWAAlgorithm): - - name = 'A256KW' - description = "AES Key Wrap using 256-bit key" - keysize = 256 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - -class _AesGcmKw(_RawKeyMgmt): - - keysize = None - - def __init__(self): - self.backend = default_backend() - - def _get_key(self, key, op): - if not isinstance(key, JWK): - raise ValueError('key is not a JWK object') - if key.key_type != 'oct': - raise InvalidJWEKeyType('oct', key.key_type) - rk = base64url_decode(key.get_op_key(op)) - if _bitsize(rk) != self.keysize: - raise InvalidJWEKeyLength(self.keysize, _bitsize(rk)) - return rk - - def wrap(self, key, bitsize, cek, headers): - rk = self._get_key(key, 'encrypt') - if not cek: - cek = _randombits(bitsize) - - iv = _randombits(96) - cipher = Cipher(algorithms.AES(rk), modes.GCM(iv), - backend=self.backend) - encryptor = cipher.encryptor() - ek = encryptor.update(cek) + encryptor.finalize() - - tag = encryptor.tag - return {'cek': cek, 'ek': ek, - 'header': {'iv': base64url_encode(iv), - 'tag': base64url_encode(tag)}} - - def unwrap(self, key, bitsize, ek, headers): - rk = self._get_key(key, 'decrypt') - - if 'iv' not in headers: - raise ValueError('Invalid Header, missing "iv" parameter') - iv = base64url_decode(headers['iv']) - if 'tag' not in headers: - raise ValueError('Invalid Header, missing "tag" parameter') - tag = base64url_decode(headers['tag']) - - cipher = Cipher(algorithms.AES(rk), modes.GCM(iv, tag), - backend=self.backend) - decryptor = cipher.decryptor() - cek = decryptor.update(ek) + decryptor.finalize() - if _bitsize(cek) != bitsize: - raise InvalidJWEKeyLength(bitsize, _bitsize(cek)) - return cek - - -class _A128GcmKw(_AesGcmKw, JWAAlgorithm): - - name = 'A128GCMKW' - description = "Key wrapping with AES GCM using 128-bit key" - keysize = 128 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - -class _A192GcmKw(_AesGcmKw, JWAAlgorithm): - - name = 'A192GCMKW' - description = "Key wrapping with AES GCM using 192-bit key" - keysize = 192 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - -class _A256GcmKw(_AesGcmKw, JWAAlgorithm): - - name = 'A256GCMKW' - description = "Key wrapping with AES GCM using 256-bit key" - keysize = 256 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - -class _Pbes2HsAesKw(_RawKeyMgmt): - - name = None - keysize = None - hashsize = None - - def __init__(self): - self.backend = default_backend() - self.aeskwmap = {128: _A128KW, 192: _A192KW, 256: _A256KW} - - def _get_key(self, alg, key, p2s, p2c): - if isinstance(key, bytes): - plain = key - else: - plain = key.encode('utf8') - salt = bytes(self.name.encode('utf8')) + b'\x00' + p2s - - if self.hashsize == 256: - hashalg = hashes.SHA256() - elif self.hashsize == 384: - hashalg = hashes.SHA384() - elif self.hashsize == 512: - hashalg = hashes.SHA512() - else: - raise ValueError('Unknown Hash Size') - - kdf = PBKDF2HMAC(algorithm=hashalg, length=_inbytes(self.keysize), - salt=salt, iterations=p2c, backend=self.backend) - rk = kdf.derive(plain) - if _bitsize(rk) != self.keysize: - raise InvalidJWEKeyLength(self.keysize, len(rk)) - return JWK(kty="oct", use="enc", k=base64url_encode(rk)) - - def wrap(self, key, bitsize, cek, headers): - p2s = _randombits(128) - p2c = 8192 - kek = self._get_key(headers['alg'], key, p2s, p2c) - - aeskw = self.aeskwmap[self.keysize]() - ret = aeskw.wrap(kek, bitsize, cek, headers) - ret['header'] = {'p2s': base64url_encode(p2s), 'p2c': p2c} - return ret - - def unwrap(self, key, bitsize, ek, headers): - if 'p2s' not in headers: - raise ValueError('Invalid Header, missing "p2s" parameter') - if 'p2c' not in headers: - raise ValueError('Invalid Header, missing "p2c" parameter') - p2s = base64url_decode(headers['p2s']) - p2c = headers['p2c'] - kek = self._get_key(headers['alg'], key, p2s, p2c) - - aeskw = self.aeskwmap[self.keysize]() - return aeskw.unwrap(kek, bitsize, ek, headers) - - -class _Pbes2Hs256A128Kw(_Pbes2HsAesKw, JWAAlgorithm): - - name = 'PBES2-HS256+A128KW' - description = 'PBES2 with HMAC SHA-256 and "A128KW" wrapping' - keysize = 128 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - hashsize = 256 - - -class _Pbes2Hs384A192Kw(_Pbes2HsAesKw, JWAAlgorithm): - - name = 'PBES2-HS384+A192KW' - description = 'PBES2 with HMAC SHA-384 and "A192KW" wrapping' - keysize = 192 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - hashsize = 384 - - -class _Pbes2Hs512A256Kw(_Pbes2HsAesKw, JWAAlgorithm): - - name = 'PBES2-HS512+A256KW' - description = 'PBES2 with HMAC SHA-512 and "A256KW" wrapping' - keysize = 256 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - hashsize = 512 - - -class _Direct(_RawKeyMgmt, JWAAlgorithm): - - name = 'dir' - description = "Direct use of a shared symmetric key" - keysize = 128 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - def _check_key(self, key): - if not isinstance(key, JWK): - raise ValueError('key is not a JWK object') - if key.key_type != 'oct': - raise InvalidJWEKeyType('oct', key.key_type) - - def wrap(self, key, bitsize, cek, headers): - self._check_key(key) - if cek: - return (cek, None) - k = base64url_decode(key.get_op_key('encrypt')) - if _bitsize(k) != bitsize: - raise InvalidCEKeyLength(bitsize, _bitsize(k)) - return {'cek': k} - - def unwrap(self, key, bitsize, ek, headers): - self._check_key(key) - if ek != b'': - raise ValueError('Invalid Encryption Key.') - cek = base64url_decode(key.get_op_key('decrypt')) - if _bitsize(cek) != bitsize: - raise InvalidJWEKeyLength(bitsize, _bitsize(cek)) - return cek - - -class _EcdhEs(_RawKeyMgmt, JWAAlgorithm): - - name = 'ECDH-ES' - description = "ECDH-ES using Concat KDF" - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - keysize = None - - def __init__(self): - self.backend = default_backend() - self.aeskwmap = {128: _A128KW, 192: _A192KW, 256: _A256KW} - - def _check_key(self, key): - if not isinstance(key, JWK): - raise ValueError('key is not a JWK object') - if key.key_type != 'EC': - raise InvalidJWEKeyType('EC', key.key_type) - - def _derive(self, privkey, pubkey, alg, bitsize, headers): - # OtherInfo is defined in NIST SP 56A 5.8.1.2.1 - - # AlgorithmID - otherinfo = struct.pack('>I', len(alg)) - otherinfo += bytes(alg.encode('utf8')) - - # PartyUInfo - apu = base64url_decode(headers['apu']) if 'apu' in headers else b'' - otherinfo += struct.pack('>I', len(apu)) - otherinfo += apu - - # PartyVInfo - apv = base64url_decode(headers['apv']) if 'apv' in headers else b'' - otherinfo += struct.pack('>I', len(apv)) - otherinfo += apv - - # SuppPubInfo - otherinfo += struct.pack('>I', bitsize) - - # no SuppPrivInfo - - shared_key = privkey.exchange(ec.ECDH(), pubkey) - ckdf = ConcatKDFHash(algorithm=hashes.SHA256(), - length=_inbytes(bitsize), - otherinfo=otherinfo, - backend=self.backend) - return ckdf.derive(shared_key) - - def wrap(self, key, bitsize, cek, headers): - self._check_key(key) - dk_size = self.keysize - if self.keysize is None: - if cek is not None: - raise InvalidJWEOperation('ECDH-ES cannot use an existing CEK') - alg = headers['enc'] - dk_size = bitsize - else: - alg = headers['alg'] - - epk = JWK.generate(kty=key.key_type, crv=key.key_curve) - dk = self._derive(epk.get_op_key('unwrapKey'), - key.get_op_key('wrapKey'), - alg, dk_size, headers) - - if self.keysize is None: - ret = {'cek': dk} - else: - aeskw = self.aeskwmap[self.keysize]() - kek = JWK(kty="oct", use="enc", k=base64url_encode(dk)) - ret = aeskw.wrap(kek, bitsize, cek, headers) - - ret['header'] = {'epk': json_decode(epk.export_public())} - return ret - - def unwrap(self, key, bitsize, ek, headers): - if 'epk' not in headers: - raise ValueError('Invalid Header, missing "epk" parameter') - self._check_key(key) - dk_size = self.keysize - if self.keysize is None: - alg = headers['enc'] - dk_size = bitsize - else: - alg = headers['alg'] - - epk = JWK(**headers['epk']) - dk = self._derive(key.get_op_key('unwrapKey'), - epk.get_op_key('wrapKey'), - alg, dk_size, headers) - if self.keysize is None: - return dk - else: - aeskw = self.aeskwmap[self.keysize]() - kek = JWK(kty="oct", use="enc", k=base64url_encode(dk)) - cek = aeskw.unwrap(kek, bitsize, ek, headers) - return cek - - -class _EcdhEsAes128Kw(_EcdhEs): - - name = 'ECDH-ES+A128KW' - description = 'ECDH-ES using Concat KDF and "A128KW" wrapping' - keysize = 128 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - -class _EcdhEsAes192Kw(_EcdhEs): - - name = 'ECDH-ES+A192KW' - description = 'ECDH-ES using Concat KDF and "A192KW" wrapping' - keysize = 192 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - -class _EcdhEsAes256Kw(_EcdhEs): - - name = 'ECDH-ES+A256KW' - description = 'ECDH-ES using Concat KDF and "A256KW" wrapping' - keysize = 256 - algorithm_usage_location = 'alg' - algorithm_use = 'kex' - - -class _RawJWE(object): - - def encrypt(self, k, a, m): - raise NotImplementedError - - def decrypt(self, k, a, iv, e, t): - raise NotImplementedError - - -class _AesCbcHmacSha2(_RawJWE): - - keysize = None - - def __init__(self, hashfn): - self.backend = default_backend() - self.hashfn = hashfn - self.blocksize = algorithms.AES.block_size - self.wrap_key_size = self.keysize * 2 - - def _mac(self, k, a, iv, e): - al = _encode_int(_bitsize(a), 64) - h = hmac.HMAC(k, self.hashfn, backend=self.backend) - h.update(a) - h.update(iv) - h.update(e) - h.update(al) - m = h.finalize() - return m[:_inbytes(self.keysize)] - - # RFC 7518 - 5.2.2 - def encrypt(self, k, a, m): - """ Encrypt according to the selected encryption and hashing - functions. - - :param k: Encryption key (optional) - :param a: Additional Authentication Data - :param m: Plaintext - - Returns a dictionary with the computed data. - """ - hkey = k[:_inbytes(self.keysize)] - ekey = k[_inbytes(self.keysize):] - - # encrypt - iv = _randombits(self.blocksize) - cipher = Cipher(algorithms.AES(ekey), modes.CBC(iv), - backend=self.backend) - encryptor = cipher.encryptor() - padder = PKCS7(self.blocksize).padder() - padded_data = padder.update(m) + padder.finalize() - e = encryptor.update(padded_data) + encryptor.finalize() - - # mac - t = self._mac(hkey, a, iv, e) - - return (iv, e, t) - - def decrypt(self, k, a, iv, e, t): - """ Decrypt according to the selected encryption and hashing - functions. - :param k: Encryption key (optional) - :param a: Additional Authenticated Data - :param iv: Initialization Vector - :param e: Ciphertext - :param t: Authentication Tag - - Returns plaintext or raises an error - """ - hkey = k[:_inbytes(self.keysize)] - dkey = k[_inbytes(self.keysize):] - - # verify mac - if not constant_time.bytes_eq(t, self._mac(hkey, a, iv, e)): - raise InvalidSignature('Failed to verify MAC') - - # decrypt - cipher = Cipher(algorithms.AES(dkey), modes.CBC(iv), - backend=self.backend) - decryptor = cipher.decryptor() - d = decryptor.update(e) + decryptor.finalize() - unpadder = PKCS7(self.blocksize).unpadder() - return unpadder.update(d) + unpadder.finalize() - - -class _A128CbcHs256(_AesCbcHmacSha2, JWAAlgorithm): - - name = 'A128CBC-HS256' - description = "AES_128_CBC_HMAC_SHA_256 authenticated" - keysize = 128 - algorithm_usage_location = 'enc' - algorithm_use = 'enc' - - def __init__(self): - super(_A128CbcHs256, self).__init__(hashes.SHA256()) - - -class _A192CbcHs384(_AesCbcHmacSha2, JWAAlgorithm): - - name = 'A192CBC-HS384' - description = "AES_192_CBC_HMAC_SHA_384 authenticated" - keysize = 192 - algorithm_usage_location = 'enc' - algorithm_use = 'enc' - - def __init__(self): - super(_A192CbcHs384, self).__init__(hashes.SHA384()) - - -class _A256CbcHs512(_AesCbcHmacSha2, JWAAlgorithm): - - name = 'A256CBC-HS512' - description = "AES_256_CBC_HMAC_SHA_512 authenticated" - keysize = 256 - algorithm_usage_location = 'enc' - algorithm_use = 'enc' - - def __init__(self): - super(_A256CbcHs512, self).__init__(hashes.SHA512()) - - -class _AesGcm(_RawJWE): - - keysize = None - - def __init__(self): - self.backend = default_backend() - self.wrap_key_size = self.keysize - - # RFC 7518 - 5.3 - def encrypt(self, k, a, m): - """ Encrypt accoriding to the selected encryption and hashing - functions. - - :param k: Encryption key (optional) - :param a: Additional Authentication Data - :param m: Plaintext - - Returns a dictionary with the computed data. - """ - iv = _randombits(96) - cipher = Cipher(algorithms.AES(k), modes.GCM(iv), - backend=self.backend) - encryptor = cipher.encryptor() - encryptor.authenticate_additional_data(a) - e = encryptor.update(m) + encryptor.finalize() - - return (iv, e, encryptor.tag) - - def decrypt(self, k, a, iv, e, t): - """ Decrypt accoriding to the selected encryption and hashing - functions. - :param k: Encryption key (optional) - :param a: Additional Authenticated Data - :param iv: Initialization Vector - :param e: Ciphertext - :param t: Authentication Tag - - Returns plaintext or raises an error - """ - cipher = Cipher(algorithms.AES(k), modes.GCM(iv, t), - backend=self.backend) - decryptor = cipher.decryptor() - decryptor.authenticate_additional_data(a) - return decryptor.update(e) + decryptor.finalize() - - -class _A128Gcm(_AesGcm, JWAAlgorithm): - - name = 'A128GCM' - description = "AES GCM using 128-bit key" - keysize = 128 - algorithm_usage_location = 'enc' - algorithm_use = 'enc' - - -class _A192Gcm(_AesGcm, JWAAlgorithm): - - name = 'A192GCM' - description = "AES GCM using 192-bit key" - keysize = 192 - algorithm_usage_location = 'enc' - algorithm_use = 'enc' - - -class _A256Gcm(_AesGcm, JWAAlgorithm): - - name = 'A256GCM' - description = "AES GCM using 256-bit key" - keysize = 256 - algorithm_usage_location = 'enc' - algorithm_use = 'enc' - - -class JWA(object): - """JWA Signing Algorithms. - - This class provides access to all JWA algorithms. - """ - - algorithms_registry = { - 'HS256': _HS256, - 'HS384': _HS384, - 'HS512': _HS512, - 'RS256': _RS256, - 'RS384': _RS384, - 'RS512': _RS512, - 'ES256': _ES256, - 'ES384': _ES384, - 'ES512': _ES512, - 'PS256': _PS256, - 'PS384': _PS384, - 'PS512': _PS512, - 'none': _None, - 'RSA1_5': _Rsa15, - 'RSA-OAEP': _RsaOaep, - 'RSA-OAEP-256': _RsaOaep256, - 'A128KW': _A128KW, - 'A192KW': _A192KW, - 'A256KW': _A256KW, - 'dir': _Direct, - 'ECDH-ES': _EcdhEs, - 'ECDH-ES+A128KW': _EcdhEsAes128Kw, - 'ECDH-ES+A192KW': _EcdhEsAes192Kw, - 'ECDH-ES+A256KW': _EcdhEsAes256Kw, - 'A128GCMKW': _A128GcmKw, - 'A192GCMKW': _A192GcmKw, - 'A256GCMKW': _A256GcmKw, - 'PBES2-HS256+A128KW': _Pbes2Hs256A128Kw, - 'PBES2-HS384+A192KW': _Pbes2Hs384A192Kw, - 'PBES2-HS512+A256KW': _Pbes2Hs512A256Kw, - 'A128CBC-HS256': _A128CbcHs256, - 'A192CBC-HS384': _A192CbcHs384, - 'A256CBC-HS512': _A256CbcHs512, - 'A128GCM': _A128Gcm, - 'A192GCM': _A192Gcm, - 'A256GCM': _A256Gcm - } - - @classmethod - def instantiate_alg(cls, name, use=None): - alg = cls.algorithms_registry[name] - if use is not None and alg.algorithm_use != use: - raise KeyError - return alg() - - @classmethod - def signing_alg(cls, name): - try: - return cls.instantiate_alg(name, use='sig') - except KeyError: - raise InvalidJWAAlgorithm( - '%s is not a valid Signign algorithm name' % name) - - @classmethod - def keymgmt_alg(cls, name): - try: - return cls.instantiate_alg(name, use='kex') - except KeyError: - raise InvalidJWAAlgorithm( - '%s is not a valid Key Management algorithm name' % name) - - @classmethod - def encryption_alg(cls, name): - try: - return cls.instantiate_alg(name, use='enc') - except KeyError: - raise InvalidJWAAlgorithm( - '%s is not a valid Encryption algorithm name' % name) diff --git a/jwcrypto/jwe.py b/jwcrypto/jwe.py deleted file mode 100644 index b4c0f57..0000000 --- a/jwcrypto/jwe.py +++ /dev/null @@ -1,498 +0,0 @@ -# Copyright (C) 2015 JWCrypto Project Contributors - see LICENSE file - -import zlib - -from jwcrypto import common -from jwcrypto.common import JWException -from jwcrypto.common import base64url_decode, base64url_encode -from jwcrypto.common import json_decode, json_encode -from jwcrypto.jwa import JWA - - -# RFC 7516 - 4.1 -# name: (description, supported?) -JWEHeaderRegistry = {'alg': ('Algorithm', True), - 'enc': ('Encryption Algorithm', True), - 'zip': ('Compression Algorithm', True), - 'jku': ('JWK Set URL', False), - 'jwk': ('JSON Web Key', False), - 'kid': ('Key ID', True), - 'x5u': ('X.509 URL', False), - 'x5c': ('X.509 Certificate Chain', False), - 'x5t': ('X.509 Certificate SHA-1 Thumbprint', False), - 'x5t#S256': ('X.509 Certificate SHA-256 Thumbprint', - False), - 'typ': ('Type', True), - 'cty': ('Content Type', True), - 'crit': ('Critical', True)} -"""Registry of valid header parameters""" - -default_allowed_algs = [ - # Key Management Algorithms - 'RSA1_5', 'RSA-OAEP', 'RSA-OAEP-256', - 'A128KW', 'A192KW', 'A256KW', - 'dir', - 'ECDH-ES', 'ECDH-ES+A128KW', 'ECDH-ES+A192KW', 'ECDH-ES+A256KW', - 'A128GCMKW', 'A192GCMKW', 'A256GCMKW', - 'PBES2-HS256+A128KW', 'PBES2-HS384+A192KW', 'PBES2-HS512+A256KW', - # Content Encryption Algoritms - 'A128CBC-HS256', 'A192CBC-HS384', 'A256CBC-HS512', - 'A128GCM', 'A192GCM', 'A256GCM'] -"""Default allowed algorithms""" - - -class InvalidJWEData(JWException): - """Invalid JWE Object. - - This exception is raised when the JWE Object is invalid and/or - improperly formatted. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = message - else: - msg = 'Unknown Data Verification Failure' - if exception: - msg += ' {%s}' % str(exception) - super(InvalidJWEData, self).__init__(msg) - - -# These have been moved to jwcrypto.common, maintain here for backwards compat -InvalidCEKeyLength = common.InvalidCEKeyLength -InvalidJWEKeyLength = common.InvalidJWEKeyLength -InvalidJWEKeyType = common.InvalidJWEKeyType -InvalidJWEOperation = common.InvalidJWEOperation - - -class JWE(object): - """JSON Web Encryption object - - This object represent a JWE token. - """ - - def __init__(self, plaintext=None, protected=None, unprotected=None, - aad=None, algs=None, recipient=None, header=None): - """Creates a JWE token. - - :param plaintext(bytes): An arbitrary plaintext to be encrypted. - :param protected: A JSON string with the protected header. - :param unprotected: A JSON string with the shared unprotected header. - :param aad(bytes): Arbitrary additional authenticated data - :param algs: An optional list of allowed algorithms - :param recipient: An optional, default recipient key - :param header: An optional header for the default recipient - """ - self._allowed_algs = None - self.objects = dict() - self.plaintext = None - if plaintext is not None: - if isinstance(plaintext, bytes): - self.plaintext = plaintext - else: - self.plaintext = plaintext.encode('utf-8') - self.cek = None - self.decryptlog = None - if aad: - self.objects['aad'] = aad - if protected: - if isinstance(protected, dict): - protected = json_encode(protected) - else: - json_decode(protected) # check header encoding - self.objects['protected'] = protected - if unprotected: - if isinstance(unprotected, dict): - unprotected = json_encode(unprotected) - else: - json_decode(unprotected) # check header encoding - self.objects['unprotected'] = unprotected - if algs: - self._allowed_algs = algs - - if recipient: - self.add_recipient(recipient, header=header) - elif header: - raise ValueError('Header is allowed only with default recipient') - - def _jwa_keymgmt(self, name): - allowed = self._allowed_algs or default_allowed_algs - if name not in allowed: - raise InvalidJWEOperation('Algorithm not allowed') - return JWA.keymgmt_alg(name) - - def _jwa_enc(self, name): - allowed = self._allowed_algs or default_allowed_algs - if name not in allowed: - raise InvalidJWEOperation('Algorithm not allowed') - return JWA.encryption_alg(name) - - @property - def allowed_algs(self): - """Allowed algorithms. - - The list of allowed algorithms. - Can be changed by setting a list of algorithm names. - """ - - if self._allowed_algs: - return self._allowed_algs - else: - return default_allowed_algs - - @allowed_algs.setter - def allowed_algs(self, algs): - if not isinstance(algs, list): - raise TypeError('Allowed Algs must be a list') - self._allowed_algs = algs - - def _merge_headers(self, h1, h2): - for k in list(h1.keys()): - if k in h2: - raise InvalidJWEData('Duplicate header: "%s"' % k) - h1.update(h2) - return h1 - - def _get_jose_header(self, header=None): - jh = dict() - if 'protected' in self.objects: - ph = json_decode(self.objects['protected']) - jh = self._merge_headers(jh, ph) - if 'unprotected' in self.objects: - uh = json_decode(self.objects['unprotected']) - jh = self._merge_headers(jh, uh) - if header: - rh = json_decode(header) - jh = self._merge_headers(jh, rh) - return jh - - def _get_alg_enc_from_headers(self, jh): - algname = jh.get('alg', None) - if algname is None: - raise InvalidJWEData('Missing "alg" from headers') - alg = self._jwa_keymgmt(algname) - encname = jh.get('enc', None) - if encname is None: - raise InvalidJWEData('Missing "enc" from headers') - enc = self._jwa_enc(encname) - return alg, enc - - def _encrypt(self, alg, enc, jh): - aad = base64url_encode(self.objects.get('protected', '')) - if 'aad' in self.objects: - aad += '.' + base64url_encode(self.objects['aad']) - aad = aad.encode('utf-8') - - compress = jh.get('zip', None) - if compress == 'DEF': - data = zlib.compress(self.plaintext)[2:-4] - elif compress is None: - data = self.plaintext - else: - raise ValueError('Unknown compression') - - iv, ciphertext, tag = enc.encrypt(self.cek, aad, data) - self.objects['iv'] = iv - self.objects['ciphertext'] = ciphertext - self.objects['tag'] = tag - - def add_recipient(self, key, header=None): - """Encrypt the plaintext with the given key. - - :param key: A JWK key or password of appropriate type for the 'alg' - provided in the JOSE Headers. - :param header: A JSON string representing the per-recipient header. - - :raises ValueError: if the plaintext is missing or not of type bytes. - :raises ValueError: if the compression type is unknown. - :raises InvalidJWAAlgorithm: if the 'alg' provided in the JOSE - headers is missing or unknown, or otherwise not implemented. - """ - if self.plaintext is None: - raise ValueError('Missing plaintext') - if not isinstance(self.plaintext, bytes): - raise ValueError("Plaintext must be 'bytes'") - - if isinstance(header, dict): - header = json_encode(header) - - jh = self._get_jose_header(header) - alg, enc = self._get_alg_enc_from_headers(jh) - - rec = dict() - if header: - rec['header'] = header - - wrapped = alg.wrap(key, enc.wrap_key_size, self.cek, jh) - self.cek = wrapped['cek'] - - if 'ek' in wrapped: - rec['encrypted_key'] = wrapped['ek'] - - if 'header' in wrapped: - h = json_decode(rec.get('header', '{}')) - nh = self._merge_headers(h, wrapped['header']) - rec['header'] = json_encode(nh) - - if 'ciphertext' not in self.objects: - self._encrypt(alg, enc, jh) - - if 'recipients' in self.objects: - self.objects['recipients'].append(rec) - elif 'encrypted_key' in self.objects or 'header' in self.objects: - self.objects['recipients'] = list() - n = dict() - if 'encrypted_key' in self.objects: - n['encrypted_key'] = self.objects.pop('encrypted_key') - if 'header' in self.objects: - n['header'] = self.objects.pop('header') - self.objects['recipients'].append(n) - self.objects['recipients'].append(rec) - else: - self.objects.update(rec) - - def serialize(self, compact=False): - """Serializes the object into a JWE token. - - :param compact(boolean): if True generates the compact - representation, otherwise generates a standard JSON format. - - :raises InvalidJWEOperation: if the object cannot serialized - with the compact representation and `compact` is True. - :raises InvalidJWEOperation: if no recipients have been added - to the object. - """ - - if 'ciphertext' not in self.objects: - raise InvalidJWEOperation("No available ciphertext") - - if compact: - for invalid in 'aad', 'unprotected': - if invalid in self.objects: - raise InvalidJWEOperation( - "Can't use compact encoding when the '%s' parameter" - "is set" % invalid) - if 'protected' not in self.objects: - raise InvalidJWEOperation( - "Can't use compat encoding without protected headers") - else: - ph = json_decode(self.objects['protected']) - for required in 'alg', 'enc': - if required not in ph: - raise InvalidJWEOperation( - "Can't use compat encoding, '%s' must be in the " - "protected header" % required) - if 'recipients' in self.objects: - if len(self.objects['recipients']) != 1: - raise InvalidJWEOperation("Invalid number of recipients") - rec = self.objects['recipients'][0] - else: - rec = self.objects - if 'header' in rec: - # The AESGCMKW algorithm generates data (iv, tag) we put in the - # per-recipient unpotected header by default. Move it to the - # protected header and re-encrypt the payload, as the protected - # header is used as additional authenticated data. - h = json_decode(rec['header']) - ph = json_decode(self.objects['protected']) - nph = self._merge_headers(h, ph) - self.objects['protected'] = json_encode(nph) - jh = self._get_jose_header() - alg, enc = self._get_alg_enc_from_headers(jh) - self._encrypt(alg, enc, jh) - del rec['header'] - - return '.'.join([base64url_encode(self.objects['protected']), - base64url_encode(rec.get('encrypted_key', '')), - base64url_encode(self.objects['iv']), - base64url_encode(self.objects['ciphertext']), - base64url_encode(self.objects['tag'])]) - else: - obj = self.objects - enc = {'ciphertext': base64url_encode(obj['ciphertext']), - 'iv': base64url_encode(obj['iv']), - 'tag': base64url_encode(self.objects['tag'])} - if 'protected' in obj: - enc['protected'] = base64url_encode(obj['protected']) - if 'unprotected' in obj: - enc['unprotected'] = json_decode(obj['unprotected']) - if 'aad' in obj: - enc['aad'] = base64url_encode(obj['aad']) - if 'recipients' in obj: - enc['recipients'] = list() - for rec in obj['recipients']: - e = dict() - if 'encrypted_key' in rec: - e['encrypted_key'] = \ - base64url_encode(rec['encrypted_key']) - if 'header' in rec: - e['header'] = json_decode(rec['header']) - enc['recipients'].append(e) - else: - if 'encrypted_key' in obj: - enc['encrypted_key'] = \ - base64url_encode(obj['encrypted_key']) - if 'header' in obj: - enc['header'] = json_decode(obj['header']) - return json_encode(enc) - - def _check_crit(self, crit): - for k in crit: - if k not in JWEHeaderRegistry: - raise InvalidJWEData('Unknown critical header: "%s"' % k) - else: - if not JWEHeaderRegistry[k][1]: - raise InvalidJWEData('Unsupported critical header: ' - '"%s"' % k) - - # FIXME: allow to specify which algorithms to accept as valid - def _decrypt(self, key, ppe): - - jh = self._get_jose_header(ppe.get('header', None)) - - # TODO: allow caller to specify list of headers it understands - self._check_crit(jh.get('crit', dict())) - - alg = self._jwa_keymgmt(jh.get('alg', None)) - enc = self._jwa_enc(jh.get('enc', None)) - - aad = base64url_encode(self.objects.get('protected', '')) - if 'aad' in self.objects: - aad += '.' + base64url_encode(self.objects['aad']) - - cek = alg.unwrap(key, enc.wrap_key_size, - ppe.get('encrypted_key', b''), jh) - data = enc.decrypt(cek, aad.encode('utf-8'), - self.objects['iv'], - self.objects['ciphertext'], - self.objects['tag']) - - self.decryptlog.append('Success') - self.cek = cek - - compress = jh.get('zip', None) - if compress == 'DEF': - self.plaintext = zlib.decompress(data, -zlib.MAX_WBITS) - elif compress is None: - self.plaintext = data - else: - raise ValueError('Unknown compression') - - def decrypt(self, key): - """Decrypt a JWE token. - - :param key: The (:class:`jwcrypto.jwk.JWK`) decryption key. - :param key: A (:class:`jwcrypto.jwk.JWK`) decryption key or a password - string (optional). - - :raises InvalidJWEOperation: if the key is not a JWK object. - :raises InvalidJWEData: if the ciphertext can't be decrypted or - the object is otherwise malformed. - """ - - if 'ciphertext' not in self.objects: - raise InvalidJWEOperation("No available ciphertext") - self.decryptlog = list() - - if 'recipients' in self.objects: - for rec in self.objects['recipients']: - try: - self._decrypt(key, rec) - except Exception as e: # pylint: disable=broad-except - self.decryptlog.append('Failed: [%s]' % repr(e)) - else: - try: - self._decrypt(key, self.objects) - except Exception as e: # pylint: disable=broad-except - self.decryptlog.append('Failed: [%s]' % repr(e)) - - if not self.plaintext: - raise InvalidJWEData('No recipient matched the provided ' - 'key' + repr(self.decryptlog)) - - def deserialize(self, raw_jwe, key=None): - """Deserialize a JWE token. - - NOTE: Destroys any current status and tries to import the raw - JWE provided. - - :param raw_jwe: a 'raw' JWE token (JSON Encoded or Compact - notation) string. - :param key: A (:class:`jwcrypto.jwk.JWK`) decryption key or a password - string (optional). - If a key is provided a decryption step will be attempted after - the object is successfully deserialized. - - :raises InvalidJWEData: if the raw object is an invaid JWE token. - :raises InvalidJWEOperation: if the decryption fails. - """ - - self.objects = dict() - self.plaintext = None - self.cek = None - - o = dict() - try: - try: - djwe = json_decode(raw_jwe) - o['iv'] = base64url_decode(djwe['iv']) - o['ciphertext'] = base64url_decode(djwe['ciphertext']) - o['tag'] = base64url_decode(djwe['tag']) - if 'protected' in djwe: - p = base64url_decode(djwe['protected']) - o['protected'] = p.decode('utf-8') - if 'unprotected' in djwe: - o['unprotected'] = json_encode(djwe['unprotected']) - if 'aad' in djwe: - o['aad'] = base64url_decode(djwe['aad']) - if 'recipients' in djwe: - o['recipients'] = list() - for rec in djwe['recipients']: - e = dict() - if 'encrypted_key' in rec: - e['encrypted_key'] = \ - base64url_decode(rec['encrypted_key']) - if 'header' in rec: - e['header'] = json_encode(rec['header']) - o['recipients'].append(e) - else: - if 'encrypted_key' in djwe: - o['encrypted_key'] = \ - base64url_decode(djwe['encrypted_key']) - if 'header' in djwe: - o['header'] = json_encode(djwe['header']) - - except ValueError: - c = raw_jwe.split('.') - if len(c) != 5: - raise InvalidJWEData() - p = base64url_decode(c[0]) - o['protected'] = p.decode('utf-8') - ekey = base64url_decode(c[1]) - if ekey != b'': - o['encrypted_key'] = base64url_decode(c[1]) - o['iv'] = base64url_decode(c[2]) - o['ciphertext'] = base64url_decode(c[3]) - o['tag'] = base64url_decode(c[4]) - - self.objects = o - - except Exception as e: # pylint: disable=broad-except - raise InvalidJWEData('Invalid format', repr(e)) - - if key: - self.decrypt(key) - - @property - def payload(self): - if not self.plaintext: - raise InvalidJWEOperation("Plaintext not available") - return self.plaintext - - @property - def jose_header(self): - jh = self._get_jose_header() - if len(jh) == 0: - raise InvalidJWEOperation("JOSE Header not available") - return jh diff --git a/jwcrypto/jwk.py b/jwcrypto/jwk.py deleted file mode 100644 index b494414..0000000 --- a/jwcrypto/jwk.py +++ /dev/null @@ -1,875 +0,0 @@ -# Copyright (C) 2015 JWCrypto Project Contributors - see LICENSE file - -import os -from binascii import hexlify, unhexlify -from collections import namedtuple -from enum import Enum - -from cryptography import x509 -from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.asymmetric import rsa - -from six import iteritems - -from jwcrypto.common import JWException -from jwcrypto.common import base64url_decode, base64url_encode -from jwcrypto.common import json_decode, json_encode - - -# RFC 7518 - 7.4 -JWKTypesRegistry = {'EC': 'Elliptic Curve', - 'RSA': 'RSA', - 'oct': 'Octet sequence'} -"""Registry of valid Key Types""" - - -# RFC 7518 - 7.5 -# It is part of the JWK Parameters Registry, but we want a more -# specific map for internal usage -class ParmType(Enum): - name = 'A string with a name' - b64 = 'Base64url Encoded' - b64U = 'Base64urlUint Encoded' - unsupported = 'Unsupported Parameter' - - -JWKParameter = namedtuple('Parameter', 'description public required type') -JWKValuesRegistry = { - 'EC': { - 'crv': JWKParameter('Curve', True, True, ParmType.name), - 'x': JWKParameter('X Coordinate', True, True, ParmType.b64), - 'y': JWKParameter('Y Coordinate', True, True, ParmType.b64), - 'd': JWKParameter('ECC Private Key', False, False, ParmType.b64), - }, - 'RSA': { - 'n': JWKParameter('Modulus', True, True, ParmType.b64), - 'e': JWKParameter('Exponent', True, True, ParmType.b64U), - 'd': JWKParameter('Private Exponent', False, False, ParmType.b64U), - 'p': JWKParameter('First Prime Factor', False, False, ParmType.b64U), - 'q': JWKParameter('Second Prime Factor', False, False, ParmType.b64U), - 'dp': JWKParameter('First Factor CRT Exponent', - False, False, ParmType.b64U), - 'dq': JWKParameter('Second Factor CRT Exponent', - False, False, ParmType.b64U), - 'qi': JWKParameter('First CRT Coefficient', - False, False, ParmType.b64U), - 'oth': JWKParameter('Other Primes Info', - False, False, ParmType.unsupported), - }, - 'oct': { - 'k': JWKParameter('Key Value', False, True, ParmType.b64), - } -} -"""Registry of valid key values""" - -JWKParamsRegistry = { - 'kty': JWKParameter('Key Type', True, None, None), - 'use': JWKParameter('Public Key Use', True, None, None), - 'key_ops': JWKParameter('Key Operations', True, None, None), - 'alg': JWKParameter('Algorithm', True, None, None), - 'kid': JWKParameter('Key ID', True, None, None), - 'x5u': JWKParameter('X.509 URL', True, None, None), - 'x5c': JWKParameter('X.509 Certificate Chain', True, None, None), - 'x5t': JWKParameter('X.509 Certificate SHA-1 Thumbprint', - True, None, None), - 'x5t#S256': JWKParameter('X.509 Certificate SHA-256 Thumbprint', - True, None, None) -} -"""Regstry of valid key parameters""" - -# RFC 7518 - 7.6 -JWKEllipticCurveRegistry = {'P-256': 'P-256 curve', - 'P-384': 'P-384 curve', - 'P-521': 'P-521 curve'} -"""Registry of allowed Elliptic Curves""" - -# RFC 7517 - 8.2 -JWKUseRegistry = {'sig': 'Digital Signature or MAC', - 'enc': 'Encryption'} -"""Registry of allowed uses""" - -# RFC 7517 - 8.3 -JWKOperationsRegistry = {'sign': 'Compute digital Signature or MAC', - 'verify': 'Verify digital signature or MAC', - 'encrypt': 'Encrypt content', - 'decrypt': 'Decrypt content and validate' - ' decryption, if applicable', - 'wrapKey': 'Encrypt key', - 'unwrapKey': 'Decrypt key and validate' - ' decryption, if applicable', - 'deriveKey': 'Derive key', - 'deriveBits': 'Derive bits not to be used as a key'} -"""Registry of allowed operations""" - -JWKpycaCurveMap = {'secp256r1': 'P-256', - 'secp384r1': 'P-384', - 'secp521r1': 'P-521'} - - -class InvalidJWKType(JWException): - """Invalid JWK Type Exception. - - This exception is raised when an invalid parameter type is used. - """ - - def __init__(self, value=None): - super(InvalidJWKType, self).__init__() - self.value = value - - def __str__(self): - return 'Unknown type "%s", valid types are: %s' % ( - self.value, list(JWKTypesRegistry.keys())) - - -class InvalidJWKUsage(JWException): - """Invalid JWK usage Exception. - - This exception is raised when an invalid key usage is requested, - based on the key type and declared usage constraints. - """ - - def __init__(self, use, value): - super(InvalidJWKUsage, self).__init__() - self.value = value - self.use = use - - def __str__(self): - if self.use in list(JWKUseRegistry.keys()): - usage = JWKUseRegistry[self.use] - else: - usage = 'Unknown(%s)' % self.use - if self.value in list(JWKUseRegistry.keys()): - valid = JWKUseRegistry[self.value] - else: - valid = 'Unknown(%s)' % self.value - return 'Invalid usage requested: "%s". Valid for: "%s"' % (usage, - valid) - - -class InvalidJWKOperation(JWException): - """Invalid JWK Operation Exception. - - This exception is raised when an invalid key operation is requested, - based on the key type and declared usage constraints. - """ - - def __init__(self, operation, values): - super(InvalidJWKOperation, self).__init__() - self.op = operation - self.values = values - - def __str__(self): - if self.op in list(JWKOperationsRegistry.keys()): - op = JWKOperationsRegistry[self.op] - else: - op = 'Unknown(%s)' % self.op - valid = list() - for v in self.values: - if v in list(JWKOperationsRegistry.keys()): - valid.append(JWKOperationsRegistry[v]) - else: - valid.append('Unknown(%s)' % v) - return 'Invalid operation requested: "%s". Valid for: "%s"' % (op, - valid) - - -class InvalidJWKValue(JWException): - """Invalid JWK Value Exception. - - This exception is raised when an invalid/unknown value is used in the - context of an operation that requires specific values to be used based - on the key type or other constraints. - """ - - pass - - -class JWK(object): - """JSON Web Key object - - This object represent a Key. - It must be instantiated by using the standard defined key/value pairs - as arguments of the initialization function. - """ - - def __init__(self, **kwargs): - """Creates a new JWK object. - - The function arguments must be valid parameters as defined in the - 'IANA JSON Web Key Set Parameters registry' and specified in - the :data:`JWKParamsRegistry` variable. The 'kty' parameter must - always be provided and its value must be a valid one as defined - by the 'IANA JSON Web Key Types registry' and specified in the - :data:`JWKTypesRegistry` variable. The valid key parameters per - key type are defined in the :data:`JWKValuesregistry` variable. - - To generate a new random key call the class method generate() with - the appropriate 'kty' parameter, and other parameters as needed (key - size, public exponents, curve types, etc..) - - Valid options per type, when generating new keys: - * oct: size(int) - * RSA: public_exponent(int), size(int) - * EC: curve(str) (one of P-256, P-384, P-521) - - Deprecated: - Alternatively if the 'generate' parameter is provided, with a - valid key type as value then a new key will be generated according - to the defaults or provided key strenght options (type specific). - - :raises InvalidJWKType: if the key type is invalid - :raises InvalidJWKValue: if incorrect or inconsistent parameters - are provided. - """ - self._params = dict() - self._key = dict() - self._unknown = dict() - - if 'generate' in kwargs: - self.generate_key(**kwargs) - elif kwargs: - self.import_key(**kwargs) - - @classmethod - def generate(cls, **kwargs): - obj = cls() - kty = None - try: - kty = kwargs['kty'] - gen = getattr(obj, '_generate_%s' % kty) - except (KeyError, AttributeError): - raise InvalidJWKType(kty) - gen(kwargs) - return obj - - def generate_key(self, **params): - kty = None - try: - kty = params.pop('generate') - gen = getattr(self, '_generate_%s' % kty) - except (KeyError, AttributeError): - raise InvalidJWKType(kty) - - gen(params) - - def _get_gen_size(self, params, default_size=None): - size = default_size - if 'size' in params: - size = params.pop('size') - elif 'alg' in params: - try: - from jwcrypto.jwa import JWA - alg = JWA.instantiate_alg(params['alg']) - except KeyError: - raise ValueError("Invalid 'alg' parameter") - size = alg.keysize - return size - - def _generate_oct(self, params): - size = self._get_gen_size(params, 128) - key = os.urandom(size // 8) - params['kty'] = 'oct' - params['k'] = base64url_encode(key) - self.import_key(**params) - - def _encode_int(self, i): - intg = hex(i).rstrip("L").lstrip("0x") - return base64url_encode(unhexlify((len(intg) % 2) * '0' + intg)) - - def _generate_RSA(self, params): - pubexp = 65537 - size = self._get_gen_size(params, 2048) - if 'public_exponent' in params: - pubexp = params.pop('public_exponent') - key = rsa.generate_private_key(pubexp, size, default_backend()) - self._import_pyca_pri_rsa(key, **params) - - def _import_pyca_pri_rsa(self, key, **params): - pn = key.private_numbers() - params.update( - kty='RSA', - n=self._encode_int(pn.public_numbers.n), - e=self._encode_int(pn.public_numbers.e), - d=self._encode_int(pn.d), - p=self._encode_int(pn.p), - q=self._encode_int(pn.q), - dp=self._encode_int(pn.dmp1), - dq=self._encode_int(pn.dmq1), - qi=self._encode_int(pn.iqmp) - ) - self.import_key(**params) - - def _import_pyca_pub_rsa(self, key, **params): - pn = key.public_numbers() - params.update( - kty='RSA', - n=self._encode_int(pn.n), - e=self._encode_int(pn.e) - ) - self.import_key(**params) - - def _get_curve_by_name(self, name): - if name == 'P-256': - return ec.SECP256R1() - elif name == 'P-384': - return ec.SECP384R1() - elif name == 'P-521': - return ec.SECP521R1() - else: - raise InvalidJWKValue('Unknown Elliptic Curve Type') - - def _generate_EC(self, params): - curve = 'P-256' - if 'curve' in params: - curve = params.pop('curve') - # 'curve' is for backwards compat, if 'crv' is defined it takes - # precedence - if 'crv' in params: - curve = params.pop('crv') - curve_name = self._get_curve_by_name(curve) - key = ec.generate_private_key(curve_name, default_backend()) - self._import_pyca_pri_ec(key, **params) - - def _import_pyca_pri_ec(self, key, **params): - pn = key.private_numbers() - params.update( - kty='EC', - crv=JWKpycaCurveMap[key.curve.name], - x=self._encode_int(pn.public_numbers.x), - y=self._encode_int(pn.public_numbers.y), - d=self._encode_int(pn.private_value) - ) - self.import_key(**params) - - def _import_pyca_pub_ec(self, key, **params): - pn = key.public_numbers() - params.update( - kty='EC', - crv=JWKpycaCurveMap[key.curve.name], - x=self._encode_int(pn.x), - y=self._encode_int(pn.y), - ) - self.import_key(**params) - - def import_key(self, **kwargs): - names = list(kwargs.keys()) - - for name in list(JWKParamsRegistry.keys()): - if name in kwargs: - self._params[name] = kwargs[name] - while name in names: - names.remove(name) - - kty = self._params.get('kty', None) - if kty not in JWKTypesRegistry: - raise InvalidJWKType(kty) - - for name in list(JWKValuesRegistry[kty].keys()): - if name in kwargs: - self._key[name] = kwargs[name] - while name in names: - names.remove(name) - - for name, val in iteritems(JWKValuesRegistry[kty]): - if val.required and name not in self._key: - raise InvalidJWKValue('Missing required value %s' % name) - if val.type == ParmType.unsupported and name in self._key: - raise InvalidJWKValue('Unsupported parameter %s' % name) - if val.type == ParmType.b64 and name in self._key: - # Check that the value is base64url encoded - try: - base64url_decode(self._key[name]) - except Exception: # pylint: disable=broad-except - raise InvalidJWKValue( - '"%s" is not base64url encoded' % name - ) - if val[3] == ParmType.b64U and name in self._key: - # Check that the value is Base64urlUInt encoded - try: - self._decode_int(self._key[name]) - except Exception: # pylint: disable=broad-except - raise InvalidJWKValue( - '"%s" is not Base64urlUInt encoded' % name - ) - - # Unknown key parameters are allowed - # Let's just store them out of the way - for name in names: - self._unknown[name] = kwargs[name] - - if len(self._key) == 0: - raise InvalidJWKValue('No Key Values found') - - # check key_ops - if 'key_ops' in self._params: - for ko in self._params['key_ops']: - c = 0 - for cko in self._params['key_ops']: - if ko == cko: - c += 1 - if c != 1: - raise InvalidJWKValue('Duplicate values in "key_ops"') - - # check use/key_ops consistency - if 'use' in self._params and 'key_ops' in self._params: - sigl = ['sign', 'verify'] - encl = ['encrypt', 'decrypt', 'wrapKey', 'unwrapKey', - 'deriveKey', 'deriveBits'] - if self._params['use'] == 'sig': - for op in encl: - if op in self._params['key_ops']: - raise InvalidJWKValue('Incompatible "use" and' - ' "key_ops" values specified at' - ' the same time') - elif self._params['use'] == 'enc': - for op in sigl: - if op in self._params['key_ops']: - raise InvalidJWKValue('Incompatible "use" and' - ' "key_ops" values specified at' - ' the same time') - - @classmethod - def from_json(cls, key): - """Creates a RFC 7517 JWK from the standard JSON format. - - :param key: The RFC 7517 representation of a JWK. - """ - obj = cls() - try: - jkey = json_decode(key) - except Exception as e: # pylint: disable=broad-except - raise InvalidJWKValue(e) - obj.import_key(**jkey) - return obj - - def export(self, private_key=True): - """Exports the key in the standard JSON format. - Exports the key regardless of type, if private_key is False - and the key is_symmetric an exceptionis raised. - - :param private_key(bool): Whether to export the private key. - Defaults to True. - """ - if private_key is True: - # Use _export_all for backwards compatibility, as this - # function allows to export symmetrict keys too - return self._export_all() - else: - return self.export_public() - - def export_public(self): - """Exports the public key in the standard JSON format. - It fails if one is not available like when this function - is called on a symmetric key. - """ - pub = self._public_params() - return json_encode(pub) - - def _public_params(self): - if not self.has_public: - raise InvalidJWKType("No public key available") - pub = {} - preg = JWKParamsRegistry - for name in preg: - if preg[name].public: - if name in self._params: - pub[name] = self._params[name] - reg = JWKValuesRegistry[self._params['kty']] - for param in reg: - if reg[param].public: - pub[param] = self._key[param] - return pub - - def _export_all(self): - d = dict() - d.update(self._params) - d.update(self._key) - d.update(self._unknown) - return json_encode(d) - - def export_private(self): - """Export the private key in the standard JSON format. - It fails for a JWK that has only a public key or is symmetric. - """ - if self.has_private: - return self._export_all() - raise InvalidJWKType("No private key available") - - def export_symmetric(self): - if self.is_symmetric: - return self._export_all() - raise InvalidJWKType("Not a symmetric key") - - def public(self): - pub = self._public_params() - return JWK(**pub) - - @property - def has_public(self): - """Whether this JWK has an asymmetric Public key.""" - if self.is_symmetric: - return False - reg = JWKValuesRegistry[self._params['kty']] - for value in reg: - if reg[value].public and value in self._key: - return True - - @property - def has_private(self): - """Whether this JWK has an asymmetric key Private key.""" - if self.is_symmetric: - return False - reg = JWKValuesRegistry[self._params['kty']] - for value in reg: - if not reg[value].public and value in self._key: - return True - return False - - @property - def is_symmetric(self): - """Whether this JWK is a symmetric key.""" - return self.key_type == 'oct' - - @property - def key_type(self): - """The Key type""" - return self._params.get('kty', None) - - @property - def key_id(self): - """The Key ID. - Provided by the kid parameter if present, otherwise returns None. - """ - return self._params.get('kid', None) - - @property - def key_curve(self): - """The Curve Name.""" - if self._params['kty'] != 'EC': - raise InvalidJWKType('Not an EC key') - return self._key['crv'] - - def get_curve(self, arg): - """Gets the Elliptic Curve associated with the key. - - :param arg: an optional curve name - - :raises InvalidJWKType: the key is not an EC key. - :raises InvalidJWKValue: if the curve names is invalid. - """ - k = self._key - if self._params['kty'] != 'EC': - raise InvalidJWKType('Not an EC key') - if arg and k['crv'] != arg: - raise InvalidJWKValue('Curve requested is "%s", but ' - 'key curve is "%s"' % (arg, k['crv'])) - - return self._get_curve_by_name(k['crv']) - - def _check_constraints(self, usage, operation): - use = self._params.get('use', None) - if use and use != usage: - raise InvalidJWKUsage(usage, use) - ops = self._params.get('key_ops', None) - if ops: - if not isinstance(ops, list): - ops = [ops] - if operation not in ops: - raise InvalidJWKOperation(operation, ops) - # TODO: check alg ? - - def _decode_int(self, n): - return int(hexlify(base64url_decode(n)), 16) - - def _rsa_pub(self, k): - return rsa.RSAPublicNumbers(self._decode_int(k['e']), - self._decode_int(k['n'])) - - def _rsa_pri(self, k): - return rsa.RSAPrivateNumbers(self._decode_int(k['p']), - self._decode_int(k['q']), - self._decode_int(k['d']), - self._decode_int(k['dp']), - self._decode_int(k['dq']), - self._decode_int(k['qi']), - self._rsa_pub(k)) - - def _ec_pub(self, k, curve): - return ec.EllipticCurvePublicNumbers(self._decode_int(k['x']), - self._decode_int(k['y']), - self.get_curve(curve)) - - def _ec_pri(self, k, curve): - return ec.EllipticCurvePrivateNumbers(self._decode_int(k['d']), - self._ec_pub(k, curve)) - - def _get_public_key(self, arg=None): - if self._params['kty'] == 'oct': - return self._key['k'] - elif self._params['kty'] == 'RSA': - return self._rsa_pub(self._key).public_key(default_backend()) - elif self._params['kty'] == 'EC': - return self._ec_pub(self._key, arg).public_key(default_backend()) - else: - raise NotImplementedError - - def _get_private_key(self, arg=None): - if self._params['kty'] == 'oct': - return self._key['k'] - elif self._params['kty'] == 'RSA': - return self._rsa_pri(self._key).private_key(default_backend()) - elif self._params['kty'] == 'EC': - return self._ec_pri(self._key, arg).private_key(default_backend()) - else: - raise NotImplementedError - - def get_op_key(self, operation=None, arg=None): - """Get the key object associated to the requested opration. - For example the public RSA key for the 'verify' operation or - the private EC key for the 'decrypt' operation. - - :param operation: The requested operation. - The valid set of operations is availble in the - :data:`JWKOperationsRegistry` registry. - :param arg: an optional, context specific, argument - For example a curve name. - - :raises InvalidJWKOperation: if the operation is unknown or - not permitted with this key. - :raises InvalidJWKUsage: if the use constraints do not permit - the operation. - """ - validops = self._params.get('key_ops', - list(JWKOperationsRegistry.keys())) - if validops is not list: - validops = [validops] - if operation is None: - if self._params['kty'] == 'oct': - return self._key['k'] - raise InvalidJWKOperation(operation, validops) - elif operation == 'sign': - self._check_constraints('sig', operation) - return self._get_private_key(arg) - elif operation == 'verify': - self._check_constraints('sig', operation) - return self._get_public_key(arg) - elif operation == 'encrypt' or operation == 'wrapKey': - self._check_constraints('enc', operation) - return self._get_public_key(arg) - elif operation == 'decrypt' or operation == 'unwrapKey': - self._check_constraints('enc', operation) - return self._get_private_key(arg) - else: - raise NotImplementedError - - def import_from_pyca(self, key): - if isinstance(key, rsa.RSAPrivateKey): - self._import_pyca_pri_rsa(key) - elif isinstance(key, rsa.RSAPublicKey): - self._import_pyca_pub_rsa(key) - elif isinstance(key, ec.EllipticCurvePrivateKey): - self._import_pyca_pri_ec(key) - elif isinstance(key, ec.EllipticCurvePublicKey): - self._import_pyca_pub_ec(key) - else: - raise InvalidJWKValue('Unknown key object %r' % key) - - def import_from_pem(self, data, password=None): - """Imports a key from data loaded from a PEM file. - The key may be encrypted with a password. - Private keys (PKCS#8 format), public keys, and X509 certificate's - public keys can be imported with this interface. - - :param data(bytes): The data contained in a PEM file. - :param password(bytes): An optional password to unwrap the key. - """ - - try: - key = serialization.load_pem_private_key( - data, password=password, backend=default_backend()) - except ValueError as e: - if password is not None: - raise e - try: - key = serialization.load_pem_public_key( - data, backend=default_backend()) - except ValueError: - try: - cert = x509.load_pem_x509_certificate( - data, backend=default_backend()) - key = cert.public_key() - except ValueError: - raise e - - self.import_from_pyca(key) - self._params['kid'] = self.thumbprint() - - def export_to_pem(self, private_key=False, password=False): - """Exports keys to a data buffer suitable to be stored as a PEM file. - Either the public or the private key can be exported to a PEM file. - For private keys the PKCS#8 format is used. If a password is provided - the best encryption method available as determined by the cryptography - module is used to wrap the key. - - :param private_key: Whether the private key should be exported. - Defaults to `False` which means the public key is exported by default. - :param password(bytes): A password for wrapping the private key. - Defaults to False which will cause the operation to fail. To avoid - encryption the user must explicitly pass None, otherwise the user - needs to provide a password in a bytes buffer. - """ - e = serialization.Encoding.PEM - if private_key: - if not self.has_private: - raise InvalidJWKType("No private key available") - f = serialization.PrivateFormat.PKCS8 - if password is None: - a = serialization.NoEncryption() - elif isinstance(password, bytes): - a = serialization.BestAvailableEncryption(password) - elif password is False: - raise ValueError("The password must be None or a bytes string") - else: - raise TypeError("The password string must be bytes") - return self._get_private_key().private_bytes( - encoding=e, format=f, encryption_algorithm=a) - else: - if not self.has_public: - raise InvalidJWKType("No public key available") - f = serialization.PublicFormat.SubjectPublicKeyInfo - return self._get_public_key().public_bytes(encoding=e, format=f) - - @classmethod - def from_pyca(cls, key): - obj = cls() - obj.import_from_pyca(key) - return obj - - @classmethod - def from_pem(cls, data, password=None): - """Creates a key from PKCS#8 formatted data loaded from a PEM file. - See the function `import_from_pem` for details. - - :param data(bytes): The data contained in a PEM file. - :param password(bytes): An optional password to unwrap the key. - """ - obj = cls() - obj.import_from_pem(data, password) - return obj - - def thumbprint(self, hashalg=hashes.SHA256()): - """Returns the key thumbprint as specified by RFC 7638. - - :param hashalg: A hash function (defaults to SHA256) - """ - - t = {'kty': self._params['kty']} - for name, val in iteritems(JWKValuesRegistry[t['kty']]): - if val.required: - t[name] = self._key[name] - digest = hashes.Hash(hashalg, backend=default_backend()) - digest.update(bytes(json_encode(t).encode('utf8'))) - return base64url_encode(digest.finalize()) - - -class _JWKkeys(set): - - def add(self, elem): - """Adds a JWK object to the set - - :param elem: the JWK object to add. - - :raises TypeError: if the object is not a JWK. - """ - if not isinstance(elem, JWK): - raise TypeError('Only JWK objects are valid elements') - set.add(self, elem) - - -class JWKSet(dict): - """A set of JWK objects. - - Inherits from the standard 'dict' bultin type. - Creates a special key 'keys' that is of a type derived from 'set' - The 'keys' attribute accepts only :class:`jwcrypto.jwk.JWK` elements. - """ - def __init__(self, *args, **kwargs): - super(JWKSet, self).__init__() - super(JWKSet, self).__setitem__('keys', _JWKkeys()) - self.update(*args, **kwargs) - - def __iter__(self): - return self['keys'].__iter__() - - def __contains__(self, key): - return self['keys'].__contains__(key) - - def __setitem__(self, key, val): - if key == 'keys': - self['keys'].add(val) - else: - super(JWKSet, self).__setitem__(key, val) - - def update(self, *args, **kwargs): - for k, v in iteritems(dict(*args, **kwargs)): - self.__setitem__(k, v) - - def add(self, elem): - self['keys'].add(elem) - - def export(self, private_keys=True): - """Exports a RFC 7517 keyset using the standard JSON format - - :param private_key(bool): Whether to export private keys. - Defaults to True. - """ - exp_dict = dict() - for k, v in iteritems(self): - if k == 'keys': - keys = list() - for jwk in v: - keys.append(json_decode(jwk.export(private_keys))) - v = keys - exp_dict[k] = v - return json_encode(exp_dict) - - def import_keyset(self, keyset): - """Imports a RFC 7517 keyset using the standard JSON format. - - :param keyset: The RFC 7517 representation of a JOSE Keyset. - """ - try: - jwkset = json_decode(keyset) - except Exception: # pylint: disable=broad-except - raise InvalidJWKValue() - - if 'keys' not in jwkset: - raise InvalidJWKValue() - - for k, v in iteritems(jwkset): - if k == 'keys': - for jwk in v: - self['keys'].add(JWK(**jwk)) - else: - self[k] = v - - @classmethod - def from_json(cls, keyset): - """Creates a RFC 7517 keyset from the standard JSON format. - - :param keyset: The RFC 7517 representation of a JOSE Keyset. - """ - obj = cls() - obj.import_keyset(keyset) - return obj - - def get_key(self, kid): - """Gets a key from the set. - :param kid: the 'kid' key identifier. - """ - for jwk in self['keys']: - if jwk.key_id == kid: - return jwk - return None diff --git a/jwcrypto/jws.py b/jwcrypto/jws.py deleted file mode 100644 index 0dfd15d..0000000 --- a/jwcrypto/jws.py +++ /dev/null @@ -1,611 +0,0 @@ -# Copyright (C) 2015 JWCrypto Project Contributors - see LICENSE file - -from collections import namedtuple - -from jwcrypto.common import JWException -from jwcrypto.common import base64url_decode, base64url_encode -from jwcrypto.common import json_decode, json_encode -from jwcrypto.jwa import JWA -from jwcrypto.jwk import JWK - - -# RFC 7515 - 9.1 -# name: (description, supported?) -JWSHeaderParameter = namedtuple('Parameter', - 'description mustprotect supported') -JWSHeaderRegistry = { - 'alg': JWSHeaderParameter('Algorithm', False, True), - 'jku': JWSHeaderParameter('JWK Set URL', False, False), - 'jwk': JWSHeaderParameter('JSON Web Key', False, False), - 'kid': JWSHeaderParameter('Key ID', False, True), - 'x5u': JWSHeaderParameter('X.509 URL', False, False), - 'x5c': JWSHeaderParameter('X.509 Certificate Chain', False, False), - 'x5t': JWSHeaderParameter( - 'X.509 Certificate SHA-1 Thumbprint', False, False), - 'x5t#S256': JWSHeaderParameter( - 'X.509 Certificate SHA-256 Thumbprint', False, False), - 'typ': JWSHeaderParameter('Type', False, True), - 'cty': JWSHeaderParameter('Content Type', False, True), - 'crit': JWSHeaderParameter('Critical', True, True), - 'b64': JWSHeaderParameter('Base64url-Encode Payload', True, True) -} -"""Registry of valid header parameters""" - -default_allowed_algs = [ - 'HS256', 'HS384', 'HS512', - 'RS256', 'RS384', 'RS512', - 'ES256', 'ES384', 'ES512', - 'PS256', 'PS384', 'PS512'] -"""Default allowed algorithms""" - - -class InvalidJWSSignature(JWException): - """Invalid JWS Signature. - - This exception is raised when a signature cannot be validated. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = str(message) - else: - msg = 'Unknown Signature Verification Failure' - if exception: - msg += ' {%s}' % str(exception) - super(InvalidJWSSignature, self).__init__(msg) - - -class InvalidJWSObject(JWException): - """Invalid JWS Object. - - This exception is raised when the JWS Object is invalid and/or - improperly formatted. - """ - - def __init__(self, message=None, exception=None): - msg = 'Invalid JWS Object' - if message: - msg += ' [%s]' % message - if exception: - msg += ' {%s}' % str(exception) - super(InvalidJWSObject, self).__init__(msg) - - -class InvalidJWSOperation(JWException): - """Invalid JWS Object. - - This exception is raised when a requested operation cannot - be execute due to unsatisfied conditions. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = message - else: - msg = 'Unknown Operation Failure' - if exception: - msg += ' {%s}' % str(exception) - super(InvalidJWSOperation, self).__init__(msg) - - -class JWSCore(object): - """The inner JWS Core object. - - This object SHOULD NOT be used directly, the JWS object should be - used instead as JWS perform necessary checks on the validity of - the object and requested operations. - - """ - - def __init__(self, alg, key, header, payload, algs=None): - """Core JWS token handling. - - :param alg: The algorithm used to produce the signature. - See RFC 7518 - :param key: A (:class:`jwcrypto.jwk.JWK`) key of appropriate - type for the "alg" provided in the 'protected' json string. - :param header: A JSON string representing the protected header. - :param payload(bytes): An arbitrary value - :param algs: An optional list of allowed algorithms - - :raises ValueError: if the key is not a :class:`JWK` object - :raises InvalidJWAAlgorithm: if the algorithm is not valid, is - unknown or otherwise not yet implemented. - """ - self.alg = alg - self.engine = self._jwa(alg, algs) - if not isinstance(key, JWK): - raise ValueError('key is not a JWK object') - self.key = key - - if header is not None: - if isinstance(header, dict): - self.header = header - header = json_encode(header) - else: - self.header = json_decode(header) - - self.protected = base64url_encode(header.encode('utf-8')) - else: - self.header = dict() - self.protected = '' - self.payload = payload - - def _jwa(self, name, allowed): - if allowed is None: - allowed = default_allowed_algs - if name not in allowed: - raise InvalidJWSOperation('Algorithm not allowed') - return JWA.signing_alg(name) - - def _payload(self): - if self.header.get('b64', True): - return base64url_encode(self.payload).encode('utf-8') - else: - if isinstance(self.payload, bytes): - return self.payload - else: - return self.payload.encode('utf-8') - - def sign(self): - """Generates a signature""" - payload = self._payload() - sigin = b'.'.join([self.protected.encode('utf-8'), payload]) - signature = self.engine.sign(self.key, sigin) - return {'protected': self.protected, - 'payload': payload, - 'signature': base64url_encode(signature)} - - def verify(self, signature): - """Verifies a signature - - :raises InvalidJWSSignature: if the verification fails. - """ - try: - payload = self._payload() - sigin = b'.'.join([self.protected.encode('utf-8'), payload]) - self.engine.verify(self.key, sigin, signature) - except Exception as e: # pylint: disable=broad-except - raise InvalidJWSSignature('Verification failed', repr(e)) - return True - - -class JWS(object): - """JSON Web Signature object - - This object represent a JWS token. - """ - - def __init__(self, payload=None): - """Creates a JWS object. - - :param payload(bytes): An arbitrary value (optional). - """ - self.objects = dict() - if payload: - self.objects['payload'] = payload - self.verifylog = None - self._allowed_algs = None - - @property - def allowed_algs(self): - """Allowed algorithms. - - The list of allowed algorithms. - Can be changed by setting a list of algorithm names. - """ - - if self._allowed_algs: - return self._allowed_algs - else: - return default_allowed_algs - - @allowed_algs.setter - def allowed_algs(self, algs): - if not isinstance(algs, list): - raise TypeError('Allowed Algs must be a list') - self._allowed_algs = algs - - @property - def is_valid(self): - return self.objects.get('valid', False) - - # TODO: allow caller to specify list of headers it understands - def _merge_check_headers(self, protected, *headers): - header = None - crit = [] - if protected is not None: - if 'crit' in protected: - crit = protected['crit'] - # Check immediately if we support these critical headers - for k in crit: - if k not in JWSHeaderRegistry: - raise InvalidJWSObject( - 'Unknown critical header: "%s"' % k) - else: - if not JWSHeaderRegistry[k][1]: - raise InvalidJWSObject( - 'Unsupported critical header: "%s"' % k) - header = protected - if 'b64' in header: - if not isinstance(header['b64'], bool): - raise InvalidJWSObject('b64 header must be a boolean') - - for hn in headers: - if hn is None: - continue - if header is None: - header = dict() - for h in list(hn.keys()): - if h in JWSHeaderRegistry: - if JWSHeaderRegistry[h].mustprotect: - raise InvalidJWSObject('"%s" must be protected' % h) - if h in header: - raise InvalidJWSObject('Duplicate header: "%s"' % h) - header.update(hn) - - for k in crit: - if k not in header: - raise InvalidJWSObject('Missing critical header "%s"' % k) - - return header - - # TODO: support selecting key with 'kid' and passing in multiple keys - def _verify(self, alg, key, payload, signature, protected, header=None): - p = dict() - # verify it is a valid JSON object and decode - if protected is not None: - p = json_decode(protected) - if not isinstance(p, dict): - raise InvalidJWSSignature('Invalid Protected header') - # merge heders, and verify there are no duplicates - if header: - if not isinstance(header, dict): - raise InvalidJWSSignature('Invalid Unprotected header') - - # Merge and check (critical) headers - self._merge_check_headers(p, header) - # check 'alg' is present - if alg is None and 'alg' not in p: - raise InvalidJWSSignature('No "alg" in headers') - if alg: - if 'alg' in p and alg != p['alg']: - raise InvalidJWSSignature('"alg" mismatch, requested ' - '"%s", found "%s"' % (alg, - p['alg'])) - a = alg - else: - a = p['alg'] - - # the following will verify the "alg" is supported and the signature - # verifies - c = JWSCore(a, key, protected, payload, self._allowed_algs) - c.verify(signature) - - def verify(self, key, alg=None): - """Verifies a JWS token. - - :param key: The (:class:`jwcrypto.jwk.JWK`) verification key. - :param alg: The signing algorithm (optional). usually the algorithm - is known as it is provided with the JOSE Headers of the token. - - :raises InvalidJWSSignature: if the verification fails. - """ - - self.verifylog = list() - self.objects['valid'] = False - obj = self.objects - if 'signature' in obj: - try: - self._verify(alg, key, - obj['payload'], - obj['signature'], - obj.get('protected', None), - obj.get('header', None)) - obj['valid'] = True - except Exception as e: # pylint: disable=broad-except - self.verifylog.append('Failed: [%s]' % repr(e)) - - elif 'signatures' in obj: - for o in obj['signatures']: - try: - self._verify(alg, key, - obj['payload'], - o['signature'], - o.get('protected', None), - o.get('header', None)) - # Ok if at least one verifies - obj['valid'] = True - except Exception as e: # pylint: disable=broad-except - self.verifylog.append('Failed: [%s]' % repr(e)) - else: - raise InvalidJWSSignature('No signatures availble') - - if not self.is_valid: - raise InvalidJWSSignature('Verification failed for all ' - 'signatures' + repr(self.verifylog)) - - def _deserialize_signature(self, s): - o = dict() - o['signature'] = base64url_decode(str(s['signature'])) - if 'protected' in s: - p = base64url_decode(str(s['protected'])) - o['protected'] = p.decode('utf-8') - if 'header' in s: - o['header'] = s['header'] - return o - - def _deserialize_b64(self, o, protected): - if protected is None: - b64n = None - else: - p = json_decode(protected) - b64n = p.get('b64') - if b64n is not None: - if not isinstance(b64n, bool): - raise InvalidJWSObject('b64 header must be boolean') - b64 = o.get('b64') - if b64 == b64n: - return - elif b64 is None: - o['b64'] = b64n - else: - raise InvalidJWSObject('conflicting b64 values') - - def deserialize(self, raw_jws, key=None, alg=None): - """Deserialize a JWS token. - - NOTE: Destroys any current status and tries to import the raw - JWS provided. - - :param raw_jws: a 'raw' JWS token (JSON Encoded or Compact - notation) string. - :param key: A (:class:`jwcrypto.jwk.JWK`) verification key (optional). - If a key is provided a verification step will be attempted after - the object is successfully deserialized. - :param alg: The signing algorithm (optional). usually the algorithm - is known as it is provided with the JOSE Headers of the token. - - :raises InvalidJWSObject: if the raw object is an invaid JWS token. - :raises InvalidJWSSignature: if the verification fails. - """ - self.objects = dict() - o = dict() - try: - try: - djws = json_decode(raw_jws) - if 'signatures' in djws: - o['signatures'] = list() - for s in djws['signatures']: - os = self._deserialize_signature(s) - o['signatures'].append(os) - self._deserialize_b64(o, os.get('protected')) - else: - o = self._deserialize_signature(djws) - self._deserialize_b64(o, o.get('protected')) - - if 'payload' in djws: - if o.get('b64', True): - o['payload'] = base64url_decode(str(djws['payload'])) - else: - o['payload'] = djws['payload'] - - except ValueError: - c = raw_jws.split('.') - if len(c) != 3: - raise InvalidJWSObject('Unrecognized representation') - p = base64url_decode(str(c[0])) - if len(p) > 0: - o['protected'] = p.decode('utf-8') - self._deserialize_b64(o, o['protected']) - o['payload'] = base64url_decode(str(c[1])) - o['signature'] = base64url_decode(str(c[2])) - - self.objects = o - - except Exception as e: # pylint: disable=broad-except - raise InvalidJWSObject('Invalid format', repr(e)) - - if key: - self.verify(key, alg) - - def add_signature(self, key, alg=None, protected=None, header=None): - """Adds a new signature to the object. - - :param key: A (:class:`jwcrypto.jwk.JWK`) key of appropriate for - the "alg" provided. - :param alg: An optional algorithm name. If already provided as an - element of the protected or unprotected header it can be safely - omitted. - :param potected: The Protected Header (optional) - :param header: The Unprotected Header (optional) - - :raises InvalidJWSObject: if no payload has been set on the object, - or invalid headers are provided. - :raises ValueError: if the key is not a :class:`JWK` object. - :raises ValueError: if the algorithm is missing or is not provided - by one of the headers. - :raises InvalidJWAAlgorithm: if the algorithm is not valid, is - unknown or otherwise not yet implemented. - """ - - if not self.objects.get('payload', None): - raise InvalidJWSObject('Missing Payload') - - b64 = True - - p = dict() - if protected: - if isinstance(protected, dict): - p = protected - protected = json_encode(p) - else: - p = json_decode(protected) - - # If b64 is present we must enforce criticality - if 'b64' in list(p.keys()): - crit = p.get('crit', []) - if 'b64' not in crit: - raise InvalidJWSObject('b64 header must always be critical') - b64 = p['b64'] - - if 'b64' in self.objects: - if b64 != self.objects['b64']: - raise InvalidJWSObject('Mixed b64 headers on signatures') - - h = None - if header: - if isinstance(header, dict): - h = header - header = json_encode(header) - else: - h = json_decode(header) - - p = self._merge_check_headers(p, h) - - if 'alg' in p: - if alg is None: - alg = p['alg'] - elif alg != p['alg']: - raise ValueError('"alg" value mismatch, specified "alg" ' - 'does not match JOSE header value') - - if alg is None: - raise ValueError('"alg" not specified') - - c = JWSCore(alg, key, protected, self.objects['payload']) - sig = c.sign() - - o = dict() - o['signature'] = base64url_decode(sig['signature']) - if protected: - o['protected'] = protected - if header: - o['header'] = h - o['valid'] = True - - if 'signatures' in self.objects: - self.objects['signatures'].append(o) - elif 'signature' in self.objects: - self.objects['signatures'] = list() - n = dict() - n['signature'] = self.objects.pop('signature') - if 'protected' in self.objects: - n['protected'] = self.objects.pop('protected') - if 'header' in self.objects: - n['header'] = self.objects.pop('header') - if 'valid' in self.objects: - n['valid'] = self.objects.pop('valid') - self.objects['signatures'].append(n) - self.objects['signatures'].append(o) - else: - self.objects.update(o) - self.objects['b64'] = b64 - - def serialize(self, compact=False): - """Serializes the object into a JWS token. - - :param compact(boolean): if True generates the compact - representation, otherwise generates a standard JSON format. - - :raises InvalidJWSOperation: if the object cannot serialized - with the compact representation and `compat` is True. - :raises InvalidJWSSignature: if no signature has been added - to the object, or no valid signature can be found. - """ - if compact: - if 'signatures' in self.objects: - raise InvalidJWSOperation("Can't use compact encoding with " - "multiple signatures") - if 'signature' not in self.objects: - raise InvalidJWSSignature("No available signature") - if not self.objects.get('valid', False): - raise InvalidJWSSignature("No valid signature found") - if 'protected' in self.objects: - protected = base64url_encode(self.objects['protected']) - else: - protected = '' - if self.objects.get('payload', False): - if self.objects.get('b64', True): - payload = base64url_encode(self.objects['payload']) - else: - if isinstance(self.objects['payload'], bytes): - payload = self.objects['payload'].decode('utf-8') - else: - payload = self.objects['payload'] - if '.' in payload: - raise InvalidJWSOperation( - "Can't use compact encoding with unencoded " - "payload that uses the . character") - else: - payload = '' - return '.'.join([protected, payload, - base64url_encode(self.objects['signature'])]) - else: - obj = self.objects - sig = dict() - if self.objects.get('payload', False): - if self.objects.get('b64', True): - sig['payload'] = base64url_encode(self.objects['payload']) - else: - sig['payload'] = self.objects['payload'] - if 'signature' in obj: - if not obj.get('valid', False): - raise InvalidJWSSignature("No valid signature found") - sig['signature'] = base64url_encode(obj['signature']) - if 'protected' in obj: - sig['protected'] = base64url_encode(obj['protected']) - if 'header' in obj: - sig['header'] = obj['header'] - elif 'signatures' in obj: - sig['signatures'] = list() - for o in obj['signatures']: - if not o.get('valid', False): - continue - s = {'signature': base64url_encode(o['signature'])} - if 'protected' in o: - s['protected'] = base64url_encode(o['protected']) - if 'header' in o: - s['header'] = o['header'] - sig['signatures'].append(s) - if len(sig['signatures']) == 0: - raise InvalidJWSSignature("No valid signature found") - else: - raise InvalidJWSSignature("No available signature") - return json_encode(sig) - - @property - def payload(self): - if 'payload' not in self.objects: - raise InvalidJWSOperation("Payload not available") - if not self.is_valid: - raise InvalidJWSOperation("Payload not verified") - return self.objects['payload'] - - def detach_payload(self): - self.objects.pop('payload', None) - - @property - def jose_header(self): - obj = self.objects - if 'signature' in obj: - if 'protected' in obj: - p = json_decode(obj['protected']) - else: - p = None - return self._merge_check_headers(p, obj.get('header', dict())) - elif 'signatures' in self.objects: - jhl = list() - for o in obj['signatures']: - jh = dict() - if 'protected' in o: - p = json_decode(o['protected']) - else: - p = None - jh = self._merge_check_headers(p, o.get('header', dict())) - jhl.append(jh) - return jhl - else: - raise InvalidJWSOperation("JOSE Header(s) not available") diff --git a/jwcrypto/jwt.py b/jwcrypto/jwt.py deleted file mode 100644 index 06233d4..0000000 --- a/jwcrypto/jwt.py +++ /dev/null @@ -1,506 +0,0 @@ -# Copyright (C) 2015 JWCrypto Project Contributors - see LICENSE file - -import time -import uuid - -from six import string_types - -from jwcrypto.common import JWException, json_decode, json_encode -from jwcrypto.jwe import JWE -from jwcrypto.jwk import JWK, JWKSet -from jwcrypto.jws import JWS - - -# RFC 7519 - 4.1 -# name: description -JWTClaimsRegistry = {'iss': 'Issuer', - 'sub': 'Subject', - 'aud': 'Audience', - 'exp': 'Expiration Time', - 'nbf': 'Not Before', - 'iat': 'Issued At', - 'jti': 'JWT ID'} - - -class JWTExpired(JWException): - """Json Web Token is expired. - - This exception is raised when a token is expired accoring to its claims. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = str(message) - else: - msg = 'Token expired' - if exception: - msg += ' {%s}' % str(exception) - super(JWTExpired, self).__init__(msg) - - -class JWTNotYetValid(JWException): - """Json Web Token is not yet valid. - - This exception is raised when a token is not valid yet according to its - claims. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = str(message) - else: - msg = 'Token not yet valid' - if exception: - msg += ' {%s}' % str(exception) - super(JWTNotYetValid, self).__init__(msg) - - -class JWTMissingClaim(JWException): - """Json Web Token claim is invalid. - - This exception is raised when a claim does not match the expected value. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = str(message) - else: - msg = 'Invalid Claim Value' - if exception: - msg += ' {%s}' % str(exception) - super(JWTMissingClaim, self).__init__(msg) - - -class JWTInvalidClaimValue(JWException): - """Json Web Token claim is invalid. - - This exception is raised when a claim does not match the expected value. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = str(message) - else: - msg = 'Invalid Claim Value' - if exception: - msg += ' {%s}' % str(exception) - super(JWTInvalidClaimValue, self).__init__(msg) - - -class JWTInvalidClaimFormat(JWException): - """Json Web Token claim format is invalid. - - This exception is raised when a claim is not in a valid format. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = str(message) - else: - msg = 'Invalid Claim Format' - if exception: - msg += ' {%s}' % str(exception) - super(JWTInvalidClaimFormat, self).__init__(msg) - - -class JWTMissingKeyID(JWException): - """Json Web Token is missing key id. - - This exception is raised when trying to decode a JWT with a key set - that does not have a kid value in its header. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = str(message) - else: - msg = 'Missing Key ID' - if exception: - msg += ' {%s}' % str(exception) - super(JWTMissingKeyID, self).__init__(msg) - - -class JWTMissingKey(JWException): - """Json Web Token is using a key not in the key set. - - This exception is raised if the key that was used is not available - in the passed key set. - """ - - def __init__(self, message=None, exception=None): - msg = None - if message: - msg = str(message) - else: - msg = 'Missing Key' - if exception: - msg += ' {%s}' % str(exception) - super(JWTMissingKey, self).__init__(msg) - - -class JWT(object): - """JSON Web token object - - This object represent a generic token. - """ - - def __init__(self, header=None, claims=None, jwt=None, key=None, - algs=None, default_claims=None, check_claims=None): - """Creates a JWT object. - - :param header: A dict or a JSON string with the JWT Header data. - :param claims: A dict or a string with the JWT Claims data. - :param jwt: a 'raw' JWT token - :param key: A (:class:`jwcrypto.jwk.JWK`) key to deserialize - the token. A (:class:`jwcrypto.jwk.JWKSet`) can also be used. - :param algs: An optional list of allowed algorithms - :param default_claims: An optional dict with default values for - registred claims. A None value for NumericDate type claims - will cause generation according to system time. Only the values - from RFC 7519 - 4.1 are evaluated. - :param check_claims: An optional dict of claims that must be - present in the token, if the value is not None the claim must - match exactly. - - Note: either the header,claims or jwt,key parameters should be - provided as a deserialization operation (which occurs if the jwt - is provided will wipe any header os claim provided by setting - those obtained from the deserialization of the jwt token. - - Note: if check_claims is not provided the 'exp' and 'nbf' claims - are checked if they are set on the token but not enforced if not - set. Any other RFC 7519 registered claims are checked only for - format conformance. - """ - - self._header = None - self._claims = None - self._token = None - self._algs = algs - self._reg_claims = None - self._check_claims = None - self._leeway = 60 # 1 minute clock skew allowed - self._validity = 600 # 10 minutes validity (up to 11 with leeway) - - if header: - self.header = header - - if default_claims is not None: - self._reg_claims = default_claims - - if check_claims is not None: - self._check_claims = check_claims - - if claims: - self.claims = claims - - if jwt is not None: - self.deserialize(jwt, key) - - @property - def header(self): - if self._header is None: - raise KeyError("'header' not set") - return self._header - - @header.setter - def header(self, h): - if isinstance(h, dict): - eh = json_encode(h) - else: - eh = h - h = json_decode(eh) - - if h.get('b64') is False: - raise ValueError("b64 header is invalid." - "JWTs cannot use unencoded payloads") - self._header = eh - - @property - def claims(self): - if self._claims is None: - raise KeyError("'claims' not set") - return self._claims - - @claims.setter - def claims(self, c): - if self._reg_claims and not isinstance(c, dict): - # decode c so we can set default claims - c = json_decode(c) - - if isinstance(c, dict): - self._add_default_claims(c) - self._claims = json_encode(c) - else: - self._claims = c - - @property - def token(self): - return self._token - - @token.setter - def token(self, t): - if isinstance(t, JWS) or isinstance(t, JWE) or isinstance(t, JWT): - self._token = t - else: - raise TypeError("Invalid token type, must be one of JWS,JWE,JWT") - - @property - def leeway(self): - return self._leeway - - @leeway.setter - def leeway(self, l): - self._leeway = int(l) - - @property - def validity(self): - return self._validity - - @validity.setter - def validity(self, v): - self._validity = int(v) - - def _add_optional_claim(self, name, claims): - if name in claims: - return - val = self._reg_claims.get(name, None) - if val is not None: - claims[name] = val - - def _add_time_claim(self, name, claims, defval): - if name in claims: - return - if name in self._reg_claims: - if self._reg_claims[name] is None: - claims[name] = defval - else: - claims[name] = self._reg_claims[name] - - def _add_jti_claim(self, claims): - if 'jti' in claims or 'jti' not in self._reg_claims: - return - claims['jti'] = str(uuid.uuid4()) - - def _add_default_claims(self, claims): - if self._reg_claims is None: - return - - now = int(time.time()) - self._add_optional_claim('iss', claims) - self._add_optional_claim('sub', claims) - self._add_optional_claim('aud', claims) - self._add_time_claim('exp', claims, now + self.validity) - self._add_time_claim('nbf', claims, now) - self._add_time_claim('iat', claims, now) - self._add_jti_claim(claims) - - def _check_string_claim(self, name, claims): - if name not in claims: - return - if not isinstance(claims[name], string_types): - raise JWTInvalidClaimFormat("Claim %s is not a StringOrURI type") - - def _check_array_or_string_claim(self, name, claims): - if name not in claims: - return - if isinstance(claims[name], list): - if any(not isinstance(claim, string_types) for claim in claims): - raise JWTInvalidClaimFormat( - "Claim %s contains non StringOrURI types" % (name, )) - elif not isinstance(claims[name], string_types): - raise JWTInvalidClaimFormat( - "Claim %s is not a StringOrURI type" % (name, )) - - def _check_integer_claim(self, name, claims): - if name not in claims: - return - try: - int(claims[name]) - except ValueError: - raise JWTInvalidClaimFormat( - "Claim %s is not an integer" % (name, )) - - def _check_exp(self, claim, limit, leeway): - if claim < limit - leeway: - raise JWTExpired('Expired at %d, time: %d(leeway: %d)' % ( - claim, limit, leeway)) - - def _check_nbf(self, claim, limit, leeway): - if claim > limit + leeway: - raise JWTNotYetValid('Valid from %d, time: %d(leeway: %d)' % ( - claim, limit, leeway)) - - def _check_default_claims(self, claims): - self._check_string_claim('iss', claims) - self._check_string_claim('sub', claims) - self._check_array_or_string_claim('aud', claims) - self._check_integer_claim('exp', claims) - self._check_integer_claim('nbf', claims) - self._check_integer_claim('iat', claims) - self._check_string_claim('jti', claims) - - if self._check_claims is None: - if 'exp' in claims: - self._check_exp(claims['exp'], time.time(), self._leeway) - if 'nbf' in claims: - self._check_nbf(claims['nbf'], time.time(), self._leeway) - - def _check_provided_claims(self): - # check_claims can be set to False to skip any check - if self._check_claims is False: - return - - try: - claims = json_decode(self.claims) - if not isinstance(claims, dict): - raise ValueError() - except ValueError: - if self._check_claims is not None: - raise JWTInvalidClaimFormat( - "Claims check requested but claims is not a json dict") - return - - self._check_default_claims(claims) - - if self._check_claims is None: - return - - for name, value in self._check_claims.items(): - if name not in claims: - raise JWTMissingClaim("Claim %s is missing" % (name, )) - - if name in ['iss', 'sub', 'jti']: - if value is not None and value != claims[name]: - raise JWTInvalidClaimValue( - "Invalid '%s' value. Expected '%s' got '%s'" % ( - name, value, claims[name])) - - elif name == 'aud': - if value is not None: - if value == claims[name]: - continue - if isinstance(claims[name], list): - if value in claims[name]: - continue - raise JWTInvalidClaimValue( - "Invalid '%s' value. Expected '%s' to be in '%s'" % ( - name, claims[name], value)) - - elif name == 'exp': - if value is not None: - self._check_exp(claims[name], value, 0) - else: - self._check_exp(claims[name], time.time(), self._leeway) - - elif name == 'nbf': - if value is not None: - self._check_nbf(claims[name], value, 0) - else: - self._check_nbf(claims[name], time.time(), self._leeway) - - else: - if value is not None and value != claims[name]: - raise JWTInvalidClaimValue( - "Invalid '%s' value. Expected '%s' got '%s'" % ( - name, value, claims[name])) - - def make_signed_token(self, key): - """Signs the payload. - - Creates a JWS token with the header as the JWS protected header and - the claims as the payload. See (:class:`jwcrypto.jws.JWS`) for - details on the exceptions that may be reaised. - - :param key: A (:class:`jwcrypto.jwk.JWK`) key. - """ - - t = JWS(self.claims) - t.add_signature(key, protected=self.header) - self.token = t - - def make_encrypted_token(self, key): - """Encrypts the payload. - - Creates a JWE token with the header as the JWE protected header and - the claims as the plaintext. See (:class:`jwcrypto.jwe.JWE`) for - details on the exceptions that may be reaised. - - :param key: A (:class:`jwcrypto.jwk.JWK`) key. - """ - - t = JWE(self.claims, self.header) - t.add_recipient(key) - self.token = t - - def deserialize(self, jwt, key=None): - """Deserialize a JWT token. - - NOTE: Destroys any current status and tries to import the raw - token provided. - - :param jwt: a 'raw' JWT token. - :param key: A (:class:`jwcrypto.jwk.JWK`) verification or - decryption key, or a (:class:`jwcrypto.jwk.JWKSet`) that - contains a key indexed by the 'kid' header. - """ - c = jwt.count('.') - if c == 2: - self.token = JWS() - elif c == 4: - self.token = JWE() - else: - raise ValueError("Token format unrecognized") - - # Apply algs restrictions if any, before performing any operation - if self._algs: - self.token.allowed_algs = self._algs - - # now deserialize and also decrypt/verify (or raise) if we - # have a key - if key is None: - self.token.deserialize(jwt, None) - elif isinstance(key, JWK): - self.token.deserialize(jwt, key) - elif isinstance(key, JWKSet): - self.token.deserialize(jwt, None) - if 'kid' not in self.token.jose_header: - raise JWTMissingKeyID('No key ID in JWT header') - - token_key = key.get_key(self.token.jose_header['kid']) - if not token_key: - raise JWTMissingKey('Key ID %s not in key set' - % self.token.jose_header['kid']) - - if isinstance(self.token, JWE): - self.token.decrypt(token_key) - elif isinstance(self.token, JWS): - self.token.verify(token_key) - else: - raise RuntimeError("Unknown Token Type") - else: - raise ValueError("Unrecognized Key Type") - - if key is not None: - self.header = self.token.jose_header - self.claims = self.token.payload.decode('utf-8') - self._check_provided_claims() - - def serialize(self, compact=True): - """Serializes the object into a JWS token. - - :param compact(boolean): must be True. - - Note: the compact parameter is provided for general compatibility - with the serialize() functions of :class:`jwcrypto.jws.JWS` and - :class:`jwcrypto.jwe.JWE` so that these objects can all be used - interchangeably. However the only valid JWT representtion is the - compact representation. - """ - return self.token.serialize(compact) -- GitLab