# Copyright 2020 Karlsruhe Institute of Technology
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import redirect_stderr
from datetime import timezone
import alembic
from flask import current_app
from sqlalchemy import UniqueConstraint
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.properties import RelationshipProperty
from sqlalchemy.schema import CheckConstraint
from sqlalchemy.schema import Index
import kadi.lib.constants as const
from kadi.ext.db import db
from kadi.lib.utils import rgetattr
from kadi.lib.utils import utcnow
[docs]class UTCDateTime(db.TypeDecorator):
"""Custom timezone aware DateTime type using UTC.
As dates are currently saved without timezone information (and always interpreted as
UTC), the timezone information has to be removed from datetime objects before
persisting, as otherwise they are converted to local time. When retrieving the
value, the timezone will be added back in.
"""
impl = db.DateTime
cache_ok = True
[docs] def process_bind_param(self, value, dialect):
"""Convert to UTC and then remove the timezone."""
if value is None:
return value
return value.astimezone(timezone.utc).replace(tzinfo=None)
[docs] def process_result_value(self, value, dialect):
"""Replace the missing timezone with UTC."""
if value is None:
return value
return value.replace(tzinfo=timezone.utc)
[docs]class BaseTimestampMixin:
"""Base mixin class for SQLAlchemy models to add timestamp columns.
In all current implementations, changes in columns and relationship-based
collections can be ignored by specifying the ``Meta.timestamp_exclude`` attribute in
the inheriting class. It should be a list of strings specifying the attribute names
to exclude.
**Example:**
.. code-block:: python3
class Foo:
class Meta:
timestamp_exclude = ["bar", "baz"]
"""
created_at = db.Column(UTCDateTime, default=utcnow, nullable=False)
"""The date and time an object has been created at.
Always uses the current UTC time.
"""
last_modified = db.Column(UTCDateTime, default=utcnow, nullable=False)
"""The date and time an object was last modified.
Always uses the current UTC time as initial value. After calling
:meth:`register_timestamp_listener` for an inheriting mixin class, this timestamp
can automatically get updated by implementing "_before_flush_timestamp" as a class
method.
"""
[docs] @classmethod
def register_timestamp_listener(cls):
"""Register a listener to automatically update the last modification timestamp.
Uses SQLAlchemy's ``before_flush`` event and propagates to all inheriting
models.
"""
db.event.listen(
db.session, "before_flush", cls._before_flush_timestamp, propagate=True
)
[docs] def update_timestamp(self):
"""Manually trigger an update to the last modification timestamp."""
self.last_modified = utcnow()
[docs]class SimpleTimestampMixin(BaseTimestampMixin):
"""Timestamp mixin class which triggers on all changes."""
@classmethod
def _before_flush_timestamp(cls, session, flush_context, instances):
for obj in session.dirty:
if isinstance(obj, cls) and session.is_modified(obj):
excluded_attrs = rgetattr(obj.__class__, "Meta.timestamp_exclude", [])
for attr in db.inspect(obj).attrs:
if attr.key in excluded_attrs:
continue
if attr.load_history().has_changes():
obj.last_modified = utcnow()
break
[docs]class StateTimestampMixin(BaseTimestampMixin):
"""Timestamp mixin class which only triggers on changes in active objects.
An object is considered active if its marked as active via its ``state`` attribute,
which needs to be present in an inheriting model. Besides the object itself, changes
in relationship-based collections consisting of inactive objects only are also
ignored.
"""
@classmethod
def _before_flush_timestamp(cls, session, flush_context, instances):
for obj in session.dirty:
if isinstance(obj, cls) and session.is_modified(obj):
# Do not update the timestamp if the state of the object (if present) is
# not active, except for when it just changed.
if obj.state != const.MODEL_STATE_ACTIVE:
history = db.inspect(obj).attrs.state.load_history()
if (
not history.deleted
or history.deleted[0] != const.MODEL_STATE_ACTIVE
):
continue
excluded_attrs = rgetattr(obj.__class__, "Meta.timestamp_exclude", [])
timestamp_updated = False
# Do not update the timestamp if only changes in relationship-based
# collections occured and none of the related objects' state, if
# present, is active.
for attr in db.inspect(obj).attrs:
if timestamp_updated:
break
if attr.key in excluded_attrs:
continue
history = attr.load_history()
items = list(history.added) + list(history.deleted)
for item in items:
if (
not isinstance(item, db.Model)
or getattr(item, "state", const.MODEL_STATE_ACTIVE)
== const.MODEL_STATE_ACTIVE
):
obj.last_modified = utcnow()
timestamp_updated = True
break
[docs] def update_timestamp(self):
"""Manually trigger an update to the last modification timestamp.
Only actually triggers the update if the object is marked as active.
"""
if self.state == const.MODEL_STATE_ACTIVE:
self.last_modified = utcnow()
[docs]class NestedTransaction:
"""Context manager to start a "nested" transaction.
The nested transaction uses the ``SAVEPOINT`` feature of the database, which will be
triggered when entering the context manager. Once the context manager exits, the
savepoint is released, which does not actually persist the changes done in the
savepoint yet. If the release produces the given exception, the savepoint will be
rolled back automatically. The ``success`` attribute of the transaction can be used
to check the result of the release operation afterwards.
:param exc: (optional) The exception to catch when releasing the savepoint.
"""
def __init__(self, exc=SQLAlchemyError):
self.exc = exc
self._success = False
self._savepoint = None
def __enter__(self):
self._savepoint = db.session.begin_nested()
return self
def __exit__(self, type, value, traceback):
try:
self._savepoint.commit()
self._success = True
except self.exc:
self._savepoint.rollback()
@property
def success(self):
"""Get the status of the nested transaction once the savepoint is released.
Will always be ``False`` before the context manager exits.
"""
return self._success
[docs]def update_object(obj, **kwargs):
r"""Convenience function to update database objects.
Only columns (i.e. attributes) that actually exist will get updated.
:param obj: The object to update.
:param \**kwargs: The columns to update and their respective values.
"""
for key, value in kwargs.items():
if hasattr(obj, key):
setattr(obj, key, value)
[docs]def composite_index(tablename, *cols):
r"""Generate a composite index.
:param tablename: The name of the table.
:param \*cols: The names of the columns.
:return: The Index instance.
"""
return Index(f"ix_{tablename}_{'_'.join(cols)}", *cols)
[docs]def unique_constraint(tablename, *cols):
r"""Generate a unique constraint.
:param tablename: The name of the table.
:param \*cols: The names of the columns.
:return: The UniqueConstraint instance.
"""
return UniqueConstraint(*cols, name=f"uq_{tablename}_{'_'.join(cols)}")
[docs]def check_constraint(constraint, name):
"""Generate a check constraint.
:param constraint: The constraint expression as string.
:param name: The name of the constraint.
:return: The CheckConstraint instance.
"""
return CheckConstraint(constraint, name=f"ck_{name}")
[docs]def length_constraint(col, min_value=None, max_value=None):
"""Generate a length check constraint for a column.
:param col: The name of the column.
:param min_value: (optional) Minimum length.
:param max_value: (optional) Maximum length.
:return: The CheckConstraint instance.
"""
constraint = ""
if min_value is not None:
constraint += f"char_length({col}) >= {min_value}"
if max_value is not None:
if min_value is not None:
constraint += " AND "
constraint += f"char_length({col}) <= {max_value}"
return check_constraint(constraint, name=f"{col}_length")
[docs]def range_constraint(col, min_value=None, max_value=None):
"""Generate a range check constraint for a column.
:param col: The name of the column.
:param min_value: (optional) Minimum value.
:param max_value: (optional) Maximum value.
:return: The CheckConstraint instance.
"""
constraint = ""
if min_value is not None:
constraint += f"{col} >= {min_value}"
if max_value is not None:
if min_value is not None:
constraint += " AND "
constraint += f"{col} <= {max_value}"
return check_constraint(constraint, name=f"{col}_range")
[docs]def values_constraint(col, values):
"""Generate a values check constraint for a column.
:param col: The name of the column.
:param values: List of values.
:return: The CheckConstraint instance.
"""
values = values if values is not None else []
values = ", ".join(f"'{value}'" for value in values)
constraint = f"{col} IN ({values})"
return check_constraint(constraint, name=f"{col}_values")
[docs]def generate_check_constraints(constraints):
"""Generate database check constraints.
Supports check constraints of type ``"length"``, ``"range"`` and ``"values"``. The
constraints have to be given in the following form:
.. code-block:: python3
{
"col_1": {"length": {"min": 0, "max": 10}},
"col_2": {"range": {"min": 0, "max": 10}},
"col_3": {"values": ["val_1", "val_2"]},
}
:param constraints: Dictionary of constraints to generate.
:return: A tuple of CheckConstraint instances.
"""
results = []
for col, constraint in constraints.items():
for name, args in constraint.items():
if name == "length":
results.append(
length_constraint(
col, min_value=args.get("min"), max_value=args.get("max")
)
)
elif name == "range":
results.append(
range_constraint(
col, min_value=args.get("min"), max_value=args.get("max")
)
)
elif name == "values":
results.append(values_constraint(col, args))
return tuple(results)
[docs]def get_class_by_tablename(tablename):
"""Get the model class mapped to a certain database table name.
:param tablename: Name of the table.
:return: The class reference or ``None`` if the table does not exist.
"""
for mapper in db.Model.registry.mappers:
if getattr(mapper.class_, "__tablename__", None) == tablename:
return mapper.class_
return None
[docs]def is_column(model, attr):
"""Check if a model's attribute is a regular column.
:param model: The model that contains the column.
:param attr: Name of the column attribute.
:return: ``True`` if the attribute is a column, ``False`` otherwise.
"""
return isinstance(getattr(model, attr).property, ColumnProperty)
[docs]def get_column_type(model, attr):
"""Get the type of a column.
:param model: The model that contains the column.
:param attr: Name of the column attribute.
:return: The type of the column or ``None`` if the attribute is not a regular
column.
"""
if is_column(model, attr):
return getattr(model, attr).property.columns[0].type
return None
[docs]def is_relationship(model, attr):
"""Check if a model's attribute is a relationship.
:param model: The model that contains the column.
:param attr: Name of the relationship attribute.
:return: ``True`` if the attribute is a relationship, ``False`` otherwise.
"""
return isinstance(getattr(model, attr).property, RelationshipProperty)
[docs]def is_many_relationship(model, attr):
"""Check if a model's attribute is a many-relationship.
:param model: The model that contains the column.
:param attr: Name of the relationship attribute.
:return: ``True`` if the attribute is a many-relationship, ``False`` otherwise.
"""
if is_relationship(model, attr):
return getattr(model, attr).property.uselist
return False
[docs]def get_class_of_relationship(model, attr):
"""Get the class of a relationship.
:param model: The model that contains the relationship.
:param attr: Name of the relationship attribute.
:return: The class reference or ``None`` if the attribute is not a relationship.
"""
if is_relationship(model, attr):
return getattr(model, attr).property.mapper.class_
return None
[docs]def escape_like(value, escape="\\"):
"""Escape a string for use in LIKE queries.
Will escape ``"%"``, ``"_"`` and the escape character specified by ``escape``.
:param value: The string to escape.
:param escape: (optional) The escape character to use.
:return: The escaped string.
"""
return (
value.replace(escape, 2 * escape)
.replace("%", f"{escape}%")
.replace("_", f"{escape}_")
)
[docs]def acquire_lock(obj):
"""Acquire a database lock on a given object.
The database row corresponding to the given object will be locked using ``FOR
UPDATE`` until the session is committed or rolled back. Once the lock is acquired,
the given object is also refreshed. Should only be used if strictly necessary, e.g.
to prevent certain race conditions.
:param obj: The object to lock.
:return: The locked and refreshed object.
"""
return (
obj.__class__.query.populate_existing()
.with_for_update()
.filter(obj.__class__.id == obj.id)
.first()
)
[docs]def has_extension(extension_name):
"""Check if a given database extension is installed.
Currently only supported for PostgreSQL databases.
:param extension_name: The name of the extension.
:return: ``True`` if the extension is installed, ``False`` if the extension is not
installed or the database is not PostgreSQL.
"""
if db.engine.name != "postgresql":
return False
with db.engine.connect() as connection:
return (
connection.execute(
db.text("SELECT * FROM pg_extension WHERE extname=:extname"),
{"extname": extension_name},
).first()
is not None
)
[docs]def has_pending_revisions():
"""Check if the database has pending revisions.
:return: ``True`` if there are pending revisions, ``False`` otherwise.
"""
migration_config = current_app.extensions["migrate"].migrate.get_config()
script_dir = alembic.script.ScriptDirectory.from_config(migration_config)
latest_revision = script_dir.get_current_head()
# pylint: disable=unspecified-encoding
with open(os.devnull, mode="w") as f:
with redirect_stderr(f):
with db.engine.connect() as connection:
migration_context = (
alembic.runtime.migration.MigrationContext.configure(connection)
)
current_revision = migration_context.get_current_revision()
return current_revision != latest_revision