Select Git revision
To find the state of this project's repository at the time of any of these versions, check out the tags.
contrib.py 10.33 KiB
"""Additional goodies"""
import functools
import itertools
import operator
from typing import Any, List, Tuple, Union
from django.db import models
from pgtrigger import core, utils
# A sentinel value to determine if a kwarg is unset. It is used instead of
# ``None`` where ``None`` is itself a meaningful argument: ``SoftDelete(value=None)``
# renders the SQL ``NULL``, so only ``_unset`` means "value was not passed".
_unset = object()
def _get_column(model, field):
    """Return the single database column that backs ``field`` on ``model``.

    ``field`` may be a Django field instance or a field name, which is resolved
    through ``model._meta.get_field``.

    Raises:
        ValueError: if the field has no single column, either because it
            exposes ``columns`` (a composite key) or because it is not backed
            by a database column at all.
    """
    if not isinstance(field, models.Field):
        field = model._meta.get_field(field)

    column = field.column
    if not column:  # pragma: no cover
        if getattr(field, "columns", None):
            message = f"Field {field.name} references a composite key and is not supported."
        else:
            message = f"Field {field.name} does not reference a database column."
        raise ValueError(message)

    return column
def _get_columns(model, field):
    """Return the list of database columns that back ``field`` on ``model``.

    A field with a ``column`` yields a one-element list. A field without one
    falls back to its ``columns`` attribute (e.g. a composite primary key).
    ``field`` may be a Django field instance or a field name.

    Raises:
        ValueError: if neither ``column`` nor ``columns`` yields anything.
    """
    if isinstance(field, models.Field):
        resolved = field
    else:
        resolved = model._meta.get_field(field)

    single = resolved.column
    if single:
        columns = [single]
    else:
        columns = getattr(resolved, "columns", None)

    if not columns:  # pragma: no cover
        raise ValueError(f"Field {resolved.name} does not reference a database column.")
    return columns
class Protect(core.Trigger):
    """A trigger that raises an exception.

    Defaults to firing ``core.Before`` the configured operation. Its function
    body aborts the operation with a
    ``pgtrigger: Cannot <operation> rows from <table> table`` error.
    """

    when: core.When = core.Before

    def get_func(self, model):
        verb = str(self.operation).lower()
        return self.format_sql(
            f"""
            RAISE EXCEPTION
                'pgtrigger: Cannot {verb} rows from % table',
                TG_TABLE_NAME;
            """
        )
class ReadOnly(Protect):
    """A trigger that prevents edits to fields.

    If `fields` are provided, will protect edits to only those fields.
    If `exclude` is provided, will protect all fields except the ones
    excluded.
    If none of these arguments are provided, all fields cannot be edited.

    Entries in `fields` and `exclude` are resolved with
    `model._meta.get_field`. Any spelling it accepts (for example a foreign
    key's `<name>_id` attribute name) refers to the same field in both lists.

    Raises:
        ValueError: at construction if both `fields` and `exclude` are given,
            or when rendering the condition if `exclude` removes every field.
    """

    fields: Union[List[str], None] = None
    exclude: Union[List[str], None] = None
    operation: core.Operation = core.Update

    def __init__(
        self,
        *,
        fields: Union[List[str], None] = None,
        exclude: Union[List[str], None] = None,
        **kwargs: Any,
    ):
        self.fields = fields or self.fields
        self.exclude = exclude or self.exclude

        if self.fields and self.exclude:
            raise ValueError('Must provide only one of "fields" or "exclude" to ReadOnly trigger')

        super().__init__(**kwargs)

    def get_condition(self, model):
        if not self.fields and not self.exclude:
            return core.Condition("OLD.* IS DISTINCT FROM NEW.*")

        if self.exclude:
            # Resolving through get_field() both validates each entry and
            # normalizes it to the field's canonical name. Without the
            # normalization, an entry such as "author_id" would pass validation
            # but never match `f.name` ("author"), so the field would silently
            # stay protected.
            excluded = {model._meta.get_field(field).name for field in self.exclude}
            fields = [f.name for f in model._meta.fields if f.name not in excluded]
        else:
            fields = [model._meta.get_field(field).name for field in self.fields]

        if not fields:
            # functools.reduce() would otherwise fail with an opaque TypeError
            raise ValueError(
                'ReadOnly trigger "exclude" removes every field; there is nothing left to protect'
            )

        # OR together one per-field comparison of OLD vs NEW (the `__df` lookup,
        # presumably pgtrigger's IS DISTINCT FROM), so a change to any protected
        # field satisfies the condition.
        return functools.reduce(
            operator.or_,
            [core.Q(**{f"old__{field}__df": core.F(f"new__{field}")}) for field in fields],
        )
def _quote_pg_array_element(value):
    """Render ``value`` as one element of a Postgres text-array literal.

    Some elements would confuse the array parser: those containing braces,
    commas or backslashes, those with leading/trailing whitespace, and empty
    or ``NULL``-spelled strings. These are double-quoted, with backslashes
    escaped. Every other value is returned unchanged, so ordinary values
    render exactly as before.

    Double quotes are not escaped here; callers must reject them beforehand.
    """
    needs_quotes = (
        not value
        or value != value.strip()
        or value.upper() == "NULL"
        or any(char in value for char in "{},\\")
    )
    if not needs_quotes:
        return value
    return '"' + value.replace("\\", "\\\\") + '"'


class FSM(core.Trigger):
    """Enforces a finite state machine on a field.

    Supply the trigger with the `field` that transitions and then
    a list of tuples of valid transitions to the `transitions` argument.

    !!! note

        Only non-null `CharField` fields without quotes are currently supported.
        If your strings have a colon symbol in them, you must override the
        "separator" argument to be a value other than a colon.
    """

    when: core.When = core.Before
    operation: core.Operation = core.Update
    field: str = None
    transitions: List[Tuple[str, str]] = None
    separator: str = ":"

    def __init__(
        self,
        *,
        name: str = None,
        condition: Union[core.Condition, None] = None,
        field: str = None,
        transitions: List[Tuple[str, str]] = None,
        separator: str = None,
    ):
        self.field = field or self.field
        self.transitions = transitions or self.transitions
        self.separator = separator or self.separator

        if not self.field:  # pragma: no cover
            raise ValueError('Must provide "field" for FSM')

        if not self.transitions:  # pragma: no cover
            raise ValueError('Must provide "transitions" for FSM')

        # This trigger doesn't accept quoted values or values that
        # contain the configured separator
        for value in itertools.chain(*self.transitions):
            if "'" in value or '"' in value:
                raise ValueError(f'FSM transition value "{value}" contains quotes')
            elif self.separator in value:
                raise ValueError(
                    f'FSM value "{value}" contains separator "{self.separator}".'
                    ' Configure your trigger with a different "separator" attribute'
                )

        # The separator must be a single character that isn't a quote
        if len(self.separator) != 1:
            raise ValueError(f'Separator "{self.separator}" must be a single character')
        elif self.separator in ('"', "'"):
            raise ValueError("Separator must not have quotes")

        super().__init__(name=name, condition=condition)

    def get_declare(self, model):
        return [("_is_valid_transition", "BOOLEAN")]

    def get_func(self, model):
        col = utils.quote(_get_column(model, self.field))
        # Each allowed transition is encoded as "old<separator>new" and stored in
        # a text-array literal. Elements are quoted when needed so that commas,
        # braces, backslashes or surrounding whitespace in a value (or in the
        # separator) cannot split or alter the array's elements.
        transition_uris = (
            "{"
            + ",".join(
                _quote_pg_array_element(f"{old}{self.separator}{new}")
                for old, new in self.transitions
            )
            + "}"
        )
        sql = f"""
            SELECT CONCAT(OLD.{col}, '{self.separator}', NEW.{col}) = ANY('{transition_uris}'::text[])
                INTO _is_valid_transition;
            IF (_is_valid_transition IS FALSE AND OLD.{col} IS DISTINCT FROM NEW.{col}) THEN
                RAISE EXCEPTION
                    'pgtrigger: Invalid transition of field "{self.field}" from "%" to "%" on table %',
                    OLD.{col},
                    NEW.{col},
                    TG_TABLE_NAME;
            ELSE
                RETURN NEW;
            END IF;
        """  # noqa
        return self.format_sql(sql)
