1
0
Fork 0
mirror of https://gitlab.com/bramw/baserow.git synced 2025-04-15 09:34:13 +00:00

Merge branch '3207-charge-for-user-source-users-count-the-number-of-external-users' into 'develop'

Resolve "Charge for user source users: count the number of external users."

Closes 

See merge request 
This commit is contained in:
Peter Evans 2024-12-10 15:02:28 +00:00
commit a1b7ee4790
16 changed files with 729 additions and 48 deletions
backend
src/baserow
tests/baserow
api/user_sources
contrib/builder/user_sources
core/user_sources
enterprise/backend
src/baserow_enterprise
audit_log
config/settings
integrations/local_baserow
tasks.py
tests/baserow_enterprise_tests/integrations/local_baserow

View file

@ -42,11 +42,27 @@ class UserSourceSerializer(serializers.ModelSerializer):
"""
type = serializers.SerializerMethodField(help_text="The type of the user_source.")
user_count = serializers.SerializerMethodField(
help_text="The total number of users in the user source."
)
user_count_updated_at = serializers.SerializerMethodField(
help_text="When the last user count took place."
)
@extend_schema_field(OpenApiTypes.STR)
def get_type(self, instance):
return user_source_type_registry.get_by_model(instance.specific_class).type
@extend_schema_field(OpenApiTypes.INT)
def get_user_count(self, instance):
user_count = instance.get_type().get_user_count(instance)
return user_count.count if user_count else None
@extend_schema_field(OpenApiTypes.DATETIME)
def get_user_count_updated_at(self, instance):
user_count = instance.get_type().get_user_count(instance)
return user_count.last_updated if user_count else None
auth_providers = ReadPolymorphicAppAuthProviderSerializer(
required=False,
many=True,
@ -64,6 +80,8 @@ class UserSourceSerializer(serializers.ModelSerializer):
"name",
"order",
"auth_providers",
"user_count",
"user_count_updated_at",
)
extra_kwargs = {
"id": {"read_only": True},
@ -74,6 +92,8 @@ class UserSourceSerializer(serializers.ModelSerializer):
"type": {"read_only": True},
"name": {"read_only": True},
"order": {"read_only": True, "help_text": "Lowest first."},
"user_count": {"read_only": True},
"user_count_updated_at": {"read_only": True},
}

View file

@ -295,7 +295,7 @@ class UserSourceView(APIView):
)
def patch(self, request, user_source_id: int):
"""
Update an user_source.
Update a user_source.
"""
user_source = UserSourceHandler().get_user_source_for_update(user_source_id)

View file

@ -0,0 +1,30 @@
from django.core.management.base import BaseCommand
from baserow.core.user_sources.handler import UserSourceHandler
class Command(BaseCommand):
help = (
"A management command which counts and caches all user source external "
"users. It's possible to reduce the scope by providing a user source type."
)
def add_arguments(self, parser):
parser.add_argument(
"--type",
type=str,
default=None,
help="Optionally choose a user source type to update, "
"instead of all of them.",
)
def handle(self, *args, **options):
user_source_type = options["type"]
UserSourceHandler().update_all_user_source_counts(user_source_type)
self.stdout.write(
self.style.SUCCESS(
"All configured user sources have been updated."
if not user_source_type
else f"All configured {user_source_type} user sources have been updated."
)
)

View file

@ -2,9 +2,12 @@ from ast import Dict
from typing import Iterable, List, Optional, Union
from zipfile import ZipFile
from django.conf import settings
from django.core.files.storage import Storage
from django.db.models import QuerySet
from loguru import logger
from baserow.core.db import specific_iterator
from baserow.core.exceptions import ApplicationOperationNotSupported
from baserow.core.models import Application
@ -29,7 +32,7 @@ class UserSourceHandler:
self, user_source_id: int, base_queryset: Optional[QuerySet] = None
) -> UserSource:
"""
Returns an user_source instance from the database.
Returns a user_source instance from the database.
:param user_source_id: The ID of the user_source.
:param base_queryset: The base queryset use to build the query if provided.
@ -54,11 +57,11 @@ class UserSourceHandler:
def get_user_source_by_uid(
self,
user_source_uid: int,
user_source_uid: str,
base_queryset: Optional[QuerySet] = None,
) -> UserSource:
"""
Returns an user_source instance from the database.
Returns a user_source instance from the database.
:param user_source_uid: The uid of the user_source.
:param base_queryset: The base queryset use to build the query if provided.
@ -85,7 +88,7 @@ class UserSourceHandler:
self, user_source_id: int, base_queryset: Optional[QuerySet] = None
) -> UserSourceForUpdate:
"""
Returns an user_source instance from the database that can be safely updated.
Returns a user_source instance from the database that can be safely updated.
:param user_source_id: The ID of the user_source.
:param base_queryset: The base queryset use to build the query if provided.
@ -212,8 +215,9 @@ class UserSourceHandler:
Updates and user_source with values. Will also check if the values are allowed
to be set on the user_source first.
:param user_source_type: The type of the user_source.
:param user_source: The user_source that should be updated.
:param values: The values that should be set on the user_source.
:param kwargs: The values that should be set on the user_source.
:return: The updated user_source.
"""
@ -232,7 +236,7 @@ class UserSourceHandler:
def delete_user_source(self, user_source: UserSource):
"""
Deletes an user_source.
Deletes a user_source.
:param user_source: The to-be-deleted user_source.
"""
@ -316,3 +320,33 @@ class UserSourceHandler:
id_mapping["user_sources"][serialized_user_source["id"]] = user_source.id
return user_source
def update_all_user_source_counts(
self, user_source_type: Optional[str] = None, raise_on_error: bool = False
):
"""
Responsible for iterating over all registered user source types, and asking the
implementation to count the number of external users it points to.
:param user_source_type: Optionally, a specific user source type to update.
:param raise_on_error: Whether to raise an exception when a user source
type raises an exception, or to continue with the remaining user sources.
:return: None
"""
user_source_types = (
[user_source_type_registry.get(user_source_type)]
if user_source_type
else user_source_type_registry.get_all()
)
for user_source_type in user_source_types:
try:
user_source_type.update_user_count()
except Exception as e:
if not settings.TESTS:
logger.exception(
f"Counting {user_source_type.type} external users failed: {e}"
)
if raise_on_error:
raise e
continue

View file

@ -37,7 +37,7 @@ class UserSource(
UserSources provide a way to configure user authentication source within an
application like the Application Builder.
An user_source can be associated with an application and it stores the data
A user_source can be associated with an application, and it stores the data
required to use the corresponding external service. This data may include an API
key for accessing an external database service, a user account for querying a
Baserow database, as well as the necessary URL, credentials, and headers for making

View file

@ -1,8 +1,10 @@
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
from datetime import datetime
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Type, TypeVar
from django.contrib.auth.models import AbstractUser
from django.db.models import QuerySet
from baserow.core.app_auth_providers.handler import AppAuthProviderHandler
from baserow.core.app_auth_providers.registries import app_auth_provider_type_registry
@ -24,6 +26,11 @@ from .models import UserSource
from .types import UserSourceDictSubClass, UserSourceSubClass
class UserSourceCount(NamedTuple):
count: Optional[int]
last_updated: Optional[datetime]
class UserSourceType(
ModelInstanceMixin[UserSource],
EasyImportExportMixin[UserSourceSubClass],
@ -35,6 +42,9 @@ class UserSourceType(
parent_property_name = "application"
id_mapping_name = "user_sources"
# When any of these properties are updated, the user count will be updated.
properties_requiring_user_recount = []
"""
An user_source type define a specific user_source with a given external service.
"""
@ -90,15 +100,37 @@ class UserSourceType(
user, ap_type, user_source, **ap
)
def after_update(self, user, user_source, values):
def after_update(
self,
user: AbstractUser,
user_source: UserSource,
values: Dict[str, Any],
trigger_user_count_update: bool = False,
):
"""
Recreate the auth providers.
Responsible for re-creating `auth_providers` if they are updated, and also
updating the user count if necessary.
:param user: The user on whose behalf the change is made.
:param user_source: The user source that has been updated.
:param values: The values that have been updated.
:param trigger_user_count_update: If True, the user count will be updated.
"""
if "auth_providers" in values:
user_source.auth_providers.all().delete()
self.after_create(user, user_source, values)
if trigger_user_count_update:
from baserow.core.user_sources.handler import UserSourceHandler
queryset = UserSourceHandler().get_user_sources(
user_source.application,
self.model_class.objects.filter(pk=user_source.pk),
specific=True,
)
self.update_user_count(queryset)
def serialize_property(
self,
instance: UserSource,
@ -278,6 +310,61 @@ class UserSourceType(
:param kwargs: The credential used to authenticate the user.
"""
def after_user_source_update_requires_user_recount(
self,
user_source: UserSource,
prepared_values: dict[str, Any],
) -> bool:
"""
Detects if any of the properties in the prepared_values require
a recount of the user source's user count.
:param user_source: the user source which is being updated.
:param prepared_values: the prepared values which will be
used to update the user source.
:return: whether a re-count is required.
"""
recount_required = False
for recount_property in self.properties_requiring_user_recount:
if recount_property in prepared_values:
current_value = getattr(user_source, recount_property)
updated_value = prepared_values[recount_property]
if current_value != updated_value:
recount_required = True
return recount_required
@abstractmethod
def update_user_count(
self,
user_sources: QuerySet[UserSource] = None,
) -> Optional[UserSourceCount]:
"""
Responsible for updating the cached number of users in this user source type.
If `user_sources` are provided, we will only update the user count for those
user sources. If no `user_sources` are provided, we will update the user count
for all user sources of this type.
:param user_sources: If a queryset of user sources is provided, we will only
update the user count for those user sources, otherwise we'll find all
user sources and update their user counts.
:return: if a `user_source` is provided, a `UserSourceCount is returned,
otherwise we will return `None`.
"""
@abstractmethod
def get_user_count(
self, user_source, force_recount: bool = False
) -> UserSourceCount:
"""
Responsible for retrieving a user source's count.
:param user_source: The user source we want a count from.
:param force_recount: If True, we will re-count the users and ignore any
existing cached count.
:return: A `UserSourceCount` instance or `None`.
"""
UserSourceTypeSubClass = TypeVar("UserSourceTypeSubClass", bound=UserSourceType)

