Skip to content
Snippets Groups Projects
Select Git revision
  • upstream/2.5.8+ds
  • main default
  • pristine-tar
  • upstream
  • debian/2.5.8+ds-1
  • debian/2.5.7+ds-1
  • upstream/2.5.7+ds
  • debian/2.5.4+ds-1
  • upstream/2.5.4+ds
  • debian/2.4.3+ds-2
  • debian/2.4.3+ds-1
  • upstream/2.4.3+ds
  • debian/2.4.2+ds-3
  • debian/2.4.2+ds-2
  • debian/2.4.2+ds-1
  • upstream/2.4.2+ds
  • debian/2.3.3-1
  • upstream/2.3.3
  • debian/2.1.2-1
  • upstream/2.1.2
  • debian/1.5.0-1
  • upstream/1.5.0
  • upstream/1.4.0
23 results

test_history.py

Blame
  • contrib.py 10.33 KiB
    """Additional goodies"""
    
    import functools
    import itertools
    import operator
    from typing import Any, List, Tuple, Union
    
    from django.db import models
    
    from pgtrigger import core, utils
    
    # Sentinel default marking a keyword argument as "not supplied". It is used
    # where ``None`` is itself a meaningful value (e.g. ``SoftDelete(value=None)``
    # renders SQL NULL), so ``None`` cannot double as "use the class default".
    _unset = object()
    
    
    def _get_column(model, field):
        """Return the single database column backing ``field`` on ``model``.

        ``field`` may be a ``models.Field`` instance or a field name resolved via
        ``model._meta.get_field``. Raises ``ValueError`` when the field has no
        single column. The message distinguishes composite fields (those exposing
        a truthy ``columns`` attribute) from fields with no column at all.
        """
        resolved = field if isinstance(field, models.Field) else model._meta.get_field(field)
        column = resolved.column
        if column:
            return column

        # Composite fields expose ``columns`` instead of a single ``column``
        is_composite = bool(getattr(resolved, "columns", None))  # pragma: no cover
        raise ValueError(  # pragma: no cover
            f"Field {resolved.name} references a composite key and is not supported."
            if is_composite
            else f"Field {resolved.name} does not reference a database column."
        )
    
    
    def _get_columns(model, field):
        """Return the database column(s) backing ``field`` on ``model``.

        ``field`` may be a ``models.Field`` instance or a field name. A field with
        a single ``column`` yields a one-element list. Otherwise the field's
        ``columns`` attribute is returned as-is (used for composite keys). Raises
        ``ValueError`` when neither is available.
        """
        if not isinstance(field, models.Field):
            field = model._meta.get_field(field)

        if field.column:
            columns = [field.column]
        else:
            columns = getattr(field, "columns", None)

        if not columns:  # pragma: no cover
            raise ValueError(f"Field {field.name} does not reference a database column.")
        return columns
    
    
    class Protect(core.Trigger):
        """A trigger that raises an exception.

        The trigger runs ``BEFORE`` its configured operation. Its function body
        raises a Postgres exception naming the (lower-cased) operation and the
        table the trigger fired on.
        """

        when: core.When = core.Before

        def get_func(self, model):
            # e.g. "update" / "delete", depending on the configured operation
            verb = str(self.operation).lower()
            sql = f"""
                RAISE EXCEPTION
                    'pgtrigger: Cannot {verb} rows from % table',
                    TG_TABLE_NAME;
            """
            return self.format_sql(sql)
    
    
    class ReadOnly(Protect):
        """A trigger that prevents edits to fields.

        If `fields` are provided, will protect edits to only those fields.
        If `exclude` is provided, will protect all fields except the ones
        excluded.
        If none of these arguments are provided, all fields cannot be edited.
        """

        fields: Union[List[str], None] = None
        exclude: Union[List[str], None] = None
        operation: core.Operation = core.Update

        def __init__(
            self,
            *,
            fields: Union[List[str], None] = None,
            exclude: Union[List[str], None] = None,
            **kwargs: Any,
        ):
            # Explicit arguments win; otherwise fall back to class attributes so
            # subclasses can declare ``fields``/``exclude`` declaratively.
            self.fields = fields or self.fields
            self.exclude = exclude or self.exclude

            if self.fields and self.exclude:
                raise ValueError('Must provide only one of "fields" or "exclude" to ReadOnly trigger')

            super().__init__(**kwargs)

        def get_condition(self, model):
            """Build the condition under which an update is rejected.

            With neither ``fields`` nor ``exclude``, any change to the row matches.
            Otherwise the condition matches when any protected field differs
            between ``OLD`` and ``NEW``.

            Raises:
                ValueError: if ``exclude`` leaves no fields on the model to protect.
            """
            if not self.fields and not self.exclude:
                return core.Condition("OLD.* IS DISTINCT FROM NEW.*")

            if self.exclude:
                # Sanity check that the exclude list contains valid fields
                # (get_field raises for unknown names)
                for field in self.exclude:
                    model._meta.get_field(field)

                fields = [f.name for f in model._meta.fields if f.name not in self.exclude]
            else:
                fields = [model._meta.get_field(field).name for field in self.fields]

            if not fields:
                # functools.reduce() on an empty list would otherwise fail with an
                # opaque "reduce() of empty iterable" TypeError.
                raise ValueError(
                    f'ReadOnly trigger on "{model.__name__}" excludes every field;'
                    " there is nothing left to protect"
                )

            # OR together one per-field ``__df`` (distinct-from) comparison of OLD vs NEW
            return functools.reduce(
                operator.or_,
                [core.Q(**{f"old__{field}__df": core.F(f"new__{field}")}) for field in fields],
            )
    
    
    class FSM(core.Trigger):
        """Enforces a finite state machine on a field.

        Supply the trigger with the `field` that transitions and then
        a list of tuples of valid transitions to the `transitions` argument.

        !!! note

            Only non-null `CharField` fields without quotes are currently supported.
            If your strings have a colon symbol in them, you must override the
            "separator" argument to be a value other than a colon.
            Transition values and the separator also may not contain commas,
            curly braces, or backslashes, because the transitions are compiled
            into a Postgres array literal where those characters are special.
        """

        when: core.When = core.Before
        operation: core.Operation = core.Update
        field: str = None
        transitions: List[Tuple[str, str]] = None
        separator: str = ":"

        def __init__(
            self,
            *,
            name: str = None,
            condition: Union[core.Condition, None] = None,
            field: str = None,
            transitions: List[Tuple[str, str]] = None,
            separator: str = None,
        ):
            # Explicit arguments win; otherwise fall back to class attributes so
            # subclasses can declare the configuration declaratively.
            self.field = field or self.field
            self.transitions = transitions or self.transitions
            self.separator = separator or self.separator

            if not self.field:  # pragma: no cover
                raise ValueError('Must provide "field" for FSM')

            if not self.transitions:  # pragma: no cover
                raise ValueError('Must provide "transitions" for FSM')

            # Characters that are structural in an unquoted Postgres array literal
            # (element delimiter, braces, escape character). Letting them through
            # would silently split or corrupt the '{old:new,...}'::text[] literal
            # built in get_func.
            array_special_chars = (",", "{", "}", "\\")

            # This trigger doesn't accept quoted values, values that contain the
            # configured separator, or values with array-literal special characters
            for value in itertools.chain(*self.transitions):
                if "'" in value or '"' in value:
                    raise ValueError(f'FSM transition value "{value}" contains quotes')
                elif self.separator in value:
                    raise ValueError(
                        f'FSM value "{value}" contains separator "{self.separator}".'
                        ' Configure your trigger with a different "separator" attribute'
                    )
                elif any(char in value for char in array_special_chars):
                    raise ValueError(
                        f'FSM transition value "{value}" contains a comma, curly brace,'
                        " or backslash"
                    )

            # The separator must be a single character that isn't a quote or an
            # array-literal special character
            if len(self.separator) != 1:
                raise ValueError(f'Separator "{self.separator}" must be a single character')
            elif self.separator in ('"', "'"):
                raise ValueError("Separator must not have quotes")
            elif self.separator in array_special_chars:
                raise ValueError(
                    f'Separator "{self.separator}" must not be a comma, curly brace, or backslash'
                )

            super().__init__(name=name, condition=condition)

        def get_declare(self, model):
            # PL/pgSQL local that receives the result of the transition lookup
            return [("_is_valid_transition", "BOOLEAN")]

        def get_func(self, model):
            col = _get_column(model, self.field)
            # e.g. [("a", "b"), ("b", "c")] with separator ":" -> "{a:b,b:c}"
            transition_uris = (
                "{" + ",".join([f"{old}{self.separator}{new}" for old, new in self.transitions]) + "}"
            )

            # Join OLD and NEW with the separator and look the pair up in the allowed
            # transitions. Updates that leave the field unchanged are always allowed.
            sql = f"""
                SELECT CONCAT(OLD.{utils.quote(col)}, '{self.separator}', NEW.{utils.quote(col)}) = ANY('{transition_uris}'::text[])
                    INTO _is_valid_transition;

                IF (_is_valid_transition IS FALSE AND OLD.{utils.quote(col)} IS DISTINCT FROM NEW.{utils.quote(col)}) THEN
                    RAISE EXCEPTION
                        'pgtrigger: Invalid transition of field "{self.field}" from "%" to "%" on table %',
                        OLD.{utils.quote(col)},
                        NEW.{utils.quote(col)},
                        TG_TABLE_NAME;
                ELSE
                    RETURN NEW;
                END IF;
            """  # noqa
            return self.format_sql(sql)
    
    
    class SoftDelete(core.Trigger):
        """Sets a field to a value when a delete happens.

        Supply the trigger with the "field" that will be set
        upon deletion and the "value" to which it should be set.
        The "value" defaults to `False`.

        !!! note

            This trigger currently only supports nullable `BooleanField`,
            `CharField`, and `IntegerField` fields.
        """

        when: core.When = core.Before
        operation: core.Operation = core.Delete
        field: str = None
        value: Union[bool, str, int, None] = False

        def __init__(
            self,
            *,
            name: str = None,
            condition: Union[core.Condition, None] = None,
            field: str = None,
            value: Union[bool, str, int, None] = _unset,
        ):
            self.field = field or self.field
            # ``_unset`` distinguishes "not passed" from an explicit ``None``,
            # which is a valid value (rendered as SQL NULL).
            self.value = value if value is not _unset else self.value

            if not self.field:  # pragma: no cover
                raise ValueError('Must provide "field" for soft delete')

            super().__init__(name=name, condition=condition)

        def get_func(self, model):
            # NOTE(review): unlike the table and pk columns below, this column is not
            # passed through utils.quote(). Quoting it would change the generated SQL
            # (and thus every installed trigger's definition); confirm before changing.
            soft_field = _get_column(model, self.field)

            # Support composite primary keys in Django 5.2+
            pk_cols = _get_columns(model, model._meta.pk)
            table_pk_cols = ",".join(utils.quote(col) for col in pk_cols)
            trigger_pk_cols = ",".join(f"OLD.{utils.quote(col)}" for col in pk_cols)

            if len(pk_cols) > 1:
                # Compare row values: ("a","b") = (OLD."a",OLD."b")
                table_pk_cols = f"({table_pk_cols})"
                trigger_pk_cols = f"({trigger_pk_cols})"

            def _render_value():
                if self.value is None:
                    return "NULL"
                elif isinstance(self.value, str):
                    # Double embedded single quotes so the value stays a single,
                    # well-formed SQL string literal.
                    escaped = self.value.replace("'", "''")
                    return f"'{escaped}'"
                else:
                    return str(self.value)

            # Update the row instead of deleting it. Returning NULL from a BEFORE row
            # trigger makes Postgres skip the original DELETE for that row.
            sql = f"""
                UPDATE {utils.quote(model._meta.db_table)}
                SET {soft_field} = {_render_value()}
                WHERE {table_pk_cols} = {trigger_pk_cols};
                RETURN NULL;
            """
            return self.format_sql(sql)
    
    
    class UpdateSearchVector(core.Trigger):
        """Updates a `django.contrib.postgres.search.SearchVectorField` from document fields.

        Supply the trigger with the `vector_field` that will be updated with
        changes to the `document_fields`. Optionally provide a `config_name`, which
        defaults to `pg_catalog.english`.

        This trigger uses `tsvector_update_trigger` to update the vector field.
        See [the Postgres docs](https://www.postgresql.org/docs/current/textsearch-features.html#TEXTSEARCH-UPDATE-TRIGGERS)
        for more information.

        !!! note

            `UpdateSearchVector` triggers are not compatible with [pgtrigger.ignore][] since
            it references a built-in trigger. Trying to ignore this trigger results in a
            `RuntimeError`.
        """  # noqa

        when: core.When = core.Before
        vector_field: str = None
        document_fields: List[str] = None
        config_name: str = "pg_catalog.english"

        def __init__(
            self,
            *,
            name: str = None,
            vector_field: str = None,
            document_fields: List[str] = None,
            config_name: str = None,
        ):
            # Explicit arguments win; otherwise fall back to class attributes so
            # subclasses can declare the configuration declaratively.
            self.vector_field = vector_field or self.vector_field
            self.document_fields = document_fields or self.document_fields
            self.config_name = config_name or self.config_name

            if not self.vector_field:
                raise ValueError('Must provide "vector_field" to update search vector')

            if not self.document_fields:
                raise ValueError('Must provide "document_fields" to update search vector')

            if not self.config_name:  # pragma: no cover
                raise ValueError('Must provide "config_name" to update search vector')

            # Fire on inserts and on updates of any document field. Use the resolved
            # ``self.document_fields`` so class-level defaults work when the argument
            # is omitted (unpacking the ``None`` argument would raise TypeError).
            operation = core.Insert | core.UpdateOf(*self.document_fields)
            super().__init__(name=name, operation=operation)

        def ignore(self, model):
            # The trigger executes a built-in function, so it cannot be ignored.
            raise RuntimeError(f"Cannot ignore {self.__class__.__name__} triggers")

        def get_func(self, model):
            # No custom function body: render_execute calls a built-in instead.
            return ""

        def render_execute(self, model):
            # Build the call to Postgres's built-in
            # tsvector_update_trigger(vector_col, config, document_col, ...).
            document_cols = [_get_column(model, field) for field in self.document_fields]
            rendered_document_cols = ", ".join(utils.quote(col) for col in document_cols)
            vector_col = _get_column(model, self.vector_field)
            return (
                f"tsvector_update_trigger({utils.quote(vector_col)},"
                f" {utils.quote(self.config_name)}, {rendered_document_cols})"
            )