class SoftDelete(core.Trigger):
    """Sets a field to a value when a delete happens.

    Supply the trigger with the "field" that will be set
    upon deletion and the "value" to which it should be set.
    The "value" defaults to `False`.

    !!! note

        This trigger currently only supports nullable `BooleanField`,
        `CharField`, and `IntField` fields.
    """

    when: core.When = core.Before
    operation: core.Operation = core.Delete
    field: str = None
    value: Union[bool, str, int, None] = False

    def __init__(
        self,
        *,
        name: str = None,
        condition: Union[core.Condition, None] = None,
        field: str = None,
        value: Union[bool, str, int, None] = _unset,
    ):
        self.field = field or self.field
        # `_unset` (not None) marks "not passed", because value=None is a
        # legitimate request to write SQL NULL.
        self.value = value if value is not _unset else self.value

        if not self.field:  # pragma: no cover
            raise ValueError('Must provide "field" for soft delete')

        super().__init__(name=name, condition=condition)

    def get_func(self, model):
        soft_field = _get_column(model, self.field)
        # Emit the column unquoted when it is a plain lowercase ASCII identifier,
        # so existing triggers keep rendering identical SQL. Quote anything else
        # (e.g. a mixed-case db_column), which Postgres would otherwise fold to
        # lowercase or fail to parse.
        if not (
            soft_field.isascii() and soft_field.isidentifier() and soft_field == soft_field.lower()
        ):
            soft_field = utils.quote(soft_field)

        # Support composite primary keys in Django 5.2+
        pk_cols = _get_columns(model, model._meta.pk)
        table_pk_cols = ",".join(utils.quote(col) for col in pk_cols)
        trigger_pk_cols = ",".join(f"OLD.{utils.quote(col)}" for col in pk_cols)
        if len(pk_cols) > 1:
            # Row-value comparison: (a, b) = (OLD.a, OLD.b)
            table_pk_cols = f"({table_pk_cols})"
            trigger_pk_cols = f"({trigger_pk_cols})"

        def _render_value():
            if self.value is None:
                return "NULL"
            elif isinstance(self.value, str):
                # Double embedded single quotes so the value stays one SQL literal
                escaped = self.value.replace("'", "''")
                return f"'{escaped}'"
            else:
                return str(self.value)

        # Returning NULL from a BEFORE row trigger makes Postgres skip the
        # original DELETE, leaving only the UPDATE above.
        sql = f"""
            UPDATE {utils.quote(model._meta.db_table)}
            SET {soft_field} = {_render_value()}
            WHERE {table_pk_cols} = {trigger_pk_cols};
            RETURN NULL;
        """
        return self.format_sql(sql)
class UpdateSearchVector(core.Trigger):
    """Updates a `django.contrib.postgres.search.SearchVectorField` from document fields.

    Supply the trigger with the `vector_field` that will be updated with
    changes to the `document_fields`. Optionally provide a `config_name`, which
    defaults to `pg_catalog.english`.

    This trigger uses `tsvector_update_trigger` to update the vector field.
    See [the Postgres docs](https://www.postgresql.org/docs/current/textsearch-features.html#TEXTSEARCH-UPDATE-TRIGGERS)
    for more information.

    !!! note

        `UpdateSearchVector` triggers are not compatible with [pgtrigger.ignore][] since
        it references a built-in trigger. Trying to ignore this trigger results in a
        `RuntimeError`.
    """  # noqa

    when: core.When = core.Before
    vector_field: str = None
    document_fields: List[str] = None
    config_name: str = "pg_catalog.english"

    def __init__(
        self,
        *,
        name: str = None,
        vector_field: str = None,
        document_fields: List[str] = None,
        config_name: str = None,
    ):
        self.vector_field = vector_field or self.vector_field
        self.document_fields = document_fields or self.document_fields
        self.config_name = config_name or self.config_name

        if not self.vector_field:
            raise ValueError('Must provide "vector_field" to update search vector')

        if not self.document_fields:
            raise ValueError('Must provide "document_fields" to update search vector')

        if not self.config_name:  # pragma: no cover
            raise ValueError('Must provide "config_name" to update search vector')

        # Use the resolved attribute, not the raw argument: a subclass may set
        # `document_fields` as a class attribute and pass nothing here, in which
        # case the argument is None and `UpdateOf(*None)` would raise TypeError.
        super().__init__(
            name=name, operation=core.Insert | core.UpdateOf(*self.document_fields)
        )

    def ignore(self, model):
        raise RuntimeError(f"Cannot ignore {self.__class__.__name__} triggers")

    def get_func(self, model):
        # No custom function body: render_execute() calls the built-in
        # tsvector_update_trigger instead.
        return ""

    def render_execute(self, model):
        document_cols = [_get_column(model, field) for field in self.document_fields]
        rendered_document_cols = ", ".join(utils.quote(col) for col in document_cols)
        vector_col = _get_column(model, self.vector_field)
        return (
            f"tsvector_update_trigger({utils.quote(vector_col)},"
            f" {utils.quote(self.config_name)}, {rendered_document_cols})"
        )