View file

@ -65,11 +65,11 @@ class UserSourceService:
self, user: AbstractUser, user_source_uid: str, for_authentication: bool = False
) -> UserSource:
"""
Returns an user_source instance from the database. Also checks the user
Returns a user_source instance from the database. Also checks the user
permissions.
:param user: The user trying to get the user_source.
:param user_source_id: The ID of the user_source.
:param user_source_uid: The uid of the user_source.
:param for_authentication: If true we check a different permission.
:return: The user_source instance.
"""
@ -184,8 +184,7 @@ class UserSourceService:
:param user: The user trying to update the user_source.
:param user_source: The user_source that should be updated.
:param values: The values that should be set on the user_source.
:param kwargs: Additional attributes of the user_source.
:param kwargs: The values that should be set on the user_source.
:return: The updated user_source.
"""
@ -196,15 +195,27 @@ class UserSourceService:
context=user_source,
)
prepared_values = user_source.get_type().prepare_values(
kwargs, user, user_source
user_source_type: UserSourceType = user_source.get_type()
prepared_values = user_source_type.prepare_values(kwargs, user, user_source)
# Detect if a user re-count is required. Per user source type
# we track which properties changing triggers a recount.
trigger_user_count_update = (
user_source_type.after_user_source_update_requires_user_recount(
user_source, prepared_values
)
)
user_source = self.handler.update_user_source(
user_source.get_type(), user_source, **prepared_values
user_source_type, user_source, **prepared_values
)
user_source.get_type().after_update(user, user_source, prepared_values)
user_source_type.after_update(
user,
user_source,
prepared_values,
trigger_user_count_update=trigger_user_count_update,
)
user_source_updated.send(self, user_source=user_source, user=user)

View file

@ -4,6 +4,7 @@ import os
import sys
import threading
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional
@ -34,6 +35,7 @@ from baserow.core.permission_manager import CorePermissionManagerType
from baserow.core.services.dispatch_context import DispatchContext
from baserow.core.services.utils import ServiceAdhocRefinements
from baserow.core.trash.trash_types import WorkspaceTrashableItemType
from baserow.core.user_sources.registries import UserSourceCount
from baserow.core.utils import get_value_at_path
SKIP_FLAGS = ["disabled-in-ci", "once-per-day-in-ci"]
@ -229,6 +231,9 @@ def stub_user_source_registry(data_fixture, mutable_user_source_registry, fake):
get_user_return=None,
list_users_return=None,
gen_uid_return=None,
get_user_count_return=None,
update_user_count_return=None,
properties_requiring_user_recount_return=None,
):
"""
Replace first user_source type with the stub class
@ -241,6 +246,19 @@ def stub_user_source_registry(data_fixture, mutable_user_source_registry, fake):
class StubbedUserSourceType(UserSourceType):
type = user_source_type.type
model_class = user_source_type.model_class
properties_requiring_user_recount = properties_requiring_user_recount_return
def get_user_count(self, user_source, force_recount=False):
if get_user_count_return:
if callable(get_user_count_return):
return get_user_count_return(user_source, force_recount)
return UserSourceCount(count=5, last_updated=datetime.now())
def update_user_count(self, user_source=None):
if update_user_count_return:
if callable(update_user_count_return):
return update_user_count_return(user_source)
return None
def gen_uid(self, user_source):
if gen_uid_return:

