Skip to content
Snippets Groups Projects
Commit fad75c9f authored by Ray Luo's avatar Ray Luo
Browse files

Merge branch 'oauth2' into assemble

parents 17307d84 02030e98
No related branches found
No related tags found
No related merge requests found
__version__ = "0.0.1"
from .oauth2 import Client
import time
import binascii
import base64
import uuid
import logging
import jwt
logger = logging.getLogger(__file__)
class Signer(object):
def sign_assertion(
self, audience, issuer, subject, expires_at,
issued_at=None, assertion_id=None, **kwargs):
# Names are defined in https://tools.ietf.org/html/rfc7521#section-5
raise NotImplementedError("Will be implemented by sub-class")
class JwtSigner(Signer):
def __init__(self, key, algorithm, sha1_thumbprint=None, headers=None):
"""Create a signer.
Args:
key (str): The key for signing, e.g. a base64 encoded private key.
algorithm (str):
"RS256", etc.. See https://pyjwt.readthedocs.io/en/latest/algorithms.html
RSA and ECDSA algorithms require "pip install cryptography".
sha1_thumbprint (str): The x5t aka X.509 certificate SHA-1 thumbprint.
headers (dict): Additional headers, e.g. "kid" or "x5c" etc.
"""
self.key = key
self.algorithm = algorithm
self.headers = headers or {}
if sha1_thumbprint: # https://tools.ietf.org/html/rfc7515#section-4.1.7
self.headers["x5t"] = base64.urlsafe_b64encode(
binascii.a2b_hex(sha1_thumbprint)).decode()
def sign_assertion(
self, audience, issuer, subject=None, expires_at=None,
issued_at=None, assertion_id=None, not_before=None,
additional_claims=None, **kwargs):
"""Sign a JWT Assertion.
Parameters are defined in https://tools.ietf.org/html/rfc7523#section-3
Key-value pairs in additional_claims will be added into payload as-is.
"""
now = time.time()
payload = {
'aud': audience,
'iss': issuer,
'sub': subject or issuer,
'exp': expires_at or (now + 10*60), # 10 minutes
'iat': issued_at or now,
'jti': assertion_id or str(uuid.uuid4()),
}
if not_before:
payload['nbf'] = not_before
payload.update(additional_claims or {})
try:
return jwt.encode(
payload, self.key, algorithm=self.algorithm, headers=self.headers)
except:
if self.algorithm.startswith("RS") or self.algorithm.starswith("ES"):
logger.exception(
'Some algorithms requires "pip install cryptography". '
'See https://pyjwt.readthedocs.io/en/latest/installation.html#cryptographic-dependencies-optional')
raise
# Note: This docstring is also used by this script's command line help.
"""A one-stop helper for desktop app to acquire an authorization code.
It starts a web server to listen redirect_uri, waiting for auth code.
It optionally opens a browser window to guide a human user to manually login.
After obtaining an auth code, the web server will automatically shut down.
"""
import argparse
import webbrowser
import logging
try: # Python 3
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs, urlencode
except ImportError: # Fall back to Python 2
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from urlparse import urlparse, parse_qs
from urllib import urlencode
from .oauth2 import Client
logger = logging.getLogger(__file__)
def obtain_auth_code(listen_port, auth_uri=None):
"""This function will start a web server listening on http://localhost:port
and then you need to open a browser on this device and visit your auth_uri.
When interaction finishes, this function will return the auth code,
and then shut down the local web server.
:param listen_port:
The local web server will listen at http://localhost:<listen_port>
Unless the authorization server supports dynamic port,
you need to use the same port when you register with your app.
:param auth_uri: If provided, this function will try to open a local browser.
:return: Hang indefinitely, until it receives and then return the auth code.
"""
exit_hint = "Visit http://localhost:{p}?code=exit to abort".format(p=listen_port)
logger.warning(exit_hint)
if auth_uri:
page = "http://localhost:{p}?{q}".format(p=listen_port, q=urlencode({
"text": "Open this link to sign in. You may use incognito window",
"link": auth_uri,
"exit_hint": exit_hint,
}))
browse(page)
server = HTTPServer(("", int(listen_port)), AuthCodeReceiver)
try:
server.authcode = None
while not server.authcode:
# Derived from
# https://docs.python.org/2/library/basehttpserver.html#more-examples
server.handle_request()
return server.authcode
finally:
server.server_close()
def browse(auth_uri):
controller = webbrowser.get() # Get a default controller
# Some Linux Distro does not setup default browser properly,
# so we try to explicitly use some popular browser, if we found any.
for browser in ["chrome", "firefox", "safari", "windows-default"]:
try:
controller = webbrowser.get(browser)
break
except webbrowser.Error:
pass # This browser is not installed. Try next one.
logger.info("Please open a browser on THIS device to visit: %s" % auth_uri)
controller.open(auth_uri)
class AuthCodeReceiver(BaseHTTPRequestHandler):
def do_GET(self):
# For flexibility, we choose to not check self.path matching redirect_uri
#assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP')
qs = parse_qs(urlparse(self.path).query)
if qs.get('code'): # Then store it into the server instance
ac = self.server.authcode = qs['code'][0]
self._send_full_response('Authcode:\n{}'.format(ac))
# NOTE: Don't do self.server.shutdown() here. It'll halt the server.
elif qs.get('text') and qs.get('link'): # Then display a landing page
self._send_full_response(
'<a href={link}>{text}</a><hr/>{exit_hint}'.format(
link=qs['link'][0], text=qs['text'][0],
exit_hint=qs.get("exit_hint", [''])[0],
))
else:
self._send_full_response("This web service serves your redirect_uri")
def _send_full_response(self, body, is_ok=True):
self.send_response(200 if is_ok else 400)
content_type = 'text/html' if body.startswith('<') else 'text/plain'
self.send_header('Content-type', content_type)
self.end_headers()
self.wfile.write(body.encode("utf-8"))
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
p = parser = argparse.ArgumentParser(
description=__doc__ + "The auth code received will be shown at stdout.")
p.add_argument('endpoint',
help="The auth endpoint for your app. For example: "
"https://login.microsoftonline.com/your_tenant/oauth2/authorize")
p.add_argument('client_id', help="The client_id of your application")
p.add_argument('redirect_port', type=int, help="The port in redirect_uri")
args = parser.parse_args()
client = Client(args.client_id, authorization_endpoint=args.endpoint)
auth_uri = client.build_auth_request_uri("code")
print(obtain_auth_code(args.redirect_port, auth_uri))
This diff is collapsed.
"""This module extends the OAuth2 Client with a builtin refresh token storage.
Design Goal(s):
1. Keep the number of user prompts as low as possible.
This will be achieved by saving the refresh token (RT) into a storage
and then automatically use them to acquire new access token (AT) later.
Consequently, the RT concept will largely be abstracted away from the API,
so callers don't even need to know about and deal with RTs.
- The token storage should be scalable,
because a confidential client can potentially store lots of RTs in it.
- The token storage should be in a generic format,
so it would potentially be shared among different languages or platforms.
2. AT cache is NOT implemented in this module.
"""
import time
from .oauth2 import Client
class AbstractStorage(object): # The concrete implementation can be an RDBMS
"""Define the storage behaviors that will be used by this module."""
def find(self, query):
raise NotImplementedError("Will return items matching given query")
def add(self, item):
raise NotImplementedError("Will add item into this storage")
def remove(self, item):
# TBD: Shall this be changed into remove(query)?
raise NotImplementedError("Will remove item (if any) from this storage")
def commit(self):
"""This can be useful when you use a db cursor instance as the storage.
This default implementation does nothing, though."""
def build_item(client_id, token_endpoint, scope_in_request, response):
# response is defined in https://tools.ietf.org/html/rfc6749#section-5.1
# The scope is usually optional in both request and response, so likely None
return {
"scope": Client._stringify(response.get("scope", scope_in_request)),
"client_id": client_id,
"authority": token_endpoint,
"refresh_token": response.get("refresh_token"),
"access_token": response.get("access_token"),
}
def _is_subdict_of(small, big):
return dict(big, **small) == big
class StorageInRam(AbstractStorage):
def __init__(self):
self.storage = []
def find(self, query):
return [item for item in self.storage if _is_subdict_of(query, item)]
def add(self, item):
self.storage.append(item)
def remove(self, item):
self.storage.remove(item)
class ClientWithRefreshTokenStorage(Client):
def __init__(
self, client_id,
refresh_token_storage=StorageInRam(),
item_builder=build_item, # We use this callback rather than a method
# so that it can be customized without subclass this class
**kwargs):
super(ClientWithRefreshTokenStorage, self).__init__(client_id, **kwargs)
self.refresh_token_storage = refresh_token_storage
self.item_builder = item_builder
def _obtain_token(self, grant_type, params=None, data=None, *args, **kwargs):
"""Automatically maintains self.refresh_token_storage"""
resp = super(ClientWithRefreshTokenStorage, self)._obtain_token(
grant_type, params, data, *args, **kwargs)
if 'error' not in resp and 'refresh_token' in resp: # A RT is obtained
self.refresh_token_storage.add(self.item_builder(
self.client_id, self.token_endpoint, (data or {}).get('scope'),
resp))
self.refresh_token_storage.commit()
# if grant_type == "refresh_token" and "error" in resp,
# then the old RT is rejected.
# In such case, we could want to also remove old RT from storage.
# But, this method knows only a RT but not the item containing this RT,
# so we would have to delete by search:
# for item in find({'RT': kwargs['RT']}): delete(item)
# which would typically be slpw.
# So we make a decision to only do storage cleaning in the other method,
# obtain_token_silent(...), which can handily do
# delete(old_item)
# The tradeoff is that, those obtain_token_with_refresh_token(...) call
# which are NOT initialized by obtain_token_silent()
# would then NOT trigger a RT cleanup. This is not a big issue though,
# considering that we expect caller of this class no longer need to
# directly invoke obtain_token_with_refresh_token(...) anyway.
return resp
def obtain_token_silent(self, query):
"""If a RT in the storage matches query,
then use it to talk to authorization server, and acquire a fresh AT.
Otherwise returns None.
Caller is supposed to keep the returned AT,
typically in a local session, which is out of the scope of this class.
Usage:
# Leverage a pre-existing RT (if any) to skip user interaction
AT = client.obtain_token_silent(query)
if AT is None: # This happens when a matching RT does not exist
AT = client.acquire_token_with_one_of_actual_grants(...)
# Now you end up with a fresh AT, so it will just work
happily_access_resource_with(AT)
# Of course, after some time, the AT may expire or get revoked,
# so you will need to redo this process again.
There can be cases that, even this method returns an AT for you,
a resource server might still reject an AT with inadequate claims
(such as missing MFA, or other condtional access policy).
In those cases, repeating this method call will get you nowhere.
App developer is expected to call other grants instead.
:param query: A query to be matched against those items in storage.
It is conceptually a dict, e.g. {"client_id": "...", "scope": "..."}.
Its exact format, or even type, is decided by the token storage's
find() method. So you may be able to use a lambda here.
:returns: The json object from authorization server, or None
"""
has_state_changed = False
try:
for item in self.refresh_token_storage.find(query):
assert 'refresh_token' in item
if isinstance(query, dict) and 'scope' in query:
scope = query['scope']
else:
scope = item.get('scope')
resp = self.obtain_token_with_refresh_token(
item["refresh_token"], scope)
if resp.get('error') == 'invalid_grant' or 'refresh_token' in resp:
self.refresh_token_storage.remove(item) # Discard old RT
has_state_changed = True
if 'error' not in resp:
return resp
finally:
if has_state_changed: # commit all the changes during loop
self.refresh_token_storage.commit()
import sys
import logging
if sys.version_info[:2] < (2, 7):
# The unittest module got a significant overhaul in Python 2.7,
# so if we're in 2.6 we can use the backported version unittest2.
......@@ -6,3 +8,19 @@ if sys.version_info[:2] < (2, 7):
else:
import unittest
class Oauth2TestCase(unittest.TestCase):
logger = logging.getLogger(__file__)
def assertLoosely(self, response, assertion=None,
skippable_errors=("invalid_grant", "interaction_required")):
if response.get("error") in skippable_errors:
self.logger.debug("Response = %s", response)
# Some of these errors are configuration issues, not library issues
raise unittest.SkipTest(response.get("error_description"))
else:
if assertion is None:
assertion = lambda: self.assertIn("access_token", response)
assertion()
import os
import json
import logging
try: # Python 2
from urlparse import urljoin
except: # Python 3
from urllib.parse import urljoin
import time
import requests
from oauth2cli.oauth2 import Client
from oauth2cli.authcode import obtain_auth_code
from oauth2cli.assertion import JwtSigner
from tests import unittest, Oauth2TestCase
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__file__)
CONFIG_FILENAME = "config.json"
def load_conf(filename):
"""
Example of a configuration file:
{
"Note": "the OpenID Discovery will be updated by following optional content",
"openid_configuration": {
"authorization_endpoint": "https://example.com/tenant/oauth2/authorize",
"token_endpoint": "https://example.com/tenant/oauth2/token",
"device_authorization_endpoint": "device_authorization"
},
"client_id": "289a413d-284b-4303-9c79-94380abe5d22",
"client_secret": "your_secret",
"scope": ["your_scope"],
"resource": "Some IdP needs this",
"oidp": "https://example.com/tenant/",
"username": "you@example.com",
"password": "I could tell you but then I would have to kill you",
"placeholder": null
}
"""
try:
with open(filename) as f:
conf = json.load(f)
except:
logger.warning("Unable to open/read JSON configuration %s" % filename)
raise
openid_configuration = {}
try:
# The following line may duplicate a '/' at the joining point,
# but requests.get(...) would still work.
# Besides, standard urljoin(...) is picky on insisting oidp ends with '/'
discovery_uri = conf["oidp"] + '/.well-known/openid-configuration'
openid_configuration.update(requests.get(discovery_uri).json())
except:
logger.warning("openid-configuration uri not accesible: %s", discovery_uri)
openid_configuration.update(conf.get("openid_configuration", {}))
if openid_configuration.get("device_authorization_endpoint"):
# The following urljoin(..., ...) trick allows a "path_name" shorthand
openid_configuration["device_authorization_endpoint"] = urljoin(
openid_configuration.get("token_endpoint", ""),
openid_configuration.get("device_authorization_endpoint", ""))
conf["openid_configuration"] = openid_configuration
return conf
THIS_FOLDER = os.path.dirname(__file__)
CONFIG = load_conf(os.path.join(THIS_FOLDER, CONFIG_FILENAME)) or {}
# Since the OAuth2 specs uses snake_case, this test config also uses snake_case
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
class TestClient(Oauth2TestCase):
@classmethod
def setUpClass(cls):
if "client_certificate" in CONFIG:
private_key_path = CONFIG["client_certificate"]["private_key_path"]
with open(os.path.join(THIS_FOLDER, private_key_path)) as f:
private_key = f.read() # Expecting PEM format
cls.client = Client(
CONFIG["openid_configuration"],
CONFIG['client_id'],
client_assertion=JwtSigner(
private_key,
algorithm="RS256",
sha1_thumbprint=CONFIG["client_certificate"]["thumbprint"]
).sign_assertion(
audience=CONFIG["openid_configuration"]["token_endpoint"],
issuer=CONFIG["client_id"],
),
)
else:
cls.client = Client(
CONFIG["openid_configuration"], CONFIG['client_id'],
client_secret=CONFIG.get('client_secret'))
@unittest.skipUnless("client_secret" in CONFIG, "client_secret missing")
def test_client_credentials(self):
result = self.client.obtain_token_for_client(CONFIG.get('scope'))
self.assertIn('access_token', result)
@unittest.skipUnless(
"username" in CONFIG and "password" in CONFIG, "username/password missing")
def test_username_password(self):
result = self.client.obtain_token_by_username_password(
CONFIG["username"], CONFIG["password"],
data={"resource": CONFIG.get("resource")}, # MSFT AAD V1 only
scope=CONFIG.get("scope"))
self.assertLoosely(result)
@unittest.skipUnless(
"authorization_endpoint" in CONFIG.get("openid_configuration", {}),
"authorization_endpoint missing")
def test_auth_code(self):
port = CONFIG.get("listen_port", 44331)
redirect_uri = "http://localhost:%s" % port
auth_request_uri = self.client.build_auth_request_uri(
"code", redirect_uri=redirect_uri, scope=CONFIG.get("scope"))
ac = obtain_auth_code(port, auth_uri=auth_request_uri)
self.assertNotEqual(ac, None)
result = self.client.obtain_token_by_authorization_code(
ac,
data={
"scope": CONFIG.get("scope"),
"resource": CONFIG.get("resource"),
}, # MSFT AAD only
redirect_uri=redirect_uri)
self.assertLoosely(result, lambda: self.assertIn('access_token', result))
@unittest.skipUnless(
CONFIG.get("openid_configuration", {}).get("device_authorization_endpoint"),
"device_authorization_endpoint is missing")
def test_device_flow(self):
flow = self.client.initiate_device_flow(scope=CONFIG.get("scope"))
try:
msg = ("Use a web browser to open the page {verification_uri} and "
"enter the code {user_code} to authenticate.".format(**flow))
except KeyError: # Some IdP might not be standard compliant
msg = flow["message"] # Not a standard parameter though
logger.warning(msg) # Avoid print(...) b/c its output would be buffered
duration = 30
logger.warning("We will wait up to %d seconds for you to sign in" % duration)
flow["expires_at"] = time.time() + duration # Shorten the time for quick test
result = self.client.obtain_token_by_device_flow(flow)
self.assertLoosely(
result,
assertion=lambda: self.assertIn('access_token', result),
skippable_errors=self.client.DEVICE_FLOW_RETRIABLE_ERRORS)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment