diff --git a/backend/models.py b/backend/models.py index 0469bf088b88c5b070d1123f74d96195313ad0bf..02000576fae2736c66906d089d32386960fe176e 100644 --- a/backend/models.py +++ b/backend/models.py @@ -49,6 +49,7 @@ class PersonManager(BaseUserManager): def create_superuser(self, email, **other_fields): other_fields["is_superuser"] = True + other_fields["is_staff"] = True return self.create_user(email, **other_fields) def get_or_none(self, *args, **kw): diff --git a/backend/unittest.py b/backend/unittest.py index abc0b1b7c516621b89ed17100e7dd91096ee33f0..13a2f1f296faa6aaf311d19a72054635665bf78c 100644 --- a/backend/unittest.py +++ b/backend/unittest.py @@ -103,7 +103,8 @@ class TestBase(nm2.lib.unittest.TestBase): else: raise NotImplementedError(f"{identity.issuer} not supported as identity during testing") - person = self.persons[person] + if isinstance(person, str): + person = self.persons[person] if person is not None: client.force_login(person, backend=self.TEST_AUTH_BACKEND) client.visitor = person @@ -128,7 +129,8 @@ class TestBase(nm2.lib.unittest.TestBase): else: raise NotImplementedError(f"{identity.issuer} not supported as identity during testing") - person = self.persons[person] + if isinstance(person, str): + person = self.persons[person] if person is not None: client.force_login(person, backend=self.TEST_AUTH_BACKEND) client.visitor = person diff --git a/impersonate/middleware.py b/impersonate/middleware.py new file mode 100644 index 0000000000000000000000000000000000000000..1d3fcc12bf3ff5cfc4b6320a99c12869e1676114 --- /dev/null +++ b/impersonate/middleware.py @@ -0,0 +1,34 @@ +from __future__ import annotations +from django.core.exceptions import ImproperlyConfigured +from django.contrib.auth import get_user_model + + +class ImpersonateMiddleware: + def __init__(self, get_response): + self.get_response = get_response + self.User = get_user_model() + + def __call__(self, request): + # AuthenticationMiddleware is required so that request.user exists. + if not hasattr(request, 'user'): + raise ImproperlyConfigured( + "The impersonator middleware requires the authentication middleware" + " to be installed. Edit your MIDDLEWARE setting to insert" + " 'django.contrib.auth.middleware.AuthenticationMiddleware'" + " before the ImpersonateMiddleware class.") + + if request.user.is_authenticated: + # Implement impersonation if requested in session + if request.user.is_staff: + pk = request.session.get("impersonate", None) + if pk is not None: + try: + user = self.User.objects.get(pk=pk) + except self.User.DoesNotExist: + user = None + + if user is not None: + request.impersonator = request.user + request.user = user + + return self.get_response(request) diff --git a/impersonate/tests.py b/impersonate/tests.py index 1c0f2dbccf49264fcce8fa5af62f69357fc48eb5..e655703b195ade8b09ff98216176c55d5a58dae2 100644 --- a/impersonate/tests.py +++ b/impersonate/tests.py @@ -1,29 +1,63 @@ from __future__ import annotations from django.test import TestCase from django.urls import reverse -from backend.unittest import PersonFixtureMixin +from backend.unittest import TestBase +from django.contrib.auth import get_user_model -class TestPermissions(PersonFixtureMixin, TestCase): - @classmethod - def __add_extra_tests__(cls): - non_fd = ["pending", "dc", "dc_ga", "dm", "dm_ga", "dd_nu", "dd_u", "dd_e", "dd_r", "activeam", "oldam"] - fd = ["fd", "dam"] - - for visitor in [None] + non_fd: - for visited in non_fd + fd: - cls._add_method(cls._test_impersonate_fail, visitor, visited) +class TestPermissions(TestBase, TestCase): + def test_impersonate_staff(self): + User = get_user_model() + visitor = User.objects.create_superuser(email="admin@example.org", fullname="Admin", audit_skip=True) + visited = User.objects.create_user(email="user@example.org", fullname="User", audit_skip=True) + client = self.make_test_client(visitor) - for visitor in fd: - for visited in non_fd + fd: - cls._add_method(cls._test_impersonate_success, visitor, visited) + response = client.get(reverse("impersonate:whoami")) + self.assertJSONEqual(response.content, { + 'impersonator': None, + 'impersonator_desc': None, + 'user': visitor.pk, + 'user_desc': str(visitor), + }) - def _test_impersonate_success(self, visitor, visited): - client = self.make_test_client(visitor) - response = client.post(reverse("impersonate"), data={"pk": self.persons[visited].pk, "next": "/"}) + response = client.post(reverse("impersonate:impersonate"), data={"pk": visited.pk, "next": "/"}) self.assertRedirectMatches(response, "^/$") - def _test_impersonate_fail(self, visitor, visited): + response = client.get(reverse("impersonate:whoami")) + self.assertJSONEqual(response.content, { + 'impersonator': visitor.pk, + 'impersonator_desc': str(visitor), + 'user': visited.pk, + 'user_desc': str(visited), + }) + + def test_impersonate_user(self): + User = get_user_model() + visitor = User.objects.create_user(email="user@example.org", fullname="User", audit_skip=True) + visited = User.objects.create_user(email="user1@example.org", fullname="User1", audit_skip=True) client = self.make_test_client(visitor) - response = client.post(reverse("impersonate"), data={"pk": self.persons[visited].pk}) + response = client.post(reverse("impersonate:impersonate"), data={"pk": visited.pk}) self.assertPermissionDenied(response) + + response = client.get(reverse("impersonate:whoami")) + self.assertJSONEqual(response.content, { + 'impersonator': None, + 'impersonator_desc': None, + 'user': visitor.pk, + 'user_desc': str(visitor), + }) + + def test_impersonate_anonymous(self): + User = get_user_model() + visited = User.objects.create_user(email="user@example.org", fullname="User", audit_skip=True) + client = self.make_test_client(None) + response = client.post(reverse("impersonate:impersonate"), data={"pk": visited.pk}) + self.assertPermissionDenied(response) + + response = client.get(reverse("impersonate:whoami")) + self.assertJSONEqual(response.content, { + 'impersonator': None, + 'impersonator_desc': None, + 'user': None, + 'user_desc': "AnonymousUser", + }) diff --git a/impersonate/urls.py b/impersonate/urls.py index dce4619b931404960fe0ab464e785feb9c5f039b..0dfa00731f56ffc010c86a2bb818460f93c71c0d 100644 --- a/impersonate/urls.py +++ b/impersonate/urls.py @@ -6,4 +6,5 @@ app_name = "impersonate" urlpatterns = [ # Impersonate a user path('impersonate/', views.Impersonate.as_view(), name="impersonate"), + path('whoami/', views.Whoami.as_view(), name="whoami"), ] diff --git a/impersonate/views.py b/impersonate/views.py index eab368c6d6562732e352636e7e23132151c3a65d..83e52f43dc153601bd698de65ae35853aec9ea56 100644 --- a/impersonate/views.py +++ b/impersonate/views.py @@ -5,6 +5,7 @@ from django.shortcuts import redirect from django.core.exceptions import PermissionDenied from django.contrib import messages from django.contrib.auth import get_user_model +from django import http class Impersonate(View): @@ -13,7 +14,7 @@ class Impersonate(View): effective_user = getattr(request, "impersonator", None) if effective_user is None: effective_user = request.user - if not effective_user.is_authenticated or not effective_user.is_admin: + if not effective_user.is_authenticated or not effective_user.is_staff: raise PermissionDenied pk = request.POST.get("pk") if pk is None: @@ -33,3 +34,14 @@ class Impersonate(View): return redirect(user.get_absolute_url()) else: return redirect(url) + + +class Whoami(View): + def get(self, request, *args, **kw): + impersonator = getattr(request, "impersonator", None) + return http.JsonResponse({ + "user": request.user.pk, + "user_desc": str(request.user), + "impersonator": None if impersonator is None else impersonator.pk, + "impersonator_desc": None if impersonator is None else str(impersonator), + }) diff --git a/nm2/settings.py b/nm2/settings.py index 868d1d3f1d279fcd30ab0f4dd1f9d9db9ea900b0..a7aa49089a77b51836681b33b606511600ea4e60 100644 --- a/nm2/settings.py +++ b/nm2/settings.py @@ -88,6 +88,7 @@ MIDDLEWARE = [ 'signon.middleware.SignonMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'impersonate.middleware.ImpersonateMiddleware', ] AUTHENTICATION_BACKENDS = [