View file

@ -122,6 +122,8 @@ def test_create_user_source(api_client, data_fixture):
response_json = response.json()
assert response.status_code == HTTP_200_OK
assert response_json["type"] == "local_baserow"
assert response_json["user_count"] is None
assert response_json["user_count_updated_at"] is None
@pytest.mark.django_db
@ -416,11 +418,14 @@ def test_create_user_source_bad_application_type(api_client, data_fixture):
@pytest.mark.django_db
def test_update_user_source(api_client, data_fixture):
user, token = data_fixture.create_user_and_token()
application = data_fixture.create_builder_application(user=user)
workspace = data_fixture.create_workspace(user=user)
application = data_fixture.create_builder_application(workspace=workspace)
user_source1 = data_fixture.create_user_source_with_first_type(
application=application
)
integration = data_fixture.create_local_baserow_integration(user=user)
integration = data_fixture.create_local_baserow_integration(
application=application, authorized_user=user
)
url = reverse("api:user_sources:item", kwargs={"user_source_id": user_source1.id})
response = api_client.patch(
@ -431,8 +436,35 @@ def test_update_user_source(api_client, data_fixture):
)
assert response.status_code == HTTP_200_OK
assert response.json()["name"] == "newName"
assert response.json()["integration_id"] == integration.id
response_json = response.json()
assert response_json["name"] == "newName"
assert response_json["integration_id"] == integration.id
assert response_json["user_count"] is None
assert response_json["user_count_updated_at"] is None
database = data_fixture.create_database_application(workspace=workspace)
table = data_fixture.create_database_table(database=database)
name_field = data_fixture.create_text_field(table=table)
email_field = data_fixture.create_email_field(table=table)
model = table.get_model(field_ids=[])
model.objects.create()
model.objects.create()
model.objects.create()
response = api_client.patch(
url,
{
"table_id": table.id,
"name_field_id": name_field.id,
"email_field_id": email_field.id,
},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
assert response.status_code == HTTP_200_OK
response_json = response.json()
assert response_json["user_count"] == 3
assert response_json["user_count_updated_at"] is not None
@pytest.mark.django_db

View file

@ -1,6 +1,7 @@
"""Test the UserSourceSerializer serializer."""
import pytest
from pytest_unordered import unordered
from baserow.api.user_sources.serializers import UserSourceSerializer
@ -23,6 +24,8 @@ def test_serializer_has_expected_fields(user_source):
"""Ensure the serializer returns the expected fields."""
expected_fields = [
"user_count",
"user_count_updated_at",
"application_id",
"auth_providers",
"id",
@ -34,4 +37,4 @@ def test_serializer_has_expected_fields(user_source):
]
serializer = UserSourceSerializer(instance=user_source)
assert sorted(serializer.data.keys()) == expected_fields
assert list(serializer.data.keys()) == unordered(expected_fields)

View file

@ -520,3 +520,26 @@ def test_get_all_roles_for_application_returns_user_roles(
user_roles = UserSourceHandler().get_all_roles_for_application(builder)
assert user_roles == expected_roles
@pytest.mark.django_db
def test_update_all_user_source_counts(stub_user_source_registry):
# Calling each `update_user_count`.
with stub_user_source_registry(update_user_count_return=lambda: 123):
UserSourceHandler().update_all_user_source_counts()
# When an exception raises, by default we won't propagate it.
def mock_raise_update_user_count(user_source):
raise Exception("An error has occurred.")
with stub_user_source_registry(
update_user_count_return=mock_raise_update_user_count
):
UserSourceHandler().update_all_user_source_counts()
# When an exception raises, we can make it propagate.
with stub_user_source_registry(
update_user_count_return=mock_raise_update_user_count
), pytest.raises(Exception) as exc:
UserSourceHandler().update_all_user_source_counts(raise_on_error=True)
assert str(exc.value) == "An error has occurred."

View file

@ -24,7 +24,7 @@ def clean_up_audit_log_entries(self):
@app.on_after_finalize.connect
def setup_periodic_tasks(sender, **kwargs):
def setup_periodic_audit_log_tasks(sender, **kwargs):
every = timedelta(
minutes=settings.BASEROW_ENTERPRISE_AUDIT_LOG_CLEANUP_INTERVAL_MINUTES
)

View file

@ -15,6 +15,17 @@ def setup(settings):
value['engine'] = 'some custom engine'
"""
settings.BASEROW_ENTERPRISE_USER_SOURCE_COUNTING_TASK_INTERVAL_MINUTES = int(
os.getenv("BASEROW_ENTERPRISE_USER_SOURCE_COUNTING_TASK_INTERVAL_MINUTES", "")
or 2 * 60
)
settings.BASEROW_ENTERPRISE_USER_SOURCE_COUNTING_CACHE_TTL_SECONDS = int(
# Default TTL is 120 minutes: 60 seconds * 120
os.getenv("BASEROW_ENTERPRISE_USER_SOURCE_COUNTING_CACHE_TTL_SECONDS")
or 7200
)
settings.BASEROW_ENTERPRISE_AUDIT_LOG_CLEANUP_INTERVAL_MINUTES = int(
os.getenv("BASEROW_ENTERPRISE_AUDIT_LOG_CLEANUP_INTERVAL_MINUTES", "")
or 24 * 60

View file

@ -1,6 +1,13 @@
import operator
from collections import defaultdict
from datetime import datetime, timezone
from functools import reduce
from typing import Any, Dict, List, Optional
from django.conf import settings
from django.contrib.auth.models import AbstractUser
from django.core.cache import cache
from django.db.models import Q, QuerySet
from loguru import logger
from rest_framework import serializers
@ -28,9 +35,9 @@ from baserow.core.formula.validator import ensure_string
from baserow.core.handler import CoreHandler
from baserow.core.user.exceptions import UserNotFound
from baserow.core.user_sources.exceptions import UserSourceImproperlyConfigured
from baserow.core.user_sources.models import UserSource
from baserow.core.user_sources.registries import UserSourceType
from baserow.core.user_sources.types import UserSourceDict, UserSourceSubClass
from baserow.core.user_sources.handler import UserSourceHandler
from baserow.core.user_sources.registries import UserSourceCount, UserSourceType
from baserow.core.user_sources.types import UserSourceDict
from baserow.core.user_sources.user_source_user import UserSourceUser
from baserow_enterprise.integrations.local_baserow.models import (
LocalBaserowPasswordAppAuthProvider,
@ -75,6 +82,15 @@ class LocalBaserowUserSourceType(UserSourceType):
]
allowed_fields = ["table", "email_field", "name_field", "role_field"]
# A list of fields which the page designer must configure so
# that the `LocalBaserowUserSource` is considered "configured".
fields_to_configure = [
"table_id",
"name_field_id",
"email_field_id",
"integration_id",
]
serializer_field_overrides = {
"table_id": serializers.IntegerField(
required=False,
@ -102,7 +118,7 @@ class LocalBaserowUserSourceType(UserSourceType):
self,
values: Dict[str, Any],
user: AbstractUser,
instance: Optional[UserSourceSubClass] = None,
instance: Optional[LocalBaserowUserSource] = None,
) -> Dict[str, Any]:
"""Load the table instance instead of the ID."""
@ -367,7 +383,13 @@ class LocalBaserowUserSourceType(UserSourceType):
return values
def after_update(self, user, user_source, values):
def after_update(
self,
user: AbstractUser,
user_source: LocalBaserowUserSource,
values: Dict[str, Any],
trigger_user_count_update: bool = False,
):
if "auth_provider" not in values and "table" in values:
# We clear all auth provider when the table changes
for ap in AppAuthProviderHandler.list_app_auth_providers_for_user_source(
@ -375,7 +397,9 @@ class LocalBaserowUserSourceType(UserSourceType):
):
ap.get_type().after_user_source_update(user, ap, user_source)
return super().after_update(user, user_source, values)
return super().after_update(
user, user_source, values, trigger_user_count_update
)
def deserialize_property(
self,
@ -411,7 +435,7 @@ class LocalBaserowUserSourceType(UserSourceType):
**kwargs,
)
def get_user_model(self, user_source):
def get_user_model(self, user_source: LocalBaserowUserSource):
try:
# Use table handler to exclude trashed table
table = TableHandler().get_table(user_source.table_id)
@ -433,19 +457,16 @@ class LocalBaserowUserSourceType(UserSourceType):
return model
def is_configured(self, user_source):
def is_configured(self, user_source: LocalBaserowUserSource) -> bool:
"""
Returns True if the user source is configured properly. False otherwise.
"""
return (
user_source.email_field_id is not None
and user_source.name_field_id is not None
and user_source.table_id is not None
and user_source.integration_id is not None
return not any(
[getattr(user_source, field) is None for field in self.fields_to_configure]
)
def gen_uid(self, user_source):
def gen_uid(self, user_source: LocalBaserowUserSource):
"""
We want to invalidate user tokens if the table or the email field change.
"""
@ -465,7 +486,7 @@ class LocalBaserowUserSourceType(UserSourceType):
return role_field.get_type().type in self.field_types_allowed_as_role
def get_user_role(self, user, user_source: UserSource) -> str:
def get_user_role(self, user, user_source: LocalBaserowUserSource) -> str:
"""
Return the User Role of the user if the role_field is defined.
@ -481,7 +502,9 @@ class LocalBaserowUserSourceType(UserSourceType):
return self.get_default_user_role(user_source)
def list_users(self, user_source: UserSource, count: int = 5, search: str = ""):
def list_users(
self, user_source: LocalBaserowUserSource, count: int = 5, search: str = ""
):
"""
Returns the users from the table selected with the user source.
"""
@ -516,7 +539,7 @@ class LocalBaserowUserSourceType(UserSourceType):
for user in queryset[:count]
]
def get_roles(self, user_source: UserSource) -> List[str]:
def get_roles(self, user_source: LocalBaserowUserSource) -> List[str]:
"""
Given a UserSource, return all valid roles for it.
@ -558,7 +581,7 @@ class LocalBaserowUserSourceType(UserSourceType):
return roles
def get_user(self, user_source: UserSource, **kwargs):
def get_user(self, user_source: LocalBaserowUserSource, **kwargs):
"""
Returns a user from the selected table.
"""
@ -598,7 +621,7 @@ class LocalBaserowUserSourceType(UserSourceType):
raise UserNotFound()
def create_user(self, user_source: UserSource, email, name, role=None):
def create_user(self, user_source: LocalBaserowUserSource, email, name, role=None):
"""
Creates the user in the configured table.
"""
@ -646,7 +669,7 @@ class LocalBaserowUserSourceType(UserSourceType):
self.get_user_role(user, user_source),
)
def authenticate(self, user_source: UserSource, **kwargs):
def authenticate(self, user_source: LocalBaserowUserSource, **kwargs):
"""
Authenticates using the given credentials. It uses the password auth provider.
"""
@ -670,3 +693,154 @@ class LocalBaserowUserSourceType(UserSourceType):
kwargs.get("email", ""),
kwargs.get("password", ""),
)
def _get_cached_user_count(
self, user_source: LocalBaserowUserSource
) -> Optional[UserSourceCount]:
"""
Given a `user_source`, return the cached user count if it exists.
:param user_source: The `LocalBaserowUserSource` instance.
:return: A `UserSourceCount` instance if the cached user count exists,
otherwise `None`.
"""
cached_value = cache.get(
self._generate_update_user_count_cache_key(user_source)
)
if cached_value is not None:
user_count, timestamp = cached_value.split("-")
return UserSourceCount(
count=int(user_count),
last_updated=datetime.fromtimestamp(float(timestamp)),
)
return None
def _generate_update_user_count_cache_key(
self, user_source: LocalBaserowUserSource
) -> str:
"""
Given a `user_source`, generate a cache key for the user count cache entry.
:param user_source: The `LocalBaserowUserSource` instance.
:return: A string representing the cache key.
"""
return f"local_baserow_user_source_{user_source.id}_user_count"
def _generate_update_user_count_cache_value(
self, user_count: int, now: datetime = None
) -> str:
"""
Given a `user_count`, generate a cache value for the user count cache entry.
:param user_count: The user count integer.
:param now: The datetime object representing the current time. If not provided,
we will use the current datetime.
:return: A string representing the cache value.
"""
now = now or datetime.now(tz=timezone.utc)
return f"{user_count}-{now.timestamp()}"
def after_user_source_update_requires_user_recount(
self,
user_source: LocalBaserowUserSource,
prepared_values: dict[str, Any],
) -> bool:
"""
By default, the Local Baserow user source type will re-count
its users following any change to the user source.
:param user_source: the user source which is being updated.
:param prepared_values: the prepared values which will be
used to update the user source.
:return: whether a re-count is required.
"""
return True
def update_user_count(
self,
user_sources: QuerySet[LocalBaserowUserSource] = None,
) -> Optional[UserSourceCount]:
"""
Responsible for updating the cached number of users in this user source type.
If `user_sources` are provided, we will only update the user count for those
user sources. If no `user_sources` are provided, we will update the user count
for all configured `LocalBaserowUserSource`.
:param user_sources: If a queryset of user sources is provided, we will only
update the user count for those user sources, otherwise we'll find all
configured user sources and update their user counts.
:return: if a `user_source` is provided, a `UserSourceCount is returned,
otherwise we will return `None`.
"""
# If no `user_sources` are provided, we will query for all "configured"
# user sources, i.e. those that have all the required fields set.
if not user_sources:
field_q = reduce(
operator.and_,
(~Q(**{field: None}) for field in self.fields_to_configure),
)
user_sources = self.model_class.objects.filter(field_q)
# Fetch all the table records in bulk.
user_source_table_map = defaultdict(list)
for us in user_sources:
user_source_table_map[us.table_id].append(us)
tables = TableHandler.get_tables().filter(id__in=user_source_table_map.keys())
user_source_count = None
for table in tables:
model = table.get_model(field_ids=[])
user_count = model.objects.count()
user_sources_using_table = user_source_table_map[table.id]
for user_source_using_table in user_sources_using_table:
now = datetime.now(tz=timezone.utc)
cache.set(
self._generate_update_user_count_cache_key(user_source_using_table),
self._generate_update_user_count_cache_value(user_count, now),
timeout=settings.BASEROW_ENTERPRISE_USER_SOURCE_COUNTING_CACHE_TTL_SECONDS,
)
if user_sources and user_source_using_table in user_sources:
user_source_count = UserSourceCount(
count=user_count,
last_updated=now,
)
return user_source_count
def get_user_count(
self, user_source: LocalBaserowUserSource, force_recount: bool = False
) -> Optional[UserSourceCount]:
"""
Responsible for retrieving a user source's count. If the user source isn't
configured, `None` will be returned. If it's configured, and cached, so long
as we're not `force_recount=True`, the cached user count will be returned.
If the count isn't cached, or `force_recount=True`, we will count the users,
cache the result, and return the count.
:param user_source: The user source we want a count from.
:param force_recount: If True, we will re-count the users and ignore any
existing cached count.
:return: A `UserSourceCount` instance if the user source is configured,
otherwise `None`.
"""
# If we're being asked for the user count of a
# misconfigured user source, we'll return None.
if not self.is_configured(user_source):
return None
cached_user_source_count = self._get_cached_user_count(user_source)
if cached_user_source_count and not force_recount:
return cached_user_source_count
queryset = UserSourceHandler().get_user_sources(
user_source.application,
self.model_class.objects.filter(pk=user_source.pk),
specific=True,
)
return self.update_user_count(queryset) # type: ignore

View file

@ -1,13 +1,37 @@
from datetime import timedelta
from django.conf import settings
from baserow.config.celery import app
from baserow.contrib.database.table.tasks import (
unsubscribe_subject_from_tables_currently_subscribed_to,
)
from baserow.core.user_sources.handler import UserSourceHandler
from baserow_enterprise.audit_log.tasks import (
clean_up_audit_log_entries,
setup_periodic_tasks,
setup_periodic_audit_log_tasks,
)
@app.task(bind=True, queue="export")
def count_all_user_source_users(self):
"""
Responsible for periodically looping through all user sources, counting the number
of external sources there are per user source type, and caching the results.
"""
UserSourceHandler().update_all_user_source_counts()
@app.on_after_finalize.connect
def setup_periodic_enterprise_tasks(sender, **kwargs):
every = timedelta(
minutes=settings.BASEROW_ENTERPRISE_USER_SOURCE_COUNTING_TASK_INTERVAL_MINUTES
)
sender.add_periodic_task(every, count_all_user_source_users.s())
@app.task(bind=True)
def unsubscribe_subject_from_tables_currently_subscribed_to_task(
self,
@ -40,4 +64,4 @@ def unsubscribe_subject_from_tables_currently_subscribed_to_task(
)
__all__ = ["clean_up_audit_log_entries", "setup_periodic_tasks"]
__all__ = ["clean_up_audit_log_entries", "setup_periodic_audit_log_tasks"]

View file

@ -1,9 +1,11 @@
from collections import defaultdict
from unittest.mock import MagicMock, patch
from datetime import datetime
from unittest.mock import MagicMock, Mock, patch
from django.urls import reverse
import pytest
from freezegun import freeze_time
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST
@ -39,6 +41,7 @@ from baserow.core.user_sources.registries import (
DEFAULT_USER_ROLE_PREFIX,
user_source_type_registry,
)
from baserow.core.user_sources.service import UserSourceService
from baserow.core.utils import MirrorDict, Progress
from baserow.test_utils.helpers import AnyStr
from baserow_enterprise.integrations.local_baserow.models import LocalBaserowUserSource
@ -2210,3 +2213,214 @@ def test_local_baserow_user_source_get_user_is_case_insensitive(
user = user_source_type.get_user(user_source, email=user_provided_email)
assert user.email == actual_email
def test__generate_update_user_count_cache_key():
user_source = Mock(id=123)
assert (
LocalBaserowUserSourceType()._generate_update_user_count_cache_key(user_source)
== "local_baserow_user_source_123_user_count"
)
def test__generate_update_user_count_cache_value():
with freeze_time("2024-11-29T12:00:00.00Z"):
assert (
LocalBaserowUserSourceType()._generate_update_user_count_cache_value(500)
== "500-1732881600.0"
)
@pytest.mark.django_db
def test_update_user_count_with_configured_user_sources(
data_fixture, django_assert_num_queries
):
user = data_fixture.create_user()
workspace = data_fixture.create_workspace(user=user)
application1 = data_fixture.create_builder_application(workspace=workspace)
integration1 = data_fixture.create_local_baserow_integration(
application=application1
)
application2 = data_fixture.create_builder_application(workspace=workspace)
integration2 = data_fixture.create_local_baserow_integration(
application=application2
)
table, fields, rows = data_fixture.build_table(
user=user,
columns=[
("Email", "text"),
("Name", "text"),
("Role", "text"),
],
rows=[
["jrmi@baserow.io", "Jérémie", ""],
["peter@baserow.io", "Peter", ""],
["afonso@baserow.io", "Afonso", ""],
["tsering@baserow.io", "Tsering", ""],
["evren@baserow.io", "Evren", ""],
],
)
email_field, name_field, role_field = fields
local_baserow_user_source_type = user_source_type_registry.get("local_baserow")
user_source1 = data_fixture.create_user_source(
local_baserow_user_source_type.model_class,
application=application1,
integration=integration1,
table=table,
email_field=email_field,
name_field=name_field,
role_field=role_field,
)
user_source2 = data_fixture.create_user_source(
local_baserow_user_source_type.model_class,
application=application2,
integration=integration2,
table=table,
email_field=email_field,
name_field=name_field,
role_field=role_field,
)
# 1. Fetching all configured user sources.
# 2. Fetching all tables for the user sources.
# 3. One COUNT per table, in our case here, once.
with freeze_time("2030-11-29T12:00:00.00Z"), django_assert_num_queries(3):
local_baserow_user_source_type.update_user_count()
with django_assert_num_queries(0):
user_count = local_baserow_user_source_type.get_user_count(user_source1)
assert user_count.count == 5
assert user_count.last_updated == datetime(2030, 11, 29, 12, 0, 0)
with django_assert_num_queries(0):
user_count = local_baserow_user_source_type.get_user_count(user_source2)
assert user_count.count == 5
assert user_count.last_updated == datetime(2030, 11, 29, 12, 0, 0)
@pytest.mark.django_db
def test_update_user_count_with_misconfigured_user_sources(
data_fixture, django_assert_num_queries
):
user = data_fixture.create_user()
workspace = data_fixture.create_workspace(user=user)
database = data_fixture.create_database_application(workspace=workspace)
table = data_fixture.create_database_table(database=database)
application1 = data_fixture.create_builder_application(workspace=workspace)
integration1 = data_fixture.create_local_baserow_integration(
application=application1
)
application2 = data_fixture.create_builder_application(workspace=workspace)
integration2 = data_fixture.create_local_baserow_integration(
application=application2
)
local_baserow_user_source_type = user_source_type_registry.get("local_baserow")
user_source1 = data_fixture.create_user_source(
local_baserow_user_source_type.model_class,
table=table,
application=application1,
integration=integration1,
)
user_source2 = data_fixture.create_user_source(
local_baserow_user_source_type.model_class,
table=table,
application=application2,
integration=integration2,
)
# 1. Fetching all configured user sources.
with django_assert_num_queries(1):
local_baserow_user_source_type.update_user_count()
# `get_user_count` will find that there's no cache entry, so they
# will each trigger 1 query. We only cache the count if the user
# source is configured.
with django_assert_num_queries(0):
assert local_baserow_user_source_type.get_user_count(user_source1) is None
with django_assert_num_queries(0):
assert local_baserow_user_source_type.get_user_count(user_source2) is None
@pytest.mark.django_db
def test_trigger_user_count_update_on_properties_requiring_user_recount_update(
data_fixture,
):
user = data_fixture.create_user()
workspace = data_fixture.create_workspace(user=user)
database = data_fixture.create_database_application(workspace=workspace)
application = data_fixture.create_builder_application(workspace=workspace)
integration = data_fixture.create_local_baserow_integration(application=application)
table_a, fields_a, rows_a = data_fixture.build_table(
database=database,
columns=[
("Email", "text"),
("Name", "text"),
("Role", "text"),
],
rows=[
["jrmi@baserow.io", "Jérémie", ""],
["peter@baserow.io", "Peter", ""],
],
)
email_field_a, name_field_a, role_field_a = fields_a
user_source_type = LocalBaserowUserSourceType()
user_source = data_fixture.create_user_source(
user_source_type.model_class,
table=table_a,
application=application,
integration=integration,
email_field=email_field_a,
name_field=name_field_a,
role_field=role_field_a,
)
table_a_count = user_source_type.get_user_count(user_source)
assert table_a_count.count == 2
table_b, fields_b, rows_b = data_fixture.build_table(
database=database,
columns=[
("Email", "text"),
("Name", "text"),
("Role", "text"),
],
rows=[
["jrmi@baserow.io", "Jérémie", ""],
["peter@baserow.io", "Peter", ""],
["afonso@baserow.io", "Afonso", ""],
["tsering@baserow.io", "Tsering", ""],
["evren@baserow.io", "Evren", ""],
],
)
email_field_b, name_field_b, role_field_b = fields_b
UserSourceService().update_user_source(
user,
user_source,
**{
"table": table_b,
"email_field": email_field_b,
"name_field": name_field_b,
"role_field": role_field_b,
},
)
table_b_count = user_source_type.get_user_count(user_source)
assert table_b_count.count == 5
assert (
table_b_count.last_updated != table_a_count.last_updated
) # confirm it's a fresh cache entry
def test_local_baserow_after_user_source_update_requires_user_recount():
assert (
LocalBaserowUserSourceType().after_user_source_update_requires_user_recount(
Mock(), {}
)
is True
)