Skip to content
Snippets Groups Projects
Commit 95ccef07 authored by Ray Luo's avatar Ray Luo
Browse files

Managed Identity implementation

Fix docs

Adjusting error message and docs

Fix typo
parent 12566ba1
Branches
Tags
No related merge requests found
......@@ -168,3 +168,23 @@ You may want to catch them to provide a better error message to your end users.
.. autoclass:: msal.IdTokenError
Managed Identity
================
MSAL supports
`Managed Identity <https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/overview>`_.
You can create one of these two kinds of managed identity configuration objects:
.. autoclass:: msal.SystemAssignedManagedIdentity
:members:
.. autoclass:: msal.UserAssignedManagedIdentity
:members:
And then feed the configuration object into a :class:`ManagedIdentityClient` object.
.. autoclass:: msal.ManagedIdentityClient
:members:
.. automethod:: __init__
......@@ -34,8 +34,14 @@ from .application import (
from .oauth2cli.oidc import Prompt, IdTokenError
from .token_cache import TokenCache, SerializableTokenCache
from .auth_scheme import PopAuthScheme
from .managed_identity import (
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
ManagedIdentityClient,
ManagedIdentityError,
)
# Putting module-level exceptions into the package namespace, to make them
# 1. officially part of the MSAL public API, and
# 2. can still be caught by the user code even if we change the module structure.
from .oauth2cli.oauth2 import BrowserInteractionTimeoutError
This diff is collapsed.
# This sample can be configured to work with Microsoft Entra ID's Managed Identity.
#
# A user-assigned managed identity can be represented as a JSON blob.
# Check MSAL Python's API Reference for its syntax.
# https://msal-python.readthedocs.io/en/latest/#managed-identity
#
# Example value when using a user-assigned managed identity:
# {"ManagedIdentityIdType": "ClientId", "Id": "your_managed_identity_id"}
# Leave it empty or absent if you are using a system-assigned managed identity.
MANAGED_IDENTITY=<managed_identity>
# Managed Identity works with resource, not scopes.
RESOURCE=<your resource>
# Required if the sample app wants to call an API.
#ENDPOINT=https://graph.microsoft.com/v1.0/me
"""
This sample demonstrates a daemon application that acquires a token using a
managed identity and then calls a web API with the token.
This sample loads its configuration from a .env file.
To make this sample work, you need to choose this template:
.env.sample.managed_identity
Copy the chosen template to a new file named .env, and fill in the values.
You can then run this sample:
python name_of_this_script.py
"""
import json
import logging
import os
import time
from dotenv import load_dotenv # Need "pip install python-dotenv"
import msal
import requests
# Optional logging
# logging.basicConfig(level=logging.DEBUG) # Enable DEBUG log for entire script
# logging.getLogger("msal").setLevel(logging.INFO) # Optionally disable MSAL DEBUG logs
load_dotenv() # We use this to load configuration from a .env file
# If for whatever reason you plan to recreate same ClientApplication periodically,
# you shall create one global token cache and reuse it by each ClientApplication
global_token_cache = msal.TokenCache() # The TokenCache() is in-memory.
# See more options in https://msal-python.readthedocs.io/en/latest/#tokencache
# Create a managed identity instance based on the environment variable value
if os.getenv('MANAGED_IDENTITY'):
managed_identity = json.loads(os.getenv('MANAGED_IDENTITY'))
else:
managed_identity = msal.SystemAssignedManagedIdentity()
# Create a preferably long-lived app instance, to avoid the overhead of app creation
global_app = msal.ManagedIdentityClient(
managed_identity,
http_client=requests.Session(),
token_cache=global_token_cache, # Let this app (re)use an existing token cache.
# If absent, ClientApplication will create its own empty token cache
)
resource = os.getenv("RESOURCE")
def acquire_and_use_token():
# ManagedIdentityClient.acquire_token_for_client(...) will automatically look up
# a token from cache, and fall back to acquire a fresh token when needed.
result = global_app.acquire_token_for_client(resource=resource)
if "access_token" in result:
if os.getenv('ENDPOINT'):
# Calling a web API using the access token
api_result = requests.get(
os.getenv('ENDPOINT'),
headers={'Authorization': 'Bearer ' + result['access_token']},
).json() # Assuming the response is JSON
print("Web API call result", json.dumps(api_result, indent=2))
else:
print("Token acquisition result", json.dumps(result, indent=2))
else:
print("Token acquisition failed", result) # Examine result["error_description"] etc. to diagnose error
while True: # Here we mimic a long-lived daemon
acquire_and_use_token()
print("Press Ctrl-C to stop.")
time.sleep(5) # Let's say your app would run a workload every X minutes
import json
import os
import sys
import time
import unittest
try:
from unittest.mock import patch, ANY, mock_open
except:
from mock import patch, ANY, mock_open
import requests
from tests.http_client import MinimalResponse
from msal import (
SystemAssignedManagedIdentity, UserAssignedManagedIdentity,
ManagedIdentityClient,
ManagedIdentityError,
)
class ManagedIdentityTestCase(unittest.TestCase):
def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_from_file_or_env_var(self):
self.assertEqual(
UserAssignedManagedIdentity(client_id="foo"),
{"ManagedIdentityIdType": "ClientId", "Id": "foo"})
self.assertEqual(
UserAssignedManagedIdentity(resource_id="foo"),
{"ManagedIdentityIdType": "ResourceId", "Id": "foo"})
self.assertEqual(
UserAssignedManagedIdentity(object_id="foo"),
{"ManagedIdentityIdType": "ObjectId", "Id": "foo"})
with self.assertRaises(ManagedIdentityError):
UserAssignedManagedIdentity()
with self.assertRaises(ManagedIdentityError):
UserAssignedManagedIdentity(client_id="foo", resource_id="bar")
self.assertEqual(
SystemAssignedManagedIdentity(),
{"ManagedIdentityIdType": "SystemAssigned", "Id": None})
class ClientTestCase(unittest.TestCase):
maxDiff = None
def setUp(self):
self.app = ManagedIdentityClient(
{ # Here we test it with the raw dict form, to test that
# the client has no hard dependency on ManagedIdentity object
"ManagedIdentityIdType": "SystemAssigned", "Id": None,
},
http_client=requests.Session(),
)
def _test_token_cache(self, app):
cache = app._token_cache._cache
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
at = list(cache["AccessToken"].values())[0]
self.assertEqual(
app._managed_identity.get("Id", "SYSTEM_ASSIGNED_MANAGED_IDENTITY"),
at["client_id"],
"Should have expected client_id")
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")
def _test_happy_path(self, app, mocked_http):
result = app.acquire_token_for_client(resource="R")
mocked_http.assert_called()
self.assertEqual({
"access_token": "AT",
"expires_in": 1234,
"resource": "R",
"token_type": "Bearer",
}, result, "Should obtain a token response")
self.assertEqual(
result["access_token"],
app.acquire_token_for_client(resource="R").get("access_token"),
"Should hit the same token from cache")
self._test_token_cache(app)
class VmTestCase(ClientTestCase):
def test_happy_path(self):
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)
def test_vm_error_should_be_returned_as_is(self):
raw_error = '{"raw": "error format is undefined"}'
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=400,
text=raw_error,
)) as mocked_method:
self.assertEqual(
json.loads(raw_error), self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)
@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"})
class AppServiceTestCase(ClientTestCase):
def test_happy_path(self):
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
int(time.time()) + 1234),
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)
def test_app_service_error_should_be_normalized(self):
raw_error = '{"statusCode": 500, "message": "error content is undefined"}'
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=500,
text=raw_error,
)) as mocked_method:
self.assertEqual({
"error": "invalid_scope",
"error_description": "500, error content is undefined",
}, self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)
@patch.dict(os.environ, {
"IDENTITY_ENDPOINT": "http://localhost",
"IDENTITY_HEADER": "foo",
"IDENTITY_SERVER_THUMBPRINT": "bar",
})
class ServiceFabricTestCase(ClientTestCase):
def _test_happy_path(self, app):
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
int(time.time()) + 1234),
)) as mocked_method:
super(ServiceFabricTestCase, self)._test_happy_path(app, mocked_method)
def test_happy_path(self):
self._test_happy_path(self.app)
def test_unified_api_service_should_ignore_unnecessary_client_id(self):
self._test_happy_path(ManagedIdentityClient(
{"ManagedIdentityIdType": "ClientId", "Id": "foo"},
http_client=requests.Session(),
))
def test_sf_error_should_be_normalized(self):
raw_error = '''
{"error": {
"correlationId": "foo",
"code": "SecretHeaderNotFound",
"message": "Secret is not found in the request headers."
}}''' # https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#error-handling
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=404,
text=raw_error,
)) as mocked_method:
self.assertEqual({
"error": "unauthorized_client",
"error_description": raw_error,
}, self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)
@patch.dict(os.environ, {
"IDENTITY_ENDPOINT": "http://localhost/token",
"IMDS_ENDPOINT": "http://localhost",
})
@patch(
"builtins.open" if sys.version_info.major >= 3 else "__builtin__.open",
new=mock_open(read_data="secret"), # `new` requires no extra argument on the decorated function.
# https://docs.python.org/3/library/unittest.mock.html#unittest.mock.patch
)
class ArcTestCase(ClientTestCase):
challenge = MinimalResponse(status_code=401, text="", headers={
"WWW-Authenticate": "Basic realm=/tmp/foo",
})
def test_happy_path(self):
with patch.object(self.app._http_client, "get", side_effect=[
self.challenge,
MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
),
]) as mocked_method:
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)
def test_arc_error_should_be_normalized(self):
with patch.object(self.app._http_client, "get", side_effect=[
self.challenge,
MinimalResponse(status_code=400, text="undefined"),
]) as mocked_method:
self.assertEqual({
"error": "invalid_request",
"error_description": "undefined",
}, self.app.acquire_token_for_client(resource="R"))
self.assertEqual({}, self.app._token_cache._cache)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment