jwk.py 29.7 KB
Newer Older
Simo Sorce's avatar
Simo Sorce committed
1 2
# Copyright (C) 2015  JWCrypto Project Contributors - see LICENSE file

Simo Sorce's avatar
Simo Sorce committed
3
import os
4
from binascii import hexlify, unhexlify
Simo Sorce's avatar
Simo Sorce committed
5

6
from cryptography import x509
Simo Sorce's avatar
Simo Sorce committed
7
from cryptography.hazmat.backends import default_backend
8
from cryptography.hazmat.primitives import hashes, serialization
Simo Sorce's avatar
Simo Sorce committed
9
from cryptography.hazmat.primitives.asymmetric import ec
Simo Sorce's avatar
Simo Sorce committed
10
from cryptography.hazmat.primitives.asymmetric import rsa
11 12 13

from six import iteritems

14
from jwcrypto.common import JWException
Simo Sorce's avatar
Simo Sorce committed
15 16
from jwcrypto.common import base64url_decode, base64url_encode
from jwcrypto.common import json_decode, json_encode
Simo Sorce's avatar
Simo Sorce committed
17

18

19
# RFC 7518 - 7.4
JWKTypesRegistry = {'EC': 'Elliptic Curve',
                    'RSA': 'RSA',
                    'oct': 'Octet sequence'}
"""Registry of valid Key Types"""

# RFC 7518 - 7.5
# It is part of the JWK Parameters Registry, but we want a more
# specific map for internal usage.
# Each entry maps a member name to (description, 'Public'/'Private',
# 'Required' or None).
JWKValuesRegistry = {'EC': {'crv': ('Curve', 'Public', 'Required'),
                            'x': ('X Coordinate', 'Public', 'Required'),
                            'y': ('Y Coordinate', 'Public', 'Required'),
                            'd': ('ECC Private Key', 'Private', None)},
                     'RSA': {'n': ('Modulus', 'Public', 'Required'),
                             'e': ('Exponent', 'Public', 'Required'),
                             'd': ('Private Exponent', 'Private', None),
                             'p': ('First Prime Factor', 'Private', None),
                             'q': ('Second Prime Factor', 'Private', None),
                             'dp': ('First Factor CRT Exponent', 'Private',
                                    None),
                             'dq': ('Second Factor CRT Exponent', 'Private',
                                    None),
                             'qi': ('First CRT Coefficient', 'Private', None)},
                     'oct': {'k': ('Key Value', 'Private', 'Required')}}
"""Registry of valid key values"""

# Each entry maps a parameter name to (description, 'Public').
JWKParamsRegistry = {'kty': ('Key Type', 'Public'),
                     'use': ('Public Key Use', 'Public'),
                     'key_ops': ('Key Operations', 'Public'),
                     'alg': ('Algorithm', 'Public'),
                     'kid': ('Key ID', 'Public'),
                     'x5u': ('X.509 URL', 'Public'),
                     'x5c': ('X.509 Certificate Chain', 'Public'),
                     'x5t': ('X.509 Certificate SHA-1 Thumbprint', 'Public'),
                     'x5t#S256': ('X.509 Certificate SHA-256 Thumbprint',
                                  'Public')}
"""Registry of valid key parameters"""

# RFC 7518 - 7.6
JWKEllipticCurveRegistry = {'P-256': 'P-256 curve',
                            'P-384': 'P-384 curve',
                            'P-521': 'P-521 curve'}
"""Registry of allowed Elliptic Curves"""

# RFC 7517 - 8.2
JWKUseRegistry = {'sig': 'Digital Signature or MAC',
                  'enc': 'Encryption'}
"""Registry of allowed uses"""

# RFC 7517 - 8.3
JWKOperationsRegistry = {'sign': 'Compute digital Signature or MAC',
                         'verify': 'Verify digital signature or MAC',
                         'encrypt': 'Encrypt content',
                         'decrypt': 'Decrypt content and validate'
                                    ' decryption, if applicable',
                         'wrapKey': 'Encrypt key',
                         'unwrapKey': 'Decrypt key and validate'
                                      ' decryption, if applicable',
                         'deriveKey': 'Derive key',
                         'deriveBits': 'Derive bits not to be used as a key'}
"""Registry of allowed operations"""

# Maps pyca/cryptography curve names (``key.curve.name``) to the JWK
# 'crv' names used in JWKEllipticCurveRegistry.
JWKpycaCurveMap = {'secp256r1': 'P-256',
                   'secp384r1': 'P-384',
                   'secp521r1': 'P-521'}

Simo Sorce's avatar
Simo Sorce committed
85

86
class InvalidJWKType(JWException):
    """Invalid JWK Type Exception.

    Raised when a key type is unknown, or is unsuitable for the action
    that was requested on the key.
    """

    def __init__(self, value=None):
        super(InvalidJWKType, self).__init__()
        # The offending type (or an explanatory message) shown by __str__.
        self.value = value

    def __str__(self):
        known_types = list(JWKTypesRegistry)
        template = 'Unknown type "%s", valid types are: %s'
        return template % (self.value, known_types)
Simo Sorce's avatar
Simo Sorce committed
99 100


101
class InvalidJWKUsage(JWException):
    """Invalid JWK usage Exception.

    Raised when the requested key usage ('sig'/'enc') conflicts with the
    'use' constraint declared on the key.
    """

    def __init__(self, use, value):
        super(InvalidJWKUsage, self).__init__()
        self.value = value
        self.use = use

    @staticmethod
    def _describe(name):
        """Returns the registry description of a use, or 'Unknown(...)'.

        Compares with == against each registered name, so unhashable
        values are described instead of raising.
        """
        for registered in JWKUseRegistry:
            if registered == name:
                return JWKUseRegistry[registered]
        return 'Unknown(%s)' % name

    def __str__(self):
        usage = self._describe(self.use)
        valid = self._describe(self.value)
        message = 'Invalid usage requested: "%s". Valid for: "%s"'
        return message % (usage, valid)


126
class InvalidJWKOperation(JWException):
    """Invalid JWK Operation Exception.

    Raised when the requested operation is not permitted by the key
    operations ('key_ops') declared on the key.
    """

    def __init__(self, operation, values):
        super(InvalidJWKOperation, self).__init__()
        self.op = operation
        # Iterable of the operations that are permitted.
        self.values = values

    @staticmethod
    def _describe(name):
        """Returns the registry description of an operation.

        Falls back to 'Unknown(...)'. Uses == comparisons, so unhashable
        values are described instead of raising.
        """
        for registered in JWKOperationsRegistry:
            if registered == name:
                return JWKOperationsRegistry[registered]
        return 'Unknown(%s)' % name

    def __str__(self):
        op = self._describe(self.op)
        valid = [self._describe(v) for v in self.values]
        message = 'Invalid operation requested: "%s". Valid for: "%s"'
        return message % (op, valid)


153
class InvalidJWKValue(JWException):
    """Invalid JWK Value Exception.

    Raised when a key value or parameter is missing, unknown, malformed,
    or inconsistent with the key type or with other declared constraints
    (for example duplicate 'key_ops' entries or an unknown curve).
    """


class JWK(object):
    """JSON Web Key object

    This object represents a Key.
    It must be instantiated by using the standard defined key/value pairs
    as arguments of the initialization function.
    """

    def __init__(self, **kwargs):
        """Creates a new JWK object.

        The function arguments must be valid parameters as defined in the
        'IANA JSON Web Key Set Parameters registry' and specified in
        the :data:`JWKParamsRegistry` variable. The 'kty' parameter must
        always be provided and its value must be a valid one as defined
        by the 'IANA JSON Web Key Types registry' and specified in the
        :data:`JWKTypesRegistry` variable. The valid key parameters per
        key type are defined in the :data:`JWKValuesRegistry` variable.

        To generate a new random key call the class method generate() with
        the appropriate 'kty' parameter, and other parameters as needed (key
        size, public exponents, curve types, etc..)

        Valid options per type, when generating new keys:
         * oct: size(int)
         * RSA: public_exponent(int), size(int)
         * EC: curve(str) (one of P-256, P-384, P-521)

        Deprecated:
        Alternatively if the 'generate' parameter is provided, with a
        valid key type as value then a new key will be generated according
        to the defaults or provided key strength options (type specific).

        :raises InvalidJWKType: if the key type is invalid
        :raises InvalidJWKValue: if incorrect or inconsistent parameters
            are provided.
        """
        self._params = dict()
        self._key = dict()
        self._unknown = dict()

        if 'generate' in kwargs:
            self.generate_key(**kwargs)
        elif kwargs:
            self.import_key(**kwargs)

    @classmethod
    def generate(cls, **kwargs):
        """Creates a new JWK holding a freshly generated random key.

        :param kwargs: must contain 'kty', naming a key type with a
         matching ``_generate_<kty>`` method; the other arguments are type
         specific generation options (see :meth:`__init__`) and key
         parameters.

        :raises InvalidJWKType: if 'kty' is missing or unsupported.
        """
        obj = cls()
        kty = None
        try:
            kty = kwargs['kty']
            gen = getattr(obj, '_generate_%s' % kty)
        except (KeyError, AttributeError):
            raise InvalidJWKType(kty)
        gen(kwargs)
        return obj

    def generate_key(self, **params):
        """Generates a new random key in place (deprecated entry point).

        :param params: must contain 'generate' set to the key type; the
         other arguments are type specific generation options.

        :raises InvalidJWKType: if 'generate' names an unsupported type.
        """
        kty = None
        try:
            kty = params.pop('generate')
            gen = getattr(self, '_generate_%s' % kty)
        except (KeyError, AttributeError):
            raise InvalidJWKType(kty)

        gen(params)

    def _get_gen_size(self, params, default_size=None):
        """Returns the key size in bits to use for generation.

        An explicit 'size' is consumed from params. Otherwise the size of
        the algorithm named by 'alg' is used, falling back to default_size.
        """
        size = default_size
        if 'size' in params:
            size = params.pop('size')
        elif 'alg' in params:
            try:
                # Imported locally; presumably to avoid an import cycle
                # with jwcrypto.jwa -- TODO confirm.
                from jwcrypto.jwa import JWA
                alg = JWA.instantiate_alg(params['alg'])
            except KeyError:
                raise ValueError("Invalid 'alg' parameter")
            size = alg.keysize
        return size

    def _generate_oct(self, params):
        """Generates a random symmetric key (128 bits by default)."""
        size = self._get_gen_size(params, 128)
        key = os.urandom(size // 8)
        params['kty'] = 'oct'
        params['k'] = base64url_encode(key)
        self.import_key(**params)

    def _encode_int(self, i, bit_size=None):
        """Encodes a non-negative integer as a base64url big-endian string.

        :param i: the integer to encode.
        :param bit_size: optional size in bits of the field the integer
         belongs to. When given, the output is left-padded with zero bytes
         to the full field length, as RFC 7518 requires for the EC 'x',
         'y' and 'd' members.
        """
        hexi = '%x' % i
        # unhexlify() needs an even number of hex digits.
        if len(hexi) % 2:
            hexi = '0' + hexi
        if bit_size is not None:
            hexi = hexi.rjust(((bit_size + 7) // 8) * 2, '0')
        return base64url_encode(unhexlify(hexi))

    def _generate_RSA(self, params):
        """Generates an RSA key (2048 bits, exponent 65537 by default)."""
        pubexp = 65537
        size = self._get_gen_size(params, 2048)
        if 'public_exponent' in params:
            pubexp = params.pop('public_exponent')
        key = rsa.generate_private_key(pubexp, size, default_backend())
        self._import_pyca_pri_rsa(key, **params)

    def _import_pyca_pri_rsa(self, key, **params):
        """Imports a pyca RSA private key object, plus extra params."""
        pn = key.private_numbers()
        params.update(
            kty='RSA',
            n=self._encode_int(pn.public_numbers.n),
            e=self._encode_int(pn.public_numbers.e),
            d=self._encode_int(pn.d),
            p=self._encode_int(pn.p),
            q=self._encode_int(pn.q),
            dp=self._encode_int(pn.dmp1),
            dq=self._encode_int(pn.dmq1),
            qi=self._encode_int(pn.iqmp)
        )
        self.import_key(**params)

    def _import_pyca_pub_rsa(self, key, **params):
        """Imports a pyca RSA public key object, plus extra params."""
        pn = key.public_numbers()
        params.update(
            kty='RSA',
            n=self._encode_int(pn.n),
            e=self._encode_int(pn.e)
        )
        self.import_key(**params)

    def _get_curve_by_name(self, name):
        """Maps a JWK 'crv' name to a pyca curve instance.

        :raises InvalidJWKValue: if the curve name is not supported.
        """
        if name == 'P-256':
            return ec.SECP256R1()
        elif name == 'P-384':
            return ec.SECP384R1()
        elif name == 'P-521':
            return ec.SECP521R1()
        else:
            raise InvalidJWKValue('Unknown Elliptic Curve Type')

    def _generate_EC(self, params):
        """Generates an EC key (P-256 by default)."""
        curve = 'P-256'
        if 'curve' in params:
            curve = params.pop('curve')
        # 'curve' is for backwards compat, if 'crv' is defined it takes
        # precedence
        if 'crv' in params:
            curve = params.pop('crv')
        curve_name = self._get_curve_by_name(curve)
        key = ec.generate_private_key(curve_name, default_backend())
        self._import_pyca_pri_ec(key, **params)

    def _import_pyca_pri_ec(self, key, **params):
        """Imports a pyca EC private key object, plus extra params."""
        pn = key.private_numbers()
        # RFC 7518 6.2.1.2/6.2.1.3/6.2.2.1: members are full-length octet
        # strings sized to the curve.
        size = key.curve.key_size
        params.update(
            kty='EC',
            crv=JWKpycaCurveMap[key.curve.name],
            x=self._encode_int(pn.public_numbers.x, size),
            y=self._encode_int(pn.public_numbers.y, size),
            d=self._encode_int(pn.private_value, size)
        )
        self.import_key(**params)

    def _import_pyca_pub_ec(self, key, **params):
        """Imports a pyca EC public key object, plus extra params."""
        pn = key.public_numbers()
        # Coordinates must be padded to the full curve size (RFC 7518).
        size = key.curve.key_size
        params.update(
            kty='EC',
            crv=JWKpycaCurveMap[key.curve.name],
            x=self._encode_int(pn.x, size),
            y=self._encode_int(pn.y, size),
        )
        self.import_key(**params)

    def import_key(self, **kwargs):
        """Imports key parameters and key values into this object.

        Known JWK parameters (:data:`JWKParamsRegistry`) and the key
        values valid for the key type (:data:`JWKValuesRegistry`) are
        stored separately; any other argument is kept as an unknown
        parameter.

        :raises InvalidJWKType: if 'kty' is missing or not supported.
        :raises InvalidJWKValue: if a required key value is missing, no key
         values are present, 'key_ops' contains duplicates, or 'use' and
         'key_ops' are incompatible.
        """
        names = list(kwargs.keys())

        for name in list(JWKParamsRegistry.keys()):
            if name in kwargs:
                self._params[name] = kwargs[name]
                while name in names:
                    names.remove(name)

        kty = self._params.get('kty', None)
        if kty not in JWKTypesRegistry:
            raise InvalidJWKType(kty)

        for name in list(JWKValuesRegistry[kty].keys()):
            if name in kwargs:
                self._key[name] = kwargs[name]
                while name in names:
                    names.remove(name)

        for name, val in iteritems(JWKValuesRegistry[kty]):
            if val[2] == 'Required' and name not in self._key:
                raise InvalidJWKValue('Missing required value %s' % name)

        # Unknown key parameters are allowed
        # Let's just store them out of the way
        for name in names:
            self._unknown[name] = kwargs[name]

        if len(self._key) == 0:
            raise InvalidJWKValue('No Key Values found')

        # check key_ops
        if 'key_ops' in self._params:
            for ko in self._params['key_ops']:
                c = 0
                for cko in self._params['key_ops']:
                    if ko == cko:
                        c += 1
                if c != 1:
                    raise InvalidJWKValue('Duplicate values in "key_ops"')

        # check use/key_ops consistency
        if 'use' in self._params and 'key_ops' in self._params:
            sigl = ['sign', 'verify']
            encl = ['encrypt', 'decrypt', 'wrapKey', 'unwrapKey',
                    'deriveKey', 'deriveBits']
            if self._params['use'] == 'sig':
                for op in encl:
                    if op in self._params['key_ops']:
                        raise InvalidJWKValue('Incompatible "use" and'
                                              ' "key_ops" values specified at'
                                              ' the same time')
            elif self._params['use'] == 'enc':
                for op in sigl:
                    if op in self._params['key_ops']:
                        raise InvalidJWKValue('Incompatible "use" and'
                                              ' "key_ops" values specified at'
                                              ' the same time')

    @classmethod
    def from_json(cls, key):
        """Creates a RFC 7517 JWK from the standard JSON format.

        :param key: The RFC 7517 representation of a JWK.

        :return: the new JWK object.
        :raises InvalidJWKValue: if the JSON cannot be decoded, or the
         decoded key values are invalid.
        :raises InvalidJWKType: if the decoded key type is invalid.
        """
        obj = cls()
        try:
            jkey = json_decode(key)
        except Exception as e:  # pylint: disable=broad-except
            raise InvalidJWKValue(e)
        # import_key() fills obj in place and returns None, so the object
        # itself has to be returned.
        obj.import_key(**jkey)
        return obj

    def export(self, private_key=True):
        """Exports the key in the standard JSON format.

        Exports the key regardless of type. If private_key is False and
        the key is symmetric, an exception is raised.

        :param private_key(bool): Whether to export the private key.
                                  Defaults to True.
        """
        if private_key is True:
            # Use _export_all for backwards compatibility, as this
            # function allows to export symmetric keys too
            return self._export_all()
        else:
            return self.export_public()

    def export_public(self):
        """Exports the public key in the standard JSON format.

        It fails if one is not available, as when this function is
        called on a symmetric key.

        :raises InvalidJWKType: if the key has no public part.
        """
        pub = self._public_params()
        return json_encode(pub)

    def _public_params(self):
        """Returns a dict of the public parameters and public key values.

        :raises InvalidJWKType: if the key has no public part.
        """
        if not self.has_public:
            raise InvalidJWKType("No public key available")
        pub = {}
        preg = JWKParamsRegistry
        for name in preg:
            if preg[name][1] == 'Public':
                if name in self._params:
                    pub[name] = self._params[name]
        reg = JWKValuesRegistry[self._params['kty']]
        for param in reg:
            if reg[param][1] == 'Public':
                pub[param] = self._key[param]
        return pub

    def _export_all(self):
        """Serializes parameters, key values and unknown parameters."""
        d = dict()
        d.update(self._params)
        d.update(self._key)
        d.update(self._unknown)
        return json_encode(d)

    def export_private(self):
        """Export the private key in the standard JSON format.

        It fails for a JWK that has only a public key or is symmetric.

        :raises InvalidJWKType: if no asymmetric private key is present.
        """
        if self.has_private:
            return self._export_all()
        raise InvalidJWKType("No private key available")

    def export_symmetric(self):
        """Export a symmetric key in the standard JSON format.

        :raises InvalidJWKType: if the key is not symmetric.
        """
        if self.is_symmetric:
            return self._export_all()
        raise InvalidJWKType("Not a symmetric key")

    def public(self):
        """Returns a new JWK holding only the public part of this key.

        :raises InvalidJWKType: if the key has no public part.
        """
        pub = self._public_params()
        return JWK(**pub)

    @property
    def has_public(self):
        """Whether this JWK has an asymmetric Public key."""
        if self.is_symmetric:
            return False
        reg = JWKValuesRegistry[self._params['kty']]
        for value in reg:
            if reg[value][1] == 'Public' and value in self._key:
                return True
        return False

    @property
    def has_private(self):
        """Whether this JWK has an asymmetric Private key."""
        if self.is_symmetric:
            return False
        reg = JWKValuesRegistry[self._params['kty']]
        for value in reg:
            if reg[value][1] == 'Private' and value in self._key:
                return True
        return False

    @property
    def is_symmetric(self):
        """Whether this JWK is a symmetric key."""
        return self.key_type == 'oct'

    @property
    def key_type(self):
        """The Key type"""
        return self._params.get('kty', None)

    @property
    def key_id(self):
        """The Key ID.

        Provided by the kid parameter if present, otherwise returns None.
        """
        return self._params.get('kid', None)

    @property
    def key_curve(self):
        """The Curve Name.

        :raises InvalidJWKType: if the key is not an EC key.
        """
        if self._params['kty'] != 'EC':
            raise InvalidJWKType('Not an EC key')
        return self._key['crv']

    def get_curve(self, arg):
        """Gets the Elliptic Curve associated with the key.

        :param arg: an optional curve name; if given it must match the
         key's curve.

        :raises InvalidJWKType: the key is not an EC key.
        :raises InvalidJWKValue: if the requested curve does not match
         the key's curve, or the curve is not supported.
        """
        k = self._key
        if self._params['kty'] != 'EC':
            raise InvalidJWKType('Not an EC key')
        if arg and k['crv'] != arg:
            raise InvalidJWKValue('Curve requested is "%s", but '
                                  'key curve is "%s"' % (arg, k['crv']))

        return self._get_curve_by_name(k['crv'])

    def _check_constraints(self, usage, operation):
        """Enforces the declared 'use' and 'key_ops' constraints.

        :raises InvalidJWKUsage: if 'use' is set and differs from usage.
        :raises InvalidJWKOperation: if 'key_ops' is set and does not
         contain operation.
        """
        use = self._params.get('use', None)
        if use and use != usage:
            raise InvalidJWKUsage(usage, use)
        ops = self._params.get('key_ops', None)
        if ops:
            if not isinstance(ops, list):
                ops = [ops]
            if operation not in ops:
                raise InvalidJWKOperation(operation, ops)
        # TODO: check alg ?

    def _decode_int(self, n):
        """Decodes a base64url big-endian string into an integer."""
        return int(hexlify(base64url_decode(n)), 16)

    def _rsa_pub(self, k):
        """Builds pyca RSA public numbers from the key values dict."""
        return rsa.RSAPublicNumbers(self._decode_int(k['e']),
                                    self._decode_int(k['n']))

    def _rsa_pri(self, k):
        """Builds pyca RSA private numbers from the key values dict."""
        return rsa.RSAPrivateNumbers(self._decode_int(k['p']),
                                     self._decode_int(k['q']),
                                     self._decode_int(k['d']),
                                     self._decode_int(k['dp']),
                                     self._decode_int(k['dq']),
                                     self._decode_int(k['qi']),
                                     self._rsa_pub(k))

    def _ec_pub(self, k, curve):
        """Builds pyca EC public numbers from the key values dict."""
        return ec.EllipticCurvePublicNumbers(self._decode_int(k['x']),
                                             self._decode_int(k['y']),
                                             self.get_curve(curve))

    def _ec_pri(self, k, curve):
        """Builds pyca EC private numbers from the key values dict."""
        return ec.EllipticCurvePrivateNumbers(self._decode_int(k['d']),
                                              self._ec_pub(k, curve))

    def _get_public_key(self, arg=None):
        """Returns the public key object ('k' value for oct keys)."""
        if self._params['kty'] == 'oct':
            return self._key['k']
        elif self._params['kty'] == 'RSA':
            return self._rsa_pub(self._key).public_key(default_backend())
        elif self._params['kty'] == 'EC':
            return self._ec_pub(self._key, arg).public_key(default_backend())
        else:
            raise NotImplementedError

    def _get_private_key(self, arg=None):
        """Returns the private key object ('k' value for oct keys)."""
        if self._params['kty'] == 'oct':
            return self._key['k']
        elif self._params['kty'] == 'RSA':
            return self._rsa_pri(self._key).private_key(default_backend())
        elif self._params['kty'] == 'EC':
            return self._ec_pri(self._key, arg).private_key(default_backend())
        else:
            raise NotImplementedError

    def get_op_key(self, operation=None, arg=None):
        """Get the key object associated to the requested operation.

        For example the public RSA key for the 'verify' operation or
        the private EC key for the 'decrypt' operation.

        :param operation: The requested operation.
         The valid set of operations is available in the
         :data:`JWKOperationsRegistry` registry. If None, only symmetric
         keys are returned.
        :param arg: an optional, context specific, argument
         For example a curve name.

        :raises InvalidJWKOperation: if operation is None and the key is
         not symmetric, or the operation is not permitted by 'key_ops'.
        :raises InvalidJWKUsage: if the use constraints do not permit
         the operation.
        :raises NotImplementedError: for any other operation name (for
         example 'deriveKey').
        """
        validops = self._params.get('key_ops',
                                    list(JWKOperationsRegistry.keys()))
        # Normalize a single operation to a list; an existing list must be
        # kept as-is (the old `validops is not list` test was always true).
        if not isinstance(validops, list):
            validops = [validops]
        if operation is None:
            if self._params['kty'] == 'oct':
                return self._key['k']
            raise InvalidJWKOperation(operation, validops)
        elif operation == 'sign':
            self._check_constraints('sig', operation)
            return self._get_private_key(arg)
        elif operation == 'verify':
            self._check_constraints('sig', operation)
            return self._get_public_key(arg)
        elif operation == 'encrypt' or operation == 'wrapKey':
            self._check_constraints('enc', operation)
            return self._get_public_key(arg)
        elif operation == 'decrypt' or operation == 'unwrapKey':
            self._check_constraints('enc', operation)
            return self._get_private_key(arg)
        else:
            raise NotImplementedError

    def import_from_pyca(self, key):
        """Imports a pyca/cryptography RSA or EC key object.

        :raises InvalidJWKValue: if the key object type is not supported.
        """
        if isinstance(key, rsa.RSAPrivateKey):
            self._import_pyca_pri_rsa(key)
        elif isinstance(key, rsa.RSAPublicKey):
            self._import_pyca_pub_rsa(key)
        elif isinstance(key, ec.EllipticCurvePrivateKey):
            self._import_pyca_pri_ec(key)
        elif isinstance(key, ec.EllipticCurvePublicKey):
            self._import_pyca_pub_ec(key)
        else:
            raise InvalidJWKValue('Unknown key object %r' % key)

    def import_from_pem(self, data, password=None):
        """Imports a key from data loaded from a PEM file.

        The key may be encrypted with a password.
        Private keys (PKCS#8 format), public keys, and X509 certificate's
        public keys can be imported with this interface. The 'kid'
        parameter is set to the key thumbprint.

        :param data(bytes): The data contained in a PEM file.
        :param password(bytes): An optional password to unwrap the key.
        """

        try:
            key = serialization.load_pem_private_key(
                data, password=password, backend=default_backend())
        except ValueError as e:
            # With a password only a private key makes sense.
            if password is not None:
                raise e
            try:
                key = serialization.load_pem_public_key(
                    data, backend=default_backend())
            except ValueError:
                try:
                    cert = x509.load_pem_x509_certificate(
                        data, backend=default_backend())
                    key = cert.public_key()
                except ValueError:
                    # Report the original private key parsing error.
                    raise e

        self.import_from_pyca(key)
        self._params['kid'] = self.thumbprint()

    def export_to_pem(self, private_key=False, password=False):
        """Exports keys to a data buffer suitable to be stored as a PEM file.

        Either the public or the private key can be exported to a PEM file.
        For private keys the PKCS#8 format is used. If a password is provided
        the best encryption method available as determined by the cryptography
        module is used to wrap the key.

        :param private_key: Whether the private key should be exported.
         Defaults to `False` which means the public key is exported by default.
        :param password(bytes): A password for wrapping the private key.
         Defaults to False which will cause the operation to fail. To avoid
         encryption the user must explicitly pass None, otherwise the user
         needs to provide a password in a bytes buffer.
        """
        e = serialization.Encoding.PEM
        if private_key:
            if not self.has_private:
                raise InvalidJWKType("No private key available")
            f = serialization.PrivateFormat.PKCS8
            if password is None:
                a = serialization.NoEncryption()
            elif isinstance(password, bytes):
                a = serialization.BestAvailableEncryption(password)
            elif password is False:
                raise ValueError("The password must be None or a bytes string")
            else:
                raise TypeError("The password string must be bytes")
            return self._get_private_key().private_bytes(
                encoding=e, format=f, encryption_algorithm=a)
        else:
            if not self.has_public:
                raise InvalidJWKType("No public key available")
            f = serialization.PublicFormat.SubjectPublicKeyInfo
            return self._get_public_key().public_bytes(encoding=e, format=f)

    @classmethod
    def from_pyca(cls, key):
        """Creates a JWK from a pyca/cryptography RSA or EC key object."""
        obj = cls()
        obj.import_from_pyca(key)
        return obj

    @classmethod
    def from_pem(cls, data, password=None):
        """Creates a key from PKCS#8 formatted data loaded from a PEM file.
           See the function `import_from_pem` for details.

        :param data(bytes): The data contained in a PEM file.
        :param password(bytes): An optional password to unwrap the key.
        """
        obj = cls()
        obj.import_from_pem(data, password)
        return obj

    def thumbprint(self, hashalg=hashes.SHA256()):
        """Returns the key thumbprint as specified by RFC 7638.

        Only 'kty' and the required key values are hashed.

        :param hashalg: A hash function (defaults to SHA256)
        """

        t = {'kty': self._params['kty']}
        for name, val in iteritems(JWKValuesRegistry[t['kty']]):
            if val[2] == 'Required':
                t[name] = self._key[name]
        digest = hashes.Hash(hashalg, backend=default_backend())
        digest.update(json_encode(t).encode('utf8'))
        return base64url_encode(digest.finalize())

Simo Sorce's avatar
Simo Sorce committed
732

Simo Sorce's avatar
Simo Sorce committed
733
class _JWKkeys(set):
    """A set whose `add` method rejects anything that is not a JWK.

    Note that only `add` is checked; other set mutators are inherited
    unchanged from the builtin 'set'.
    """

    def add(self, elem):
        """Adds a JWK object to the set

        :param elem: the JWK object to add.

        :raises TypeError: if the object is not a JWK.
        """
        if isinstance(elem, JWK):
            set.add(self, elem)
            return
        raise TypeError('Only JWK objects are valid elements')

746 747 748 749 750 751 752 753 754 755

class JWKSet(dict):
    """A set of JWK objects.

    Inherits from the standard 'dict' builtin type.
    Creates a special key 'keys' that is of a type derived from 'set'.
    The 'keys' attribute accepts only :class:`jwcrypto.jwk.JWK` elements.
    Iteration and membership tests (``in``) operate on the JWK objects
    stored under 'keys', not on the dict keys.
    """
    def __init__(self, *args, **kwargs):
        super(JWKSet, self).__init__()
        # Use dict's __setitem__ directly: our own override would try to
        # add the new set as an element instead of installing it.
        super(JWKSet, self).__setitem__('keys', _JWKkeys())
        self.update(*args, **kwargs)

    def __iter__(self):
        return self['keys'].__iter__()

    def __contains__(self, key):
        return self['keys'].__contains__(key)

    def __setitem__(self, key, val):
        # Assigning to 'keys' adds a single JWK to the set rather than
        # replacing the set itself.
        if key == 'keys':
            self['keys'].add(val)
        else:
            super(JWKSet, self).__setitem__(key, val)

    def update(self, *args, **kwargs):
        # Route every item through __setitem__ so that 'keys' keeps its
        # special handling (dict.update would bypass the override).
        for k, v in iteritems(dict(*args, **kwargs)):
            self.__setitem__(k, v)

    def add(self, elem):
        """Adds a JWK object to the 'keys' set.

        :param elem: the JWK object to add.

        :raises TypeError: if the object is not a JWK.
        """
        self['keys'].add(elem)

    def export(self, private_keys=True):
        """Exports a RFC 7517 keyset using the standard JSON format

        :param private_keys(bool): Whether to export private keys.
                                   Defaults to True.
        :returns: the keyset serialized with `json_encode`.
        """
        exp_dict = dict()
        for k, v in iteritems(self):
            if k == 'keys':
                keys = list()
                for jwk in v:
                    keys.append(json_decode(jwk.export(private_keys)))
                v = keys
            exp_dict[k] = v
        return json_encode(exp_dict)

    def import_keyset(self, keyset):
        """Imports a RFC 7517 keyset using the standard JSON format.

        Each entry of the 'keys' member is turned into a JWK and added to
        the set; every other top-level member is stored in this dict as-is.

        :param keyset: The RFC 7517 representation of a JOSE Keyset.
        :returns: this JWKSet, to allow chaining.

        :raises InvalidJWKValue: if `keyset` cannot be decoded, does not
            decode to a JSON object, or has no 'keys' member.
        """
        try:
            jwkset = json_decode(keyset)
        except Exception:  # pylint: disable=broad-except
            raise InvalidJWKValue()

        # Anything other than a JSON object (e.g. a list, or a string that
        # happens to contain "keys") is not a keyset; reject it here rather
        # than failing later with an AttributeError from iteritems().
        if not isinstance(jwkset, dict) or 'keys' not in jwkset:
            raise InvalidJWKValue()

        for k, v in iteritems(jwkset):
            if k == 'keys':
                for jwk in v:
                    self['keys'].add(JWK(**jwk))
            else:
                self[k] = v

        return self

    @classmethod
    def from_json(cls, keyset):
        """Creates a RFC 7517 keyset from the standard JSON format.

        :param keyset: The RFC 7517 representation of a JOSE Keyset.
        """
        obj = cls()
        return obj.import_keyset(keyset)

    def get_key(self, kid):
        """Gets a key from the set.

        :param kid: the 'kid' key identifier.
        :returns: the first JWK whose `key_id` equals `kid`, or None if
                  no key matches.
        """
        for jwk in self['keys']:
            if jwk.key_id == kid:
                return jwk
        return None