mirror of
https://gitlab.com/bramw/baserow.git
synced 2025-04-14 09:08:32 +00:00
Improve recursive CTE in get_all_field_dependencies
This commit is contained in:
parent
6c08674f49
commit
66e5b28f2f
3 changed files with 109 additions and 27 deletions
backend
src/baserow/contrib/database
tests/baserow/contrib/database/field/dependencies
|
@ -34,7 +34,8 @@ class BaserowFormulaSelectOptionsSerializer(serializers.ListField):
|
||||||
# but let's avoid the potentially slow query if not required.
|
# but let's avoid the potentially slow query if not required.
|
||||||
if field_type.can_represent_select_options(field):
|
if field_type.can_represent_select_options(field):
|
||||||
select_options = SelectOption.objects.filter(
|
select_options = SelectOption.objects.filter(
|
||||||
field_id__in=get_all_field_dependencies(field)
|
field_id__in=get_all_field_dependencies(field),
|
||||||
|
field__trashed=False,
|
||||||
)
|
)
|
||||||
return [self.child.to_representation(item) for item in select_options]
|
return [self.child.to_representation(item) for item in select_options]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -22,48 +22,102 @@
|
||||||
# this query so it works with our own database models and structure.
|
# this query so it works with our own database models and structure.
|
||||||
#
|
#
|
||||||
|
|
||||||
from django.conf import settings
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
from django.db import connection
|
||||||
|
|
||||||
|
from baserow.contrib.database.fields.dependencies.exceptions import (
|
||||||
|
CircularFieldDependencyError,
|
||||||
|
)
|
||||||
from baserow.contrib.database.fields.dependencies.models import FieldDependency
|
from baserow.contrib.database.fields.dependencies.models import FieldDependency
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from baserow.contrib.database.fields.models import Field
|
||||||
|
|
||||||
|
|
||||||
def will_cause_circular_dep(from_field, to_field):
|
def will_cause_circular_dep(from_field, to_field):
|
||||||
return from_field.id in get_all_field_dependencies(to_field)
|
return from_field.id in get_all_field_dependencies(to_field)
|
||||||
|
|
||||||
|
|
||||||
def get_all_field_dependencies(field):
|
def get_all_field_dependencies(field: "Field") -> set[int]:
|
||||||
|
"""
|
||||||
|
Get all field dependencies for a field. This includes all fields that the given
|
||||||
|
field depends on, directly or indirectly, even if the field have been trashed. For
|
||||||
|
example, if the given field is a formula that references another formula which in
|
||||||
|
turn references a text field, both the intermediate formula and the text field will
|
||||||
|
be returned as dependencies.
|
||||||
|
|
||||||
|
This function uses a recursive CTE to traverse the field dependency graph and return
|
||||||
|
all field ids that are reachable from the given field id. If a circular dependency
|
||||||
|
is detected, a CircularFieldDependencyError is raised.
|
||||||
|
|
||||||
|
:param field: The field to get dependencies for.
|
||||||
|
:return: A set of field ids that the given field depends on.
|
||||||
|
:raises CircularFieldDependencyError: If a circular dependency is detected.
|
||||||
|
"""
|
||||||
|
|
||||||
from baserow.contrib.database.fields.models import Field
|
from baserow.contrib.database.fields.models import Field
|
||||||
|
|
||||||
query_parameters = {
|
filtered_field_dependencies = FieldDependency.objects.filter(
|
||||||
"pk": field.pk,
|
dependant_id__table__database_id=Field.objects_and_trash.filter(pk=field.pk)
|
||||||
"max_depth": settings.MAX_FIELD_REFERENCE_DEPTH,
|
.order_by()
|
||||||
}
|
.values("table__database_id")[:1]
|
||||||
relationship_table = FieldDependency._meta.db_table
|
)
|
||||||
pk_name = "id"
|
sql, params = filtered_field_dependencies.query.get_compiler(
|
||||||
|
connection=connection
|
||||||
|
).as_sql()
|
||||||
|
|
||||||
# Only pk_name and a table name get formatted in, no user controllable input, safe.
|
# Only pk_name and a table name get formatted in, no user controllable input, safe.
|
||||||
# fmt: off
|
# fmt: off
|
||||||
raw_query = (
|
raw_query = (
|
||||||
f"""
|
f"""
|
||||||
WITH RECURSIVE traverse({pk_name}, depth) AS (
|
WITH RECURSIVE dependencies AS ({sql}),
|
||||||
SELECT first.dependency_id, 1
|
traverse(id, depth, path, is_circular) AS (
|
||||||
FROM {relationship_table} AS first
|
SELECT
|
||||||
LEFT OUTER JOIN {relationship_table} AS second
|
dependency_id,
|
||||||
ON first.dependency_id = second.dependant_id
|
1,
|
||||||
WHERE first.dependant_id = %(pk)s
|
ARRAY[dependant_id, dependency_id],
|
||||||
UNION
|
FALSE
|
||||||
SELECT DISTINCT dependency_id, traverse.depth + 1
|
FROM dependencies
|
||||||
FROM traverse
|
WHERE dependant_id = %s
|
||||||
INNER JOIN {relationship_table}
|
|
||||||
ON {relationship_table}.dependant_id = traverse.{pk_name}
|
UNION ALL
|
||||||
WHERE 1 = 1
|
|
||||||
|
SELECT
|
||||||
|
d.dependency_id,
|
||||||
|
traverse.depth + 1,
|
||||||
|
path || d.dependency_id,
|
||||||
|
d.dependency_id = ANY(path) OR traverse.is_circular -- detect circularity
|
||||||
|
FROM traverse
|
||||||
|
INNER JOIN dependencies d ON d.dependant_id = traverse.id
|
||||||
|
WHERE NOT traverse.is_circular -- stop recursion when a cycle is found
|
||||||
)
|
)
|
||||||
SELECT {pk_name} FROM traverse
|
SELECT id, is_circular
|
||||||
WHERE depth <= %(max_depth)s
|
FROM (
|
||||||
GROUP BY {pk_name}
|
SELECT
|
||||||
ORDER BY MAX(depth) DESC, {pk_name} ASC
|
id,
|
||||||
|
is_circular,
|
||||||
|
MAX(depth) AS max_depth
|
||||||
|
FROM traverse
|
||||||
|
WHERE depth <= %s
|
||||||
|
GROUP BY id, is_circular
|
||||||
|
) sub
|
||||||
|
ORDER BY max_depth DESC, id ASC;
|
||||||
""" # nosec b608
|
""" # nosec b608
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
pks = Field.objects.raw(raw_query, query_parameters)
|
|
||||||
return {item.pk for item in pks}
|
dep_ids = set()
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.execute(
|
||||||
|
raw_query, (*params, field.pk, settings.MAX_FIELD_REFERENCE_DEPTH)
|
||||||
|
)
|
||||||
|
results = cursor.fetchall()
|
||||||
|
for dep_id, is_circular in results:
|
||||||
|
if is_circular:
|
||||||
|
raise CircularFieldDependencyError()
|
||||||
|
elif dep_id is not None: # Avoid broken references
|
||||||
|
dep_ids.add(dep_id)
|
||||||
|
|
||||||
|
return dep_ids
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from baserow.contrib.database.fields.dependencies.circular_reference_checker import (
|
||||||
|
get_all_field_dependencies,
|
||||||
|
)
|
||||||
from baserow.contrib.database.fields.dependencies.exceptions import (
|
from baserow.contrib.database.fields.dependencies.exceptions import (
|
||||||
CircularFieldDependencyError,
|
CircularFieldDependencyError,
|
||||||
)
|
)
|
||||||
|
@ -368,3 +371,27 @@ def test_trashing_and_restoring_a_field_recreate_dependencies_correctly(data_fix
|
||||||
row_2.refresh_from_db()
|
row_2.refresh_from_db()
|
||||||
assert getattr(row_2, f1.db_column) == "b"
|
assert getattr(row_2, f1.db_column) == "b"
|
||||||
assert getattr(row_2, f2.db_column) == "b"
|
assert getattr(row_2, f2.db_column) == "b"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_even_with_circular_dependencies_queries_finish_in_time(data_fixture):
|
||||||
|
# This should never happen, but if somehow circular dependencies are created
|
||||||
|
# we should still be able to get the dependencies of a field without running
|
||||||
|
# into an infinite loop in the recursive query.
|
||||||
|
|
||||||
|
user = data_fixture.create_user()
|
||||||
|
table = data_fixture.create_database_table(user=user)
|
||||||
|
f1 = data_fixture.create_formula_field(
|
||||||
|
name="f1", table=table, formula_type="text", formula="1"
|
||||||
|
)
|
||||||
|
f2 = data_fixture.create_formula_field(
|
||||||
|
name="f2", table=table, formula_type="text", formula="field('f1')"
|
||||||
|
)
|
||||||
|
# Forcefully create a circular dependency
|
||||||
|
f1.dependencies.create(dependency=f2)
|
||||||
|
|
||||||
|
with pytest.raises(CircularFieldDependencyError):
|
||||||
|
assert get_all_field_dependencies(f1) == {f2.id} # f1 -> f2
|
||||||
|
|
||||||
|
with pytest.raises(CircularFieldDependencyError):
|
||||||
|
assert get_all_field_dependencies(f2) == {f1.id} # f2 -> f1
|
||||||
|
|
Loading…
Add table
Reference in a new issue