Commit e5c4cbb1 authored by Simo Sorce's avatar Simo Sorce

Turn JWKSet into a special dictionary

The JWKSet object is now a class derived from dict, with a special element
'keys' that is forced to be a class derived by set which enforces only JWK
object instances as elements.

A new classmethod "from_json" is added to create a JWKset out of a json object
directly.

The interface is fully backward compatible with the existing APIs otherwise.
Signed-off-by: default avatarSimo Sorce <simo@redhat.com>
Reviewed-by: default avatarNathaniel McCallum <npmccallum@redhat.com>
Closes #31
Closes #33
parent 72665470
......@@ -6,8 +6,12 @@ from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import ec
from jwcrypto.common import base64url_decode, base64url_encode
from jwcrypto.common import json_decode, json_encode
from six import iteritems
import os
# RFC 7518 - 7.4
JWKTypesRegistry = {'EC': 'Elliptic Curve',
'RSA': 'RSA',
......@@ -493,11 +497,7 @@ class JWK(object):
raise NotImplementedError
class JWKSet(set):
"""A set of JWK objects.
Inherits for the standard 'set' bultin type.
"""
class _jwkset(set):
def add(self, elem):
"""Adds a JWK object to the set
......@@ -510,18 +510,53 @@ class JWKSet(set):
raise TypeError('Only JWK objects are valid elements')
set.add(self, elem)
class JWKSet(dict):
"""A set of JWK objects.
Inherits from the standard 'dict' bultin type.
Creates a special key 'keys' that is of a type derived from 'set'
The 'keys' attribute accepts only :class:`jwcrypto.jwk.JWK` elements.
"""
def __init__(self, *args, **kwargs):
super(JWKSet, self).__init__()
super(JWKSet, self).__setitem__('keys', _jwkset())
self.update(*args, **kwargs)
def __setitem__(self, key, val):
if key == 'keys':
self['keys'].add(val)
else:
super(JWKSet, self).__setitem__(key, val)
def update(self, *args, **kwargs):
for k, v in iteritems(dict(*args, **kwargs)):
self.__setitem__(k, v)
def add(self, elem):
self['keys'].add(elem)
def export(self, private_keys=True):
"""Exports the set using the standard JSON format
"""Exports a RFC 7517 keyset using the standard JSON format
:param private_key(bool): Whether to export private keys.
Defaults to True.
"""
keys = list()
for jwk in self:
keys.append(json_decode(jwk.export(private_keys)))
return json_encode({'keys': keys})
exp_dict = dict()
for k, v in iteritems(self):
if k == 'keys':
keys = list()
for jwk in v:
keys.append(json_decode(jwk.export(private_keys)))
v = keys
exp_dict[k] = v
return json_encode(exp_dict)
def import_keyset(self, keyset):
"""Imports a RFC 7517 keyset using the standard JSON format.
:param keyset: The RFC 7517 representation of a JOSE Keyset.
"""
try:
jwkset = json_decode(keyset)
except:
......@@ -530,16 +565,29 @@ class JWKSet(set):
if 'keys' not in jwkset:
raise InvalidJWKValue()
for jwk in jwkset['keys']:
self.add(JWK(**jwk))
for k, v in iteritems(jwkset):
if k == 'keys':
for jwk in v:
self['keys'].add(JWK(**jwk))
else:
self[k] = v
return self
@classmethod
def from_json(cls, keyset):
"""Creates a RFC 7517 keyset from the standard JSON format.
:param keyset: The RFC 7517 representation of a JOSE Keyset.
"""
obj = cls()
return obj.import_keyset(keyset)
def get_key(self, kid):
"""Gets a key from the set.
:param kid: the 'kid' key identifier.
"""
for jwk in self:
for jwk in self['keys']:
if jwk.key_id == kid:
return jwk
return None
......@@ -229,6 +229,9 @@ class TestJWK(unittest.TestCase):
self.assertEqual(k1._key, k2._key)
# pylint: disable=protected-access
self.assertEqual(k1._key['d'], RSAPrivateKey['d'])
# test class method import too
ks3 = jwk.JWKSet.from_json(ks.export())
self.assertEqual(len(ks), len(ks3))
# RFC 7515 - A.1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment