Commit 75d03695 authored by Simo Sorce's avatar Simo Sorce

Change in/out params for key wrapping

This makes it possible to return more data as required from algorithms like
AES GCM Key Wrap, and input it back in the unwrapping function.
Signed-off-by: default avatarSimo Sorce <simo@redhat.com>
parent 81d43908
......@@ -139,10 +139,10 @@ class InvalidJWEKeyLength(Exception):
class _RawKeyMgmt(object):
def wrap(self, key, keylen, cek):
def wrap(self, key, keylen, cek, headers):
raise NotImplementedError
def unwrap(self, key, ek):
def unwrap(self, key, keylen, ek, headers):
raise NotImplementedError
......@@ -156,18 +156,20 @@ class _RSA(_RawKeyMgmt):
raise InvalidJWEKeyType('RSA', key.key_type)
# FIXME: get key size and insure > 2048 bits
def wrap(self, key, keylen, cek):
def wrap(self, key, keylen, cek, headers):
self._check_key(key)
if not cek:
cek = os.urandom(keylen)
rk = key.get_op_key('encrypt')
ek = rk.encrypt(cek, self.padfn)
return (cek, ek)
return {'cek': cek, 'ek': ek}
def unwrap(self, key, ek):
def unwrap(self, key, keylen, ek, headers):
self._check_key(key)
rk = key.get_op_key('decrypt')
cek = rk.decrypt(ek, self.padfn)
if len(cek) != keylen:
raise InvalidJWEKeyLength(keylen, len(cek))
return cek
......@@ -185,7 +187,7 @@ class _AesKw(_RawKeyMgmt):
raise InvalidJWEKeyLength(self.keysize * 8, len(rk) * 8)
return rk
def wrap(self, key, keylen, cek):
def wrap(self, key, keylen, cek, headers):
rk = self._get_key(key, 'encrypt')
if not cek:
cek = os.urandom(keylen)
......@@ -206,9 +208,9 @@ class _AesKw(_RawKeyMgmt):
ek = a
for i in range(0, n):
ek += r[i]
return (cek, ek)
return {'cek': cek, 'ek': ek}
def unwrap(self, key, ek):
def unwrap(self, key, keylen, ek, headers):
rk = self._get_key(key, 'decrypt')
# Implement RFC 3394 Key Unwrap - 2.2.3
......@@ -233,6 +235,8 @@ class _AesKw(_RawKeyMgmt):
raise InvalidJWEData('Decryption Failed')
cek = b''.join(r)
if len(cek) != keylen:
raise InvalidJWEKeyLength(keylen, len(cek))
return cek
......@@ -242,20 +246,23 @@ class _Direct(_RawKeyMgmt):
if key.key_type != 'oct':
raise InvalidJWEKeyType('oct', key.key_type)
def wrap(self, key, keylen, cek):
def wrap(self, key, keylen, cek, headers):
self._check_key(key)
if cek:
return (cek, None)
k = base64url_decode(key.get_op_key('encrypt'))
if len(k) != keylen:
raise InvalidCEKeyLength(keylen, len(k))
return (k, '')
return {'cek': k}
def unwrap(self, key, ek):
def unwrap(self, key, keylen, ek, headers):
self._check_key(key)
if ek != b'':
raise InvalidJWEData('Invalid Encryption Key.')
return base64url_decode(key.get_op_key('decrypt'))
cek = base64url_decode(key.get_op_key('decrypt'))
if len(cek) != keylen:
raise InvalidJWEKeyLength(keylen, len(cek))
return cek
class _RawJWE(object):
......@@ -562,9 +569,11 @@ class JWE(object):
if header:
rec['header'] = header
self.cek, ek = alg.wrap(key, enc.key_size, self.cek)
if ek:
rec['encrypted_key'] = ek
wrapped = alg.wrap(key, enc.key_size, self.cek, jh)
self.cek = wrapped['cek']
if 'ek' in wrapped:
rec['encrypted_key'] = wrapped['ek']
if 'ciphertext' not in self.objects:
aad = base64url_encode(self.objects.get('protected', ''))
......@@ -684,7 +693,7 @@ class JWE(object):
if 'aad' in self.objects:
aad += '.' + base64url_encode(self.objects['aad'])
cek = alg.unwrap(key, ppe.get('encrypted_key', b''))
cek = alg.unwrap(key, enc.key_size, ppe.get('encrypted_key', b''), jh)
data = enc.decrypt(cek, aad.encode('utf-8'),
self.objects['iv'],
self.objects['ciphertext'],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment