diff --git a/oslo_limit/limit.py b/oslo_limit/limit.py index 1393eedd9a700e3a2e0456035430e13eab7709b7..6fc3941b9e9fc52478abd988ae3c06f167173cfb 100644 --- a/oslo_limit/limit.py +++ b/oslo_limit/limit.py @@ -181,6 +181,13 @@ class Enforcer(object): return {resource: ProjectUsage(limit, usage[resource]) for resource, limit in limits} + def get_registered_limits(self, resources_to_check): + return self.model.get_registered_limits(resources_to_check) + + def get_project_limits(self, project_id, resources_to_check): + return self.model.get_project_limits(project_id, + resources_to_check) + class _FlatEnforcer(object): @@ -190,6 +197,9 @@ class _FlatEnforcer(object): self._usage_callback = usage_callback self._utils = _EnforcerUtils(cache=cache) + def get_registered_limits(self, resources_to_check): + return self._utils.get_registered_limits(resources_to_check) + def get_project_limits(self, project_id, resources_to_check): return self._utils.get_project_limits(project_id, resources_to_check) @@ -217,6 +227,9 @@ class _StrictTwoLevelEnforcer(object): def __init__(self, usage_callback, cache=True): self._usage_callback = usage_callback + def get_registered_limits(self, resources_to_check): + raise NotImplementedError() + def get_project_limits(self, project_id, resources_to_check): raise NotImplementedError() @@ -285,6 +298,24 @@ class _EnforcerUtils(object): LOG.debug("hit limit for project: %s", over_limit_list) raise exception.ProjectOverLimit(project_id, over_limit_list) + def get_registered_limits(self, resource_names): + """Get all the default limits for a given resource name list + + :param resource_names: list of resource_name strings + :return: list of (resource_name, limit) pairs + """ + # Using a list to preserve the resource_name order + registered_limits = [] + for resource_name in resource_names: + reg_limit = self._get_registered_limit(resource_name) + if reg_limit: + limit = reg_limit.default_limit + else: + limit = 0 + registered_limits.append((resource_name, limit)) + + return registered_limits + def get_project_limits(self, project_id, resource_names): """Get all the limits for given project a resource_name list diff --git a/oslo_limit/tests/test_limit.py b/oslo_limit/tests/test_limit.py index 2c9f1515ec5f2a1da22a00ca79fc6697ace8314d..87de31df32125d1587485bae2d2e480f07affaba 100644 --- a/oslo_limit/tests/test_limit.py +++ b/oslo_limit/tests/test_limit.py @@ -198,6 +198,27 @@ class TestEnforcer(base.BaseTestCase): enforcer.calculate_usage, 'project', ['a', 123, 'b']) + @mock.patch.object(limit._EnforcerUtils, "get_registered_limits") + def test_get_registered_limits(self, mock_get_limits): + mock_get_limits.return_value = [("a", 1), ("b", 0), ("c", 2)] + + enforcer = limit.Enforcer(lambda: None) + limits = enforcer.get_registered_limits(["a", "b", "c"]) + + mock_get_limits.assert_called_once_with(["a", "b", "c"]) + self.assertEqual(mock_get_limits.return_value, limits) + + @mock.patch.object(limit._EnforcerUtils, "get_project_limits") + def test_get_project_limits(self, mock_get_limits): + project_id = uuid.uuid4().hex + mock_get_limits.return_value = [("a", 1), ("b", 0), ("c", 2)] + + enforcer = limit.Enforcer(lambda: None) + limits = enforcer.get_project_limits(project_id, ["a", "b", "c"]) + + mock_get_limits.assert_called_once_with(project_id, ["a", "b", "c"]) + self.assertEqual(mock_get_limits.return_value, limits) + class TestFlatEnforcer(base.BaseTestCase): def setUp(self): @@ -205,6 +226,27 @@ class TestFlatEnforcer(base.BaseTestCase): self.mock_conn = mock.MagicMock() limit._SDK_CONNECTION = self.mock_conn + @mock.patch.object(limit._EnforcerUtils, "get_registered_limits") + def test_get_registered_limits(self, mock_get_limits): + mock_get_limits.return_value = [("a", 1), ("b", 0), ("c", 2)] + + enforcer = limit._FlatEnforcer(lambda: None) + limits = enforcer.get_registered_limits(["a", "b", "c"]) + + mock_get_limits.assert_called_once_with(["a", "b", "c"]) + self.assertEqual(mock_get_limits.return_value, limits) + + @mock.patch.object(limit._EnforcerUtils, "get_project_limits") + def test_get_project_limits(self, mock_get_limits): + project_id = uuid.uuid4().hex + mock_get_limits.return_value = [("a", 1), ("b", 0), ("c", 2)] + + enforcer = limit._FlatEnforcer(lambda: None) + limits = enforcer.get_project_limits(project_id, ["a", "b", "c"]) + + mock_get_limits.assert_called_once_with(project_id, ["a", "b", "c"]) + self.assertEqual(mock_get_limits.return_value, limits) + @mock.patch.object(limit._EnforcerUtils, "get_project_limits") def test_enforce(self, mock_get_limits): mock_usage = mock.MagicMock() @@ -298,6 +340,33 @@ class TestEnforcerUtils(base.BaseTestCase): self.assertEqual(foo, reg_limit) + def test_get_registered_limits(self): + fake_endpoint = endpoint.Endpoint() + fake_endpoint.service_id = "service_id" + fake_endpoint.region_id = "region_id" + self.mock_conn.get_endpoint.return_value = fake_endpoint + + # a and c have limits, b doesn't have one + empty_iterator = iter([]) + + a = registered_limit.RegisteredLimit() + a.resource_name = "a" + a.default_limit = 1 + a_iterator = iter([a]) + + c = registered_limit.RegisteredLimit() + c.resource_name = "c" + c.default_limit = 2 + c_iterator = iter([c]) + + self.mock_conn.registered_limits.side_effect = [a_iterator, + empty_iterator, + c_iterator] + + utils = limit._EnforcerUtils() + limits = utils.get_registered_limits(["a", "b", "c"]) + self.assertEqual([('a', 1), ('b', 0), ('c', 2)], limits) + def test_get_project_limits(self): fake_endpoint = endpoint.Endpoint() fake_endpoint.service_id = "service_id"