1
0
Fork 0
mirror of https://gitlab.com/bramw/baserow.git synced 2025-04-10 23:50:12 +00:00

Resolve "Optimize RBAC performance"

This commit is contained in:
Jrmi 2023-01-11 14:19:38 +00:00
parent c64cca3c66
commit 35da3622dd
20 changed files with 659 additions and 180 deletions

5
.gitignore vendored
View file

@ -92,8 +92,9 @@ __pycache__
docker-compose.override.yml
# python virtual envs
# python
venv/
.profiles/
web-frontend/plugins/
backend/plugins/
@ -132,4 +133,4 @@ field-diagrams/
# Intellij needs this package.json to allow running tests from the IDE but this isn't
# actually a node module and so we ignore it.
premium/web-frontend/package.json
premium/web-frontend/package.json

View file

@ -1,3 +1,5 @@
from django.db.models import Q
from baserow.contrib.database.models import Field
from baserow.contrib.database.object_scopes import DatabaseObjectScopeType
from baserow.contrib.database.table.object_scopes import DatabaseTableObjectScopeType
@ -15,17 +17,22 @@ class FieldObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.table
def get_all_context_objects_in_scope(self, scope):
scope_type = object_scope_type_registry.get_by_model(scope)
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related(
"table", "table__database", "table__database__group"
)
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return Field.objects.filter(table__database__group=scope)
return Q(table__database__group__in=[s.id for s in scopes])
if (
scope_type.type == DatabaseObjectScopeType.type
or scope_type.type == ApplicationObjectScopeType.type
):
return Field.objects.filter(table__database=scope)
return Q(table__database__in=[s.id for s in scopes])
if scope_type.type == DatabaseTableObjectScopeType.type:
return Field.objects.filter(table=scope)
if scope_type.type == FieldObjectScopeType.type:
return [scope]
return []
return Q(table__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")

View file

@ -1,3 +1,5 @@
from django.db.models import Q
from baserow.contrib.database.models import Database
from baserow.core.object_scopes import ApplicationObjectScopeType, GroupObjectScopeType
from baserow.core.registries import ObjectScopeType, object_scope_type_registry
@ -15,14 +17,13 @@ class DatabaseObjectScopeType(ObjectScopeType):
# but it's a more generic type
return context.application_ptr
def get_all_context_objects_in_scope(self, scope):
scope_type = object_scope_type_registry.get_by_model(scope)
if scope_type.type == GroupObjectScopeType.type:
return Database.objects.filter(group=scope)
if (
scope_type.type == DatabaseObjectScopeType.type
or scope_type.type == ApplicationObjectScopeType.type
):
return [scope]
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related("group")
return []
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return Q(group__in=[s.id for s in scopes])
if scope_type.type == ApplicationObjectScopeType.type:
return Q(id__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")

View file

@ -1,3 +1,5 @@
from django.db.models import Q
from baserow.contrib.database.object_scopes import DatabaseObjectScopeType
from baserow.contrib.database.table.models import Table
from baserow.core.object_scopes import ApplicationObjectScopeType, GroupObjectScopeType
@ -14,15 +16,17 @@ class DatabaseTableObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.database
def get_all_context_objects_in_scope(self, scope):
scope_type = object_scope_type_registry.get_by_model(scope)
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related("database", "database__group")
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return Table.objects.filter(database__group=scope.id)
return Q(database__group__in=[s.id for s in scopes])
if (
scope_type.type == DatabaseObjectScopeType.type
or scope_type.type == ApplicationObjectScopeType.type
):
return Table.objects.filter(database=scope.id)
if scope_type.type == DatabaseTableObjectScopeType.type:
return [scope]
return []
return Q(database__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")

View file

@ -33,12 +33,10 @@ class RestoreDatabaseTableOperationType(DatabaseTableOperationType):
class ListRowsDatabaseTableOperationType(DatabaseTableOperationType):
type = "database.table.list_rows"
object_scope_name = "database_row"
class ListRowNamesDatabaseTableOperationType(DatabaseTableOperationType):
type = "database.table.list_row_names"
object_scope_name = "database_row"
class ListAggregationDatabaseTableOperationType(DatabaseTableOperationType):

View file

@ -1,8 +1,8 @@
from typing import Iterable
from django.db.models import Q
from baserow.contrib.database.tokens.models import Token
from baserow.core.object_scopes import GroupObjectScopeType
from baserow.core.registries import ObjectScopeType, object_scope_type_registry
from baserow.core.types import ScopeObject
class TokenObjectScopeType(ObjectScopeType):
@ -16,10 +16,8 @@ class TokenObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.group
def get_all_context_objects_in_scope(self, scope: ScopeObject) -> Iterable:
scope_type = object_scope_type_registry.get_by_model(scope)
if scope_type.type == "group":
return Token.objects.filter(group=scope.id)
if scope_type.type == self.type:
return [scope]
return []
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return Q(group__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")

View file

@ -1,4 +1,4 @@
from typing import Iterable
from django.db.models import Q
from baserow.contrib.database.object_scopes import DatabaseObjectScopeType
from baserow.contrib.database.table.object_scopes import DatabaseTableObjectScopeType
@ -10,7 +10,6 @@ from baserow.contrib.database.views.models import (
)
from baserow.core.object_scopes import ApplicationObjectScopeType, GroupObjectScopeType
from baserow.core.registries import ObjectScopeType, object_scope_type_registry
from baserow.core.types import ScopeObject
class DatabaseViewObjectScopeType(ObjectScopeType):
@ -24,20 +23,25 @@ class DatabaseViewObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.table
def get_all_context_objects_in_scope(self, scope: ScopeObject) -> Iterable:
scope_type = object_scope_type_registry.get_by_model(scope)
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related(
"table", "table__database", "table__database__group"
)
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return View.objects.filter(table__database__group=scope.id)
return Q(table__database__group__in=[s.id for s in scopes])
if (
scope_type.type == DatabaseObjectScopeType.type
or scope_type.type == ApplicationObjectScopeType.type
):
return View.objects.filter(table__database=scope.id)
return Q(table__database__in=[s.id for s in scopes])
if scope_type.type == DatabaseTableObjectScopeType.type:
return View.objects.filter(table=scope.id)
if scope_type.type == self.type:
return [scope]
return []
return Q(table__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")
class DatabaseViewDecorationObjectScopeType(ObjectScopeType):
@ -51,22 +55,31 @@ class DatabaseViewDecorationObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.view
def get_all_context_objects_in_scope(self, scope: ScopeObject) -> Iterable:
scope_type = object_scope_type_registry.get_by_model(scope)
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related(
"view",
"view__table",
"view__table__database",
"view__table__database__group",
)
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return ViewDecoration.objects.filter(view__table__database__group=scope.id)
return Q(view_table__database__group__in=[s.id for s in scopes])
if (
scope_type.type == DatabaseObjectScopeType.type
or scope_type.type == ApplicationObjectScopeType.type
):
return ViewDecoration.objects.filter(view__table__database=scope.id)
return Q(view__table__database__in=[s.id for s in scopes])
if scope_type.type == DatabaseTableObjectScopeType.type:
return ViewDecoration.objects.filter(view__table=scope.id)
return Q(view__table__in=[s.id for s in scopes])
if scope_type.type == DatabaseViewObjectScopeType.type:
return ViewDecoration.objects.filter(view=scope.id)
if scope_type.type == self.type:
return [scope]
return []
return Q(view__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")
class DatabaseViewSortObjectScopeType(ObjectScopeType):
@ -80,22 +93,31 @@ class DatabaseViewSortObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.view
def get_all_context_objects_in_scope(self, scope: ScopeObject) -> Iterable:
scope_type = object_scope_type_registry.get_by_model(scope)
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related(
"view",
"view__table",
"view__table__database",
"view__table__database__group",
)
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return ViewSort.objects.filter(view__table__database__group=scope.id)
return Q(view_table__database__group__in=[s.id for s in scopes])
if (
scope_type.type == DatabaseObjectScopeType.type
or scope_type.type == ApplicationObjectScopeType.type
):
return ViewSort.objects.filter(view__table__database=scope.id)
return Q(view__table__database__in=[s.id for s in scopes])
if scope_type.type == DatabaseTableObjectScopeType.type:
return ViewSort.objects.filter(view__table=scope.id)
return Q(view__table__in=[s.id for s in scopes])
if scope_type.type == DatabaseViewObjectScopeType.type:
return ViewSort.objects.filter(view=scope.id)
if scope_type.type == self.type:
return [scope]
return []
return Q(view__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")
class DatabaseViewFilterObjectScopeType(ObjectScopeType):
@ -109,19 +131,28 @@ class DatabaseViewFilterObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.view
def get_all_context_objects_in_scope(self, scope: ScopeObject) -> Iterable:
scope_type = object_scope_type_registry.get_by_model(scope)
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related(
"view",
"view__table",
"view__table__database",
"view__table__database__group",
)
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return ViewFilter.objects.filter(view__table__database__group=scope.id)
return Q(view_table__database__group__in=[s.id for s in scopes])
if (
scope_type.type == DatabaseObjectScopeType.type
or scope_type.type == ApplicationObjectScopeType.type
):
return ViewFilter.objects.filter(view__table__database=scope.id)
return Q(view__table__database__in=[s.id for s in scopes])
if scope_type.type == DatabaseTableObjectScopeType.type:
return ViewFilter.objects.filter(view__table=scope.id)
return Q(view__table__in=[s.id for s in scopes])
if scope_type.type == DatabaseViewObjectScopeType.type:
return ViewFilter.objects.filter(view=scope.id)
if scope_type.type == self.type:
return [scope]
return []
return Q(view__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")

View file

@ -28,6 +28,7 @@ class CreateViewSortOperationType(ViewOperationType):
class ListViewSortOperationType(ViewOperationType):
type = "database.table.view.list_sort"
object_scope_name = DatabaseViewSortObjectScopeType.type
class ReadViewSortOperationType(ViewSortOperationType):

View file

@ -1,3 +1,5 @@
from django.db.models import Q
from baserow.core.models import Application, Group, GroupInvitation, GroupUser
from baserow.core.registries import ObjectScopeType, object_scope_type_registry
@ -6,38 +8,37 @@ class CoreObjectScopeType(ObjectScopeType):
model_class = type(None)
type = "core"
def get_all_context_objects_in_scope(self, scope):
return []
def get_filter_for_scope_type(self, scope_type, scopes):
raise TypeError("The given type is not handled.")
class GroupObjectScopeType(ObjectScopeType):
type = "group"
model_class = Group
def get_all_context_objects_in_scope(self, scope):
if object_scope_type_registry.get_by_model(scope).type == self.type:
return [scope]
return []
def get_filter_for_scope_type(self, scope_type, scopes):
raise TypeError("The given type is not handled.")
class ApplicationObjectScopeType(ObjectScopeType):
type = "application"
model_class = Application
def get_all_context_objects_in_scope(self, scope):
scope_type = object_scope_type_registry.get_by_model(scope)
if scope_type.type == GroupObjectScopeType.type:
return Application.objects.filter(group=scope)
if scope_type.type == self.type:
return [scope]
return []
def get_parent_scope(self):
return object_scope_type_registry.get("group")
def get_parent(self, context):
return context.group
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related("group")
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return Q(group__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")
class GroupInvitationObjectScopeType(ObjectScopeType):
type = "group_invitation"
@ -49,13 +50,14 @@ class GroupInvitationObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.group
def get_all_context_objects_in_scope(self, scope):
scope_type = object_scope_type_registry.get_by_model(scope)
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related("group")
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return GroupInvitation.objects.filter(group=scope)
if scope_type.type == self.type:
return [scope]
return []
return Q(group__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")
class GroupUserObjectScopeType(ObjectScopeType):
@ -68,10 +70,11 @@ class GroupUserObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.group
def get_all_context_objects_in_scope(self, scope):
scope_type = object_scope_type_registry.get_by_model(scope)
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related("group")
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return GroupUser.objects.filter(group=scope)
if scope_type.type == self.type:
return [scope]
return []
return Q(group__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")

View file

@ -1,11 +1,12 @@
import abc
from collections import defaultdict
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
from xmlrpc.client import Boolean
from zipfile import ZipFile
from django.core.files.storage import Storage
from django.db.models import QuerySet
from django.db.models import Q, QuerySet
from django.db.transaction import Atomic
from rest_framework.serializers import Serializer
@ -498,6 +499,14 @@ class ObjectScopeType(Instance, ModelInstanceMixin):
context.
"""
def get_content_type(self):
from django.contrib.contenttypes.models import ContentType
return ContentType.objects.get_for_model(self.model_class)
def get_object_for_this_type(self, **kwargs):
return self.get_content_type().get_object_for_this_type(**kwargs)
def get_parent_scope(self) -> Optional["ObjectScopeType"]:
"""
Returns the parent scope of the current scope.
@ -507,6 +516,19 @@ class ObjectScopeType(Instance, ModelInstanceMixin):
return None
def get_parent_scopes(self) -> List["ObjectScopeType"]:
"""
Returns the parent scope of the current scope.
:return: the parent `ObjectScopeType` or `None` if it's a root scope.
"""
parent_scope = self.get_parent_scope()
if not parent_scope:
return []
return [parent_scope] + parent_scope.get_parent_scopes()
def get_parent(self, context: ContextObject) -> Optional[ContextObject]:
"""
Returns the parent object of the given context which belongs to the current
@ -548,10 +570,101 @@ class ObjectScopeType(Instance, ModelInstanceMixin):
:return: An iterable containing the context objects for the given scope.
"""
return self.get_objects_in_scopes([scope])[scope]
def get_filter_for_scope_type(
self, scope_type: "ObjectScopeType", scopes: List[Any]
) -> Q:
"""
Returns the filter to apply to the queryset that selects all the context
objects included in the given scopes.
All the scopes must be members of the given scope type.
:param scope_type: The scope type the scopes belongs to.
:param scopes: The scopes objects we want the context object for.
:return: A Q object that can be used in a filter operation.
"""
raise NotImplementedError(
f"Must be implemented by the specific type <{self.type}>"
)
def get_enhanced_queryset(self) -> QuerySet:
"""
Returns the base queryset for the objects of this scope enhanced for better
performances.
"""
return self.model_class.objects.all()
def get_filter_for_scopes(self, scopes: List[Any]) -> Dict[Any, Any]:
"""
Computes the filter to apply get all the objects instance of `self.model_class`
included in the given scopes.
:param scopes: A list of scopes we want the object for.
:return: A Q object filter.
"""
# Group scope by types to use `.get_filter_for_scope_type` later
scope_by_types = defaultdict(set)
for s in scopes:
scope_by_types[object_scope_type_registry.get_by_model(s)].add(s)
union_query = Q(id__in=[])
for scope_type, scopes in scope_by_types.items():
if scope_type.type == self.type:
# Simple case: the scope type is the same as this one
# Just filter by id
union_query |= Q(id__in=[s.id for s in scopes])
else:
# Otherwise it's a parent scope. We add a part to the query_parts
union_query |= self.get_filter_for_scope_type(scope_type, scopes)
return union_query
def get_objects_in_scopes(self, scopes: List[Any]) -> Dict[Any, Any]:
"""
Computes the list of all objects, instance of the model_class property
included in the given scopes.
:param scopes: A list of scopes we want the object for.
:return: A dict where the keys are the given scopes and the value is a list
of the child objects of each scope.
"""
objects_per_scope = {}
parent_scopes = []
for scope in scopes:
if object_scope_type_registry.get_by_model(scope).type == self.type:
# Scope of the same type doesn't need to be queried
objects_per_scope[scope] = set([scope])
else:
parent_scopes.append(scope)
if parent_scopes:
query_result = list(
self.get_enhanced_queryset().filter(
self.get_filter_for_scopes(parent_scopes)
)
)
# We have all the objects in the queryset, but now we want to sort them
# into buckets per original scope they are a child of.
for scope in scopes:
objects_per_scope[scope] = set()
scope_type = object_scope_type_registry.get_by_model(scope)
for obj in query_result:
parent_scope = object_scope_type_registry.get_parent(
obj, at_scope_type=scope_type
)
if parent_scope == scope:
objects_per_scope[scope].add(obj)
return objects_per_scope
def contains(self, context: ContextObject):
"""
Returns True if the context is one object of this context.
@ -695,7 +808,7 @@ class SubjectType(abc.ABC, Instance, ModelInstanceMixin):
of subject that is being serialized
:param model_instance: instance of a subject
:param kwargs: additional kwargs that are parsed to serializer
:return: the correct seralizer for the subject
:return: the correct serializer for the subject
"""
pass

View file

@ -1,3 +1,5 @@
from django.db.models import Q
from baserow.core.models import Snapshot
from baserow.core.object_scopes import ApplicationObjectScopeType, GroupObjectScopeType
from baserow.core.registries import ObjectScopeType, object_scope_type_registry
@ -13,12 +15,16 @@ class SnapshotObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.snapshot_from_application.specific
def get_all_context_objects_in_scope(self, scope):
scope_type = object_scope_type_registry.get_by_model(scope)
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related(
"snapshot_from_application", "snapshot_from_application__group"
)
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return Snapshot.objects.filter(snapshot_from_application__group=scope)
return Q(snapshot_from_application__group__in=[s.id for s in scopes])
if scope_type.type == ApplicationObjectScopeType.type:
return Snapshot.objects.filter(snapshot_from_application=scope)
if scope_type.type == self.type:
return [scope]
return []
return Q(snapshot_from_application__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")

View file

@ -1,12 +1,15 @@
import asyncio
import contextlib
import os
from pathlib import Path
from typing import Dict, Optional
from django.apps import apps
from django.core.management import call_command
from django.db import DEFAULT_DB_ALIAS
import pytest
from pyinstrument import Profiler
from baserow.core.apps import sync_operations_after_migrate
from baserow_enterprise.apps import sync_default_roles_after_migrate
@ -172,3 +175,50 @@ def pytest_collection_modifyitems(config, items):
)
item.add_marker(skip_marker)
break
@pytest.fixture()
def profiler():
"""
A fixture to provide an easy way to profile code in your tests.
"""
TESTS_ROOT = Path.cwd()
PROFILE_ROOT = TESTS_ROOT / ".profiles"
profiler = Profiler()
@contextlib.contextmanager
def profile_this(
print_result: bool = True,
html_report_name: str = "",
output_text_params: Optional[Dict] = None,
output_html_params: Optional[Dict] = None,
):
"""
Context manager to profile something.
"""
profiler.start()
yield profiler
profiler.stop()
output_text_params = output_text_params or {}
output_html_params = output_html_params or {}
output_text_params.setdefault("unicode", True)
output_text_params.setdefault("color", True)
if print_result:
print(profiler.output_text(**output_text_params))
if html_report_name:
PROFILE_ROOT.mkdir(exist_ok=True)
results_file = PROFILE_ROOT / f"{html_report_name}.html"
with open(results_file, "w", encoding="utf-8") as f_html:
f_html.write(profiler.output_html(**output_html_params))
profiler.reset()
return profile_this

View file

@ -1,6 +1,7 @@
import inspect
from django.contrib.auth.models import AnonymousUser
from django.db.models import Q, QuerySet
from django.test.utils import override_settings
import pytest
@ -217,3 +218,17 @@ def test_all_scope_types_referenced_by_operations_are_registered():
) + " or somehow the following context types are registered but not subclasses?: " + str(
object_scope_types.difference(all_op_context_types)
)
@pytest.mark.django_db
def test_all_scope_types_query_methods():
all_scope_type = object_scope_type_registry.get_all()
for scope_type in all_scope_type:
if scope_type.type == "core":
continue
assert isinstance(scope_type.get_enhanced_queryset(), QuerySet)
for parent in scope_type.get_parent_scopes():
assert isinstance(scope_type.get_filter_for_scope_type(parent, []), Q)

View file

@ -106,7 +106,9 @@ class AssignRoleActionType(ActionType):
role_assignment_handler = RoleAssignmentHandler()
group = Group.objects.get(id=params.group_id)
scope = role_assignment_handler.get_scope(params.scope_id, params.scope_type)
scope_type = object_scope_type_registry.get(params.scope_type)
scope = scope_type.get_object_for_this_type(id=params.scope_id)
LicenseHandler.raise_if_user_doesnt_have_feature(RBAC, user, group)
@ -141,10 +143,10 @@ class AssignRoleActionType(ActionType):
role_assignment_handler = RoleAssignmentHandler()
group = Group.objects.get(id=params.group_id)
scope = role_assignment_handler.get_scope(params.scope_id, params.scope_type)
scope_type = object_scope_type_registry.get(params.scope_type)
scope = scope_type.get_object_for_this_type(id=params.scope_id)
LicenseHandler.raise_if_user_doesnt_have_feature(RBAC, user, group)
scope_type = object_scope_type_registry.get_by_model(scope)
CoreHandler().check_permissions(
user,

View file

@ -20,7 +20,12 @@ from baserow_enterprise.signals import (
)
from baserow_enterprise.teams.models import Team, TeamSubject
from .constants import NO_ACCESS_ROLE, NO_ROLE_LOW_PRIORITY_ROLE, SUBJECT_PRIORITY
from .constants import (
NO_ACCESS_ROLE,
NO_ROLE_LOW_PRIORITY_ROLE,
ROLE_ASSIGNABLE_OBJECT_MAP,
SUBJECT_PRIORITY,
)
User = get_user_model()
@ -192,7 +197,7 @@ class RoleAssignmentHandler:
actor_subject_type = subject_type_registry.get_by_model(actor)
content_types = ContentType.objects.get_for_models(
actor_subject_type.model_class, Team, Group
*[s.model_class for s in subject_type_registry.get_all()], Group
)
subjects_q = Q(
@ -218,7 +223,10 @@ class RoleAssignmentHandler:
scope_type=ContentType.objects.get_for_model(scope_type.model_class),
then=Value(scope_type.level),
)
for scope_type in object_scope_type_registry.get_all()
for scope_type in [
object_scope_type_registry.get(name)
for name in ROLE_ASSIGNABLE_OBJECT_MAP
]
if scope_type.type != CoreObjectScopeType.type
]
@ -258,23 +266,27 @@ class RoleAssignmentHandler:
.select_related("subject_type")
)
roles_by_scope = {group: []}
group_scope_param = (group.id, content_types[Group].id)
# we are using a tuple of (scope.id, content_type.id) to prevent the query
# automatically done when accessing the property from the role assignments
# to query them all at once later
roles_by_scope = {group_scope_param: []}
priorities_by_scope = {}
for role_assignment in role_assignments:
scope = role_assignment.scope
scope_param = (role_assignment.scope_id, role_assignment.scope_type_id)
role = self.get_role_by_id(role_assignment.role_id)
priority = role_assignment.role_priority
# We don't use defaultdict here to be sure we have the right key order
if scope not in roles_by_scope:
roles_by_scope[scope] = []
if scope_param not in roles_by_scope:
roles_by_scope[scope_param] = []
existing_priority = priorities_by_scope.setdefault(scope, priority)
existing_priority = priorities_by_scope.setdefault(scope_param, priority)
if priority < existing_priority:
roles_by_scope[scope] = [role]
roles_by_scope[scope_param] = [role]
elif existing_priority == priority:
roles_by_scope[scope].append(role)
roles_by_scope[scope_param].append(role)
# Get the group level role by reading the GroupUser permissions property for
# User actors.
@ -290,13 +302,19 @@ class RoleAssignmentHandler:
)
if group_level_role.uid == NO_ROLE_LOW_PRIORITY_ROLE:
# Low priority role -> Use team role or NO_ACCESS if no team role
if not roles_by_scope.get(group):
roles_by_scope[group] = [self.get_role_by_uid(NO_ACCESS_ROLE)]
if not roles_by_scope.get(group_scope_param):
roles_by_scope[group_scope_param] = [
self.get_role_by_uid(NO_ACCESS_ROLE)
]
else:
# Otherwise user role wins
roles_by_scope[group] = [group_level_role]
roles_by_scope[group_scope_param] = [group_level_role]
return list(roles_by_scope.items())
roles_by_scope = [
(self.get_scope(*key), value) for (key, value) in roles_by_scope.items()
]
return roles_by_scope
def get_computed_roles(
self, group: Group, actor: AbstractUser, context: Any, include_trash=False
@ -460,18 +478,16 @@ class RoleAssignmentHandler:
content_type = subject_type_registry.get(subject_type).get_content_type()
return content_type.get_object_for_this_type(id=subject_id)
def get_scope(self, scope_id: int, scope_type: str):
def get_scope(self, scope_id: int, content_type_id: int) -> Any:
"""
Helper method that returns the actual scope object given the scope ID and
the scope type.
:param scope_id: The scope id.
:param scope_type: The scope type. This type must be registered in the
`object_scope_registry`.
:param content_type_id: The content_type id
"""
scope_type = object_scope_type_registry.get(scope_type)
content_type = ContentType.objects.get_for_model(scope_type.model_class)
content_type = ContentType.objects.get_for_id(content_type_id)
return content_type.get_object_for_this_type(id=scope_id)
def assign_role_batch(

View file

@ -137,10 +137,9 @@ class RolePermissionManagerType(PermissionManagerType):
for (scope, roles) in roles_by_scope[1:]:
allowed_operations = set()
[
for role in roles:
allowed_operations.update(self.get_role_operations(role))
for role in roles
]
scope_type = object_scope_type_registry.get_by_model(scope)
@ -151,22 +150,20 @@ class RolePermissionManagerType(PermissionManagerType):
scope_type, base_scope_type
):
context_exceptions = list(
base_scope_type.get_all_context_objects_in_scope(scope)
)
context_exception = scope
# Remove or add exceptions to the exception list according to the
# default policy for the group
if operation_type.type not in allowed_operations:
if default:
exceptions |= set(context_exceptions)
exceptions.add(context_exception)
else:
exceptions = exceptions.difference(context_exceptions)
exceptions.discard(context_exception)
else:
if default:
exceptions = exceptions.difference(context_exceptions)
exceptions.discard(context_exception)
else:
exceptions |= set(context_exceptions)
exceptions.add(context_exception)
# Second case
# The scope of the role assignment is included by the role of the operation
# And we are doing a read operation
@ -188,13 +185,12 @@ class RolePermissionManagerType(PermissionManagerType):
if default:
if found_object in exceptions:
exceptions.remove(found_object)
exceptions.discard(found_object)
else:
exceptions.add(found_object)
return default, exceptions
# Probably needs a cache?
def get_permissions_object(
self, actor: AbstractUser, group: Optional[Group] = None
) -> List[Dict[str, OperationPermissionContent]]:
@ -219,19 +215,50 @@ class RolePermissionManagerType(PermissionManagerType):
# Get all role assignments for this actor into this group
roles_by_scope = RoleAssignmentHandler().get_roles_per_scope(group, actor)
result = defaultdict(lambda: {"default": False, "exceptions": []})
policy_per_operation = defaultdict(lambda: {"default": False, "exceptions": []})
exceptions_with_mixed_types_per_scope = defaultdict(set)
# First, for each operation we want the default policy and exceptions
for operation_type in operation_type_registry.get_all():
default, exceptions = self.get_operation_policy(
roles_by_scope, operation_type
)
if default or exceptions:
result[operation_type.type]["default"] = default
result[operation_type.type]["exceptions"] = [e.id for e in exceptions]
policy_per_operation[operation_type.type]["default"] = default
policy_per_operation[operation_type.type]["exceptions"] = exceptions
return result
if exceptions:
# We store the exceptions by scope to get all objects at once later
exceptions_with_mixed_types_per_scope[
operation_type.context_scope
] |= exceptions
# Get all objects for all exceptions at once to improve perfs
exception_ids_per_scope = {}
for object_scope, exceptions in exceptions_with_mixed_types_per_scope.items():
exception_ids_per_scope[object_scope] = {
scope: {o.id for o in exc}
for scope, exc in object_scope.get_objects_in_scopes(exceptions).items()
}
# Dispatch actual context object ids for each exceptions scopes
policy_per_operation_with_exception_ids = {}
for operation_type in operation_type_registry.get_all():
# Gather all ids for all scopes of the exception list of this operation
exceptions_ids = set()
for scope in policy_per_operation[operation_type.type]["exceptions"]:
exceptions_ids |= exception_ids_per_scope[operation_type.context_scope][
scope
]
policy_per_operation_with_exception_ids[operation_type.type] = {
"default": policy_per_operation[operation_type.type]["default"],
"exceptions": list(exceptions_ids),
}
return policy_per_operation_with_exception_ids
def filter_queryset(
self, actor, operation_name, queryset, group=None, context=None
@ -246,7 +273,6 @@ class RolePermissionManagerType(PermissionManagerType):
# Get all role assignments for this user into this group
roles_by_scope = RoleAssignmentHandler().get_roles_per_scope(group, actor)
print(roles_by_scope)
operation_type = operation_type_registry.get(operation_name)
@ -254,17 +280,17 @@ class RolePermissionManagerType(PermissionManagerType):
roles_by_scope, operation_type, True
)
exceptions = [e.id for e in exceptions]
exceptions_filter = operation_type.object_scope.get_filter_for_scopes(
exceptions
)
print(default, exceptions)
# Finally filter the queryset with the exception list.
# Finally filter the queryset with the exception filter.
if default:
if exceptions:
queryset = queryset.exclude(id__in=list(exceptions))
queryset = queryset.exclude(exceptions_filter)
else:
if exceptions:
queryset = queryset.filter(id__in=list(exceptions))
queryset = queryset.filter(exceptions_filter)
else:
queryset = queryset.none()

View file

@ -1,3 +1,5 @@
from django.db.models import Q
from baserow.core.object_scopes import GroupObjectScopeType
from baserow.core.registries import ObjectScopeType, object_scope_type_registry
from baserow_enterprise.models import Team, TeamSubject
@ -13,13 +15,15 @@ class TeamObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.group
def get_all_context_objects_in_scope(self, scope):
scope_type = object_scope_type_registry.get_by_model(scope).type
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related("group")
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return Team.objects.filter(group=scope)
if object_scope_type_registry.get_by_model(scope).type == self.type:
return [scope]
return []
return Q(group__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")
class TeamSubjectObjectScopeType(ObjectScopeType):
@ -32,12 +36,15 @@ class TeamSubjectObjectScopeType(ObjectScopeType):
def get_parent(self, context):
return context.team
def get_all_context_objects_in_scope(self, scope):
scope_type = object_scope_type_registry.get_by_model(scope).type
def get_enhanced_queryset(self):
return self.model_class.objects.prefetch_related("team", "team__group")
def get_filter_for_scope_type(self, scope_type, scopes):
if scope_type.type == GroupObjectScopeType.type:
return TeamSubject.objects.filter(team__group=scope)
return Q(team__group__in=[s.id for s in scopes])
if scope_type.type == TeamObjectScopeType.type:
return TeamSubject.objects.filter(team=scope)
if scope_type.type == self.type:
return [scope]
return []
return Q(team__in=[s.id for s in scopes])
raise TypeError("The given type is not handled.")

View file

@ -1,7 +1,5 @@
from unittest.mock import patch
from django.test import override_settings
import pytest
from baserow.core.action.handler import ActionHandler
@ -20,9 +18,6 @@ def enable_enterprise_and_roles_for_all_tests_here(enable_enterprise, synced_rol
@pytest.mark.django_db
@pytest.mark.undo_redo
@override_settings(
PERMISSION_MANAGERS=["core", "staff", "member", "basic", "role"],
)
@patch("baserow.core.handler.CoreHandler.check_permissions")
def test_can_undo_assign_role(mock_check_permissions, data_fixture):
session_id = "session-id"

View file

@ -1,8 +1,9 @@
from unittest.mock import patch
from django.contrib.contenttypes.models import ContentType
from django.db import IntegrityError
from django.db import IntegrityError, connection
from django.db.models import Q
from django.test.utils import CaptureQueriesContext
import pytest
from pyinstrument import Profiler
@ -778,3 +779,67 @@ def test_get_roles_per_scope_trashed_teams(data_fixture, enterprise_data_fixture
assert RoleAssignmentHandler().get_roles_per_scope(group, user) == [
(group, [admin_role]),
]
@pytest.mark.disabled_in_ci
# You must add --run-disabled-in-ci -s to pytest to run this test, you can do this in
# intellij by editing the run config for this test and adding --run-disabled-in-ci -s
# to additional args.
# pytest -k "test_check_get_role_per_scope_performance" -s --run-disabled-in-ci
def test_check_get_role_per_scope_performance(
data_fixture, enterprise_data_fixture, profiler
):
user = data_fixture.create_user()
user2 = data_fixture.create_user()
group = data_fixture.create_group(user=user, members=[user2])
database1 = data_fixture.create_database_application(user=user, group=group)
table11 = data_fixture.create_database_table(user=user, database=database1)
table12 = data_fixture.create_database_table(user=user, database=database1)
database2 = data_fixture.create_database_application(user=user, group=group)
table21 = data_fixture.create_database_table(user=user, database=database2)
table22 = data_fixture.create_database_table(user=user, database=database2)
team1 = enterprise_data_fixture.create_team(group=group, members=[user, user2])
team2 = enterprise_data_fixture.create_team(group=group, members=[user, user2])
team3 = enterprise_data_fixture.create_team(group=group, members=[user, user2])
editor_role = Role.objects.get(uid="EDITOR")
builder_role = Role.objects.get(uid="BUILDER")
viewer_role = Role.objects.get(uid="VIEWER")
no_role_role = Role.objects.get(uid="NO_ACCESS")
low_priority_role = Role.objects.get(uid="NO_ROLE_LOW_PRIORITY")
RoleAssignmentHandler().assign_role(
user, group, role=low_priority_role, scope=group
)
RoleAssignmentHandler().assign_role(user, group, role=editor_role, scope=database1)
RoleAssignmentHandler().assign_role(user, group, role=no_role_role, scope=table12)
RoleAssignmentHandler().assign_role(user, group, role=viewer_role, scope=table22)
RoleAssignmentHandler().assign_role(team1, group, role=builder_role, scope=group)
RoleAssignmentHandler().assign_role(team1, group, role=viewer_role, scope=database2)
RoleAssignmentHandler().assign_role(team2, group, role=editor_role, scope=group)
RoleAssignmentHandler().assign_role(team2, group, role=viewer_role, scope=database2)
RoleAssignmentHandler().assign_role(team3, group, role=builder_role, scope=group)
RoleAssignmentHandler().assign_role(team3, group, role=viewer_role, scope=database2)
role_assignment_handler = RoleAssignmentHandler()
with CaptureQueriesContext(connection) as captured:
role_assignment_handler.get_roles_per_scope(group, user)
for q in captured.captured_queries:
print(q)
print(len(captured.captured_queries))
with CaptureQueriesContext(connection) as captured:
role_assignment_handler.get_roles_per_scope(group, user)
print("----------- Second time ---------------")
for q in captured.captured_queries:
print(q)
print(len(captured.captured_queries))
with profiler(html_report_name="enterprise_get_roles_per_scope"):
for i in range(1000):
role_assignment_handler.get_roles_per_scope(group, user)

View file

@ -1,4 +1,6 @@
from django.db import connection
from django.test import override_settings
from django.test.utils import CaptureQueriesContext
import pytest
@ -683,7 +685,7 @@ def test_check_permissions_with_teams(
user,
UpdateApplicationOperationType.type,
group=group_1,
context=database_1,
context=database_1.application_ptr,
)
is True
)
@ -887,8 +889,8 @@ def test_get_permissions_object_with_teams(
perms = perm_manager.get_permissions_object(user, group=group_1)
assert all([not perm["default"] for perm in perms])
assert all([not perm["exceptions"] for perm in perms])
assert all([not perm["default"] for perm in perms.values()])
assert all([not perm["exceptions"] for perm in perms.values()])
# The user role should take the precedence
RoleAssignmentHandler().assign_role(user, group_1, role=role_builder)
@ -1111,3 +1113,141 @@ def test_all_operations_are_in_at_least_one_default_role(data_fixture):
assert missing_ops == [], "Non Assigned Ops:\n" + str(
"\n".join([o.__class__.__name__ + "," for o in missing_ops])
)
@pytest.mark.django_db
@pytest.mark.disabled_in_ci
# You must add --run-disabled-in-ci -s to pytest to run this test, you can do this in
# intellij by editing the run config for this test and adding --run-disabled-in-ci -s
# to additional args.
# pytest -k "test_check_permission_performance" -s --run-disabled-in-ci
def test_check_permission_performance(data_fixture, enterprise_data_fixture, profiler):
user = data_fixture.create_user()
user2 = data_fixture.create_user()
group = data_fixture.create_group(user=user, members=[user2])
database1 = data_fixture.create_database_application(user=user, group=group)
table11 = data_fixture.create_database_table(user=user, database=database1)
table12 = data_fixture.create_database_table(user=user, database=database1)
database2 = data_fixture.create_database_application(user=user, group=group)
table21 = data_fixture.create_database_table(user=user, database=database2)
table22 = data_fixture.create_database_table(user=user, database=database2)
team1 = enterprise_data_fixture.create_team(group=group, members=[user, user2])
team2 = enterprise_data_fixture.create_team(group=group, members=[user, user2])
team3 = enterprise_data_fixture.create_team(group=group, members=[user, user2])
editor_role = Role.objects.get(uid="EDITOR")
builder_role = Role.objects.get(uid="BUILDER")
viewer_role = Role.objects.get(uid="VIEWER")
no_role_role = Role.objects.get(uid="NO_ACCESS")
low_priority_role = Role.objects.get(uid="NO_ROLE_LOW_PRIORITY")
RoleAssignmentHandler().assign_role(
user, group, role=low_priority_role, scope=group
)
RoleAssignmentHandler().assign_role(user, group, role=editor_role, scope=database1)
RoleAssignmentHandler().assign_role(user, group, role=no_role_role, scope=table12)
RoleAssignmentHandler().assign_role(user, group, role=viewer_role, scope=table22)
RoleAssignmentHandler().assign_role(team1, group, role=builder_role, scope=group)
RoleAssignmentHandler().assign_role(team1, group, role=viewer_role, scope=database2)
RoleAssignmentHandler().assign_role(team2, group, role=editor_role, scope=group)
RoleAssignmentHandler().assign_role(team2, group, role=viewer_role, scope=database2)
RoleAssignmentHandler().assign_role(team3, group, role=builder_role, scope=group)
RoleAssignmentHandler().assign_role(team3, group, role=viewer_role, scope=database2)
permission_manager = RolePermissionManagerType()
print("----------- first call queries ---------------")
with CaptureQueriesContext(connection) as captured:
permission_manager.check_permissions(
user, ReadDatabaseTableOperationType.type, group=group, context=table11
)
for q in captured.captured_queries:
print(q)
print(len(captured.captured_queries))
print("----------- second call queries ---------------")
with CaptureQueriesContext(connection) as captured:
permission_manager.check_permissions(
user, ReadDatabaseTableOperationType.type, group=group, context=table11
)
for q in captured.captured_queries:
print(q)
print(len(captured.captured_queries))
print("----------- check_permission perfs ---------------")
with profiler(html_report_name="enterprise_check_permissions"):
for i in range(1000):
permission_manager.check_permissions(
user, ReadDatabaseTableOperationType.type, group=group, context=table11
)
@pytest.mark.django_db
@pytest.mark.disabled_in_ci
# You must add --run-disabled-in-ci -s to pytest to run this test, you can do this in
# intellij by editing the run config for this test and adding --run-disabled-in-ci -s
# to additional args.
# pytest -k "test_get_permission_object_performance" -s --run-disabled-in-ci
def test_get_permission_object_performance(
data_fixture, enterprise_data_fixture, profiler
):
user = data_fixture.create_user()
user2 = data_fixture.create_user()
group = data_fixture.create_group(user=user, members=[user2])
database1 = data_fixture.create_database_application(user=user, group=group)
table11 = data_fixture.create_database_table(user=user, database=database1)
table12 = data_fixture.create_database_table(user=user, database=database1)
database2 = data_fixture.create_database_application(user=user, group=group)
table21 = data_fixture.create_database_table(user=user, database=database2)
table22 = data_fixture.create_database_table(user=user, database=database2)
team1 = enterprise_data_fixture.create_team(group=group, members=[user, user2])
team2 = enterprise_data_fixture.create_team(group=group, members=[user, user2])
team3 = enterprise_data_fixture.create_team(group=group, members=[user, user2])
editor_role = Role.objects.get(uid="EDITOR")
builder_role = Role.objects.get(uid="BUILDER")
viewer_role = Role.objects.get(uid="VIEWER")
no_role_role = Role.objects.get(uid="NO_ACCESS")
low_priority_role = Role.objects.get(uid="NO_ROLE_LOW_PRIORITY")
RoleAssignmentHandler().assign_role(
user, group, role=low_priority_role, scope=group
)
RoleAssignmentHandler().assign_role(user, group, role=editor_role, scope=database1)
RoleAssignmentHandler().assign_role(user, group, role=no_role_role, scope=table12)
RoleAssignmentHandler().assign_role(user, group, role=viewer_role, scope=table22)
RoleAssignmentHandler().assign_role(team1, group, role=builder_role, scope=group)
RoleAssignmentHandler().assign_role(team1, group, role=viewer_role, scope=database2)
RoleAssignmentHandler().assign_role(team2, group, role=editor_role, scope=group)
RoleAssignmentHandler().assign_role(team2, group, role=viewer_role, scope=database2)
RoleAssignmentHandler().assign_role(team3, group, role=builder_role, scope=group)
RoleAssignmentHandler().assign_role(team3, group, role=viewer_role, scope=database2)
permission_manager = RolePermissionManagerType()
print("----------- first call queries ---------------")
with CaptureQueriesContext(connection) as captured:
permission_manager.get_permissions_object(user, group=group)
for q in captured.captured_queries:
print(q)
print(len(captured.captured_queries))
print("----------- second call queries ---------------")
with CaptureQueriesContext(connection) as captured:
permission_manager.get_permissions_object(user, group=group)
for q in captured.captured_queries:
print(q)
print(len(captured.captured_queries))
print("----------- get_permission_object perfs ---------------")
with profiler(html_report_name="enterprise_get_permissions_object"):
for i in range(1000):
permission_manager.get_permissions_object(user, group=group)