Skip to content
Snippets Groups Projects
Commit e2fcac51 authored by Michael Fladischer's avatar Michael Fladischer
Browse files

New upstream version 0.19

parent 02f9666f
No related branches found
No related tags found
No related merge requests found
exclude: ".yarn/|yarn.lock|\\.min\\.(css|js)$"
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-added-large-files
- id: check-builtin-literals
......@@ -14,7 +14,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/adamchainz/django-upgrade
rev: 1.15.0
rev: 1.16.0
hooks:
- id: django-upgrade
args: [--target-version, "3.2"]
......@@ -23,7 +23,7 @@ repos:
hooks:
- id: absolufy-imports
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.6"
rev: "v0.4.1"
hooks:
- id: ruff
- id: ruff-format
......@@ -34,10 +34,10 @@ repos:
args: [--list-different, --no-semi]
exclude: "^conf/|.*\\.html$"
- repo: https://github.com/tox-dev/pyproject-fmt
rev: 1.5.1
rev: 1.8.0
hooks:
- id: pyproject-fmt
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.15
rev: v0.16
hooks:
- id: validate-pyproject
......@@ -4,6 +4,18 @@ Change log
Next version
~~~~~~~~~~~~
0.19 (2024-04-25)
~~~~~~~~~~~~~~~~~
- Reimplemented the rank table construction using a real queryset; this enables
support for pre-filtering the tree queryset using ``.tree_filter()`` and
``.tree_exclude()``. Thanks rhomboss!
- Added a ``.tree_fields()`` method to allow adding additional columns to the
tree queryset, allowing collecting ancestors fields directly when running the
initial query. For example, ``.tree_fields(tree_names="name")`` will collect
all ``name`` fields in a ``tree_fields`` array on the model instances. For
now the code only supports string fields and integer fields.
0.18 (2024-04-03)
~~~~~~~~~~~~~~~~~
......
......@@ -27,7 +27,9 @@ Features and limitations
``tree_depth``, ``tree_path`` and ``tree_ordering``. The names cannot
be changed. ``tree_depth`` is an integer, ``tree_path`` an array of
primary keys and ``tree_ordering`` an array of values used for
ordering nodes within their siblings.
ordering nodes within their siblings. Note that the contents of the
``tree_path`` and ``tree_ordering`` are subject to change. You shouldn't rely
on their contents.
- Besides adding the fields mentioned above the package only adds queryset
methods for ordering siblings and filtering ancestors and descendants. Other
features may be useful, but will not be added to the package just because
......
......@@ -46,6 +46,12 @@ include = ["tree_queries/"]
path = "tree_queries/__init__.py"
[tool.ruff]
fix = true
preview = true
show-fixes = true
target-version = "py38"
[tool.ruff.lint]
extend-select = [
# pyflakes, pycodestyle
"F", "E", "W",
......@@ -80,7 +86,7 @@ extend-select = [
# pygrep-hooks
"PGH",
# pylint
"PL",
"PLC", "PLE", "PLW",
# unused noqa
"RUF100",
]
......@@ -90,18 +96,15 @@ extend-ignore = [
# No line length errors
"E501",
]
fix = true
show-fixes = true
target-version = "py38"
[tool.ruff.isort]
[tool.ruff.lint.isort]
combine-as-imports = true
lines-after-imports = 2
[tool.ruff.mccabe]
[tool.ruff.lint.mccabe]
max-complexity = 15
[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"*/migrat*/*" = [
# Allow using PascalCase model names in migrations
"N806",
......
......@@ -75,7 +75,7 @@ class AlwaysTreeQueryModel(TreeNode):
class UUIDModel(TreeNode):
id = models.UUIDField(primary_key=True, default=uuid.uuid4) # noqa: A003
id = models.UUIDField(primary_key=True, default=uuid.uuid4)
name = models.CharField(max_length=100)
def __str__(self):
......@@ -136,3 +136,6 @@ class OneToOneRelatedOrder(models.Model):
related_name="related",
)
order = models.PositiveIntegerField(default=0)
def __str__(self):
return ""
from types import SimpleNamespace
from django import forms
from django.core.exceptions import ValidationError
from django.db import connections, models
......@@ -29,13 +31,13 @@ from tree_queries.query import pk
@override_settings(DEBUG=True)
class Test(TestCase):
def create_tree(self):
tree = type("Namespace", (), {})() # SimpleNamespace for PY2...
tree = SimpleNamespace()
tree.root = Model.objects.create(name="root")
tree.child1 = Model.objects.create(parent=tree.root, order=0, name="1")
tree.child2 = Model.objects.create(parent=tree.root, order=1, name="2")
tree.child1_1 = Model.objects.create(parent=tree.child1, order=0, name="1-1")
tree.child2_1 = Model.objects.create(parent=tree.child2, order=0, name="2-1")
tree.child2_2 = Model.objects.create(parent=tree.child2, order=1, name="2-2")
tree.child2_2 = Model.objects.create(parent=tree.child2, order=42, name="2-2")
return tree
def test_stuff(self):
......@@ -257,7 +259,7 @@ class Test(TestCase):
self.assertNotIn("root", html)
def test_string_ordering(self):
tree = type("Namespace", (), {})() # SimpleNamespace for PY2...
tree = SimpleNamespace()
tree.americas = StringOrderedModel.objects.create(name="Americas")
tree.europe = StringOrderedModel.objects.create(name="Europe")
......@@ -373,7 +375,7 @@ class Test(TestCase):
def test_reference(self):
tree = self.create_tree()
references = type("Namespace", (), {})() # SimpleNamespace for PY2...
references = SimpleNamespace()
references.none = ReferenceModel.objects.create(position=0)
references.root = ReferenceModel.objects.create(
position=1, tree_field=tree.root
......@@ -500,9 +502,7 @@ class Test(TestCase):
else:
qs = qs.annotate(
is_my_field=RawSQL(
'instr(__tree.tree_path, "{sep}{pk}{sep}") <> 0'.format(
pk=pk(tree.child2_1), sep=SEPARATOR
),
f'instr(__tree.tree_path, "{SEPARATOR}{pk(tree.child2_1)}{SEPARATOR}") <> 0',
[],
output_field=models.BooleanField(),
)
......@@ -534,7 +534,7 @@ class Test(TestCase):
)
def test_sibling_ordering(self):
tree = type("Namespace", (), {})() # SimpleNamespace for PY2...
tree = SimpleNamespace()
tree.root = MultiOrderedModel.objects.create(name="root")
tree.child1 = MultiOrderedModel.objects.create(
......@@ -682,7 +682,7 @@ class Test(TestCase):
)
def test_multi_field_order(self):
tree = type("Namespace", (), {})() # SimpleNamespace for PY2...
tree = SimpleNamespace()
tree.root = MultiOrderedModel.objects.create(name="root")
tree.child1 = MultiOrderedModel.objects.create(
......@@ -717,7 +717,7 @@ class Test(TestCase):
)
def test_order_by_related(self):
tree = type("Namespace", (), {})() # SimpleNamespace for PY2...
tree = SimpleNamespace()
tree.root = RelatedOrderModel.objects.create(name="root")
tree.child1 = RelatedOrderModel.objects.create(parent=tree.root, name="1")
......@@ -753,3 +753,202 @@ class Test(TestCase):
tree.child2_2,
],
)
def test_tree_exclude(self):
tree = self.create_tree()
# Tree-filter should remove children if
# the parent meets the filtering criteria
nodes = Model.objects.tree_exclude(name="2")
self.assertEqual(
list(nodes),
[
tree.root,
tree.child1,
tree.child1_1,
],
)
def test_tree_filter(self):
tree = self.create_tree()
# Tree-filter should remove children if
# the parent does not meet the filtering criteria
nodes = Model.objects.tree_filter(name__in=["root", "1-1", "2", "2-1", "2-2"])
self.assertEqual(
list(nodes),
[
tree.root,
tree.child2,
tree.child2_1,
tree.child2_2,
],
)
def test_tree_filter_chaining(self):
tree = self.create_tree()
# Tree-filter should remove children if
# the parent does not meet the filtering criteria
nodes = Model.objects.tree_exclude(name="2-2").tree_filter(
name__in=["root", "1-1", "2", "2-1", "2-2"]
)
self.assertEqual(
list(nodes),
[
tree.root,
tree.child2,
tree.child2_1,
],
)
def test_tree_filter_related(self):
tree = SimpleNamespace()
tree.root = RelatedOrderModel.objects.create(name="root")
tree.root_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.root, order=0
)
tree.child1 = RelatedOrderModel.objects.create(parent=tree.root, name="1")
tree.child1_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child1, order=0
)
tree.child2 = RelatedOrderModel.objects.create(parent=tree.root, name="2")
tree.child2_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child2, order=1
)
tree.child1_1 = RelatedOrderModel.objects.create(parent=tree.child1, name="1-1")
tree.child1_1_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child1_1, order=0
)
tree.child2_1 = RelatedOrderModel.objects.create(parent=tree.child2, name="2-1")
tree.child2_1_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child2_1, order=0
)
tree.child2_2 = RelatedOrderModel.objects.create(parent=tree.child2, name="2-2")
tree.child2_2_related = OneToOneRelatedOrder.objects.create(
relatedmodel=tree.child2_2, order=1
)
nodes = RelatedOrderModel.objects.tree_filter(related__order=0)
self.assertEqual(
list(nodes),
[
tree.root,
tree.child1,
tree.child1_1,
],
)
def test_tree_filter_with_order(self):
tree = SimpleNamespace()
tree.root = MultiOrderedModel.objects.create(
name="root",
first_position=1,
)
tree.child1 = MultiOrderedModel.objects.create(
parent=tree.root, first_position=0, second_position=1, name="1"
)
tree.child2 = MultiOrderedModel.objects.create(
parent=tree.root, first_position=1, second_position=0, name="2"
)
tree.child1_1 = MultiOrderedModel.objects.create(
parent=tree.child1, first_position=1, second_position=1, name="1-1"
)
tree.child2_1 = MultiOrderedModel.objects.create(
parent=tree.child2, first_position=1, second_position=1, name="2-1"
)
tree.child2_2 = MultiOrderedModel.objects.create(
parent=tree.child2, first_position=1, second_position=0, name="2-2"
)
nodes = MultiOrderedModel.objects.tree_filter(
first_position__gt=0
).order_siblings_by("-second_position")
self.assertEqual(
list(nodes),
[
tree.root,
tree.child2,
tree.child2_1,
tree.child2_2,
],
)
def test_tree_filter_q_objects(self):
tree = self.create_tree()
# Tree-filter should remove children if
# the parent does not meet the filtering criteria
nodes = Model.objects.tree_filter(
Q(name__in=["root", "1-1", "2", "2-1", "2-2"])
)
self.assertEqual(
list(nodes),
[
tree.root,
tree.child2,
tree.child2_1,
tree.child2_2,
],
)
def test_tree_filter_q_mix(self):
tree = SimpleNamespace()
tree.root = MultiOrderedModel.objects.create(
name="root", first_position=1, second_position=2
)
tree.child1 = MultiOrderedModel.objects.create(
parent=tree.root, first_position=1, second_position=0, name="1"
)
tree.child2 = MultiOrderedModel.objects.create(
parent=tree.root, first_position=1, second_position=2, name="2"
)
tree.child1_1 = MultiOrderedModel.objects.create(
parent=tree.child1, first_position=1, second_position=1, name="1-1"
)
tree.child2_1 = MultiOrderedModel.objects.create(
parent=tree.child2, first_position=1, second_position=1, name="2-1"
)
tree.child2_2 = MultiOrderedModel.objects.create(
parent=tree.child2, first_position=1, second_position=2, name="2-2"
)
# Tree-filter should remove children if
# the parent does not meet the filtering criteria
nodes = MultiOrderedModel.objects.tree_filter(
Q(first_position=1), second_position=2
)
self.assertEqual(
list(nodes),
[
tree.root,
tree.child2,
tree.child2_2,
],
)
def test_tree_fields(self):
self.create_tree()
qs = Model.objects.tree_fields(tree_names="name", tree_orders="order")
names = [obj.tree_names for obj in qs]
self.assertEqual(
names,
[
["root"],
["root", "1"],
["root", "1", "1-1"],
["root", "2"],
["root", "2", "2-1"],
["root", "2", "2-2"],
],
)
orders = [obj.tree_orders for obj in qs]
self.assertEqual(
orders, [[0], [0, 0], [0, 0, 0], [0, 1], [0, 1, 0], [0, 1, 42]]
)
# ids = [obj.tree_pks for obj in Model.objects.tree_fields(tree_pks="custom_id")]
# self.assertIsInstance(ids[0][0], int)
# ids = [obj.tree_pks for obj in Model.objects.tree_fields(tree_pks="parent_id")]
# self.assertEqual(ids[0], [""])
__version__ = "0.18.0"
__version__ = "0.19.0"
import django
from django.db import connections
from django.db.models import Value
from django.db.models import Expression, F, QuerySet, Value, Window
from django.db.models.functions import RowNumber
from django.db.models.sql.compiler import SQLCompiler
from django.db.models.sql.query import Query
......@@ -12,8 +14,33 @@ def _find_tree_model(cls):
class TreeQuery(Query):
# Set by TreeQuerySet.order_siblings_by
sibling_order = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._setup_query()
def _setup_query(self):
"""
Run on initialization and at the end of chaining. Any attributes that
would normally be set in __init__() should go here instead.
"""
# We add the variables for `sibling_order` and `rank_table_query` here so they
# act as instance variables which do not persist between user queries
# the way class variables do
# Only add the sibling_order attribute if the query doesn't already have one to preserve cloning behavior
if not hasattr(self, "sibling_order"):
# Add an attribute to control the ordering of siblings within trees
opts = _find_tree_model(self.model)._meta
self.sibling_order = opts.ordering if opts.ordering else opts.pk.attname
# Only add the rank_table_query attribute if the query doesn't already have one to preserve cloning behavior
if not hasattr(self, "rank_table_query"):
# Create a default QuerySet for the rank_table to use
# so we can avoid recursion
self.rank_table_query = QuerySet(model=_find_tree_model(self.model))
if not hasattr(self, "tree_fields"):
self.tree_fields = {}
def get_compiler(self, using=None, connection=None, **kwargs):
# Copied from django/db/models/sql/query.py
......@@ -28,37 +55,37 @@ class TreeQuery(Query):
return TreeCompiler(self, connection, using, **kwargs)
def get_sibling_order(self):
if self.sibling_order is not None:
return self.sibling_order
opts = _find_tree_model(self.model)._meta
if opts.ordering:
return opts.ordering
return opts.pk.attname
return self.sibling_order
def get_rank_table_query(self):
return self.rank_table_query
def get_tree_fields(self):
return self.tree_fields
class TreeCompiler(SQLCompiler):
CTE_POSTGRESQL = """
WITH RECURSIVE __rank_table(
{tree_fields_columns}
"{pk}",
"{parent}",
"rank_order"
) AS (
SELECT
{rank_pk},
{rank_parent},
ROW_NUMBER() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
{rank_table}
),
__tree (
{tree_fields_names}
"tree_depth",
"tree_path",
"tree_ordering",
"tree_pk"
) AS (
SELECT
0 AS tree_depth,
array[T.{pk}] AS tree_path,
array[T.rank_order] AS tree_ordering,
{tree_fields_initial}
0,
array[T.{pk}],
array[T.rank_order],
T."{pk}"
FROM __rank_table T
WHERE T."{parent}" IS NULL
......@@ -66,7 +93,8 @@ class TreeCompiler(SQLCompiler):
UNION ALL
SELECT
__tree.tree_depth + 1 AS tree_depth,
{tree_fields_recursive}
__tree.tree_depth + 1,
__tree.tree_path || T.{pk},
__tree.tree_ordering || T.rank_order,
T."{pk}"
......@@ -76,15 +104,23 @@ class TreeCompiler(SQLCompiler):
"""
CTE_MYSQL = """
WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS (
SELECT
{rank_pk},
{rank_parent},
ROW_NUMBER() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
WITH RECURSIVE __rank_table(
{tree_fields_columns}
{pk},
{parent},
rank_order
) AS (
{rank_table}
),
__tree(tree_depth, tree_path, tree_ordering, tree_pk) AS (
__tree(
{tree_fields_names}
tree_depth,
tree_path,
tree_ordering,
tree_pk
) AS (
SELECT
{tree_fields_initial}
0,
-- Limit to max. 50 levels...
CAST(CONCAT("{sep}", {pk}, "{sep}") AS char(1000)),
......@@ -97,6 +133,7 @@ class TreeCompiler(SQLCompiler):
UNION ALL
SELECT
{tree_fields_recursive}
__tree.tree_depth + 1,
CONCAT(__tree.tree_path, T2.{pk}, "{sep}"),
CONCAT(__tree.tree_ordering, LPAD(CONCAT(T2.rank_order, "{sep}"), 20, "0")),
......@@ -106,19 +143,27 @@ class TreeCompiler(SQLCompiler):
)
"""
CTE_SQLITE3 = """
WITH RECURSIVE __rank_table({pk}, {parent}, rank_order) AS (
SELECT
{rank_pk},
{rank_parent},
row_number() OVER (ORDER BY {rank_order_by})
FROM {rank_from}
CTE_SQLITE = """
WITH RECURSIVE __rank_table(
{tree_fields_columns}
{pk},
{parent},
rank_order
) AS (
{rank_table}
),
__tree(tree_depth, tree_path, tree_ordering, tree_pk) AS (
__tree(
{tree_fields_names}
tree_depth,
tree_path,
tree_ordering,
tree_pk
) AS (
SELECT
0 tree_depth,
printf("{sep}%%s{sep}", {pk}) tree_path,
printf("{sep}%%020s{sep}", T.rank_order) tree_ordering,
{tree_fields_initial}
0,
printf("{sep}%%s{sep}", {pk}),
printf("{sep}%%020s{sep}", T.rank_order),
T."{pk}" tree_pk
FROM __rank_table T
WHERE T."{parent}" IS NULL
......@@ -126,6 +171,7 @@ class TreeCompiler(SQLCompiler):
UNION ALL
SELECT
{tree_fields_recursive}
__tree.tree_depth + 1,
__tree.tree_path || printf("%%s{sep}", T.{pk}),
__tree.tree_ordering || printf("%%020s{sep}", T.rank_order),
......@@ -135,13 +181,8 @@ class TreeCompiler(SQLCompiler):
)
"""
def get_sibling_order_params(self):
"""
This method uses a simple django queryset to generate sql
that can be used to create the __rank_table that orders
siblings. This is done so that any joins required by order_by
are pre-calculated by django
"""
def get_rank_table(self):
# Get and validate sibling_order
sibling_order = self.query.get_sibling_order()
if isinstance(sibling_order, (list, tuple)):
......@@ -153,38 +194,41 @@ class TreeCompiler(SQLCompiler):
"Sibling order must be a string or a list or tuple of strings."
)
# Use Django to make a SQL query whose parts can be repurposed for __rank_table
base_query = (
_find_tree_model(self.query.model)
.objects.only("pk", "parent")
.order_by(*order_fields)
.query
# Convert strings to expressions. This is to maintain backwards compatibility
# with Django versions < 4.1
if django.VERSION < (4, 1):
base_order = []
for field in order_fields:
if isinstance(field, Expression):
base_order.append(field)
elif isinstance(field, str):
if field[0] == "-":
base_order.append(F(field[1:]).desc())
else:
base_order.append(F(field).asc())
order_fields = base_order
# Get the rank table query
rank_table_query = self.query.get_rank_table_query()
rank_table_query = (
rank_table_query.order_by() # Ensure there is no ORDER BY at the end of the SQL
# Values allows us to both limit and specify the order of
# the columns selected so that they match the CTE
.values(
*self.query.get_tree_fields().values(),
"pk",
"parent",
rank_order=Window(
expression=RowNumber(),
order_by=order_fields,
),
)
)
# Use the base compiler because we want vanilla sql and want to avoid recursion.
base_compiler = SQLCompiler(base_query, self.connection, None)
base_sql, base_params = base_compiler.as_sql()
result_sql = base_sql % base_params
# Split the base SQL string on the SQL keywords 'FROM' and 'ORDER BY'
from_split = result_sql.split("FROM")
order_split = from_split[1].split("ORDER BY")
# Identify the FROM and ORDER BY parts of the base SQL
ordering_params = {
"rank_from": order_split[0].strip(),
"rank_order_by": order_split[1].strip(),
}
# Identify the primary key field and parent_id field from the SELECT section
base_select = from_split[0][6:]
for field in base_select.split(","):
if "parent_id" in field: # XXX Taking advantage of Hardcoded.
ordering_params["rank_parent"] = field.strip()
else:
ordering_params["rank_pk"] = field.strip()
rank_table_sql, rank_table_params = rank_table_query.query.sql_with_params()
return ordering_params
return rank_table_sql, rank_table_params
def as_sql(self, *args, **kwargs):
# Try detecting if we're used in a EXISTS(1 as "a") subquery like
......@@ -229,8 +273,39 @@ class TreeCompiler(SQLCompiler):
"sep": SEPARATOR,
}
# Add ordering params to params
params.update(self.get_sibling_order_params())
# Get the rank_table SQL and params
rank_table_sql, rank_table_params = self.get_rank_table()
params["rank_table"] = rank_table_sql
if self.connection.vendor == "postgresql":
cte = self.CTE_POSTGRESQL
cte_initial = "array[T.{column}]::text[], "
cte_recursive = "__tree.{name} || T.{column}::text, "
elif self.connection.vendor == "sqlite":
cte = self.CTE_SQLITE
cte_initial = 'printf("{sep}%%s{sep}", {column}), '
cte_recursive = '__tree.{name} || printf("%%s{sep}", T.{column}), '
elif self.connection.vendor == "mysql":
cte = self.CTE_MYSQL
cte_initial = 'CAST(CONCAT("{sep}", {column}, "{sep}") AS char(1000)), '
cte_recursive = 'CONCAT(__tree.{name}, T2.{column}, "{sep}"), '
tree_fields = self.query.get_tree_fields()
qn = self.connection.ops.quote_name
params.update({
"tree_fields_columns": "".join(
f"{qn(column)}, " for column in tree_fields.values()
),
"tree_fields_names": "".join(f"{qn(name)}, " for name in tree_fields),
"tree_fields_initial": "".join(
cte_initial.format(column=qn(column), name=qn(name), sep=SEPARATOR)
for name, column in tree_fields.items()
),
"tree_fields_recursive": "".join(
cte_recursive.format(column=qn(column), name=qn(name), sep=SEPARATOR)
for name, column in tree_fields.items()
),
})
if "__tree" not in self.query.extra_tables: # pragma: no branch - unlikely
tree_params = params.copy()
......@@ -246,16 +321,16 @@ class TreeCompiler(SQLCompiler):
if aliases:
tree_params["db_table"] = aliases[0]
select = {
"tree_depth": "__tree.tree_depth",
"tree_path": "__tree.tree_path",
"tree_ordering": "__tree.tree_ordering",
}
select.update({name: f"__tree.{name}" for name in tree_fields})
self.query.add_extra(
# Do not add extra fields to the select statement when it is a
# summary query or when using .values() or .values_list()
select={}
if skip_tree_fields or self.query.values_select
else {
"tree_depth": "__tree.tree_depth",
"tree_path": "__tree.tree_path",
"tree_ordering": "__tree.tree_ordering",
},
select={} if skip_tree_fields or self.query.values_select else select,
select_params=None,
where=["__tree.tree_pk = {db_table}.{pk}".format(**tree_params)],
params=None,
......@@ -269,27 +344,29 @@ class TreeCompiler(SQLCompiler):
),
)
if self.connection.vendor == "postgresql":
cte = self.CTE_POSTGRESQL
elif self.connection.vendor == "sqlite":
cte = self.CTE_SQLITE3
elif self.connection.vendor == "mysql":
cte = self.CTE_MYSQL
sql_0, sql_1 = super().as_sql(*args, **kwargs)
explain = ""
if sql_0.startswith("EXPLAIN "):
explain, sql_0 = sql_0.split(" ", 1)
return ("".join([explain, cte.format(**params), sql_0]), sql_1)
# Pass any additional rank table sql paramaters so that the db backend can handle them.
# This only works because we know that the CTE is at the start of the query.
return (
"".join([explain, cte.format(**params), sql_0]),
rank_table_params + sql_1,
)
def get_converters(self, expressions):
converters = super().get_converters(expressions)
tree_fields = {"__tree.tree_path", "__tree.tree_ordering"} | {
f"__tree.{name}" for name in self.query.tree_fields
}
for i, expression in enumerate(expressions):
# We care about tree fields and annotations only
if not hasattr(expression, "sql"):
continue
if expression.sql in {"__tree.tree_path", "__tree.tree_ordering"}:
if expression.sql in tree_fields:
converters[i] = ([converter], expression)
return converters
......
......@@ -5,7 +5,7 @@ from tree_queries.forms import TreeNodeChoiceField
class TreeNodeForeignKey(models.ForeignKey):
def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
name, _path, args, kwargs = super().deconstruct()
return (name, "django.db.models.ForeignKey", args, kwargs)
def formfield(self, **kwargs):
......
......@@ -27,6 +27,7 @@ class TreeQuerySet(models.QuerySet):
"""
if tree_fields:
self.query.__class__ = TreeQuery
self.query._setup_query()
else:
self.query.__class__ = Query
return self
......@@ -45,10 +46,44 @@ class TreeQuerySet(models.QuerySet):
to order tree siblings by those model fields
"""
self.query.__class__ = TreeQuery
self.query._setup_query()
self.query.sibling_order = order_by
return self
def as_manager(cls, *, with_tree_fields=False): # noqa: N805
def tree_filter(self, *args, **kwargs):
"""
Adds a filter to the TreeQuery rank_table_query
Takes the same arguements as a Django QuerySet .filter()
"""
self.query.__class__ = TreeQuery
self.query._setup_query()
self.query.rank_table_query = self.query.rank_table_query.filter(
*args, **kwargs
)
return self
def tree_exclude(self, *args, **kwargs):
"""
Adds a filter to the TreeQuery rank_table_query
Takes the same arguements as a Django QuerySet .exclude()
"""
self.query.__class__ = TreeQuery
self.query._setup_query()
self.query.rank_table_query = self.query.rank_table_query.exclude(
*args, **kwargs
)
return self
def tree_fields(self, **tree_fields):
self.query.__class__ = TreeQuery
self.query._setup_query()
self.query.tree_fields = tree_fields
return self
@classmethod
def as_manager(cls, *, with_tree_fields=False):
manager_class = TreeManager.from_queryset(cls)
# Only used in deconstruct:
manager_class._built_with_as_manager = True
......@@ -59,7 +94,6 @@ class TreeQuerySet(models.QuerySet):
return manager_class()
as_manager.queryset_only = True
as_manager = classmethod(as_manager)
def ancestors(self, of, *, include_self=False):
"""
......@@ -94,10 +128,7 @@ class TreeQuerySet(models.QuerySet):
where=[
# XXX This *may* be unsafe with some primary key field types.
# It is certainly safe with integers.
'instr(__tree.tree_path, "{sep}{pk}{sep}") <> 0'.format(
pk=self.model._meta.pk.get_db_prep_value(pk(of), connection),
sep=SEPARATOR,
)
f'instr(__tree.tree_path, "{SEPARATOR}{self.model._meta.pk.get_db_prep_value(pk(of), connection)}{SEPARATOR}") <> 0'
]
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment