mirror of
https://gitlab.com/bramw/baserow.git
synced 2025-04-07 14:25:37 +00:00
Improve recursive CTE in get_all_field_dependencies
This commit is contained in:
parent
6c08674f49
commit
66e5b28f2f
3 changed files with 109 additions and 27 deletions
backend
src/baserow/contrib/database
tests/baserow/contrib/database/field/dependencies
|
@ -34,7 +34,8 @@ class BaserowFormulaSelectOptionsSerializer(serializers.ListField):
|
|||
# but let's avoid the potentially slow query if not required.
|
||||
if field_type.can_represent_select_options(field):
|
||||
select_options = SelectOption.objects.filter(
|
||||
field_id__in=get_all_field_dependencies(field)
|
||||
field_id__in=get_all_field_dependencies(field),
|
||||
field__trashed=False,
|
||||
)
|
||||
return [self.child.to_representation(item) for item in select_options]
|
||||
else:
|
||||
|
|
|
@ -22,48 +22,102 @@
|
|||
# this query so it works with our own database models and structure.
|
||||
#
|
||||
|
||||
from django.conf import settings
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import connection
|
||||
|
||||
from baserow.contrib.database.fields.dependencies.exceptions import (
|
||||
CircularFieldDependencyError,
|
||||
)
|
||||
from baserow.contrib.database.fields.dependencies.models import FieldDependency
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from baserow.contrib.database.fields.models import Field
|
||||
|
||||
|
||||
def will_cause_circular_dep(from_field, to_field):
    """
    Check whether `from_field` is already one of the dependencies of `to_field`.

    The dependency ids of `to_field` are collected with
    `get_all_field_dependencies` and `from_field.id` is looked up in them. Going by
    the function name, a positive result means that letting `from_field` depend on
    `to_field` would close a loop.

    :param from_field: The field whose id is searched for.
    :param to_field: The field whose dependencies are collected.
    :return: True if `from_field.id` is among the dependency ids of `to_field`,
        False otherwise.
    """

    reachable_ids = get_all_field_dependencies(to_field)
    return from_field.id in reachable_ids
|
||||
|
||||
|
||||
def get_all_field_dependencies(field: "Field") -> set[int]:
    """
    Get all field dependencies for a field. This includes all fields that the given
    field depends on, directly or indirectly, even if the field has been trashed. For
    example, if the given field is a formula that references another formula which in
    turn references a text field, both the intermediate formula and the text field will
    be returned as dependencies.

    This function uses a recursive CTE to traverse the field dependency graph and return
    all field ids that are reachable from the given field id. If a circular dependency
    is detected, a CircularFieldDependencyError is raised.

    :param field: The field to get dependencies for.
    :return: A set of field ids that the given field depends on.
    :raises CircularFieldDependencyError: If a circular dependency is detected.
    """

    # Imported locally; at module level `Field` is only imported for type checking
    # (presumably to avoid a circular import).
    from baserow.contrib.database.fields.models import Field

    # Restrict the dependency rows the CTE walks over to the ones whose dependant
    # lives in the same database as the given field. The database id is looked up
    # through `objects_and_trash`, so the lookup also works for a trashed field.
    filtered_field_dependencies = FieldDependency.objects.filter(
        dependant_id__table__database_id=Field.objects_and_trash.filter(pk=field.pk)
        .order_by()
        .values("table__database_id")[:1]
    )
    sql, params = filtered_field_dependencies.query.get_compiler(
        connection=connection
    ).as_sql()

    # Only the compiled `dependencies` subquery gets formatted in. It is produced by
    # the Django SQL compiler and its values are passed separately via `params`, so
    # no values end up in the query text itself.
    #
    # The `traverse` CTE keeps, for every row, the path of field ids walked so far.
    # A row is flagged as circular as soon as the next dependency id is already in
    # its path, and flagged rows are not expanded any further. Note that the depth
    # limit is only applied by the outer query; the recursive term itself stops
    # only when a cycle is flagged or when there is nothing left to join.
    # fmt: off
    raw_query = (
        f"""
        WITH RECURSIVE dependencies AS ({sql}),
        traverse(id, depth, path, is_circular) AS (
            SELECT
                dependency_id,
                1,
                ARRAY[dependant_id, dependency_id],
                FALSE
            FROM dependencies
            WHERE dependant_id = %s

            UNION ALL

            SELECT
                d.dependency_id,
                traverse.depth + 1,
                path || d.dependency_id,
                d.dependency_id = ANY(path) OR traverse.is_circular -- detect circularity
            FROM traverse
            INNER JOIN dependencies d ON d.dependant_id = traverse.id
            WHERE NOT traverse.is_circular -- stop recursion when a cycle is found
        )
        SELECT id, is_circular
        FROM (
            SELECT
                id,
                is_circular,
                MAX(depth) AS max_depth
            FROM traverse
            WHERE depth <= %s
            GROUP BY id, is_circular
        ) sub
        ORDER BY max_depth DESC, id ASC;
        """  # nosec b608
    )
    # fmt: on

    dep_ids = set()
    with connection.cursor() as cursor:
        # Placeholder order: first the parameters of the compiled subquery, then
        # the starting field id, then the maximum depth of the outer query.
        cursor.execute(
            raw_query, (*params, field.pk, settings.MAX_FIELD_REFERENCE_DEPTH)
        )
        results = cursor.fetchall()
        for dep_id, is_circular in results:
            if is_circular:
                raise CircularFieldDependencyError()
            elif dep_id is not None:  # Avoid broken references
                dep_ids.add(dep_id)

    return dep_ids
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import pytest
|
||||
|
||||
from baserow.contrib.database.fields.dependencies.circular_reference_checker import (
|
||||
get_all_field_dependencies,
|
||||
)
|
||||
from baserow.contrib.database.fields.dependencies.exceptions import (
|
||||
CircularFieldDependencyError,
|
||||
)
|
||||
|
@ -368,3 +371,27 @@ def test_trashing_and_restoring_a_field_recreate_dependencies_correctly(data_fix
|
|||
row_2.refresh_from_db()
|
||||
assert getattr(row_2, f1.db_column) == "b"
|
||||
assert getattr(row_2, f2.db_column) == "b"
|
||||
|
||||
|
||||
@pytest.mark.django_db
def test_even_with_circular_dependencies_queries_finish_in_time(data_fixture):
    # This should never happen, but if somehow circular dependencies are created
    # we should still be able to get the dependencies of a field without running
    # into an infinite loop in the recursive query.

    user = data_fixture.create_user()
    table = data_fixture.create_database_table(user=user)
    f1 = data_fixture.create_formula_field(
        name="f1", table=table, formula_type="text", formula="1"
    )
    # f2 references f1 through its formula.
    f2 = data_fixture.create_formula_field(
        name="f2", table=table, formula_type="text", formula="field('f1')"
    )
    # Forcefully create a circular dependency
    f1.dependencies.create(dependency=f2)

    # The lookup is expected to raise instead of returning a set, so only the call
    # itself belongs inside the `raises` block; a comparison of its return value
    # placed there would never be evaluated.
    with pytest.raises(CircularFieldDependencyError):
        get_all_field_dependencies(f1)  # f1 -> f2 -> f1

    with pytest.raises(CircularFieldDependencyError):
        get_all_field_dependencies(f2)  # f2 -> f1 -> f2
|
||||
|
|
Loading…
Add table
Reference in a new issue