From e2fcac5118d1a416841e6fab16a83a0eca9a9e31 Mon Sep 17 00:00:00 2001 From: Michael Fladischer <FladischerMichael@fladi.at> Date: Thu, 18 Jul 2024 20:55:24 +0000 Subject: [PATCH] New upstream version 0.19 --- .pre-commit-config.yaml | 10 +- CHANGELOG.rst | 12 ++ README.rst | 4 +- pyproject.toml | 17 ++- tests/testapp/models.py | 5 +- tests/testapp/test_queries.py | 219 +++++++++++++++++++++++++++-- tree_queries/__init__.py | 2 +- tree_queries/compiler.py | 255 ++++++++++++++++++++++------------ tree_queries/fields.py | 2 +- tree_queries/query.py | 43 +++++- 10 files changed, 448 insertions(+), 121 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index da146f6..0f399c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ exclude: ".yarn/|yarn.lock|\\.min\\.(css|js)$" repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-added-large-files - id: check-builtin-literals @@ -14,7 +14,7 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/adamchainz/django-upgrade - rev: 1.15.0 + rev: 1.16.0 hooks: - id: django-upgrade args: [--target-version, "3.2"] @@ -23,7 +23,7 @@ repos: hooks: - id: absolufy-imports - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.1.6" + rev: "v0.4.1" hooks: - id: ruff - id: ruff-format @@ -34,10 +34,10 @@ repos: args: [--list-different, --no-semi] exclude: "^conf/|.*\\.html$" - repo: https://github.com/tox-dev/pyproject-fmt - rev: 1.5.1 + rev: 1.8.0 hooks: - id: pyproject-fmt - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.15 + rev: v0.16 hooks: - id: validate-pyproject diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2c2ccff..9256bb5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,18 @@ Change log Next version ~~~~~~~~~~~~ +0.19 (2024-04-25) +~~~~~~~~~~~~~~~~~ + +- Reimplemented the rank table construction using a real queryset; this enables + support for pre-filtering the tree queryset using ``.tree_filter()`` and + ``.tree_exclude()``. Thanks rhomboss! +- Added a ``.tree_fields()`` method to allow adding additional columns to the + tree queryset, allowing collecting ancestors fields directly when running the + initial query. For example, ``.tree_fields(tree_names="name")`` will collect + all ``name`` fields in a ``tree_fields`` array on the model instances. For + now the code only supports string fields and integer fields. + 0.18 (2024-04-03) ~~~~~~~~~~~~~~~~~ diff --git a/README.rst b/README.rst index 864bfe6..ff2615e 100644 --- a/README.rst +++ b/README.rst @@ -27,7 +27,9 @@ Features and limitations ``tree_depth``, ``tree_path`` and ``tree_ordering``. The names cannot be changed. ``tree_depth`` is an integer, ``tree_path`` an array of primary keys and ``tree_ordering`` an array of values used for - ordering nodes within their siblings. + ordering nodes within their siblings. Note that the contents of the + ``tree_path`` and ``tree_ordering`` are subject to change. You shouldn't rely + on their contents. - Besides adding the fields mentioned above the package only adds queryset methods for ordering siblings and filtering ancestors and descendants. Other features may be useful, but will not be added to the package just because diff --git a/pyproject.toml b/pyproject.toml index a367b68..a5d782d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,12 @@ include = ["tree_queries/"] path = "tree_queries/__init__.py" [tool.ruff] +fix = true +preview = true +show-fixes = true +target-version = "py38" + +[tool.ruff.lint] extend-select = [ # pyflakes, pycodestyle "F", "E", "W", @@ -80,7 +86,7 @@ extend-select = [ # pygrep-hooks "PGH", # pylint - "PL", + "PLC", "PLE", "PLW", # unused noqa "RUF100", ] @@ -90,18 +96,15 @@ extend-ignore = [ # No line length errors "E501", ] -fix = true -show-fixes = true -target-version = "py38" -[tool.ruff.isort] +[tool.ruff.lint.isort] combine-as-imports = true lines-after-imports = 2 -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 15 -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "*/migrat*/*" = [ # Allow using PascalCase model names in migrations "N806", diff --git a/tests/testapp/models.py b/tests/testapp/models.py index 03ded71..4954f41 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -75,7 +75,7 @@ class AlwaysTreeQueryModel(TreeNode): class UUIDModel(TreeNode): - id = models.UUIDField(primary_key=True, default=uuid.uuid4) # noqa: A003 + id = models.UUIDField(primary_key=True, default=uuid.uuid4) name = models.CharField(max_length=100) def __str__(self): @@ -136,3 +136,6 @@ class OneToOneRelatedOrder(models.Model): related_name="related", ) order = models.PositiveIntegerField(default=0) + + def __str__(self): + return "" diff --git a/tests/testapp/test_queries.py b/tests/testapp/test_queries.py index 344ce45..13ee741 100644 --- a/tests/testapp/test_queries.py +++ b/tests/testapp/test_queries.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + from django import forms from django.core.exceptions import ValidationError from django.db import connections, models @@ -29,13 +31,13 @@ from tree_queries.query import pk @override_settings(DEBUG=True) class Test(TestCase): def create_tree(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = Model.objects.create(name="root") tree.child1 = Model.objects.create(parent=tree.root, order=0, name="1") tree.child2 = Model.objects.create(parent=tree.root, order=1, name="2") tree.child1_1 = Model.objects.create(parent=tree.child1, order=0, name="1-1") tree.child2_1 = Model.objects.create(parent=tree.child2, order=0, name="2-1") - tree.child2_2 = Model.objects.create(parent=tree.child2, order=1, name="2-2") + tree.child2_2 = Model.objects.create(parent=tree.child2, order=42, name="2-2") return tree def test_stuff(self): @@ -257,7 +259,7 @@ class Test(TestCase): self.assertNotIn("root", html) def test_string_ordering(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.americas = StringOrderedModel.objects.create(name="Americas") tree.europe = StringOrderedModel.objects.create(name="Europe") @@ -373,7 +375,7 @@ class Test(TestCase): def test_reference(self): tree = self.create_tree() - references = type("Namespace", (), {})() # SimpleNamespace for PY2... + references = SimpleNamespace() references.none = ReferenceModel.objects.create(position=0) references.root = ReferenceModel.objects.create( position=1, tree_field=tree.root @@ -500,9 +502,7 @@ class Test(TestCase): else: qs = qs.annotate( is_my_field=RawSQL( - 'instr(__tree.tree_path, "{sep}{pk}{sep}") <> 0'.format( - pk=pk(tree.child2_1), sep=SEPARATOR - ), + f'instr(__tree.tree_path, "{SEPARATOR}{pk(tree.child2_1)}{SEPARATOR}") <> 0', [], output_field=models.BooleanField(), ) @@ -534,7 +534,7 @@ class Test(TestCase): ) def test_sibling_ordering(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = MultiOrderedModel.objects.create(name="root") tree.child1 = MultiOrderedModel.objects.create( @@ -682,7 +682,7 @@ class Test(TestCase): ) def test_multi_field_order(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = MultiOrderedModel.objects.create(name="root") tree.child1 = MultiOrderedModel.objects.create( @@ -717,7 +717,7 @@ class Test(TestCase): ) def test_order_by_related(self): - tree = type("Namespace", (), {})() # SimpleNamespace for PY2... + tree = SimpleNamespace() tree.root = RelatedOrderModel.objects.create(name="root") tree.child1 = RelatedOrderModel.objects.create(parent=tree.root, name="1") @@ -753,3 +753,202 @@ class Test(TestCase): tree.child2_2, ], ) + + def test_tree_exclude(self): + tree = self.create_tree() + # Tree-filter should remove children if + # the parent meets the filtering criteria + nodes = Model.objects.tree_exclude(name="2") + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child1, + tree.child1_1, + ], + ) + + def test_tree_filter(self): + tree = self.create_tree() + # Tree-filter should remove children if + # the parent does not meet the filtering criteria + nodes = Model.objects.tree_filter(name__in=["root", "1-1", "2", "2-1", "2-2"]) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_1, + tree.child2_2, + ], + ) + + def test_tree_filter_chaining(self): + tree = self.create_tree() + # Tree-filter should remove children if + # the parent does not meet the filtering criteria + nodes = Model.objects.tree_exclude(name="2-2").tree_filter( + name__in=["root", "1-1", "2", "2-1", "2-2"] + ) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_1, + ], + ) + + def test_tree_filter_related(self): + tree = SimpleNamespace() + + tree.root = RelatedOrderModel.objects.create(name="root") + tree.root_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.root, order=0 + ) + tree.child1 = RelatedOrderModel.objects.create(parent=tree.root, name="1") + tree.child1_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child1, order=0 + ) + tree.child2 = RelatedOrderModel.objects.create(parent=tree.root, name="2") + tree.child2_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child2, order=1 + ) + tree.child1_1 = RelatedOrderModel.objects.create(parent=tree.child1, name="1-1") + tree.child1_1_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child1_1, order=0 + ) + tree.child2_1 = RelatedOrderModel.objects.create(parent=tree.child2, name="2-1") + tree.child2_1_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child2_1, order=0 + ) + tree.child2_2 = RelatedOrderModel.objects.create(parent=tree.child2, name="2-2") + tree.child2_2_related = OneToOneRelatedOrder.objects.create( + relatedmodel=tree.child2_2, order=1 + ) + + nodes = RelatedOrderModel.objects.tree_filter(related__order=0) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child1, + tree.child1_1, + ], + ) + + def test_tree_filter_with_order(self): + tree = SimpleNamespace() + + tree.root = MultiOrderedModel.objects.create( + name="root", + first_position=1, + ) + tree.child1 = MultiOrderedModel.objects.create( + parent=tree.root, first_position=0, second_position=1, name="1" + ) + tree.child2 = MultiOrderedModel.objects.create( + parent=tree.root, first_position=1, second_position=0, name="2" + ) + tree.child1_1 = MultiOrderedModel.objects.create( + parent=tree.child1, first_position=1, second_position=1, name="1-1" + ) + tree.child2_1 = MultiOrderedModel.objects.create( + parent=tree.child2, first_position=1, second_position=1, name="2-1" + ) + tree.child2_2 = MultiOrderedModel.objects.create( + parent=tree.child2, first_position=1, second_position=0, name="2-2" + ) + + nodes = MultiOrderedModel.objects.tree_filter( + first_position__gt=0 + ).order_siblings_by("-second_position") + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_1, + tree.child2_2, + ], + ) + + def test_tree_filter_q_objects(self): + tree = self.create_tree() + # Tree-filter should remove children if + # the parent does not meet the filtering criteria + nodes = Model.objects.tree_filter( + Q(name__in=["root", "1-1", "2", "2-1", "2-2"]) + ) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_1, + tree.child2_2, + ], + ) + + def test_tree_filter_q_mix(self): + tree = SimpleNamespace() + + tree.root = MultiOrderedModel.objects.create( + name="root", first_position=1, second_position=2 + ) + tree.child1 = MultiOrderedModel.objects.create( + parent=tree.root, first_position=1, second_position=0, name="1" + ) + tree.child2 = MultiOrderedModel.objects.create( + parent=tree.root, first_position=1, second_position=2, name="2" + ) + tree.child1_1 = MultiOrderedModel.objects.create( + parent=tree.child1, first_position=1, second_position=1, name="1-1" + ) + tree.child2_1 = MultiOrderedModel.objects.create( + parent=tree.child2, first_position=1, second_position=1, name="2-1" + ) + tree.child2_2 = MultiOrderedModel.objects.create( + parent=tree.child2, first_position=1, second_position=2, name="2-2" + ) + # Tree-filter should remove children if + # the parent does not meet the filtering criteria + nodes = MultiOrderedModel.objects.tree_filter( + Q(first_position=1), second_position=2 + ) + self.assertEqual( + list(nodes), + [ + tree.root, + tree.child2, + tree.child2_2, + ], + ) + + def test_tree_fields(self): + self.create_tree() + qs = Model.objects.tree_fields(tree_names="name", tree_orders="order") + + names = [obj.tree_names for obj in qs] + self.assertEqual( + names, + [ + ["root"], + ["root", "1"], + ["root", "1", "1-1"], + ["root", "2"], + ["root", "2", "2-1"], + ["root", "2", "2-2"], + ], + ) + + orders = [obj.tree_orders for obj in qs] + self.assertEqual( + orders, [[0], [0, 0], [0, 0, 0], [0, 1], [0, 1, 0], [0, 1, 42]] + ) + + # ids = [obj.tree_pks for obj in Model.objects.tree_fields(tree_pks="custom_id")] + # self.assertIsInstance(ids[0][0], int) + + # ids = [obj.tree_pks for obj in Model.objects.tree_fields(tree_pks="parent_id")] + # self.assertEqual(ids[0], [""]) diff --git a/tree_queries/__init__.py b/tree_queries/__init__.py index 1317d75..11ac8e1 100644 --- a/tree_queries/__init__.py +++ b/tree_queries/__init__.py @@ -1 +1 @@ -__version__ = "0.18.0" +__version__ = "0.19.0" diff --git a/tree_queries/compiler.py b/tree_queries/compiler.py index 74cc32c..0aefd36 100644 --- a/tree_queries/compiler.py +++ b/tree_queries/compiler.py @@ -1,5 +1,7 @@ +import django from django.db import connections -from django.db.models import Value +from django.db.models import Expression, F, QuerySet, Value, Window +from django.db.models.functions import RowNumber from django.db.models.sql.compiler import SQLCompiler from django.db.models.sql.query import Query @@ -12,8 +14,33 @@ def _find_tree_model(cls): class TreeQuery(Query): - # Set by TreeQuerySet.order_siblings_by - sibling_order = None + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._setup_query() + + def _setup_query(self): + """ + Run on initialization and at the end of chaining. Any attributes that + would normally be set in __init__() should go here instead. + """ + # We add the variables for `sibling_order` and `rank_table_query` here so they + # act as instance variables which do not persist between user queries + # the way class variables do + + # Only add the sibling_order attribute if the query doesn't already have one to preserve cloning behavior + if not hasattr(self, "sibling_order"): + # Add an attribute to control the ordering of siblings within trees + opts = _find_tree_model(self.model)._meta + self.sibling_order = opts.ordering if opts.ordering else opts.pk.attname + + # Only add the rank_table_query attribute if the query doesn't already have one to preserve cloning behavior + if not hasattr(self, "rank_table_query"): + # Create a default QuerySet for the rank_table to use + # so we can avoid recursion + self.rank_table_query = QuerySet(model=_find_tree_model(self.model)) + + if not hasattr(self, "tree_fields"): + self.tree_fields = {} def get_compiler(self, using=None, connection=None, **kwargs): # Copied from django/db/models/sql/query.py @@ -28,37 +55,37 @@ class TreeQuery(Query): return TreeCompiler(self, connection, using, **kwargs) def get_sibling_order(self): - if self.sibling_order is not None: - return self.sibling_order - opts = _find_tree_model(self.model)._meta - if opts.ordering: - return opts.ordering - return opts.pk.attname + return self.sibling_order + + def get_rank_table_query(self): + return self.rank_table_query + + def get_tree_fields(self): + return self.tree_fields class TreeCompiler(SQLCompiler): CTE_POSTGRESQL = """ WITH RECURSIVE __rank_table( + {tree_fields_columns} "{pk}", "{parent}", "rank_order" ) AS ( - SELECT - {rank_pk}, - {rank_parent}, - ROW_NUMBER() OVER (ORDER BY {rank_order_by}) - FROM {rank_from} + {rank_table} ), __tree ( + {tree_fields_names} "tree_depth", "tree_path", "tree_ordering", "tree_pk" ) AS ( SELECT - 0 AS tree_depth, - array[T.{pk}] AS tree_path, - array[T.rank_order] AS tree_ordering, + {tree_fields_initial} + 0, + array[T.{pk}], + array[T.rank_order], T."{pk}" FROM __rank_table T WHERE T."{parent}" IS NULL @@ -66,7 +93,8 @@ class TreeCompiler(SQLCompiler): UNION ALL SELECT - __tree.tree_depth + 1 AS tree_depth, + {tree_fields_recursive} + __tree.tree_depth + 1, __tree.tree_path || T.{pk}, __tree.tree_ordering || T.rank_order, T."{pk}" @@ -76,15 +104,23 @@ class TreeCompiler(SQLCompiler): """ CTE_MYSQL = """ - WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS ( - SELECT - {rank_pk}, - {rank_parent}, - ROW_NUMBER() OVER (ORDER BY {rank_order_by}) - FROM {rank_from} + WITH RECURSIVE __rank_table( + {tree_fields_columns} + {pk}, + {parent}, + rank_order + ) AS ( + {rank_table} ), - __tree(tree_depth, tree_path, tree_ordering, tree_pk) AS ( + __tree( + {tree_fields_names} + tree_depth, + tree_path, + tree_ordering, + tree_pk + ) AS ( SELECT + {tree_fields_initial} 0, -- Limit to max. 50 levels... CAST(CONCAT("{sep}", {pk}, "{sep}") AS char(1000)), @@ -97,6 +133,7 @@ class TreeCompiler(SQLCompiler): UNION ALL SELECT + {tree_fields_recursive} __tree.tree_depth + 1, CONCAT(__tree.tree_path, T2.{pk}, "{sep}"), CONCAT(__tree.tree_ordering, LPAD(CONCAT(T2.rank_order, "{sep}"), 20, "0")), @@ -106,19 +143,27 @@ class TreeCompiler(SQLCompiler): ) """ - CTE_SQLITE3 = """ - WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS ( - SELECT - {rank_pk}, - {rank_parent}, - row_number() OVER (ORDER BY {rank_order_by}) - FROM {rank_from} + CTE_SQLITE = """ + WITH RECURSIVE __rank_table( + {tree_fields_columns} + {pk}, + {parent}, + rank_order + ) AS ( + {rank_table} ), - __tree(tree_depth, tree_path, tree_ordering, tree_pk) AS ( + __tree( + {tree_fields_names} + tree_depth, + tree_path, + tree_ordering, + tree_pk + ) AS ( SELECT - 0 tree_depth, - printf("{sep}%%s{sep}", {pk}) tree_path, - printf("{sep}%%020s{sep}", T.rank_order) tree_ordering, + {tree_fields_initial} + 0, + printf("{sep}%%s{sep}", {pk}), + printf("{sep}%%020s{sep}", T.rank_order), T."{pk}" tree_pk FROM __rank_table T WHERE T."{parent}" IS NULL @@ -126,6 +171,7 @@ class TreeCompiler(SQLCompiler): UNION ALL SELECT + {tree_fields_recursive} __tree.tree_depth + 1, __tree.tree_path || printf("%%s{sep}", T.{pk}), __tree.tree_ordering || printf("%%020s{sep}", T.rank_order), @@ -135,13 +181,8 @@ class TreeCompiler(SQLCompiler): ) """ - def get_sibling_order_params(self): - """ - This method uses a simple django queryset to generate sql - that can be used to create the __rank_table that orders - siblings. This is done so that any joins required by order_by - are pre-calculated by django - """ + def get_rank_table(self): + # Get and validate sibling_order sibling_order = self.query.get_sibling_order() if isinstance(sibling_order, (list, tuple)): @@ -153,38 +194,41 @@ class TreeCompiler(SQLCompiler): "Sibling order must be a string or a list or tuple of strings." ) - # Use Django to make a SQL query whose parts can be repurposed for __rank_table - base_query = ( - _find_tree_model(self.query.model) - .objects.only("pk", "parent") - .order_by(*order_fields) - .query + # Convert strings to expressions. This is to maintain backwards compatibility + # with Django versions < 4.1 + if django.VERSION < (4, 1): + base_order = [] + for field in order_fields: + if isinstance(field, Expression): + base_order.append(field) + elif isinstance(field, str): + if field[0] == "-": + base_order.append(F(field[1:]).desc()) + else: + base_order.append(F(field).asc()) + order_fields = base_order + + # Get the rank table query + rank_table_query = self.query.get_rank_table_query() + + rank_table_query = ( + rank_table_query.order_by() # Ensure there is no ORDER BY at the end of the SQL + # Values allows us to both limit and specify the order of + # the columns selected so that they match the CTE + .values( + *self.query.get_tree_fields().values(), + "pk", + "parent", + rank_order=Window( + expression=RowNumber(), + order_by=order_fields, + ), + ) ) - # Use the base compiler because we want vanilla sql and want to avoid recursion. - base_compiler = SQLCompiler(base_query, self.connection, None) - base_sql, base_params = base_compiler.as_sql() - result_sql = base_sql % base_params - - # Split the base SQL string on the SQL keywords 'FROM' and 'ORDER BY' - from_split = result_sql.split("FROM") - order_split = from_split[1].split("ORDER BY") - - # Identify the FROM and ORDER BY parts of the base SQL - ordering_params = { - "rank_from": order_split[0].strip(), - "rank_order_by": order_split[1].strip(), - } - - # Identify the primary key field and parent_id field from the SELECT section - base_select = from_split[0][6:] - for field in base_select.split(","): - if "parent_id" in field: # XXX Taking advantage of Hardcoded. - ordering_params["rank_parent"] = field.strip() - else: - ordering_params["rank_pk"] = field.strip() + rank_table_sql, rank_table_params = rank_table_query.query.sql_with_params() - return ordering_params + return rank_table_sql, rank_table_params def as_sql(self, *args, **kwargs): # Try detecting if we're used in a EXISTS(1 as "a") subquery like @@ -229,8 +273,39 @@ class TreeCompiler(SQLCompiler): "sep": SEPARATOR, } - # Add ordering params to params - params.update(self.get_sibling_order_params()) + # Get the rank_table SQL and params + rank_table_sql, rank_table_params = self.get_rank_table() + params["rank_table"] = rank_table_sql + + if self.connection.vendor == "postgresql": + cte = self.CTE_POSTGRESQL + cte_initial = "array[T.{column}]::text[], " + cte_recursive = "__tree.{name} || T.{column}::text, " + elif self.connection.vendor == "sqlite": + cte = self.CTE_SQLITE + cte_initial = 'printf("{sep}%%s{sep}", {column}), ' + cte_recursive = '__tree.{name} || printf("%%s{sep}", T.{column}), ' + elif self.connection.vendor == "mysql": + cte = self.CTE_MYSQL + cte_initial = 'CAST(CONCAT("{sep}", {column}, "{sep}") AS char(1000)), ' + cte_recursive = 'CONCAT(__tree.{name}, T2.{column}, "{sep}"), ' + + tree_fields = self.query.get_tree_fields() + qn = self.connection.ops.quote_name + params.update({ + "tree_fields_columns": "".join( + f"{qn(column)}, " for column in tree_fields.values() + ), + "tree_fields_names": "".join(f"{qn(name)}, " for name in tree_fields), + "tree_fields_initial": "".join( + cte_initial.format(column=qn(column), name=qn(name), sep=SEPARATOR) + for name, column in tree_fields.items() + ), + "tree_fields_recursive": "".join( + cte_recursive.format(column=qn(column), name=qn(name), sep=SEPARATOR) + for name, column in tree_fields.items() + ), + }) if "__tree" not in self.query.extra_tables: # pragma: no branch - unlikely tree_params = params.copy() @@ -246,16 +321,16 @@ class TreeCompiler(SQLCompiler): if aliases: tree_params["db_table"] = aliases[0] + select = { + "tree_depth": "__tree.tree_depth", + "tree_path": "__tree.tree_path", + "tree_ordering": "__tree.tree_ordering", + } + select.update({name: f"__tree.{name}" for name in tree_fields}) self.query.add_extra( # Do not add extra fields to the select statement when it is a # summary query or when using .values() or .values_list() - select={} - if skip_tree_fields or self.query.values_select - else { - "tree_depth": "__tree.tree_depth", - "tree_path": "__tree.tree_path", - "tree_ordering": "__tree.tree_ordering", - }, + select={} if skip_tree_fields or self.query.values_select else select, select_params=None, where=["__tree.tree_pk = {db_table}.{pk}".format(**tree_params)], params=None, @@ -269,27 +344,29 @@ class TreeCompiler(SQLCompiler): ), ) - if self.connection.vendor == "postgresql": - cte = self.CTE_POSTGRESQL - elif self.connection.vendor == "sqlite": - cte = self.CTE_SQLITE3 - elif self.connection.vendor == "mysql": - cte = self.CTE_MYSQL sql_0, sql_1 = super().as_sql(*args, **kwargs) explain = "" if sql_0.startswith("EXPLAIN "): explain, sql_0 = sql_0.split(" ", 1) - return ("".join([explain, cte.format(**params), sql_0]), sql_1) + # Pass any additional rank table sql paramaters so that the db backend can handle them. + # This only works because we know that the CTE is at the start of the query. + return ( + "".join([explain, cte.format(**params), sql_0]), + rank_table_params + sql_1, + ) def get_converters(self, expressions): converters = super().get_converters(expressions) + tree_fields = {"__tree.tree_path", "__tree.tree_ordering"} | { + f"__tree.{name}" for name in self.query.tree_fields + } for i, expression in enumerate(expressions): # We care about tree fields and annotations only if not hasattr(expression, "sql"): continue - if expression.sql in {"__tree.tree_path", "__tree.tree_ordering"}: + if expression.sql in tree_fields: converters[i] = ([converter], expression) return converters diff --git a/tree_queries/fields.py b/tree_queries/fields.py index 76340bd..fe5dd07 100644 --- a/tree_queries/fields.py +++ b/tree_queries/fields.py @@ -5,7 +5,7 @@ from tree_queries.forms import TreeNodeChoiceField class TreeNodeForeignKey(models.ForeignKey): def deconstruct(self): - name, path, args, kwargs = super().deconstruct() + name, _path, args, kwargs = super().deconstruct() return (name, "django.db.models.ForeignKey", args, kwargs) def formfield(self, **kwargs): diff --git a/tree_queries/query.py b/tree_queries/query.py index 250ad48..f408ef4 100644 --- a/tree_queries/query.py +++ b/tree_queries/query.py @@ -27,6 +27,7 @@ class TreeQuerySet(models.QuerySet): """ if tree_fields: self.query.__class__ = TreeQuery + self.query._setup_query() else: self.query.__class__ = Query return self @@ -45,10 +46,44 @@ class TreeQuerySet(models.QuerySet): to order tree siblings by those model fields """ self.query.__class__ = TreeQuery + self.query._setup_query() self.query.sibling_order = order_by return self - def as_manager(cls, *, with_tree_fields=False): # noqa: N805 + def tree_filter(self, *args, **kwargs): + """ + Adds a filter to the TreeQuery rank_table_query + + Takes the same arguements as a Django QuerySet .filter() + """ + self.query.__class__ = TreeQuery + self.query._setup_query() + self.query.rank_table_query = self.query.rank_table_query.filter( + *args, **kwargs + ) + return self + + def tree_exclude(self, *args, **kwargs): + """ + Adds a filter to the TreeQuery rank_table_query + + Takes the same arguements as a Django QuerySet .exclude() + """ + self.query.__class__ = TreeQuery + self.query._setup_query() + self.query.rank_table_query = self.query.rank_table_query.exclude( + *args, **kwargs + ) + return self + + def tree_fields(self, **tree_fields): + self.query.__class__ = TreeQuery + self.query._setup_query() + self.query.tree_fields = tree_fields + return self + + @classmethod + def as_manager(cls, *, with_tree_fields=False): manager_class = TreeManager.from_queryset(cls) # Only used in deconstruct: manager_class._built_with_as_manager = True @@ -59,7 +94,6 @@ class TreeQuerySet(models.QuerySet): return manager_class() as_manager.queryset_only = True - as_manager = classmethod(as_manager) def ancestors(self, of, *, include_self=False): """ @@ -94,10 +128,7 @@ class TreeQuerySet(models.QuerySet): where=[ # XXX This *may* be unsafe with some primary key field types. # It is certainly safe with integers. - 'instr(__tree.tree_path, "{sep}{pk}{sep}") <> 0'.format( - pk=self.model._meta.pk.get_db_prep_value(pk(of), connection), - sep=SEPARATOR, - ) + f'instr(__tree.tree_path, "{SEPARATOR}{self.model._meta.pk.get_db_prep_value(pk(of), connection)}{SEPARATOR}") <> 0' ] ) -- GitLab