diff --git a/backend/src/baserow/contrib/database/api/formula/serializers.py b/backend/src/baserow/contrib/database/api/formula/serializers.py index f777091f3..34a4c35b6 100644 --- a/backend/src/baserow/contrib/database/api/formula/serializers.py +++ b/backend/src/baserow/contrib/database/api/formula/serializers.py @@ -34,7 +34,8 @@ class BaserowFormulaSelectOptionsSerializer(serializers.ListField): # but let's avoid the potentially slow query if not required. if field_type.can_represent_select_options(field): select_options = SelectOption.objects.filter( - field_id__in=get_all_field_dependencies(field) + field_id__in=get_all_field_dependencies(field), + field__trashed=False, ) return [self.child.to_representation(item) for item in select_options] else: diff --git a/backend/src/baserow/contrib/database/fields/dependencies/circular_reference_checker.py b/backend/src/baserow/contrib/database/fields/dependencies/circular_reference_checker.py index 403f2d376..f173efe3c 100644 --- a/backend/src/baserow/contrib/database/fields/dependencies/circular_reference_checker.py +++ b/backend/src/baserow/contrib/database/fields/dependencies/circular_reference_checker.py @@ -22,48 +22,102 @@ # this query so it works with our own database models and structure. # -from django.conf import settings +from typing import TYPE_CHECKING +from django.conf import settings +from django.db import connection + +from baserow.contrib.database.fields.dependencies.exceptions import ( + CircularFieldDependencyError, +) from baserow.contrib.database.fields.dependencies.models import FieldDependency +if TYPE_CHECKING: + from baserow.contrib.database.fields.models import Field + def will_cause_circular_dep(from_field, to_field): return from_field.id in get_all_field_dependencies(to_field) -def get_all_field_dependencies(field): +def get_all_field_dependencies(field: "Field") -> set[int]: + """ + Get all field dependencies for a field. This includes all fields that the given + field depends on, directly or indirectly, even if the field have been trashed. For + example, if the given field is a formula that references another formula which in + turn references a text field, both the intermediate formula and the text field will + be returned as dependencies. + + This function uses a recursive CTE to traverse the field dependency graph and return + all field ids that are reachable from the given field id. If a circular dependency + is detected, a CircularFieldDependencyError is raised. + + :param field: The field to get dependencies for. + :return: A set of field ids that the given field depends on. + :raises CircularFieldDependencyError: If a circular dependency is detected. + """ + from baserow.contrib.database.fields.models import Field - query_parameters = { - "pk": field.pk, - "max_depth": settings.MAX_FIELD_REFERENCE_DEPTH, - } - relationship_table = FieldDependency._meta.db_table - pk_name = "id" + filtered_field_dependencies = FieldDependency.objects.filter( + dependant_id__table__database_id=Field.objects_and_trash.filter(pk=field.pk) + .order_by() + .values("table__database_id")[:1] + ) + sql, params = filtered_field_dependencies.query.get_compiler( + connection=connection + ).as_sql() # Only pk_name and a table name get formatted in, no user controllable input, safe. # fmt: off raw_query = ( f""" - WITH RECURSIVE traverse({pk_name}, depth) AS ( - SELECT first.dependency_id, 1 - FROM {relationship_table} AS first - LEFT OUTER JOIN {relationship_table} AS second - ON first.dependency_id = second.dependant_id - WHERE first.dependant_id = %(pk)s - UNION - SELECT DISTINCT dependency_id, traverse.depth + 1 - FROM traverse - INNER JOIN {relationship_table} - ON {relationship_table}.dependant_id = traverse.{pk_name} - WHERE 1 = 1 + WITH RECURSIVE dependencies AS ({sql}), + traverse(id, depth, path, is_circular) AS ( + SELECT + dependency_id, + 1, + ARRAY[dependant_id, dependency_id], + FALSE + FROM dependencies + WHERE dependant_id = %s + + UNION ALL + + SELECT + d.dependency_id, + traverse.depth + 1, + path || d.dependency_id, + d.dependency_id = ANY(path) OR traverse.is_circular -- detect circularity + FROM traverse + INNER JOIN dependencies d ON d.dependant_id = traverse.id + WHERE NOT traverse.is_circular -- stop recursion when a cycle is found ) - SELECT {pk_name} FROM traverse - WHERE depth <= %(max_depth)s - GROUP BY {pk_name} - ORDER BY MAX(depth) DESC, {pk_name} ASC + SELECT id, is_circular + FROM ( + SELECT + id, + is_circular, + MAX(depth) AS max_depth + FROM traverse + WHERE depth <= %s + GROUP BY id, is_circular + ) sub + ORDER BY max_depth DESC, id ASC; """ # nosec b608 ) # fmt: on - pks = Field.objects.raw(raw_query, query_parameters) - return {item.pk for item in pks} + + dep_ids = set() + with connection.cursor() as cursor: + cursor.execute( + raw_query, (*params, field.pk, settings.MAX_FIELD_REFERENCE_DEPTH) + ) + results = cursor.fetchall() + for dep_id, is_circular in results: + if is_circular: + raise CircularFieldDependencyError() + elif dep_id is not None: # Avoid broken references + dep_ids.add(dep_id) + + return dep_ids diff --git a/backend/tests/baserow/contrib/database/field/dependencies/test_dependency_rebuilder.py b/backend/tests/baserow/contrib/database/field/dependencies/test_dependency_rebuilder.py index c57d784ab..81fe5f5f4 100644 --- a/backend/tests/baserow/contrib/database/field/dependencies/test_dependency_rebuilder.py +++ b/backend/tests/baserow/contrib/database/field/dependencies/test_dependency_rebuilder.py @@ -1,5 +1,8 @@ import pytest +from baserow.contrib.database.fields.dependencies.circular_reference_checker import ( + get_all_field_dependencies, +) from baserow.contrib.database.fields.dependencies.exceptions import ( CircularFieldDependencyError, ) @@ -368,3 +371,27 @@ def test_trashing_and_restoring_a_field_recreate_dependencies_correctly(data_fix row_2.refresh_from_db() assert getattr(row_2, f1.db_column) == "b" assert getattr(row_2, f2.db_column) == "b" + + +@pytest.mark.django_db +def test_even_with_circular_dependencies_queries_finish_in_time(data_fixture): + # This should never happen, but if somehow circular dependencies are created + # we should still be able to get the dependencies of a field without running + # into an infinite loop in the recursive query. + + user = data_fixture.create_user() + table = data_fixture.create_database_table(user=user) + f1 = data_fixture.create_formula_field( + name="f1", table=table, formula_type="text", formula="1" + ) + f2 = data_fixture.create_formula_field( + name="f2", table=table, formula_type="text", formula="field('f1')" + ) + # Forcefully create a circular dependency + f1.dependencies.create(dependency=f2) + + with pytest.raises(CircularFieldDependencyError): + assert get_all_field_dependencies(f1) == {f2.id} # f1 -> f2 + + with pytest.raises(CircularFieldDependencyError): + assert get_all_field_dependencies(f2) == {f1.id} # f2 -> f1