diff --git a/backend/src/baserow/api/app_auth_providers/serializers.py b/backend/src/baserow/api/app_auth_providers/serializers.py index f937b9762..05e03c99f 100644 --- a/backend/src/baserow/api/app_auth_providers/serializers.py +++ b/backend/src/baserow/api/app_auth_providers/serializers.py @@ -7,6 +7,7 @@ from rest_framework import serializers from baserow.api.polymorphic import PolymorphicSerializer from baserow.core.app_auth_providers.models import AppAuthProvider from baserow.core.app_auth_providers.registries import app_auth_provider_type_registry +from baserow.core.auth_provider.validators import validate_domain class AppAuthProviderSerializer(serializers.ModelSerializer): @@ -46,6 +47,13 @@ class BaseAppAuthProviderSerializer(serializers.ModelSerializer): help_text="The type of the app_auth_provider.", ) + domain = serializers.CharField( + validators=[validate_domain], + required=False, + allow_null=True, + help_text=AppAuthProvider._meta.get_field("domain").help_text, + ) + class Meta: model = AppAuthProvider fields = ("type", "user_source_id", "domain") diff --git a/backend/src/baserow/api/auth_provider/serializers.py b/backend/src/baserow/api/auth_provider/serializers.py index 329ad4601..e732a1f59 100644 --- a/backend/src/baserow/api/auth_provider/serializers.py +++ b/backend/src/baserow/api/auth_provider/serializers.py @@ -3,12 +3,19 @@ from drf_spectacular.utils import extend_schema_field from rest_framework import serializers from baserow.core.auth_provider.models import AuthProviderModel +from baserow.core.auth_provider.validators import validate_domain from baserow.core.registries import auth_provider_type_registry class AuthProviderSerializer(serializers.ModelSerializer): type = serializers.SerializerMethodField(help_text="The type of the related field.") + domain = serializers.CharField( + validators=[validate_domain], + required=True, + help_text=AuthProviderModel._meta.get_field("domain").help_text, + ) + class Meta: model = AuthProviderModel fields = ("id", "type", "domain", "enabled") diff --git a/backend/src/baserow/api/polymorphic.py b/backend/src/baserow/api/polymorphic.py index 5e6aaecba..6a0506492 100644 --- a/backend/src/baserow/api/polymorphic.py +++ b/backend/src/baserow/api/polymorphic.py @@ -104,6 +104,7 @@ class PolymorphicSerializer(serializers.Serializer): base_class=self.base_class, request=self.request, context=self.context, + extra_params=self.extra_params, ) ret = serializer.to_representation(instance) @@ -122,6 +123,7 @@ class PolymorphicSerializer(serializers.Serializer): base_class=self.base_class, request=self.request, context=self.context, + extra_params=self.extra_params, ) return serializer.to_internal_value(data) @@ -134,6 +136,7 @@ class PolymorphicSerializer(serializers.Serializer): base_class=self.base_class, request=self.request, context=self.context, + extra_params=self.extra_params, ) return serializer.create(validated_data) @@ -150,6 +153,7 @@ class PolymorphicSerializer(serializers.Serializer): base_class=self.base_class, request=self.request, context=self.context, + extra_params=self.extra_params, ) return serializer.update(instance, validated_data) @@ -170,6 +174,7 @@ class PolymorphicSerializer(serializers.Serializer): context=self.context, data=self.data, partial=self.partial, + extra_params=self.extra_params, ) except serializers.ValidationError: child_valid = False @@ -194,6 +199,7 @@ class PolymorphicSerializer(serializers.Serializer): request=self.request, context=self.context, partial=self.partial, + extra_params=self.extra_params, ) validated_data = serializer.run_validation(data) diff --git a/backend/src/baserow/api/urls.py b/backend/src/baserow/api/urls.py index 0eb6392a3..e8199359f 100755 --- a/backend/src/baserow/api/urls.py +++ b/backend/src/baserow/api/urls.py @@ -2,7 +2,11 @@ from django.urls import include, path from drf_spectacular.views import SpectacularRedocView -from baserow.core.registries import application_type_registry, plugin_registry +from baserow.core.registries import ( + application_type_registry, + auth_provider_type_registry, + plugin_registry, +) from .applications import urls as application_urls from .auth_provider import urls as auth_provider_urls @@ -53,5 +57,6 @@ urlpatterns = ( ), ] + application_type_registry.api_urls + + auth_provider_type_registry.api_urls + plugin_registry.api_urls ) diff --git a/backend/src/baserow/api/user_sources/authentication.py b/backend/src/baserow/api/user_sources/authentication.py index e70f6236d..15e121944 100644 --- a/backend/src/baserow/api/user_sources/authentication.py +++ b/backend/src/baserow/api/user_sources/authentication.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple, TypeVar from django.conf import settings from django.contrib.auth.models import AbstractBaseUser, AnonymousUser +from drf_spectacular.extensions import OpenApiAuthenticationExtension from rest_framework import HTTP_HEADER_ENCODING, exceptions from rest_framework.request import Request from rest_framework_simplejwt.authentication import JWTAuthentication @@ -157,3 +158,19 @@ class UserSourceJSONWebTokenAuthentication(JWTAuthentication): user, validated_token, ) + + +class UserSourceJSONWebTokenAuthenticationExtension(OpenApiAuthenticationExtension): + target_class = ( + "baserow.api.user_sources.authentication.UserSourceJSONWebTokenAuthentication" + ) + name = "UserSource JWT" + match_subclasses = True + priority = -1 + + def get_security_definition(self, auto_schema): + return { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT your_token", + } diff --git a/backend/src/baserow/contrib/builder/api/domains/serializers.py b/backend/src/baserow/contrib/builder/api/domains/serializers.py index 0c09df93a..abbffc6ba 100644 --- a/backend/src/baserow/contrib/builder/api/domains/serializers.py +++ b/backend/src/baserow/contrib/builder/api/domains/serializers.py @@ -6,9 +6,7 @@ from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import extend_schema_field from rest_framework import serializers -from baserow.api.app_auth_providers.serializers import ( - ReadPolymorphicAppAuthProviderSerializer, -) +from baserow.api.app_auth_providers.serializers import AppAuthProviderSerializer from baserow.api.polymorphic import PolymorphicSerializer from baserow.api.services.serializers import PublicServiceSerializer from baserow.api.user_files.serializers import UserFileField, UserFileSerializer @@ -26,6 +24,7 @@ from baserow.contrib.builder.elements.registries import element_type_registry from baserow.contrib.builder.models import Builder from baserow.contrib.builder.pages.handler import PageHandler from baserow.contrib.builder.pages.models import Page +from baserow.core.app_auth_providers.registries import app_auth_provider_type_registry from baserow.core.services.registries import service_type_registry from baserow.core.user_sources.models import UserSource from baserow.core.user_sources.registries import user_source_type_registry @@ -175,6 +174,16 @@ class PublicPageSerializer(serializers.ModelSerializer): } +class PublicPolymorphicAppAuthProviderSerializer(PolymorphicSerializer): + """ + Polymorphic serializer for App Auth providers. + """ + + base_class = AppAuthProviderSerializer + registry = app_auth_provider_type_registry + extra_params = {"public": True} + + class BasePublicUserSourceSerializer(serializers.ModelSerializer): """ Basic user source serializer mostly for returned values. @@ -186,7 +195,7 @@ class BasePublicUserSourceSerializer(serializers.ModelSerializer): def get_type(self, instance): return user_source_type_registry.get_by_model(instance.specific_class).type - auth_providers = ReadPolymorphicAppAuthProviderSerializer( + auth_providers = PublicPolymorphicAppAuthProviderSerializer( required=False, many=True, help_text="Auth providers related to this user source.", diff --git a/backend/src/baserow/contrib/builder/application_types.py b/backend/src/baserow/contrib/builder/application_types.py index b3a7ab7da..017944609 100755 --- a/backend/src/baserow/contrib/builder/application_types.py +++ b/backend/src/baserow/contrib/builder/application_types.py @@ -1,6 +1,8 @@ from typing import Any, Dict, List, Optional +from urllib.parse import urljoin from zipfile import ZipFile +from django.conf import settings from django.contrib.auth.models import AbstractUser from django.core.files.storage import Storage from django.db import transaction @@ -419,6 +421,30 @@ class BuilderApplicationType(ApplicationType): return builder + def get_default_application_urls(self, application: Builder) -> list[str]: + """ + Returns the default frontend urls of a builder application. + """ + + from baserow.contrib.builder.domains.handler import DomainHandler + + domain = DomainHandler().get_domain_for_builder(application) + + if domain is not None: + # Let's also return the preview url so that it's easier to test + preview_url = urljoin( + settings.PUBLIC_WEB_FRONTEND_URL, + f"/builder/{domain.builder_id}/preview/", + ) + return [domain.get_public_url(), preview_url] + + preview_url = urljoin( + settings.PUBLIC_WEB_FRONTEND_URL, + f"/builder/{application.id}/preview/", + ) + # It's an unpublished version let's return to the home preview page + return [preview_url] + def enhance_queryset(self, queryset): queryset = queryset.prefetch_related("page_set") queryset = queryset.prefetch_related("user_sources") diff --git a/backend/src/baserow/contrib/builder/domains/handler.py b/backend/src/baserow/contrib/builder/domains/handler.py index 985db5b57..e821a8885 100644 --- a/backend/src/baserow/contrib/builder/domains/handler.py +++ b/backend/src/baserow/contrib/builder/domains/handler.py @@ -85,6 +85,17 @@ class DomainHandler: return domain.published_to + def get_domain_for_builder(self, builder: Builder) -> Domain | None: + """ + Returns the domain the builder is published for or None if it's not a published + builder. + """ + + try: + return Domain.objects.get(published_to=builder) + except Domain.DoesNotExist: + return None + def create_domain( self, domain_type: DomainType, builder: Builder, **kwargs ) -> Domain: diff --git a/backend/src/baserow/contrib/builder/domains/models.py b/backend/src/baserow/contrib/builder/domains/models.py index f549d0ac2..6fef00918 100644 --- a/backend/src/baserow/contrib/builder/domains/models.py +++ b/backend/src/baserow/contrib/builder/domains/models.py @@ -1,3 +1,6 @@ +from urllib.parse import urlparse + +from django.conf import settings from django.contrib.contenttypes.models import ContentType from django.db import models from django.db.models import CASCADE, SET_NULL @@ -80,6 +83,16 @@ class Domain( class Meta: ordering = ("order",) + def get_public_url(self): + """ + Returns the URL for this domain. + """ + + # Parse the PUBLIC_WEB_FRONTEND_URL to extract the scheme and port + parsed_url = urlparse(settings.PUBLIC_WEB_FRONTEND_URL) + port_string = f":{parsed_url.port}" if parsed_url.port else "" + return f"{parsed_url.scheme}://{self.domain_name}{port_string}" + @classmethod def get_last_order(cls, builder): queryset = Domain.objects.filter(builder=builder) diff --git a/backend/src/baserow/contrib/builder/domains/permission_manager.py b/backend/src/baserow/contrib/builder/domains/permission_manager.py index 8a2139771..2d50f2f9a 100755 --- a/backend/src/baserow/contrib/builder/domains/permission_manager.py +++ b/backend/src/baserow/contrib/builder/domains/permission_manager.py @@ -4,6 +4,7 @@ from baserow.contrib.builder.data_sources.operations import ( DispatchDataSourceOperationType, ListDataSourcesPageOperationType, ) +from baserow.contrib.builder.domains.handler import DomainHandler from baserow.contrib.builder.elements.operations import ListElementsPageOperationType from baserow.contrib.builder.models import Builder from baserow.contrib.builder.workflow_actions.operations import ( @@ -20,8 +21,6 @@ from baserow.core.user_sources.operations import ( ) from baserow.core.user_sources.subjects import UserSourceUserSubjectType -from .models import Domain - User = get_user_model() @@ -101,10 +100,7 @@ class AllowPublicBuilderManagerType(PermissionManagerType): # give access to specific data. continue - if ( - builder.workspace is None - and Domain.objects.filter(published_to=builder).exists() - ): + if DomainHandler().get_domain_for_builder(builder) is not None: # it's a public builder, we allow it. result[check] = True diff --git a/backend/src/baserow/contrib/builder/registries.py b/backend/src/baserow/contrib/builder/registries.py deleted file mode 100644 index 737524c7a..000000000 --- a/backend/src/baserow/contrib/builder/registries.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import Dict, List, Optional - -from baserow.core.registry import CustomFieldsInstanceMixin - - -class PublicCustomFieldsInstanceMixin(CustomFieldsInstanceMixin): - public_serializer_field_names = [] - """The field names that must be added to the serializer if it's public.""" - - public_request_serializer_field_names = [] - """ - The field names that must be added to the public request serializer if different - from the `public_serializer_field_names`. - """ - - request_serializer_field_overrides = None - """ - The fields that must be added to the request serializer if different from the - `serializer_field_overrides` property. - """ - - public_serializer_field_overrides = None - """The fields that must be added to the public serializer.""" - - public_request_serializer_field_overrides = None - """ - The fields that must be added to the public request serializer if different from the - `public_serializer_field_overrides` property. - """ - - def get_field_overrides( - self, request_serializer: bool, extra_params=None, **kwargs - ) -> Dict: - public = extra_params.get("public", False) - - if public: - if request_serializer and self.public_request_serializer_field_overrides: - return self.public_request_serializer_field_overrides - if self.public_serializer_field_overrides: - return self.public_serializer_field_overrides - - return super().get_field_overrides(request_serializer, extra_params, **kwargs) - - def get_field_names( - self, request_serializer: bool, extra_params=None, **kwargs - ) -> List[str]: - public = extra_params.get("public", False) - - if public: - if request_serializer and self.public_request_serializer_field_names: - return self.public_request_serializer_field_names - if self.public_serializer_field_names: - return self.public_serializer_field_names - - return super().get_field_names(request_serializer, extra_params, **kwargs) - - def get_meta_ref_name( - self, - request_serializer: bool, - extra_params=None, - **kwargs, - ) -> Optional[str]: - meta_ref_name = super().get_meta_ref_name( - request_serializer, extra_params, **kwargs - ) - - public = extra_params.get("public", False) - - if public: - meta_ref_name = f"Public{meta_ref_name}" - - return meta_ref_name diff --git a/backend/src/baserow/contrib/builder/workflow_actions/registries.py b/backend/src/baserow/contrib/builder/workflow_actions/registries.py index 227b2b775..ec39db948 100644 --- a/backend/src/baserow/contrib/builder/workflow_actions/registries.py +++ b/backend/src/baserow/contrib/builder/workflow_actions/registries.py @@ -4,11 +4,11 @@ from django.contrib.auth.models import AbstractUser from baserow.contrib.builder.formula_importer import import_formula from baserow.contrib.builder.mixins import BuilderInstanceWithFormulaMixin -from baserow.contrib.builder.registries import PublicCustomFieldsInstanceMixin from baserow.contrib.builder.workflow_actions.models import BuilderWorkflowAction from baserow.core.registry import ( CustomFieldsRegistryMixin, ModelRegistryMixin, + PublicCustomFieldsInstanceMixin, Registry, ) from baserow.core.workflow_actions.registries import WorkflowActionType @@ -88,6 +88,7 @@ class BuilderWorkflowActionType( cache = {} element_id = serialized_values["element_id"] + import_context = {} if element_id: imported_element_id = id_mapping["builder_page_elements"][element_id] import_context = ElementHandler().get_import_context_addition( diff --git a/backend/src/baserow/core/app_auth_providers/auth_provider_types.py b/backend/src/baserow/core/app_auth_providers/auth_provider_types.py index 3de2d05f3..1fe191a1a 100644 --- a/backend/src/baserow/core/app_auth_providers/auth_provider_types.py +++ b/backend/src/baserow/core/app_auth_providers/auth_provider_types.py @@ -1,18 +1,20 @@ -from typing import TYPE_CHECKING, Callable, List, Type, Union +from typing import TYPE_CHECKING, Callable, List, Tuple, Type, Union from django.contrib.auth.models import AbstractUser from baserow.core.app_auth_providers.exceptions import IncompatibleUserSourceType from baserow.core.app_auth_providers.types import AppAuthProviderTypeDict from baserow.core.auth_provider.registries import BaseAuthProviderType -from baserow.core.auth_provider.types import AuthProviderModelSubClass -from baserow.core.registry import EasyImportExportMixin +from baserow.core.auth_provider.types import AuthProviderModelSubClass, UserInfo +from baserow.core.registry import EasyImportExportMixin, PublicCustomFieldsInstanceMixin if TYPE_CHECKING: from baserow.core.user_sources.types import UserSourceSubClass -class AppAuthProviderType(EasyImportExportMixin, BaseAuthProviderType): +class AppAuthProviderType( + EasyImportExportMixin, PublicCustomFieldsInstanceMixin, BaseAuthProviderType +): """ Authentication provider for application user sources. """ @@ -70,3 +72,16 @@ class AppAuthProviderType(EasyImportExportMixin, BaseAuthProviderType): :param instance: The auth provider instance related to the user source. :param user_source: The user source being updated. """ + + def get_or_create_user_and_sign_in( + self, auth_provider: AuthProviderModelSubClass, user_info: UserInfo + ) -> Tuple[AbstractUser, bool]: + """ + Get or create a user for the given UserInfo. Calls the related userSource + get_or_create_user. + """ + + user_source = auth_provider.user_source.specific + return user_source.get_type().get_or_create_user( + user_source, email=user_info.email, name=user_info.name + ) diff --git a/backend/src/baserow/core/app_auth_providers/models.py b/backend/src/baserow/core/app_auth_providers/models.py index f195e00c7..371ffff13 100644 --- a/backend/src/baserow/core/app_auth_providers/models.py +++ b/backend/src/baserow/core/app_auth_providers/models.py @@ -37,4 +37,4 @@ class AppAuthProvider(BaseAuthProviderModel, HierarchicalModelMixin): return app_auth_provider_type_registry class Meta: - ordering = ["domain", "id"] + ordering = ["id"] diff --git a/backend/src/baserow/core/auth_provider/models.py b/backend/src/baserow/core/auth_provider/models.py index b4833400d..953a3b92d 100644 --- a/backend/src/baserow/core/auth_provider/models.py +++ b/backend/src/baserow/core/auth_provider/models.py @@ -17,8 +17,14 @@ class BaseAuthProviderModel( Base abstract model for app_providers. """ - domain = models.CharField(max_length=255, null=True) - enabled = models.BooleanField(default=True) + domain = models.CharField( + max_length=255, + null=True, + help_text="The email domain registered with this provider.", + ) + enabled = models.BooleanField( + help_text="Whether the provider is enabled or not.", default=True + ) class Meta: abstract = True diff --git a/backend/src/baserow/core/auth_provider/registries.py b/backend/src/baserow/core/auth_provider/registries.py index 66dec45c3..6ff0802c2 100644 --- a/backend/src/baserow/core/auth_provider/registries.py +++ b/backend/src/baserow/core/auth_provider/registries.py @@ -47,13 +47,6 @@ class BaseAuthProviderType( default_create_allowed_fields = ["domain", "enabled"] default_update_allowed_fields = ["domain", "enabled"] - @abstractmethod - def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]: - """ - Returns a dictionary containing the login options - to populate the login component accordingly. - """ - def can_create_new_providers(self, **kwargs) -> bool: """ Returns True if it's possible to create an authentication provider of this type. @@ -249,6 +242,13 @@ class AuthenticationProviderTypeRegistry( super().__init__(*args, **kwargs) self._default = None + @abstractmethod + def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]: + """ + Returns a dictionary containing the login options + to populate the login component accordingly. + """ + def get_all_available_login_options(self): login_options = {} for provider_type in self.get_all(): diff --git a/backend/src/baserow/core/management/commands/list_urls.py b/backend/src/baserow/core/management/commands/list_urls.py new file mode 100644 index 000000000..dc76b2fd3 --- /dev/null +++ b/backend/src/baserow/core/management/commands/list_urls.py @@ -0,0 +1,33 @@ +from django.core.management.base import BaseCommand +from django.urls import URLPattern, URLResolver, get_resolver + + +class Command(BaseCommand): + help = "List all registered full URLs with their namespaces" + + def handle(self, *args, **kwargs): + resolver = get_resolver() + self.list_urls(resolver.url_patterns) + + def list_urls(self, urlpatterns, prefix="", namespace=None): + for pattern in urlpatterns: + if isinstance(pattern, URLPattern): + # Construct the full URL path + full_url = f"{prefix}{pattern.pattern}" + full_namespace = ( + f"{namespace}:{pattern.name}" + if namespace and pattern.name + else pattern.name or "None" + ) + self.stdout.write(f"URL: {full_url}, Namespace: {full_namespace}") + elif isinstance(pattern, URLResolver): + # Construct the full namespace and recurse + new_prefix = f"{prefix}{pattern.pattern}" + new_namespace = ( + f"{namespace}:{pattern.namespace}" + if namespace and pattern.namespace + else pattern.namespace or namespace + ) + self.list_urls( + pattern.url_patterns, prefix=new_prefix, namespace=new_namespace + ) diff --git a/backend/src/baserow/core/migrations/0093_alter_appauthprovider_options_and_more.py b/backend/src/baserow/core/migrations/0093_alter_appauthprovider_options_and_more.py new file mode 100644 index 000000000..95b717880 --- /dev/null +++ b/backend/src/baserow/core/migrations/0093_alter_appauthprovider_options_and_more.py @@ -0,0 +1,48 @@ +# Generated by Django 5.0.9 on 2024-11-19 09:43 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("core", "0092_alter_userprofile_language"), + ] + + operations = [ + migrations.AlterModelOptions( + name="appauthprovider", + options={"ordering": ["id"]}, + ), + migrations.AlterField( + model_name="appauthprovider", + name="domain", + field=models.CharField( + help_text="The email domain registered with this provider.", + max_length=255, + null=True, + ), + ), + migrations.AlterField( + model_name="appauthprovider", + name="enabled", + field=models.BooleanField( + default=True, help_text="Whether the provider is enabled or not." + ), + ), + migrations.AlterField( + model_name="authprovidermodel", + name="domain", + field=models.CharField( + help_text="The email domain registered with this provider.", + max_length=255, + null=True, + ), + ), + migrations.AlterField( + model_name="authprovidermodel", + name="enabled", + field=models.BooleanField( + default=True, help_text="Whether the provider is enabled or not." + ), + ), + ] diff --git a/backend/src/baserow/core/registries.py b/backend/src/baserow/core/registries.py index facca86a5..831186292 100755 --- a/backend/src/baserow/core/registries.py +++ b/backend/src/baserow/core/registries.py @@ -507,6 +507,13 @@ class ApplicationType( def enhance_queryset(self, queryset): return queryset + def get_default_application_urls(self, application: "Application") -> list[str]: + """ + Returns the default frontend urls of the application if any. + """ + + return [] + ApplicationSubClassInstance = TypeVar( "ApplicationSubClassInstance", bound="Application" diff --git a/backend/src/baserow/core/registry.py b/backend/src/baserow/core/registry.py index cd1e507f9..a82f644d4 100644 --- a/backend/src/baserow/core/registry.py +++ b/backend/src/baserow/core/registry.py @@ -276,13 +276,95 @@ class CustomFieldsInstanceMixin: return None +class PublicCustomFieldsInstanceMixin(CustomFieldsInstanceMixin): + """ + A mixin for instance with custom fields but some field should remains private + when used in some APIs. + """ + + public_serializer_field_names = None + """The field names that must be added to the serializer if it's public.""" + + public_request_serializer_field_names = None + """ + The field names that must be added to the public request serializer if different + from the `public_serializer_field_names`. + """ + + request_serializer_field_overrides = None + """ + The fields that must be added to the request serializer if different from the + `serializer_field_overrides` property. + """ + + public_serializer_field_overrides = None + """The fields that must be added to the public serializer.""" + + public_request_serializer_field_overrides = None + """ + The fields that must be added to the public request serializer if different from the + `public_serializer_field_overrides` property. + """ + + def get_field_overrides( + self, request_serializer: bool, extra_params=None, **kwargs + ) -> Dict: + public = extra_params.get("public", False) + + if public: + if ( + request_serializer is not None + and self.public_request_serializer_field_overrides is not None + ): + return self.public_request_serializer_field_overrides + if self.public_serializer_field_overrides is not None: + return self.public_serializer_field_overrides + + return super().get_field_overrides(request_serializer, extra_params, **kwargs) + + def get_field_names( + self, request_serializer: bool, extra_params=None, **kwargs + ) -> List[str]: + public = extra_params.get("public", False) + + if public: + if ( + request_serializer is not None + and self.public_request_serializer_field_names is not None + ): + return self.public_request_serializer_field_names + if self.public_serializer_field_names is not None: + return self.public_serializer_field_names + + return super().get_field_names(request_serializer, extra_params, **kwargs) + + def get_meta_ref_name( + self, + request_serializer: bool, + extra_params=None, + **kwargs, + ) -> Optional[str]: + meta_ref_name = super().get_meta_ref_name( + request_serializer, extra_params, **kwargs + ) + + public = extra_params.get("public", False) + + if public: + meta_ref_name = f"Public{meta_ref_name}" + + return meta_ref_name + + class APIUrlsInstanceMixin: - def get_api_urls(self): + def get_api_urls(self) -> List: """ If needed custom api related urls to the instance can be added here. Example: + from django.urls import include, path + def get_api_urls(self): from . import api_urls @@ -298,7 +380,6 @@ class APIUrlsInstanceMixin: ] :return: A list containing the urls. - :rtype: list """ return [] diff --git a/backend/src/baserow/core/user_sources/registries.py b/backend/src/baserow/core/user_sources/registries.py index 7c4ebe6bd..d8673cf11 100644 --- a/backend/src/baserow/core/user_sources/registries.py +++ b/backend/src/baserow/core/user_sources/registries.py @@ -1,6 +1,6 @@ import uuid from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar from django.contrib.auth.models import AbstractUser @@ -16,6 +16,7 @@ from baserow.core.registry import ( ModelRegistryMixin, Registry, ) +from baserow.core.user.exceptions import UserNotFound from baserow.core.user_sources.constants import DEFAULT_USER_ROLE_PREFIX from baserow.core.user_sources.user_source_user import UserSourceUser @@ -233,15 +234,41 @@ class UserSourceType( """ @abstractmethod - def get_user(self, user_source: UserSource, **kwargs) -> Optional[UserSourceUser]: + def create_user( + self, user_source: UserSource, email: str, name: str + ) -> UserSourceUser: + """ + Create a user for the given user source instance from it's email and name. + + :param user_source: The user source we want to create the user for. + :param email: Email of the user to create. + :param name: Name of the user to create. + :return: A user instance. + """ + + @abstractmethod + def get_user(self, user_source: UserSource, **kwargs) -> UserSourceUser: """ Returns a user given some args. :param user_source: The user source used to get the user. :param kwargs: Keyword arguments to get the user. + :raises UserNotFound: When the user can't be found. :return: A user instance if any found with the given parameters. """ + def get_or_create_user( + self, user_source: UserSource, email: str, name: str + ) -> Tuple[UserSourceUser, bool]: + """ + Shorthand to create a user if he doesn't exist. + """ + + try: + return self.get_user(user_source, email=email), False + except UserNotFound: + return self.create_user(user_source, email, name), True + @abstractmethod def authenticate(self, user_source: UserSource, **kwargs) -> UserSourceUser: """ diff --git a/backend/src/baserow/test_utils/pytest_conftest.py b/backend/src/baserow/test_utils/pytest_conftest.py index e92a0fa89..88f086de8 100755 --- a/backend/src/baserow/test_utils/pytest_conftest.py +++ b/backend/src/baserow/test_utils/pytest_conftest.py @@ -265,6 +265,9 @@ def stub_user_source_registry(data_fixture, mutable_user_source_registry, fake): return get_user_return return data_fixture.create_user_source_user(user_source=user_source) + def create_user(self, user_source, email, name): + return data_fixture.create_user_source_user(user_source=user_source) + def authenticate(self, user_source, **kwargs): if authenticate_return: if callable(authenticate_return): diff --git a/backend/tests/baserow/api/user_sources/test_user_source_views.py b/backend/tests/baserow/api/user_sources/test_user_source_views.py index 6d6571e74..a8ea1638d 100644 --- a/backend/tests/baserow/api/user_sources/test_user_source_views.py +++ b/backend/tests/baserow/api/user_sources/test_user_source_views.py @@ -193,7 +193,10 @@ def test_create_user_source_w_auth_providers(api_client, data_fixture): "name": "test", "integration_id": integration.id, "auth_providers": [ - {"type": "local_baserow_password", "enabled": False, "domain": "test1"}, + { + "type": "local_baserow_password", + "enabled": False, + }, ], }, format="json", @@ -208,14 +211,87 @@ def test_create_user_source_w_auth_providers(api_client, data_fixture): assert response_json["auth_providers"] == [ { - "domain": "test1", "id": first.id, "password_field_id": None, "type": "local_baserow_password", + "domain": None, }, ] +@pytest.mark.django_db +def test_create_user_source_w_auth_providers_w_domain(api_client, data_fixture): + user, token = data_fixture.create_user_and_token() + workspace = data_fixture.create_workspace(user=user) + application = data_fixture.create_builder_application(workspace=workspace) + integration = data_fixture.create_local_baserow_integration(application=application) + + url = reverse("api:user_sources:list", kwargs={"application_id": application.id}) + response = api_client.post( + url, + { + "type": "local_baserow", + "name": "test", + "integration_id": integration.id, + "auth_providers": [ + { + "domain": "domain.com", + "type": "local_baserow_password", + "enabled": False, + }, + ], + }, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + + response_json = response.json() + assert response.status_code == HTTP_200_OK + + assert AppAuthProvider.objects.count() == 1 + first = AppAuthProvider.objects.first() + + assert response_json["auth_providers"] == [ + { + "id": first.id, + "password_field_id": None, + "type": "local_baserow_password", + "domain": "domain.com", + }, + ] + + +@pytest.mark.django_db +def test_create_user_source_w_auth_providers_w_wrong_domain(api_client, data_fixture): + user, token = data_fixture.create_user_and_token() + workspace = data_fixture.create_workspace(user=user) + application = data_fixture.create_builder_application(workspace=workspace) + integration = data_fixture.create_local_baserow_integration(application=application) + + url = reverse("api:user_sources:list", kwargs={"application_id": application.id}) + response = api_client.post( + url, + { + "type": "local_baserow", + "name": "test", + "integration_id": integration.id, + "auth_providers": [ + { + "domain": "baddomain", + "type": "local_baserow_password", + "enabled": False, + }, + ], + }, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + + response_json = response.json() + assert response.status_code == HTTP_400_BAD_REQUEST + assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION" + + @pytest.mark.django_db def test_create_user_source_w_auth_provider_wrong_type(api_client, data_fixture): user, token = data_fixture.create_user_and_token() @@ -237,7 +313,6 @@ def test_create_user_source_w_auth_provider_wrong_type(api_client, data_fixture) "integration_id": integration.id, "auth_providers": [ { - "domain": "test_domain", "enabled": True, "type": app_auth_provider_type.type, }, @@ -270,7 +345,6 @@ def test_create_user_source_w_auth_provider_missing_type(api_client, data_fixtur "integration_id": integration.id, "auth_providers": [ { - "domain": "test_domain", "enabled": True, "type": "bad_type", }, @@ -374,7 +448,11 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture): url, { "auth_providers": [ - {"type": "local_baserow_password", "enabled": False, "domain": "test1"}, + { + "type": "local_baserow_password", + "enabled": False, + "domain": "test2.com", + }, ], }, format="json", @@ -388,7 +466,7 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture): assert response.json()["auth_providers"] == [ { - "domain": "test1", + "domain": "test2.com", "id": first.id, "password_field_id": None, "type": "local_baserow_password", @@ -399,7 +477,11 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture): url, { "auth_providers": [ - {"type": "local_baserow_password", "enabled": False, "domain": "test3"}, + { + "type": "local_baserow_password", + "enabled": False, + "domain": "test3.com", + }, ], }, format="json", @@ -411,7 +493,7 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture): assert response.json()["auth_providers"] == [ { - "domain": "test3", + "domain": "test3.com", "id": first.id, "password_field_id": None, "type": "local_baserow_password", diff --git a/backend/tests/baserow/contrib/builder/domains/test_domain_handler.py b/backend/tests/baserow/contrib/builder/domains/test_domain_handler.py index 347135208..7617009f6 100644 --- a/backend/tests/baserow/contrib/builder/domains/test_domain_handler.py +++ b/backend/tests/baserow/contrib/builder/domains/test_domain_handler.py @@ -196,3 +196,34 @@ def test_domain_publishing(data_fixture): DomainHandler().publish(domain1, progress) assert Builder.objects.count() == 2 + + +@pytest.mark.django_db +def test_get_domain_for_builder(data_fixture): + user = data_fixture.create_user() + builder = data_fixture.create_builder_application(user=user) + builder_to = data_fixture.create_builder_application(workspace=None) + domain1 = data_fixture.create_builder_custom_domain( + builder=builder, published_to=builder_to, domain_name="mytest.com" + ) + domain2 = data_fixture.create_builder_custom_domain( + builder=builder, domain_name="mytest2.com" + ) + + assert ( + DomainHandler().get_domain_for_builder(builder_to).domain_name == "mytest.com" + ) + + assert DomainHandler().get_domain_for_builder(builder) is None + + +@pytest.mark.django_db +def test_get_domain_public_url(data_fixture): + user = data_fixture.create_user() + builder = data_fixture.create_builder_application(user=user) + builder_to = data_fixture.create_builder_application(workspace=None) + domain1 = data_fixture.create_builder_custom_domain( + builder=builder, published_to=builder_to, domain_name="mytest.com" + ) + + assert domain1.get_public_url() == "http://mytest.com:3000" diff --git a/backend/tests/baserow/contrib/builder/test_builder_application_type.py b/backend/tests/baserow/contrib/builder/test_builder_application_type.py index d52c29490..cb6645304 100644 --- a/backend/tests/baserow/contrib/builder/test_builder_application_type.py +++ b/backend/tests/baserow/contrib/builder/test_builder_application_type.py @@ -1636,3 +1636,21 @@ def test_builder_application_exports_file_with_zip_file( serialized_image_element = visible_pages[0]["elements"][0] assert serialized_image_element["image_source_type"] == "upload" assert serialized_image_element["image_file_id"] == serialized_file + + +@pytest.mark.django_db +def test_get_default_application_urls(data_fixture): + user = data_fixture.create_user() + builder = data_fixture.create_builder_application(user=user) + builder_to = data_fixture.create_builder_application(workspace=None) + domain1 = data_fixture.create_builder_custom_domain( + builder=builder, published_to=builder_to, domain_name="mytest.com" + ) + + assert builder.get_type().get_default_application_urls(builder) == [ + f"http://localhost:3000/builder/{builder.id}/preview/" + ] + assert builder_to.get_type().get_default_application_urls(builder_to) == [ + "http://mytest.com:3000", + f"http://localhost:3000/builder/{builder.id}/preview/", + ] diff --git a/enterprise/backend/src/baserow_enterprise/api/admin/auth_provider/serializers.py b/enterprise/backend/src/baserow_enterprise/api/admin/auth_provider/serializers.py index 15ce0e34f..6bca33d06 100644 --- a/enterprise/backend/src/baserow_enterprise/api/admin/auth_provider/serializers.py +++ b/enterprise/backend/src/baserow_enterprise/api/admin/auth_provider/serializers.py @@ -14,7 +14,7 @@ class CreateAuthProviderSerializer(serializers.ModelSerializer): class Meta: model = AuthProviderModel - fields = ("domain", "type") + fields = ("domain", "type", "enabled") class UpdateAuthProviderSerializer(serializers.ModelSerializer): @@ -29,7 +29,7 @@ class UpdateAuthProviderSerializer(serializers.ModelSerializer): class Meta: model = AuthProviderModel - fields = ("domain", "type") + fields = ("domain", "type", "enabled") extra_kwargs = { "domain": {"required": False}, } diff --git a/enterprise/backend/src/baserow_enterprise/api/integrations/__init__.py b/enterprise/backend/src/baserow_enterprise/api/integrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/enterprise/backend/src/baserow_enterprise/api/integrations/common/__init__.py b/enterprise/backend/src/baserow_enterprise/api/integrations/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/__init__.py b/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/__init__.py b/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/serializers.py b/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/serializers.py new file mode 100644 index 000000000..c29a24065 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/serializers.py @@ -0,0 +1,15 @@ +from rest_framework import serializers + +from baserow_enterprise.api.sso.saml.serializers import SAMLResponseSerializer +from baserow_enterprise.api.sso.serializers import BaseSsoLoginRequestSerializer + + +class CommonSsoLoginRequestSerializer(BaseSsoLoginRequestSerializer): + next = serializers.CharField( + required=False, + help_text="If provided, the user will be redirected to that path after login", + ) + + +class CommonSAMLResponseSerializer(SAMLResponseSerializer): + query_param_serializer = CommonSsoLoginRequestSerializer diff --git a/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/urls.py b/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/urls.py new file mode 100644 index 000000000..87b9c6b12 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/urls.py @@ -0,0 +1,19 @@ +from django.urls import re_path + +from .views import ( + SamlAppAuthProviderAssertionConsumerServiceView, + SamlAppAuthProviderBaserowInitiatedSingleSignOn, +) + +app_name = "baserow_enterprise.api.integrations.common.sso.saml" + +urlpatterns = [ + re_path( + r"acs/$", SamlAppAuthProviderAssertionConsumerServiceView.as_view(), name="acs" + ), + re_path( + r"login/$", + SamlAppAuthProviderBaserowInitiatedSingleSignOn.as_view(), + name="login", + ), +] diff --git a/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/views.py b/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/views.py new file mode 100644 index 000000000..dc86c3588 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/api/integrations/common/sso/saml/views.py @@ -0,0 +1,245 @@ +from django.db import transaction +from django.http import HttpResponseRedirect +from django.shortcuts import redirect + +from drf_spectacular.openapi import OpenApiParameter, OpenApiTypes +from drf_spectacular.utils import extend_schema +from rest_framework.permissions import AllowAny +from rest_framework.request import Request +from rest_framework.views import APIView + +from baserow.api.decorators import map_exceptions +from baserow.api.exceptions import ( + QueryParameterValidationException, + RequestBodyValidationException, +) +from baserow.api.user_sources.errors import ERROR_USER_SOURCE_DOES_NOT_EXIST +from baserow.api.utils import validate_data +from baserow.core.user.exceptions import DeactivatedUserException +from baserow.core.user_sources.exceptions import ( + UserSourceDoesNotExist, + UserSourceImproperlyConfigured, +) +from baserow.core.user_sources.handler import UserSourceHandler +from baserow_enterprise.api.integrations.common.sso.saml.serializers import ( + CommonSAMLResponseSerializer, + CommonSsoLoginRequestSerializer, +) +from baserow_enterprise.api.sso.serializers import BaseSsoLoginRequestSerializer +from baserow_enterprise.api.sso.utils import ( + SsoErrorCode, + get_valid_frontend_url, + map_sso_exceptions, + urlencode_query_params, +) +from baserow_enterprise.integrations.common.sso.saml.handler import ( + SamlAppAuthProviderHandler, +) +from baserow_enterprise.integrations.common.sso.saml.models import ( + SamlAppAuthProviderModel, +) +from baserow_enterprise.sso.saml.exceptions import ( + InvalidSamlConfiguration, + InvalidSamlRequest, + InvalidSamlResponse, +) + + +class SamlAppAuthProviderAssertionConsumerServiceView(APIView): + permission_classes = (AllowAny,) + + @extend_schema( + tags=["Auth"], + request=CommonSAMLResponseSerializer, + operation_id="auth_provider_saml_acs_url", + description=( + "Complete the SAML authentication flow by validating the SAML response. " + "Sign in the user if already exists in user_source or create a new one " + "otherwise." + "Once authenticated, the user will be redirected to the original " + "URL they were trying to access. If the response is invalid, the user " + "will be redirected to an error page with a specific error message." + "It accepts the language code and the workspace invitation token as query " + "parameters if provided." + ), + responses={302: None}, + auth=[], + ) + @transaction.atomic + @map_exceptions( + { + UserSourceDoesNotExist: ERROR_USER_SOURCE_DOES_NOT_EXIST, + } + ) + def post( + self, + request: Request, + user_source_uid, + ) -> HttpResponseRedirect: + user_source = UserSourceHandler().get_user_source_by_uid(user_source_uid) + + default_frontend_urls = ( + user_source.application.get_type().get_default_application_urls( + user_source.application.specific + ) + ) + + error_raised = {"code": None} + + def on_error(error_code): + error_raised["code"] = error_code + + with map_sso_exceptions( + { + InvalidSamlConfiguration: SsoErrorCode.INVALID_SAML_RESPONSE, + InvalidSamlResponse: SsoErrorCode.INVALID_SAML_RESPONSE, + DeactivatedUserException: SsoErrorCode.USER_DEACTIVATED, + RequestBodyValidationException: SsoErrorCode.INVALID_SAML_RESPONSE, + UserSourceDoesNotExist: SsoErrorCode.INVALID_SAML_REQUEST, + UserSourceImproperlyConfigured: SsoErrorCode.INVALID_SAML_REQUEST, + }, + on_error=on_error, + ): + # We can't use the decorator here because the redirect_url is related + # to the user source and we don't have it before. + data = validate_data( + CommonSAMLResponseSerializer, + request.data, + return_validated=True, + ) + + next_path = data["saml_request_data"].pop("next", None) + + user = SamlAppAuthProviderHandler.sign_in_user_from_saml_response( + data["SAMLResponse"], + data["saml_request_data"], + base_queryset=SamlAppAuthProviderModel.objects.filter( + user_source=user_source + ), + ) + + if error_raised["code"]: + # We redirect to the default frontend url with an error code + error_url = urlencode_query_params( + default_frontend_urls[0], + {f"saml_error__{user_source.id}": error_raised["code"].value}, + ) + return redirect(error_url) + + query_params = { + f"user_source_saml_token__{user_source.id}": user.get_refresh_token() + } + + if next_path: + query_params["next"] = next_path + + # Otherwise it a success, we redirect to the login page + redirect_url = get_valid_frontend_url( + data["RelayState"], + default_frontend_urls=default_frontend_urls, + # Add the refresh token as query parameter + query_params=query_params, + allow_any_path=False, + ) + + return redirect(redirect_url) + + +class SamlAppAuthProviderBaserowInitiatedSingleSignOn(APIView): + permission_classes = (AllowAny,) + + @extend_schema( + parameters=[ + OpenApiParameter( + name="email", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description="The email address of the user that want to sign in using SAML.", + ), + OpenApiParameter( + name="original", + location=OpenApiParameter.QUERY, + type=OpenApiTypes.STR, + description=( + "The url to which the user should be redirected after a successful " + "login or sign up." + ), + ), + ], + tags=["User sources"], + request=BaseSsoLoginRequestSerializer, + operation_id="app_auth_provider_saml_sp_login", + description=( + "This is the endpoint that is called when the user wants to initiate a " + "SSO SAML login from Baserow (the service provider). The user will be " + "redirected to the SAML identity provider (IdP) where the user " + "can authenticate. " + "Once logged in in the IdP, the user will be redirected back " + "to the assertion consumer service endpoint (ACS) where the SAML response " + "will be validated and a new JWT session token will be provided to work " + "with Baserow APIs." + ), + responses={302: None}, + auth=[], + ) + @map_exceptions( + { + UserSourceDoesNotExist: ERROR_USER_SOURCE_DOES_NOT_EXIST, + } + ) + def get(self, request: Request, user_source_uid: str) -> HttpResponseRedirect: + user_source = UserSourceHandler().get_user_source_by_uid(user_source_uid) + + default_frontend_urls = ( + user_source.application.get_type().get_default_application_urls( + user_source.application.specific + ) + ) + + error_raised = {"code": None} + + def on_error(error_code): + error_raised["code"] = error_code + + with map_sso_exceptions( + { + InvalidSamlConfiguration: SsoErrorCode.INVALID_SAML_REQUEST, + InvalidSamlRequest: SsoErrorCode.INVALID_SAML_REQUEST, + RequestBodyValidationException: SsoErrorCode.INVALID_SAML_REQUEST, + }, + on_error=on_error, + ): + # Validate query parameters + query_params = validate_data( + CommonSsoLoginRequestSerializer, + request.GET.dict(), + partial=False, + exception_to_raise=QueryParameterValidationException, + return_validated=True, + ) + + original_url = query_params.pop("original", "") + valid_relay_state_url = get_valid_frontend_url( + original_url, + query_params, + default_frontend_urls=default_frontend_urls, + allow_any_path=False, + ) + + idp_sign_in_url = SamlAppAuthProviderHandler.get_sign_in_url( + query_params, + SamlAppAuthProviderHandler.model_class.objects.filter( + user_source__uid=user_source.uid + ), + redirect_to=valid_relay_state_url, + ) + + if error_raised["code"]: + # We redirect to the default frontend url with an error code + error_url = urlencode_query_params( + default_frontend_urls[0], + {f"saml_error__{user_source.id}": error_raised["code"].value}, + ) + return redirect(error_url) + + return redirect(idp_sign_in_url) diff --git a/enterprise/backend/src/baserow_enterprise/api/sso/saml/serializers.py b/enterprise/backend/src/baserow_enterprise/api/sso/saml/serializers.py index 4916a4155..2091c14f7 100644 --- a/enterprise/backend/src/baserow_enterprise/api/sso/saml/serializers.py +++ b/enterprise/backend/src/baserow_enterprise/api/sso/saml/serializers.py @@ -7,6 +7,8 @@ from baserow_enterprise.sso.saml.exceptions import InvalidSamlResponse class SAMLResponseSerializer(serializers.Serializer): + query_param_serializer = SsoLoginRequestSerializer + SAMLResponse = serializers.CharField( required=True, help_text="The encoded SAML response from the IdP." ) @@ -25,11 +27,12 @@ class SAMLResponseSerializer(serializers.Serializer): parsed_relay_state = urlparse(relay_state) query_params = dict(parse_qsl(parsed_relay_state.query)) if query_params: - request_data_serializer = SsoLoginRequestSerializer(data=query_params) + request_data_serializer = self.query_param_serializer(data=query_params) if request_data_serializer.is_valid(): data["saml_request_data"] = request_data_serializer.validated_data else: raise InvalidSamlResponse("Invalid RelayState query parameters.") + data["RelayState"] = parsed_relay_state._replace(query="").geturl() return data diff --git a/enterprise/backend/src/baserow_enterprise/api/sso/saml/validators.py b/enterprise/backend/src/baserow_enterprise/api/sso/saml/validators.py index ef2f39d2d..66f8b1522 100644 --- a/enterprise/backend/src/baserow_enterprise/api/sso/saml/validators.py +++ b/enterprise/backend/src/baserow_enterprise/api/sso/saml/validators.py @@ -1,5 +1,7 @@ import io +from django.db.models import QuerySet + from rest_framework import serializers from saml2.xml.schema import XMLSchemaError from saml2.xml.schema import validate as validate_saml_metadata_schema @@ -9,14 +11,17 @@ from baserow_enterprise.sso.saml.models import SamlAuthProviderModel def validate_unique_saml_domain( - domain, instance=None, model_class=SamlAuthProviderModel + domain, instance=None, base_queryset: QuerySet | None = None ): - queryset = model_class.objects.filter(domain=domain) + if base_queryset is None: + base_queryset = SamlAuthProviderModel.objects + + queryset = base_queryset.filter(domain=domain) if instance: queryset = queryset.exclude(id=instance.id) if queryset.exists(): raise SamlProviderForDomainAlreadyExists( - f"There is already a {model_class.__name__} for this domain." + "There is already a provider for this domain." ) return domain diff --git a/enterprise/backend/src/baserow_enterprise/api/sso/saml/views.py b/enterprise/backend/src/baserow_enterprise/api/sso/saml/views.py index 409f0a75a..07589516d 100644 --- a/enterprise/backend/src/baserow_enterprise/api/sso/saml/views.py +++ b/enterprise/backend/src/baserow_enterprise/api/sso/saml/views.py @@ -230,7 +230,7 @@ class AdminAuthProvidersLoginUrlView(APIView): ) saml_login_url = urljoin( - settings.PUBLIC_BACKEND_URL, reverse("api:enterprise:sso:saml:login") + settings.PUBLIC_BACKEND_URL, reverse("api:enterprise_sso_saml:login") ) saml_login_url = urlencode_query_params(saml_login_url, query_params) return Response({"redirect_url": saml_login_url}) diff --git a/enterprise/backend/src/baserow_enterprise/api/sso/serializers.py b/enterprise/backend/src/baserow_enterprise/api/sso/serializers.py index 8667b6a7c..baee5de61 100644 --- a/enterprise/backend/src/baserow_enterprise/api/sso/serializers.py +++ b/enterprise/backend/src/baserow_enterprise/api/sso/serializers.py @@ -7,13 +7,13 @@ from baserow.api.user.serializers import NormalizedEmailField from baserow.api.user.validators import language_validation -class SsoLoginRequestSerializer(serializers.Serializer): +class BaseSsoLoginRequestSerializer(serializers.Serializer): email = NormalizedEmailField( required=False, help_text="The email address of the user." ) original = serializers.CharField( required=False, - help_text="The relative part of URL that the user wanted to access.", + help_text="The original URL that the user wanted to access.", ) language = serializers.CharField( required=False, @@ -23,6 +23,13 @@ class SsoLoginRequestSerializer(serializers.Serializer): help_text="An ISO 639 language code (with optional variant) " "selected by the user. Ex: en-GB.", ) + + def to_internal_value(self, instance) -> Dict[str, str]: + data = super().to_internal_value(instance) + return {k: v for k, v in data.items() if v is not None} + + +class SsoLoginRequestSerializer(BaseSsoLoginRequestSerializer): workspace_invitation_token = serializers.CharField( required=False, help_text="If provided and valid, the user accepts the workspace invitation and" @@ -34,8 +41,5 @@ class SsoLoginRequestSerializer(serializers.Serializer): if urlparse(value).hostname: return None - return value - def to_internal_value(self, instance) -> Dict[str, str]: - data = super().to_internal_value(instance) - return {k: v for k, v in data.items() if v is not None} + return value diff --git a/enterprise/backend/src/baserow_enterprise/api/sso/urls.py b/enterprise/backend/src/baserow_enterprise/api/sso/urls.py deleted file mode 100644 index 26e23a3d6..000000000 --- a/enterprise/backend/src/baserow_enterprise/api/sso/urls.py +++ /dev/null @@ -1,11 +0,0 @@ -from django.urls import include, path - -from .oauth2 import urls as oauth2_urls -from .saml import urls as saml_urls - -app_name = "baserow_enterprise.api.sso" - -urlpatterns = [ - path("saml/", include(saml_urls, namespace="saml")), - path("oauth2/", include(oauth2_urls, namespace="oauth2")), -] diff --git a/enterprise/backend/src/baserow_enterprise/api/sso/utils.py b/enterprise/backend/src/baserow_enterprise/api/sso/utils.py index 0e31c41c8..13620285c 100644 --- a/enterprise/backend/src/baserow_enterprise/api/sso/utils.py +++ b/enterprise/backend/src/baserow_enterprise/api/sso/utils.py @@ -1,5 +1,7 @@ +from contextlib import ContextDecorator from enum import Enum -from typing import Dict, Optional +from functools import wraps +from typing import Callable, Dict, Optional, Type from urllib.parse import urljoin, urlparse from django.conf import settings @@ -25,31 +27,58 @@ class SsoErrorCode(Enum): SIGNUP_DISABLED = "errorSignupDisabled" -def map_sso_exceptions(mapping: Dict[Exception, SsoErrorCode]): +class map_sso_exceptions(ContextDecorator): """ - This decorator can be used to map exceptions to SSO error codes. If the - decorated function raises an exception that is in the mapping, the - redirect_to_sign_in_error_page() function will be called with the mapped - error code. If the exception is not in the mapping, it will be raised - normally. + A context manager and decorator to map exceptions to SSO error codes. If the + enclosed code block or decorated function raises an exception that is in the + mapping, the provided redirect function will be called with the mapped error code. + If the exception is not in the mapping, it will be raised normally. :param mapping: A dictionary that maps exceptions to SSO error codes. - :return: The decorator. + :param on_error: A callable that takes an error code and handles the action. """ - def decorator(func): - def wrapper(*args, **kwargs): + def __init__( + self, + mapping: Dict[Type[Exception], str], + on_error: Callable[[str], None] | None = None, + ): + self.mapping = mapping + if on_error is None: + self.on_error = redirect_to_sign_in_error_page + else: + self.on_error = on_error + + def __enter__(self): + pass + + # Context manager + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + # No exception occurred + return False + + for exception, error_code in self.mapping.items(): + if isinstance(exc_value, exception): + self.on_error(error_code) + return True # Swallow the exception after handling it + + # If exception not handled, propagate it + return False + + # Decorator version + def __call__(self, func): + @wraps(func) + def wrapped_function(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - for exception, error_code in mapping.items(): + for exception, error_code in self.mapping.items(): if isinstance(e, exception): - return redirect_to_sign_in_error_page(error_code) + return self.on_error(error_code) raise e - return wrapper - - return decorator + return wrapped_function def urlencode_query_params(url: str, query_params: Dict[str, str]) -> str: @@ -79,6 +108,7 @@ def redirect_to_sign_in_error_page( """ frontend_error_page_url = get_frontend_login_error_url() + if error_code: frontend_error_page_url = urlencode_query_params( frontend_error_page_url, {"error": error_code.value} @@ -89,40 +119,64 @@ def redirect_to_sign_in_error_page( def get_valid_frontend_url( requested_original_url: Optional[str] = None, query_params: Optional[Dict[str, str]] = None, -) -> str: + default_frontend_urls: list[str] | None = None, + allow_any_path: bool = True, +): """ Returns a valid absolute frontend url based on the original url requested before the redirection to the login (can be relative or absolute). If the - original url is relative, it will be prefixed with the frontend hostname to - make the IdP redirection work. If the original url is external to Baserow, - the default frontend dashboard url will be returned instead. + original url is relative, it will be prefixed with the default hostname to + make the IdP redirection work. If the original url doesn't match any of the given + default_front_urls, the first default frontend url will be used instead. :param requested_original_url: The url to which the user should be redirected after a successful login. + :param query_params: The query parameters to add to the URL. + :param default_frontend_urls: The first one is the default fallback frontend URL. + Baserow one is used if None. Others are also allowed as valid URLs. :return: The url with the token as a query parameter. """ requested_url_parsed = urlparse(requested_original_url or "") - default_frontend_url_parsed = urlparse(get_frontend_default_redirect_url()) - if requested_url_parsed.path in ["", "/"]: - # use the default frontend path if the requested one is empty - requested_url_parsed = requested_url_parsed._replace( - path=default_frontend_url_parsed.path - ) + if default_frontend_urls is None: + default_frontend_urls = [get_frontend_default_redirect_url()] + + default_frontend_urls_parsed = [urlparse(u) for u in default_frontend_urls] + default_frontend_url_parsed = default_frontend_urls_parsed[0] + + matching_url = default_frontend_url_parsed + if requested_url_parsed.hostname is None: # provide a correct absolute url if the requested one is relative requested_url_parsed = default_frontend_url_parsed._replace( path=requested_url_parsed.path ) - elif requested_url_parsed.hostname != default_frontend_url_parsed.hostname: - # return the default url if the requested url is external to Baserow - requested_url_parsed = default_frontend_url_parsed + + else: + found = False + for allowed_url in default_frontend_urls_parsed: + if requested_url_parsed.hostname == allowed_url.hostname: + matching_url = allowed_url + found = True + + if not found: + # None are matching -> redirecting to main homepage + requested_url_parsed = default_frontend_url_parsed + matching_url = default_frontend_url_parsed + + if allow_any_path: + if requested_url_parsed.path in ["", "/"]: + # use the default frontend path if the requested one is empty + requested_url_parsed = requested_url_parsed._replace(path=matching_url.path) + elif not requested_url_parsed.geturl().startswith(matching_url.geturl()): + # if using a path that doesn't match the allowed urls, we reset to default url + requested_url_parsed = matching_url if query_params: return urlencode_query_params(requested_url_parsed.geturl(), query_params) - return str(requested_url_parsed.geturl()) + return requested_url_parsed.geturl() def urlencode_user_token(frontend_url: str, user: AbstractUser) -> str: diff --git a/enterprise/backend/src/baserow_enterprise/api/urls.py b/enterprise/backend/src/baserow_enterprise/api/urls.py index ee78f1291..d4c23b465 100644 --- a/enterprise/backend/src/baserow_enterprise/api/urls.py +++ b/enterprise/backend/src/baserow_enterprise/api/urls.py @@ -4,7 +4,6 @@ from .admin import urls as admin_urls from .audit_log import urls as audit_log_urls from .role import urls as role_urls from .secure_file_serve import urls as secure_file_serve_urls -from .sso import urls as sso_urls from .teams import urls as teams_urls app_name = "baserow_enterprise.api" @@ -13,7 +12,6 @@ urlpatterns = [ path("teams/", include(teams_urls, namespace="teams")), path("role/", include(role_urls, namespace="role")), path("admin/", include(admin_urls, namespace="admin")), - path("sso/", include(sso_urls, namespace="sso")), path("audit-log/", include(audit_log_urls, namespace="audit_log")), path("files/", include(secure_file_serve_urls, namespace="files")), ] diff --git a/enterprise/backend/src/baserow_enterprise/apps.py b/enterprise/backend/src/baserow_enterprise/apps.py index 97b10d1ad..75287770d 100755 --- a/enterprise/backend/src/baserow_enterprise/apps.py +++ b/enterprise/backend/src/baserow_enterprise/apps.py @@ -174,6 +174,12 @@ class BaserowEnterpriseConfig(AppConfig): LocalBaserowPasswordAppAuthProviderType() ) + from baserow_enterprise.integrations.common.sso.saml.app_auth_provider_types import ( + SamlAppAuthProviderType, + ) + + app_auth_provider_type_registry.register(SamlAppAuthProviderType()) + from baserow.contrib.builder.elements.registries import element_type_registry from baserow_enterprise.builder.elements.element_types import ( AuthFormElementType, diff --git a/enterprise/backend/src/baserow_enterprise/integrations/common/__init__.py b/enterprise/backend/src/baserow_enterprise/integrations/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/enterprise/backend/src/baserow_enterprise/integrations/common/sso/__init__.py b/enterprise/backend/src/baserow_enterprise/integrations/common/sso/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/__init__.py b/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/app_auth_provider_types.py b/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/app_auth_provider_types.py new file mode 100644 index 000000000..c02902df2 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/app_auth_provider_types.py @@ -0,0 +1,113 @@ +from typing import List +from urllib.parse import urljoin + +from django.conf import settings +from django.urls import include, path, reverse + +from baserow.core.app_auth_providers.auth_provider_types import AppAuthProviderType +from baserow.core.app_auth_providers.types import AppAuthProviderTypeDict +from baserow_enterprise.api.sso.saml.validators import validate_unique_saml_domain +from baserow_enterprise.integrations.local_baserow.user_source_types import ( + LocalBaserowUserSourceType, +) +from baserow_enterprise.sso.saml.auth_provider_types import SamlAuthProviderTypeMixin + +from .models import SamlAppAuthProviderModel + + +class SamlAppAuthProviderType(SamlAuthProviderTypeMixin, AppAuthProviderType): + """ + The SAML authentication provider type allows users to login using SAML. + """ + + model_class = SamlAppAuthProviderModel + + compatible_user_source_types = [LocalBaserowUserSourceType.type] + + class SerializedDict( + AppAuthProviderTypeDict, SamlAuthProviderTypeMixin.SamlSerializedDict + ): + ... + + @property + def allowed_fields(self) -> List[str]: + return SamlAuthProviderTypeMixin.saml_allowed_fields + + @property + def serializer_field_names(self): + return SamlAuthProviderTypeMixin.saml_serializer_field_names + + public_serializer_field_names = [] + + @property + def serializer_field_overrides(self): + return SamlAuthProviderTypeMixin.saml_serializer_field_overrides + + public_serializer_field_overrides = {} + + def get_api_urls(self): + from baserow_enterprise.api.integrations.common.sso.saml import urls + + return [ + path( + "user-source/<str:user_source_uid>/sso/saml/", + include(urls, namespace="sso_saml"), + ) + ] + + def before_create(self, user, **values): + user_source = values["user_source"] + if "domain" in values: + validate_unique_saml_domain( + values["domain"], + base_queryset=SamlAppAuthProviderModel.objects.filter( + user_source=user_source + ), + ) + return super().before_create(user, **values) + + def before_update(self, user, provider, **values): + if "domain" in values: + user_source = values.get("user_source", provider.user_source) + validate_unique_saml_domain( + values["domain"], + provider, + base_queryset=SamlAppAuthProviderModel.objects.filter( + user_source=user_source + ), + ) + return super().before_update(user, provider, **values) + + @classmethod + def get_acs_absolute_url( + cls, auth_provider: "SamlAuthProviderTypeMixin | None" = None + ): + """ + Returns the ACS url for SAML authentication purpose. The user is redirected + to this URL after a successful login. + """ + + return urljoin( + settings.PUBLIC_BACKEND_URL, + reverse( + "api:user_sources:sso_saml:acs", + kwargs={"user_source_uid": auth_provider.user_source.uid}, + ), + ) + + @classmethod + def get_login_absolute_url( + cls, auth_provider: "SamlAuthProviderTypeMixin | None" = None + ): + """ + Returns the login URL for this auth_provider. The login URL is used to initiate + the Saml login process. + """ + + return urljoin( + settings.PUBLIC_BACKEND_URL, + reverse( + "api:user_sources:sso_saml:login", + kwargs={"user_source_uid": auth_provider.user_source.uid}, + ), + ) diff --git a/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/handler.py b/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/handler.py new file mode 100644 index 000000000..209744480 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/handler.py @@ -0,0 +1,8 @@ +from baserow_enterprise.integrations.common.sso.saml.models import ( + SamlAppAuthProviderModel, +) +from baserow_enterprise.sso.saml.handler import SamlAuthProviderHandler + + +class SamlAppAuthProviderHandler(SamlAuthProviderHandler): + model_class = SamlAppAuthProviderModel diff --git a/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/models.py b/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/models.py new file mode 100644 index 000000000..f0f2095bc --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/integrations/common/sso/saml/models.py @@ -0,0 +1,8 @@ +from baserow.core.app_auth_providers.models import AppAuthProvider +from baserow_enterprise.sso.saml.models import SamlAuthProviderModelMixin + + +class SamlAppAuthProviderModel(SamlAuthProviderModelMixin, AppAuthProvider): + # Restore ordering + class Meta(AppAuthProvider.Meta): + ordering = ["id"] diff --git a/enterprise/backend/src/baserow_enterprise/integrations/local_baserow/auth_provider_types.py b/enterprise/backend/src/baserow_enterprise/integrations/local_baserow/auth_provider_types.py index 67b798ae1..f475496c5 100644 --- a/enterprise/backend/src/baserow_enterprise/integrations/local_baserow/auth_provider_types.py +++ b/enterprise/backend/src/baserow_enterprise/integrations/local_baserow/auth_provider_types.py @@ -31,6 +31,8 @@ class LocalBaserowPasswordAppAuthProviderType(AppAuthProviderType): ] serializer_field_names = ["password_field_id"] + public_serializer_field_names = [] + allowed_fields = ["password_field"] serializer_field_overrides = { @@ -40,6 +42,7 @@ class LocalBaserowPasswordAppAuthProviderType(AppAuthProviderType): help_text="The id of the field to use as password for the user account.", ), } + public_serializer_field_overrides = {} class SerializedDict(AppAuthProviderTypeDict): password_field_id: int @@ -151,13 +154,6 @@ class LocalBaserowPasswordAppAuthProviderType(AppAuthProviderType): instance.password_field = None instance.save() - def get_login_options(self, **kwargs) -> Dict[str, Any]: - """ - Not implemented yet. - """ - - return {} - def get_or_create_user_and_sign_in( self, auth_provider: AuthProviderModelSubClass, user_info: Dict[str, Any] ) -> Tuple[AbstractUser, bool]: diff --git a/enterprise/backend/src/baserow_enterprise/integrations/local_baserow/user_source_types.py b/enterprise/backend/src/baserow_enterprise/integrations/local_baserow/user_source_types.py index 65d9e538b..cecf93ad4 100644 --- a/enterprise/backend/src/baserow_enterprise/integrations/local_baserow/user_source_types.py +++ b/enterprise/backend/src/baserow_enterprise/integrations/local_baserow/user_source_types.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional from django.contrib.auth.models import AbstractUser +from loguru import logger from rest_framework import serializers from baserow.api.exceptions import RequestBodyValidationException @@ -17,6 +18,7 @@ from baserow.contrib.database.fields.field_types import ( ) from baserow.contrib.database.fields.handler import FieldHandler from baserow.contrib.database.fields.registries import FieldType +from baserow.contrib.database.rows.actions import CreateRowsActionType from baserow.contrib.database.rows.operations import ReadDatabaseRowOperationType from baserow.contrib.database.search.handler import SearchHandler from baserow.contrib.database.table.exceptions import TableDoesNotExist @@ -450,9 +452,9 @@ class LocalBaserowUserSourceType(UserSourceType): return ( f"{user_source.id}" - f"_{user_source.table_id if user_source.table_id else '?'}" - f"_{user_source.email_field_id if user_source.email_field_id else '?'}" - f"_{user_source.role_field_id if user_source.role_field_id else '?'}" + f"_{user_source.table_id if user_source.table_id else '0'}" + f"_{user_source.email_field_id if user_source.email_field_id else '0'}" + f"_{user_source.role_field_id if user_source.role_field_id else '0'}" ) def role_type_is_allowed(self, role_field: Optional[FieldType]) -> bool: @@ -596,6 +598,54 @@ class LocalBaserowUserSourceType(UserSourceType): raise UserNotFound() + def create_user(self, user_source: UserSource, email, name, role=None): + """ + Creates the user in the configured table. + """ + + if not self.is_configured(user_source): + raise UserSourceImproperlyConfigured() + + try: + # Use table handler to exclude trashed table + table = TableHandler().get_table(user_source.table_id) + except TableDoesNotExist as exc: + # As we CASCADE when a table is deleted, the table shouldn't + # exist only if it's trashed and not yet deleted. + raise UserSourceImproperlyConfigured("The table doesn't exist.") from exc + + integration = user_source.integration.specific + + model = table.get_model() + + values = { + user_source.name_field.db_column: name, + user_source.email_field.db_column: email, + } + if role and user_source.role_field_id: + values[user_source.role_field.db_column] = role + + try: + # Use the action to keep track on what's going on + (user,) = CreateRowsActionType.do( + user=integration.authorized_user, + table=table, + rows_values=[values], + model=model, + ) + except Exception as e: + logger.exception(e) + raise ("Error while creating the user") from e + + return UserSourceUser( + user_source, + user, + user.id, + getattr(user, user_source.name_field.db_column), + getattr(user, user_source.email_field.db_column), + self.get_user_role(user, user_source), + ) + def authenticate(self, user_source: UserSource, **kwargs): """ Authenticates using the given credentials. It uses the password auth provider. diff --git a/enterprise/backend/src/baserow_enterprise/migrations/0034_samlappauthprovidermodel_and_more.py b/enterprise/backend/src/baserow_enterprise/migrations/0034_samlappauthprovidermodel_and_more.py new file mode 100644 index 000000000..a875d8d78 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/migrations/0034_samlappauthprovidermodel_and_more.py @@ -0,0 +1,81 @@ +# Generated by Django 5.0.9 on 2024-12-05 11:23 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("baserow_enterprise", "0033_samlauthprovidermodel_email_attr_key_and_more"), + ("core", "0093_alter_appauthprovider_options_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="SamlAppAuthProviderModel", + fields=[ + ( + "appauthprovider_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="core.appauthprovider", + ), + ), + ( + "metadata", + models.TextField( + blank=True, + help_text="The XML metadata downloaded from the metadata_url.", + ), + ), + ( + "is_verified", + models.BooleanField( + default=False, + help_text="This will be set to True only after a user successfully login with this IdP. This must be True to disable normal username/password login and make SAML the only authentication provider. ", + ), + ), + ( + "email_attr_key", + models.CharField( + db_default="user.email", + default="user.email", + help_text="The key in the SAML response that contains the email address of the user. If this is not set, the email will be taken from the user's profile.", + max_length=32, + ), + ), + ( + "first_name_attr_key", + models.CharField( + db_default="user.first_name", + default="user.first_name", + help_text="The key in the SAML response that contains the first name of the user. If this is not set, the first name will be taken from the user's profile.", + max_length=32, + ), + ), + ( + "last_name_attr_key", + models.CharField( + blank=True, + db_default="user.last_name", + default="user.last_name", + help_text="The key in the SAML response that contains the last name of the user. If this is not set, the last name will be taken from the user's profile.", + max_length=32, + ), + ), + ], + options={ + "ordering": ["id"], + "abstract": False, + }, + bases=("core.appauthprovider", models.Model), + ), + migrations.AlterModelOptions( + name="samlauthprovidermodel", + options={"ordering": ["domain"]}, + ), + ] diff --git a/enterprise/backend/src/baserow_enterprise/models.py b/enterprise/backend/src/baserow_enterprise/models.py index c18036be0..dae42118f 100644 --- a/enterprise/backend/src/baserow_enterprise/models.py +++ b/enterprise/backend/src/baserow_enterprise/models.py @@ -1,6 +1,12 @@ from baserow_enterprise.builder.elements.models import AuthFormElement from baserow_enterprise.data_sync.models import LocalBaserowTableDataSync -from baserow_enterprise.integrations.models import LocalBaserowUserSource +from baserow_enterprise.integrations.common.sso.saml.models import ( + SamlAppAuthProviderModel, +) +from baserow_enterprise.integrations.models import ( + LocalBaserowPasswordAppAuthProvider, + LocalBaserowUserSource, +) from baserow_enterprise.role.models import Role, RoleAssignment from baserow_enterprise.teams.models import Team, TeamSubject @@ -12,4 +18,6 @@ __all__ = [ "LocalBaserowUserSource", "AuthFormElement", "LocalBaserowTableDataSync", + "LocalBaserowPasswordAppAuthProvider", + "SamlAppAuthProviderModel", ] diff --git a/enterprise/backend/src/baserow_enterprise/sso/oauth2/auth_provider_types.py b/enterprise/backend/src/baserow_enterprise/sso/oauth2/auth_provider_types.py index f8ff3bdc4..44c6892c4 100644 --- a/enterprise/backend/src/baserow_enterprise/sso/oauth2/auth_provider_types.py +++ b/enterprise/backend/src/baserow_enterprise/sso/oauth2/auth_provider_types.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Tuple from django.conf import settings from django.contrib.sessions.backends.base import SessionBase -from django.urls import reverse +from django.urls import include, path, reverse import requests from loguru import logger @@ -30,6 +30,8 @@ from .models import ( OAUTH_BACKEND_URL = settings.PUBLIC_BACKEND_URL +_is_url_already_loaded = False + @dataclass class WellKnownUrls: @@ -50,6 +52,20 @@ class OAuth2AuthProviderMixin: - self.SCOPE """ + def get_api_urls(self): + global _is_url_already_loaded + + from baserow_enterprise.api.sso.oauth2 import urls + + if not _is_url_already_loaded: + _is_url_already_loaded = True + # We need to register this only once + return [ + path("sso/oauth2/", include(urls, namespace="enterprise_sso_oauth2")) + ] + else: + return [] + def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]: if not is_sso_feature_active(): return None @@ -64,7 +80,7 @@ class OAuth2AuthProviderMixin: { "redirect_url": urllib.parse.urljoin( OAUTH_BACKEND_URL, - reverse("api:enterprise:sso:oauth2:login", args=(instance.id,)), + reverse("api:enterprise_sso_oauth2:login", args=(instance.id,)), ), "name": instance.name, "type": self.type, @@ -144,7 +160,7 @@ class OAuth2AuthProviderMixin: redirect_uri = urllib.parse.urljoin( OAUTH_BACKEND_URL, - reverse("api:enterprise:sso:oauth2:callback", args=(instance.id,)), + reverse("api:enterprise_sso_oauth2:callback", args=(instance.id,)), ) if "oauth_state" in session: return OAuth2Session( diff --git a/enterprise/backend/src/baserow_enterprise/sso/saml/auth_provider_types.py b/enterprise/backend/src/baserow_enterprise/sso/saml/auth_provider_types.py index ddbfc5903..54552c65f 100644 --- a/enterprise/backend/src/baserow_enterprise/sso/saml/auth_provider_types.py +++ b/enterprise/backend/src/baserow_enterprise/sso/saml/auth_provider_types.py @@ -1,8 +1,9 @@ -from typing import Any, Dict, List, Optional +from abc import abstractmethod +from typing import Any, Dict, List, Optional, TypedDict from urllib.parse import urljoin from django.conf import settings -from django.urls import reverse +from django.urls import include, path, reverse from rest_framework import serializers @@ -26,43 +27,33 @@ from baserow_enterprise.sso.utils import is_sso_feature_active from .models import SamlAuthProviderModel -class SamlAuthProviderType(AuthProviderType): +class SamlAuthProviderTypeMixin: """ The SAML authentication provider type allows users to login using SAML. """ type = "saml" - model_class = SamlAuthProviderModel - allowed_fields: List[str] = [ - "id", - "domain", - "type", - "enabled", + + class SamlSerializedDict(TypedDict): + metadata: Dict + is_verified: bool + + saml_allowed_fields: List[str] = [ "metadata", "is_verified", "email_attr_key", "first_name_attr_key", "last_name_attr_key", ] - serializer_field_names = [ - "domain", + saml_serializer_field_names = [ "metadata", - "enabled", "is_verified", "email_attr_key", "first_name_attr_key", "last_name_attr_key", ] - serializer_field_overrides = { - "domain": serializers.CharField( - validators=[validate_domain], - required=True, - help_text="The email domain registered with this provider.", - ), - "enabled": serializers.BooleanField( - help_text="Whether the provider is enabled or not.", - required=False, - ), + + saml_serializer_field_overrides = { "metadata": serializers.CharField( validators=[validate_saml_metadata], required=True, @@ -98,21 +89,61 @@ class SamlAuthProviderType(AuthProviderType): SamlProviderForDomainAlreadyExists: ERROR_SAML_PROVIDER_FOR_DOMAIN_ALREADY_EXISTS } - def before_create(self, user, **values): - validate_unique_saml_domain(values["domain"]) - return super().before_create(user, **values) + @classmethod + @abstractmethod + def get_acs_absolute_url( + cls, auth_provider: "SamlAuthProviderTypeMixin | None" = None + ): + """ + Returns the ACS url for SAML authentication purpose. The user is redirected + to this URL after a successful login. + """ - def before_update(self, user, provider, **values): - if "domain" in values: - validate_unique_saml_domain(values["domain"], provider) - return super().before_update(user, provider, **values) + @classmethod + @abstractmethod + def get_login_absolute_url(cls): + """ + Returns the login URL for this auth_provider. The login URL is used to initiate + the Saml login process. + """ + + +class SamlAuthProviderType(SamlAuthProviderTypeMixin, AuthProviderType): + """ + The SAML authentication provider type allows users to login using SAML. + """ + + model_class = SamlAuthProviderModel + + @property + def allowed_fields(self) -> List[str]: + return SamlAuthProviderTypeMixin.saml_allowed_fields + + @property + def serializer_field_names(self): + return SamlAuthProviderTypeMixin.saml_serializer_field_names + + @property + def serializer_field_overrides(self): + return SamlAuthProviderTypeMixin.saml_serializer_field_overrides | { + "domain": serializers.CharField( + validators=[validate_domain], + required=True, + help_text="The email domain registered with this provider.", + ), + } + + def get_api_urls(self): + from baserow_enterprise.api.sso.saml import urls + + return [path("sso/saml/", include(urls, namespace="enterprise_sso_saml"))] def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]: single_sign_on_feature_active = is_sso_feature_active() if not single_sign_on_feature_active: return None - configured_domains = SamlAuthProviderModel.objects.filter(enabled=True).count() + configured_domains = self.model_class.objects.filter(enabled=True).count() if not configured_domains: return None @@ -130,16 +161,28 @@ class SamlAuthProviderType(AuthProviderType): "default_redirect_url": default_redirect_url, } + def before_create(self, user, **values): + validate_unique_saml_domain(values["domain"]) + return super().before_create(user, **values) + + def before_update(self, user, provider, **values): + if "domain" in values: + validate_unique_saml_domain(values["domain"], provider) + return super().before_update(user, provider, **values) + @classmethod - def get_acs_absolute_url(cls): + def get_acs_absolute_url( + cls, auth_provider: SamlAuthProviderTypeMixin | None = None + ): return urljoin( - settings.PUBLIC_BACKEND_URL, reverse("api:enterprise:sso:saml:acs") + settings.PUBLIC_BACKEND_URL, reverse("api:enterprise_sso_saml:acs") ) @classmethod def get_login_absolute_url(cls): return urljoin( - settings.PUBLIC_BACKEND_URL, reverse("api:enterprise:sso:saml:login") + settings.PUBLIC_BACKEND_URL, + reverse("api:enterprise_sso_saml:login"), ) def export_serialized(self) -> Dict[str, Any]: diff --git a/enterprise/backend/src/baserow_enterprise/sso/saml/handler.py b/enterprise/backend/src/baserow_enterprise/sso/saml/handler.py index 1a03f1158..b96d33137 100644 --- a/enterprise/backend/src/baserow_enterprise/sso/saml/handler.py +++ b/enterprise/backend/src/baserow_enterprise/sso/saml/handler.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Optional from django.conf import settings from django.contrib.auth.models import AbstractUser +from django.db.models import Model, QuerySet from defusedxml import ElementTree from loguru import logger @@ -13,9 +14,11 @@ from saml2.config import Config as Saml2Config from saml2.response import AuthnResponse from baserow.core.auth_provider.types import UserInfo -from baserow.core.registries import auth_provider_type_registry from baserow_enterprise.api.sso.utils import get_valid_frontend_url -from baserow_enterprise.sso.saml.models import SamlAuthProviderModel +from baserow_enterprise.sso.saml.models import ( + SamlAuthProviderModel, + SamlAuthProviderModelMixin, +) from .exceptions import ( InvalidSamlConfiguration, @@ -25,10 +28,12 @@ from .exceptions import ( class SamlAuthProviderHandler: + model_class: Model = SamlAuthProviderModel + @classmethod def prepare_saml_client( cls, - saml_auth_provider: SamlAuthProviderModel, + saml_auth_provider: SamlAuthProviderModelMixin, ) -> Saml2Client: """ Returns a SAML client with the correct configuration for the given @@ -39,10 +44,9 @@ class SamlAuthProviderHandler: :return: The SAML client that can be used to authenticate the user. """ - saml_provider_type = auth_provider_type_registry.get_by_model( - saml_auth_provider + acs_url = saml_auth_provider.get_type().get_acs_absolute_url( + saml_auth_provider.specific ) - acs_url = saml_provider_type.get_acs_absolute_url() saml_settings: Dict[str, Any] = { "entityid": acs_url, @@ -95,9 +99,8 @@ class SamlAuthProviderHandler: @classmethod def get_saml_auth_provider_from_saml_response( - cls, - raw_saml_response: str, - ) -> SamlAuthProviderModel: + cls, raw_saml_response: str, base_queryset: QuerySet | None = None + ) -> SamlAuthProviderModelMixin: """ Parses the saml response and returns the authentication provider that needs to be used to authenticate the user. @@ -120,7 +123,10 @@ class SamlAuthProviderHandler: except (ElementTree.ParseError, AttributeError): raise InvalidSamlResponse("Impossible decode SAML response.") - saml_auth_provider = SamlAuthProviderModel.objects.filter( + if base_queryset is None: + base_queryset = cls.model_class.objects + + saml_auth_provider = base_queryset.filter( enabled=True, metadata__contains=issuer ).first() if not saml_auth_provider: @@ -130,7 +136,7 @@ class SamlAuthProviderHandler: @classmethod def get_user_info_from_authn_user_identity( cls, - saml_auth_provider: SamlAuthProviderModel, + saml_auth_provider: SamlAuthProviderModelMixin, authn_identity: Dict[str, str], saml_request_data: Optional[Dict[str, str]] = None, ) -> UserInfo: @@ -172,8 +178,7 @@ class SamlAuthProviderHandler: @classmethod def get_saml_auth_provider_from_email( - cls, - email: Optional[str] = None, + cls, email: Optional[str] = None, base_queryset: QuerySet | None = None ) -> SamlAuthProviderModel: """ It returns the Saml Identity Provider for the the given email address. @@ -187,20 +192,24 @@ class SamlAuthProviderHandler: address provided. """ - base_queryset = SamlAuthProviderModel.objects.filter(enabled=True) + if base_queryset is None: + base_queryset = cls.model_class.objects + + queryset = base_queryset.filter(enabled=True) if email is not None: try: domain = email.rsplit("@", 1)[1] except IndexError: raise InvalidSamlRequest("Invalid mail address provided.") - base_queryset = base_queryset.filter(domain=domain) + + queryset = queryset.filter(domain=domain) try: - return base_queryset.get() + return queryset.get() except ( - SamlAuthProviderModel.DoesNotExist, - SamlAuthProviderModel.MultipleObjectsReturned, + cls.model_class.DoesNotExist, + cls.model_class.MultipleObjectsReturned, ): raise InvalidSamlRequest("No valid SAML identity provider found.") @@ -213,7 +222,10 @@ class SamlAuthProviderHandler: @classmethod def sign_in_user_from_saml_response( - cls, saml_response: str, saml_request_data: Optional[Dict[str, str]] = None + cls, + saml_response: str, + saml_request_data: Optional[Dict[str, str]] = None, + base_queryset: QuerySet | None = None, ) -> AbstractUser: """ Signs in the user using the SAML response received from the identity @@ -230,7 +242,7 @@ class SamlAuthProviderHandler: try: saml_auth_provider = cls.get_saml_auth_provider_from_saml_response( - saml_response + saml_response, base_queryset=base_queryset ) saml_client = cls.prepare_saml_client(saml_auth_provider) @@ -249,11 +261,10 @@ class SamlAuthProviderHandler: logger.exception(exc) raise InvalidSamlResponse(str(exc)) - saml_provider_type = saml_auth_provider.get_type() ( user, _, - ) = saml_provider_type.get_or_create_user_and_sign_in( + ) = saml_auth_provider.get_type().get_or_create_user_and_sign_in( saml_auth_provider, idp_provided_user_info ) @@ -268,7 +279,7 @@ class SamlAuthProviderHandler: @classmethod def get_sign_in_url_for_auth_provider( cls, - saml_auth_provider: SamlAuthProviderModel, + saml_auth_provider: SamlAuthProviderModelMixin, original_url: str = "", ) -> str: """ @@ -284,6 +295,7 @@ class SamlAuthProviderHandler: """ saml_client = cls.prepare_saml_client(saml_auth_provider) + _, info = saml_client.prepare_for_authenticate(relay_state=original_url) for key, value in info["headers"]: @@ -294,24 +306,36 @@ class SamlAuthProviderHandler: raise InvalidSamlConfiguration("No Location header found in SAML response.") @classmethod - def get_sign_in_url(cls, query_params: Dict[str, str]) -> str: + def get_sign_in_url( + cls, + query_params: Dict[str, str], + base_queryset: QuerySet | None = None, + redirect_to: str | None = None, + ) -> str: """ Returns the sign in url for the correct identity provider. This url is used to initiate the SAML authentication flow from the service provider. :param query_params: A dict containing the query parameters from the sign in request. + :param redirect_to: if set, used as relay state url. :raises InvalidSamlRequest: If the email address is invalid. :raises InvalidSamlConfiguration: If the SAML configuration is invalid. :return: The redirect url to the identity provider. """ user_email = query_params.pop("email", None) - original_url = query_params.pop("original", "") - valid_relay_state_url = get_valid_frontend_url(original_url, query_params) + + if redirect_to: + valid_relay_state_url = redirect_to + else: + original_url = query_params.pop("original", "") + valid_relay_state_url = get_valid_frontend_url(original_url, query_params) try: - saml_auth_provider = cls.get_saml_auth_provider_from_email(user_email) + saml_auth_provider = cls.get_saml_auth_provider_from_email( + user_email, base_queryset=base_queryset + ) return cls.get_sign_in_url_for_auth_provider( saml_auth_provider, valid_relay_state_url ) diff --git a/enterprise/backend/src/baserow_enterprise/sso/saml/models.py b/enterprise/backend/src/baserow_enterprise/sso/saml/models.py index a99405a11..120a4d197 100644 --- a/enterprise/backend/src/baserow_enterprise/sso/saml/models.py +++ b/enterprise/backend/src/baserow_enterprise/sso/saml/models.py @@ -5,7 +5,7 @@ from django.dispatch import receiver from baserow.core.auth_provider.models import AuthProviderModel -class SamlAuthProviderModel(AuthProviderModel): +class SamlAuthProviderModelMixin(models.Model): metadata = models.TextField( blank=True, help_text="The XML metadata downloaded from the metadata_url." ) @@ -46,6 +46,15 @@ class SamlAuthProviderModel(AuthProviderModel): ), ) + class Meta: + abstract = True + + +class SamlAuthProviderModel(SamlAuthProviderModelMixin, AuthProviderModel): + # Restore ordering + class Meta(AuthProviderModel.Meta): + ordering = ["domain"] + @receiver(pre_save, sender=SamlAuthProviderModel) def reset_is_verified_if_metadata_changed(sender, instance, **kwargs): diff --git a/enterprise/backend/tests/baserow_enterprise_tests/api/sso/test_oauth_views.py b/enterprise/backend/tests/baserow_enterprise_tests/api/sso/test_oauth_views.py index 0feabbdec..40053d7e2 100755 --- a/enterprise/backend/tests/baserow_enterprise_tests/api/sso/test_oauth_views.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/api/sso/test_oauth_views.py @@ -52,7 +52,7 @@ def test_oauth2_login_feature_not_active(api_client, enterprise_data_fixture): ) response = api_client.get( - reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": provider.id}), + reverse("api:enterprise_sso_oauth2:login", kwargs={"provider_id": provider.id}), format="json", ) @@ -68,7 +68,7 @@ def test_oauth2_login_feature_not_active(api_client, enterprise_data_fixture): def test_oauth2_login_provider_doesnt_exist(api_client, enterprise_data_fixture): enterprise_data_fixture.enable_enterprise() response = api_client.get( - reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": 50}), + reverse("api:enterprise_sso_oauth2:login", kwargs={"provider_id": 50}), format="json", ) @@ -88,7 +88,7 @@ def test_oauth2_login_with_url_param(api_client, enterprise_data_fixture): ) response = api_client.get( reverse( - "api:enterprise:sso:oauth2:login", + "api:enterprise_sso_oauth2:login", kwargs={"provider_id": provider.id}, ) + "?original=templates&workspace_invitation_token=t&language=en", @@ -118,7 +118,7 @@ def test_oauth2_callback_feature_not_active(api_client, enterprise_data_fixture) response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", kwargs={"provider_id": provider.id} + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id} ), format="json", ) @@ -135,7 +135,7 @@ def test_oauth2_callback_feature_not_active(api_client, enterprise_data_fixture) def test_oauth2_callback_provider_doesnt_exist(api_client, enterprise_data_fixture): enterprise_data_fixture.enable_enterprise() response = api_client.get( - reverse("api:enterprise:sso:oauth2:callback", kwargs={"provider_id": 50}), + reverse("api:enterprise_sso_oauth2:callback", kwargs={"provider_id": 50}), format="json", ) @@ -167,7 +167,7 @@ def test_oauth2_callback_signup_success(api_client, enterprise_data_fixture): response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}, ) + "?code=validcode", @@ -211,7 +211,7 @@ def test_oauth2_callback_signup_set_language(api_client, enterprise_data_fixture response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}, ) + "?code=validcode", @@ -263,7 +263,7 @@ def test_oauth2_callback_signup_workspace_invitation( response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}, ) + "?code=validcode", @@ -318,7 +318,7 @@ def test_oauth2_callback_signup_workspace_invitation_email_mismatch( response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}, ) + "?code=validcode", @@ -360,7 +360,7 @@ def test_oauth2_callback_signup_disabled(api_client, enterprise_data_fixture): response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}, ) + "?code=validcode", @@ -408,7 +408,7 @@ def test_oauth2_callback_login_success( response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}, ) + "?code=validcode", @@ -460,7 +460,7 @@ def test_oauth2_callback_login_deactivated_user( response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}, ) + "?code=validcode", @@ -505,7 +505,7 @@ def test_oauth2_callback_login_different_provider( response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}, ) + "?code=validcode", @@ -549,7 +549,7 @@ def test_oauth2_callback_login_auth_flow_error( response = api_client.get( reverse( - "api:enterprise:sso:oauth2:callback", + "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}, ) + "?code=validcode", diff --git a/enterprise/backend/tests/baserow_enterprise_tests/api/sso/test_saml_views.py b/enterprise/backend/tests/baserow_enterprise_tests/api/sso/test_saml_views.py index e8b2f2f08..5d59522f4 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/api/sso/test_saml_views.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/api/sso/test_saml_views.py @@ -33,8 +33,8 @@ def test_saml_provider_get_login_url(api_client, data_fixture, enterprise_data_f auth_provider_1 = enterprise_data_fixture.create_saml_auth_provider( domain="test1.com" ) - auth_provider_login = reverse("api:enterprise:sso:saml:login") - auth_provider_login_url = reverse("api:enterprise:sso:saml:login_url") + auth_provider_login = reverse("api:enterprise_sso_saml:login") + auth_provider_login_url = reverse("api:enterprise_sso_saml:login_url") _, unauthorized_token = data_fixture.create_user_and_token() @@ -147,7 +147,7 @@ def test_user_cannot_initiate_saml_sso_without_enterprise_license( api_client, enterprise_data_fixture ): enterprise_data_fixture.create_saml_auth_provider(domain="test1.com") - response = api_client.get(reverse("api:enterprise:sso:saml:login")) + response = api_client.get(reverse("api:enterprise_sso_saml:login")) assert response.status_code == HTTP_302_FOUND assert ( response.headers["Location"] @@ -173,7 +173,7 @@ def test_user_can_initiate_saml_sso_with_enterprise_license( enterprise_data_fixture.create_saml_auth_provider(domain="test1.com") enterprise_data_fixture.enable_enterprise() - sp_sso_saml_login_url = reverse("api:enterprise:sso:saml:login") + sp_sso_saml_login_url = reverse("api:enterprise_sso_saml:login") original_relative_url = "database/1/table/1/" request_query_string = urlencode({"original": original_relative_url}) @@ -204,7 +204,7 @@ def test_user_can_initiate_saml_sso_with_enterprise_license( @override_settings(DEBUG=True) def test_saml_assertion_consumer_service(api_client, enterprise_data_fixture): user, _ = enterprise_data_fixture.create_enterprise_admin_user_and_token() - sp_sso_saml_acs_url = reverse("api:enterprise:sso:saml:acs") + sp_sso_saml_acs_url = reverse("api:enterprise_sso_saml:acs") ( metadata, diff --git a/enterprise/backend/tests/baserow_enterprise_tests/integrations/local_baserow/test_user_source_types.py b/enterprise/backend/tests/baserow_enterprise_tests/integrations/local_baserow/test_user_source_types.py index e87b70786..bfff7a294 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/integrations/local_baserow/test_user_source_types.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/integrations/local_baserow/test_user_source_types.py @@ -654,6 +654,95 @@ def test_create_user_source_field_from_other_table(api_client, data_fixture): assert response.json()["detail"]["name_field_id"][0]["code"] == "missing_table" +@pytest.mark.django_db +def test_user_source_create_user(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + application = data_fixture.create_builder_application(workspace=workspace) + database = data_fixture.create_database_application(workspace=workspace) + + integration = data_fixture.create_local_baserow_integration( + application=application, user=user + ) + + table_from_same_workspace1, fields, rows = data_fixture.build_table( + user=user, + database=database, + columns=[ + ("Email", "text"), + ("Name", "text"), + ], + rows=[ + ["test@baserow.io", "Test"], + ], + ) + + email_field, name_field = fields + + user_source = data_fixture.create_user_source_with_first_type( + integration=integration, + name="Test name", + table=table_from_same_workspace1, + email_field=email_field, + name_field=name_field, + uid="uid", + ) + + created_user = user_source.get_type().create_user( + user_source, "test2@baserow.io", "Test2" + ) + + model = table_from_same_workspace1.get_model() + assert created_user.email == "test2@baserow.io" + assert model.objects.count() == 2 + assert getattr(model.objects.last(), email_field.db_column) == "test2@baserow.io" + + +@pytest.mark.django_db +def test_user_source_create_user_w_role(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + application = data_fixture.create_builder_application(workspace=workspace) + database = data_fixture.create_database_application(workspace=workspace) + + integration = data_fixture.create_local_baserow_integration( + application=application, user=user + ) + + table_from_same_workspace1, fields, rows = data_fixture.build_table( + user=user, + database=database, + columns=[ + ("Email", "text"), + ("Name", "text"), + ("Role", "text"), + ], + rows=[ + ["test@baserow.io", "Test", "role1"], + ], + ) + + email_field, name_field, role_field = fields + + user_source = data_fixture.create_user_source_with_first_type( + integration=integration, + name="Test name", + table=table_from_same_workspace1, + email_field=email_field, + name_field=name_field, + role_field=role_field, + uid="uid", + ) + + created_user = user_source.get_type().create_user( + user_source, "test2@baserow.io", "Test2", "role2" + ) + + model = table_from_same_workspace1.get_model() + assert created_user.role == "role2" + assert getattr(model.objects.last(), role_field.db_column) == "role2" + + @pytest.mark.django_db def test_export_user_source(data_fixture): user = data_fixture.create_user() @@ -803,7 +892,7 @@ def test_create_local_baserow_user_source_w_auth_providers(api_client, data_fixt { "type": "local_baserow_password", "enabled": True, - "domain": "test1", + "domain": "test1.com", "password_field_id": password_field.id, } ], @@ -820,7 +909,7 @@ def test_create_local_baserow_user_source_w_auth_providers(api_client, data_fixt assert response_json["auth_providers"] == [ { - "domain": "test1", + "domain": "test1.com", "id": first.id, "password_field_id": password_field.id, "type": "local_baserow_password", diff --git a/enterprise/backend/tests/baserow_enterprise_tests/sso/test_sso_utils.py b/enterprise/backend/tests/baserow_enterprise_tests/sso/test_sso_utils.py new file mode 100644 index 000000000..a9e4a0316 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/sso/test_sso_utils.py @@ -0,0 +1,100 @@ +from baserow_enterprise.api.sso.utils import get_valid_frontend_url + + +def test_get_valid_front_url(): + assert get_valid_frontend_url() == "http://localhost:3000/dashboard" + assert ( + get_valid_frontend_url("http://localhost:3000/dashboard") + == "http://localhost:3000/dashboard" + ) + assert ( + get_valid_frontend_url("http://localhost:3000/dashboard/after") + == "http://localhost:3000/dashboard/after" + ) + assert ( + get_valid_frontend_url("http://localhost:3000/other") + == "http://localhost:3000/other" + ) + assert ( + get_valid_frontend_url("http://localhost:3000/other", allow_any_path=False) + == "http://localhost:3000/dashboard" + ) + assert ( + get_valid_frontend_url("http://localhost:3000/") + == "http://localhost:3000/dashboard" + ) + + assert ( + get_valid_frontend_url("http://something.com/") + == "http://localhost:3000/dashboard" + ) + assert ( + get_valid_frontend_url("http://something.com/dashboard/test") + == "http://localhost:3000/dashboard" + ) + + +def test_get_valid_front_url_with_defaults(): + defaults = ["https://test.com/toto", "http://random.net/"] + assert ( + get_valid_frontend_url(default_frontend_urls=defaults) + == "https://test.com/toto" + ) + assert ( + get_valid_frontend_url("https://test.com/toto", default_frontend_urls=defaults) + == "https://test.com/toto" + ) + assert ( + get_valid_frontend_url("http://random.net/", default_frontend_urls=defaults) + == "http://random.net/" + ) + assert ( + get_valid_frontend_url( + "https://test.com/toto/subpath/", default_frontend_urls=defaults + ) + == "https://test.com/toto/subpath/" + ) + assert ( + get_valid_frontend_url("https://test.com/titi/", default_frontend_urls=defaults) + == "https://test.com/titi/" + ) + assert ( + get_valid_frontend_url( + "https://test.com/titi/", + default_frontend_urls=defaults, + ) + == "https://test.com/titi/" + ) + assert ( + get_valid_frontend_url( + "https://test.com/titi/", + default_frontend_urls=defaults, + allow_any_path=False, + ) + == "https://test.com/toto" + ) + assert ( + get_valid_frontend_url("http://random.net/", default_frontend_urls=defaults) + == "http://random.net/" + ) + assert ( + get_valid_frontend_url("http://random.net/path", default_frontend_urls=defaults) + == "http://random.net/path" + ) + assert ( + get_valid_frontend_url("http://other.net/path", default_frontend_urls=defaults) + == "https://test.com/toto" + ) + + +def test_get_valid_front_url_w_params(): + assert ( + get_valid_frontend_url(query_params={"test": "value"}) + == "http://localhost:3000/dashboard?test=value" + ) + assert ( + get_valid_frontend_url( + "http://localhost:3000/dashboard", query_params={"test": "value"} + ) + == "http://localhost:3000/dashboard?test=value" + ) diff --git a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/all.scss b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/all.scss index 413a0ea0d..6aba9c2c8 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/all.scss +++ b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/all.scss @@ -14,3 +14,6 @@ @import 'long_text_field'; @import 'highest_role_field'; @import 'auth_form_element'; +@import 'common_saml_setting_form'; +@import 'common_saml_setting_modal'; +@import 'saml_auth_link'; diff --git a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/auth_form_element.scss b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/auth_form_element.scss index f8fbdd52f..173167ec5 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/auth_form_element.scss +++ b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/auth_form_element.scss @@ -15,3 +15,11 @@ font-size: var(--label-font-size, 13px); padding-top: 0.5em; } + +.auth-form-element__provider { + padding-top: 16px; + + &:first-child { + padding-top: 0; + } +} diff --git a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/common_saml_setting_form.scss b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/common_saml_setting_form.scss new file mode 100644 index 000000000..8b43bfdee --- /dev/null +++ b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/common_saml_setting_form.scss @@ -0,0 +1,13 @@ +.common-saml-setting-form { + @include elevation($elevation-low); + @include rounded($rounded); + + padding: 20px; + border: 1px solid $palette-neutral-200; + display: flex; + cursor: pointer; + + &--error { + border-color: $palette-red-600; + } +} diff --git a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/common_saml_setting_modal.scss b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/common_saml_setting_modal.scss new file mode 100644 index 000000000..075ba2c67 --- /dev/null +++ b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/common_saml_setting_modal.scss @@ -0,0 +1,24 @@ +.common-saml-setting-modal__url-block { + @include rounded($rounded); + + position: relative; + font-family: monospace; + padding: 12px 16px; + background-color: $palette-neutral-50; + border: 1px solid $palette-neutral-400; +} + +.common-saml-setting-modal__url { + padding-bottom: 8px; + cursor: pointer; + display: flex; + gap: 5px; +} + +.common-saml-setting-modal__url-domain { + color: $palette-neutral-800; +} + +.common-saml-setting-modal__url-dest { + @extend %ellipsis; +} diff --git a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/saml_auth_link.scss b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/saml_auth_link.scss new file mode 100644 index 000000000..d246ef8d4 --- /dev/null +++ b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/saml_auth_link.scss @@ -0,0 +1,6 @@ +.saml-auth-link, +.saml-auth-link__modal-footer { + display: flex; + flex-direction: column; + font-size: var(--label-font-size, 13px); +} diff --git a/enterprise/web-frontend/modules/baserow_enterprise/authProviderTypes.js b/enterprise/web-frontend/modules/baserow_enterprise/authProviderTypes.js index adbefad2d..5160e2d31 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/authProviderTypes.js +++ b/enterprise/web-frontend/modules/baserow_enterprise/authProviderTypes.js @@ -42,68 +42,150 @@ export class PasswordAuthProviderType extends AuthProviderType { return null } + /** + * We can create only one password provider. + */ + canCreateNew(authProviders) { + return ( + !authProviders[this.getType()] || + authProviders[this.getType()].length === 0 + ) + } + getOrder() { return 1 } } -export class SamlAuthProviderType extends AuthProviderType { - static getType() { - return 'saml' - } - - getIcon() { - return SAMLIcon - } - - getVerifiedIcon() { - return VerifiedProviderIcon - } - - getName() { - return 'SSO SAML provider' - } - - getProviderName(provider) { - return `SSO SAML: ${provider.domain}` - } - - getLoginActionComponent() { - return SamlLoginAction - } - - getAdminListComponent() { - return AuthProviderItem - } - - getAdminSettingsFormComponent() { - return SamlSettingsForm +export const SamlAuthProviderTypeMixin = (Base) => + class extends Base { + static getType() { + return 'saml' + } + + getIcon() { + return SAMLIcon + } + + getVerifiedIcon() { + return VerifiedProviderIcon + } + + getName() { + return this.app.i18n.t('authProviderTypes.saml') + } + + getProviderName(provider) { + if (provider.domain) { + return this.app.i18n.t('authProviderTypes.ssoSamlProviderName', { + domain: provider.domain, + }) + } else { + return this.app.i18n.t( + 'authProviderTypes.ssoSamlProviderNameUnconfigured' + ) + } + } + + getLoginActionComponent() { + return SamlLoginAction + } + + getAdminListComponent() { + return AuthProviderItem + } + + getAdminSettingsFormComponent() { + return SamlSettingsForm + } + + getRelayStateUrl() { + return this.app.store.getters['authProviderAdmin/getType'](this.getType()) + .relayStateUrl + } + + getAcsUrl() { + return this.app.store.getters['authProviderAdmin/getType'](this.getType()) + .acsUrl + } + + populateLoginOptions(authProviderOption) { + const loginOptions = super.populateLoginOptions(authProviderOption) + return { + redirectUrl: authProviderOption.redirect_url, + domainRequired: authProviderOption.domain_required, + ...loginOptions, + } + } + + handleServerError(vueComponentInstance, error) { + if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false + + for (const [key, value] of Object.entries(error.handler.detail || {})) { + vueComponentInstance.serverErrors[key] = value + } + return true + } + + populate(authProviderType) { + const populated = super.populate(authProviderType) + return { + acsUrl: authProviderType.acs_url, + relayStateUrl: authProviderType.relay_state_url, + ...populated, + } + } } +export class SamlAuthProviderType extends SamlAuthProviderTypeMixin( + AuthProviderType +) { getOrder() { return 50 } - - populateLoginOptions(authProviderOption) { - const loginOptions = super.populateLoginOptions(authProviderOption) - return { - redirectUrl: authProviderOption.redirect_url, - domainRequired: authProviderOption.domain_required, - ...loginOptions, - } - } - - populate(authProviderType) { - const populated = super.populate(authProviderType) - return { - acsUrl: authProviderType.acs_url, - relayStateUrl: authProviderType.relay_state_url, - ...populated, - } - } } -export class GoogleAuthProviderType extends AuthProviderType { +export const OAuth2AuthProviderTypeMixin = (Base) => + class extends Base { + getLoginButtonComponent() { + return LoginButton + } + + getAdminListComponent() { + return AuthProviderItem + } + + getAdminSettingsFormComponent() { + return OAuth2SettingsForm + } + + getCallbackUrl(authProvider) { + if (!authProvider.id) { + const nextProviderId = + this.app.store.getters['authProviderAdmin/getNextProviderId'] + return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/` + } + return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${authProvider.id}/` + } + + populateLoginOptions(authProviderOption) { + const loginOptions = super.populateLoginOptions(authProviderOption) + return { + ...loginOptions, + } + } + + populate(authProviderType) { + const populated = super.populate(authProviderType) + return { + ...populated, + } + } + } + +export class GoogleAuthProviderType extends OAuth2AuthProviderTypeMixin( + AuthProviderType +) { static getType() { return 'google' } @@ -120,38 +202,14 @@ export class GoogleAuthProviderType extends AuthProviderType { return provider.name ? provider.name : `Google` } - getLoginButtonComponent() { - return LoginButton - } - - getAdminListComponent() { - return AuthProviderItem - } - - getAdminSettingsFormComponent() { - return OAuth2SettingsForm - } - getOrder() { return 50 } - - populateLoginOptions(authProviderOption) { - const loginOptions = super.populateLoginOptions(authProviderOption) - return { - ...loginOptions, - } - } - - populate(authProviderType) { - const populated = super.populate(authProviderType) - return { - ...populated, - } - } } -export class FacebookAuthProviderType extends AuthProviderType { +export class FacebookAuthProviderType extends OAuth2AuthProviderTypeMixin( + AuthProviderType +) { static getType() { return 'facebook' } @@ -168,38 +226,14 @@ export class FacebookAuthProviderType extends AuthProviderType { return provider.name ? provider.name : this.getName() } - getLoginButtonComponent() { - return LoginButton - } - - getAdminListComponent() { - return AuthProviderItem - } - - getAdminSettingsFormComponent() { - return OAuth2SettingsForm - } - getOrder() { return 50 } - - populateLoginOptions(authProviderOption) { - const loginOptions = super.populateLoginOptions(authProviderOption) - return { - ...loginOptions, - } - } - - populate(authProviderType) { - const populated = super.populate(authProviderType) - return { - ...populated, - } - } } -export class GitHubAuthProviderType extends AuthProviderType { +export class GitHubAuthProviderType extends OAuth2AuthProviderTypeMixin( + AuthProviderType +) { static getType() { return 'github' } @@ -216,35 +250,9 @@ export class GitHubAuthProviderType extends AuthProviderType { return provider.name ? provider.name : this.getName() } - getLoginButtonComponent() { - return LoginButton - } - - getAdminListComponent() { - return AuthProviderItem - } - - getAdminSettingsFormComponent() { - return OAuth2SettingsForm - } - getOrder() { return 50 } - - populateLoginOptions(authProviderOption) { - const loginOptions = super.populateLoginOptions(authProviderOption) - return { - ...loginOptions, - } - } - - populate(authProviderType) { - const populated = super.populate(authProviderType) - return { - ...populated, - } - } } export class GitLabAuthProviderType extends AuthProviderType { @@ -276,26 +284,23 @@ export class GitLabAuthProviderType extends AuthProviderType { return GitLabSettingsForm } + getCallbackUrl(authProvider) { + if (!authProvider.id) { + const nextProviderId = + this.app.store.getters['authProviderAdmin/getNextProviderId'] + return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/` + } + return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${authProvider.id}/` + } + getOrder() { return 50 } - - populateLoginOptions(authProviderOption) { - const loginOptions = super.populateLoginOptions(authProviderOption) - return { - ...loginOptions, - } - } - - populate(authProviderType) { - const populated = super.populate(authProviderType) - return { - ...populated, - } - } } -export class OpenIdConnectAuthProviderType extends AuthProviderType { +export class OpenIdConnectAuthProviderType extends OAuth2AuthProviderTypeMixin( + AuthProviderType +) { static getType() { return 'openid_connect' } @@ -312,33 +317,38 @@ export class OpenIdConnectAuthProviderType extends AuthProviderType { return provider.name ? provider.name : this.getName() } - getLoginButtonComponent() { - return LoginButton - } - - getAdminListComponent() { - return AuthProviderItem - } - getAdminSettingsFormComponent() { return OpenIdConnectSettingsForm } + getCallbackUrl(authProvider) { + if (!authProvider.id) { + const nextProviderId = + this.app.store.getters['authProviderAdmin/getNextProviderId'] + return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/` + } + return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${authProvider.id}/` + } + + handleServerError(vueComponentInstance, error) { + if (error.handler.code === 'ERROR_INVALID_PROVIDER_URL') { + vueComponentInstance.serverErrors = { + ...vueComponentInstance.serverErrors, + baseUrl: error.handler.detail, + } + return true + } + + if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false + + vueComponentInstance.serverErrors = structuredClone( + error.handler.detail || {} + ) + + return true + } + getOrder() { return 50 } - - populateLoginOptions(authProviderOption) { - const loginOptions = super.populateLoginOptions(authProviderOption) - return { - ...loginOptions, - } - } - - populate(authProviderType) { - const populated = super.populate(authProviderType) - return { - ...populated, - } - } } diff --git a/enterprise/web-frontend/modules/baserow_enterprise/builder/components/elements/AuthFormElement.vue b/enterprise/web-frontend/modules/baserow_enterprise/builder/components/elements/AuthFormElement.vue index 54c6a1ec1..0ce20e1a4 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/builder/components/elements/AuthFormElement.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/builder/components/elements/AuthFormElement.vue @@ -1,64 +1,30 @@ <template> - <form - v-if="hasAtLeastOneLoginOption" - class="auth-form-element" - :style="getStyleOverride('input')" - @submit.prevent="onLogin" - > - <Error :error="error"></Error> - <ABFormGroup - :label="$t('authFormElement.email')" - :error-message=" - $v.values.email.$dirty - ? !$v.values.email.required - ? $t('error.requiredField') - : !$v.values.email.email - ? $t('error.invalidEmail') - : '' - : '' - " - :autocomplete="isEditMode ? 'off' : ''" - required - > - <ABInput - v-model="values.email" - :placeholder="$t('authFormElement.emailPlaceholder')" - @blur="$v.values.email.$touch()" - /> - </ABFormGroup> - <ABFormGroup - :label="$t('authFormElement.password')" - :error-message=" - $v.values.password.$dirty - ? !$v.values.password.required - ? $t('error.requiredField') - : '' - : '' - " - required - > - <ABInput - ref="passwordRef" - v-model="values.password" - type="password" - :placeholder="$t('authFormElement.passwordPlaceholder')" - @blur="$v.values.password.$touch()" - /> - </ABFormGroup> - <div :style="getStyleOverride('login_button')" class="auth-form__footer"> - <ABButton :disabled="$v.$error" :loading="loading" size="large"> - {{ resolvedLoginButtonLabel }} - </ABButton> - </div> - </form> - <p v-else>{{ $t('authFormElement.selectOrConfigureUserSourceFirst') }}</p> + <div v-if="hasAtLeastOneLoginOption" :style="fullStyle"> + <template v-for="appAuthType in appAuthProviderTypes"> + <div + v-if="hasAtLeastOneProvider(appAuthType)" + :key="appAuthType.type" + class="auth-form-element__provider" + > + <component + :is="appAuthType.component" + :user-source="selectedUserSource" + :auth-providers="appAuthProviderPerTypes[appAuthType.type]" + :login-button-label="resolvedLoginButtonLabel" + @after-login="afterLogin" + /> + </div> + </template> + </div> + <p v-else> + {{ $t('authFormElement.selectOrConfigureUserSourceFirst') }} + </p> </template> <script> import form from '@baserow/modules/core/mixins/form' import error from '@baserow/modules/core/mixins/error' import element from '@baserow/modules/builder/mixins/element' -import { required, email } from 'vuelidate/lib/validators' import { ensureString } from '@baserow/modules/core/utils/validator' import { mapActions } from 'vuex' @@ -79,12 +45,15 @@ export default { }, }, data() { - return { - loading: false, - values: { email: '', password: '' }, - } + return {} }, computed: { + fullStyle() { + return { + ...this.getStyleOverride('input'), + ...this.getStyleOverride('login_button'), + } + }, selectedUserSource() { return this.$store.getters['userSource/getUserSourceById']( this.builder, @@ -97,8 +66,21 @@ export default { } return this.$registry.get('userSource', this.selectedUserSource.type) }, - isAuthenticated() { - return this.$store.getters['userSourceUser/isAuthenticated'](this.builder) + authProviders() { + return this.selectedUserSource?.auth_providers || [] + }, + appAuthProviderTypes() { + return this.$registry.getOrderedList('appAuthProvider') + }, + appAuthProviderPerTypes() { + return Object.fromEntries( + this.appAuthProviderTypes.map((authType) => { + return [ + authType.type, + this.authProviders.filter(({ type }) => type === authType.type), + ] + }) + ) }, loginOptions() { if (!this.selectedUserSourceType) { @@ -141,68 +123,15 @@ export default { ...mapActions({ actionForceUpdateElement: 'element/forceUpdate', }), - async onLogin(event) { - if (this.isAuthenticated) { - await this.$store.dispatch('userSourceUser/logoff', { - application: this.builder, - }) - } - - this.$v.$touch() - if (this.$v.$invalid) { - this.focusOnFirstError() - return - } - this.loading = true - this.hideError() - try { - await this.$store.dispatch('userSourceUser/authenticate', { - application: this.builder, - userSource: this.selectedUserSource, - credentials: { - email: this.values.email, - password: this.values.password, - }, - setCookie: this.mode === 'public', - }) - this.values.password = '' - this.values.email = '' - this.$v.$reset() - this.fireEvent( - this.elementType.getEventByName(this.element, 'after_login') - ) - } catch (error) { - if (error.handler) { - const response = error.handler.response - if (response && response.status === 401) { - this.values.password = '' - this.$v.$reset() - this.$v.$touch() - this.$refs.passwordRef.focus() - - if (response.data?.error === 'ERROR_INVALID_CREDENTIALS') { - this.showError( - this.$t('error.incorrectCredentialTitle'), - this.$t('error.incorrectCredentialMessage') - ) - } - } else { - const message = error.handler.getMessage('login') - this.showError(message) - } - - error.handler.handled() - } else { - throw error - } - } - this.loading = false + hasAtLeastOneProvider(authProviderType) { + return ( + this.appAuthProviderPerTypes[authProviderType.getType()]?.length > 0 + ) }, - }, - validations: { - values: { - email: { required, email }, - password: { required }, + afterLogin() { + this.fireEvent( + this.elementType.getEventByName(this.element, 'after_login') + ) }, }, } diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/contexts/CreateAuthProviderContext.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/contexts/CreateAuthProviderContext.vue index fe47f3dbe..b07a81ce8 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/contexts/CreateAuthProviderContext.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/contexts/CreateAuthProviderContext.vue @@ -10,8 +10,8 @@ class="context__menu-item-link" @click="$emit('create', authProviderType)" > - <AuthProviderIcon :icon="getIcon(authProviderType)" /> - {{ getName(authProviderType) }} + <AuthProviderIcon :icon="authProviderType.getIcon()" /> + {{ authProviderType.getName() }} </a> </li> </ul> @@ -32,13 +32,5 @@ export default { required: true, }, }, - methods: { - getIcon(providerType) { - return this.$registry.get('authProvider', providerType.type).getIcon() - }, - getName(providerType) { - return this.$registry.get('authProvider', providerType.type).getName() - }, - }, } </script> diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/GitLabSettingsForm.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/GitLabSettingsForm.vue index c718bd715..d14deb2b3 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/GitLabSettingsForm.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/GitLabSettingsForm.vue @@ -107,59 +107,25 @@ <script> import { required, url } from 'vuelidate/lib/validators' -import form from '@baserow/modules/core/mixins/form' +import authProviderForm from '@baserow/modules/core/mixins/authProviderForm' export default { name: 'GitLabSettingsForm', - mixins: [form], - props: { - authProvider: { - type: Object, - required: false, - default: () => ({}), - }, - }, + mixins: [authProviderForm], data() { return { allowedValues: ['name', 'base_url', 'client_id', 'secret'], values: { name: '', - base_url: '', + base_url: 'https://gitlab.com', client_id: '', secret: '', }, } }, computed: { - providerName() { - return this.$registry - .get('authProvider', 'gitlab') - .getProviderName(this.authProvider) - }, callbackUrl() { - if (!this.authProvider.id) { - const nextProviderId = - this.$store.getters['authProviderAdmin/getNextProviderId'] - return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/` - } - return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${this.authProvider.id}/` - }, - }, - methods: { - getDefaultValues() { - return { - name: this.providerName, - base_url: this.authProvider.base_url || 'https://gitlab.com', - client_id: this.authProvider.client_id || '', - secret: this.authProvider.secret || '', - } - }, - submit() { - this.$v.$touch() - if (this.$v.$invalid) { - return - } - this.$emit('submit', this.values) + return this.authProviderType.getCallbackUrl(this.authProvider) }, }, validations() { diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/OAuth2SettingsForm.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/OAuth2SettingsForm.vue index c52dbc5a4..81f73664e 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/OAuth2SettingsForm.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/OAuth2SettingsForm.vue @@ -81,23 +81,11 @@ <script> import { required } from 'vuelidate/lib/validators' -import form from '@baserow/modules/core/mixins/form' +import authProviderForm from '@baserow/modules/core/mixins/authProviderForm' export default { name: 'OAuth2SettingsForm', - mixins: [form], - props: { - authProvider: { - type: Object, - required: false, - default: () => ({}), - }, - authProviderType: { - type: String, - required: false, - default: null, - }, - }, + mixins: [authProviderForm], data() { return { allowedValues: ['name', 'client_id', 'secret'], @@ -109,37 +97,8 @@ export default { } }, computed: { - providerName() { - const type = this.authProviderType - ? this.authProviderType - : this.authProvider.type - return this.$registry - .get('authProvider', type) - .getProviderName(this.authProvider) - }, callbackUrl() { - if (!this.authProvider.id) { - const nextProviderId = - this.$store.getters['authProviderAdmin/getNextProviderId'] - return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/` - } - return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${this.authProvider.id}/` - }, - }, - methods: { - getDefaultValues() { - return { - name: this.providerName, - client_id: this.authProvider.client_id || '', - secret: this.authProvider.secret || '', - } - }, - submit() { - this.$v.$touch() - if (this.$v.$invalid) { - return - } - this.$emit('submit', this.values) + return this.authProviderType.getCallbackUrl(this.authProvider) }, }, validations() { diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/OpenIdConnectSettingsForm.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/OpenIdConnectSettingsForm.vue index 043b0de4d..39e4f7b22 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/OpenIdConnectSettingsForm.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/OpenIdConnectSettingsForm.vue @@ -109,23 +109,11 @@ <script> import { required, url } from 'vuelidate/lib/validators' -import form from '@baserow/modules/core/mixins/form' +import authProviderForm from '@baserow/modules/core/mixins/authProviderForm' export default { name: 'OpenIdConnectSettingsForm', - mixins: [form], - props: { - authProvider: { - type: Object, - required: false, - default: () => ({}), - }, - serverErrors: { - type: Object, - required: false, - default: () => ({}), - }, - }, + mixins: [authProviderForm], data() { return { allowedValues: ['name', 'base_url', 'client_id', 'secret'], @@ -139,42 +127,7 @@ export default { }, computed: { callbackUrl() { - if (!this.authProvider.id) { - const nextProviderId = - this.$store.getters['authProviderAdmin/getNextProviderId'] - return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/` - } - return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${this.authProvider.id}/` - }, - }, - methods: { - getDefaultValues() { - return { - name: this.authProvider.name || '', - base_url: this.authProvider.base_url || '', - client_id: this.authProvider.client_id || '', - secret: this.authProvider.secret || '', - } - }, - submit() { - this.$v.$touch() - if (this.$v.$invalid) { - return - } - this.$emit('submit', this.values) - }, - handleServerError(error) { - if (error.handler.code === 'ERROR_INVALID_PROVIDER_URL') { - this.serverErrors.baseUrl = error.handler.detail - return true - } - - if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false - - for (const [key, value] of Object.entries(error.handler.detail || {})) { - this.serverErrors[key] = value - } - return true + return this.authProviderType.getCallbackUrl(this.authProvider) }, }, validations() { diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/SamlSettingsForm.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/SamlSettingsForm.vue index 7defebb6d..014f7b565 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/SamlSettingsForm.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/forms/SamlSettingsForm.vue @@ -21,9 +21,9 @@ ref="domain" v-model="values.domain" size="large" - :error="fieldHasErrors('domain') || serverErrors.domain" + :error="fieldHasErrors('domain') || !!serverErrors.domain" :placeholder="$t('samlSettingsForm.domainPlaceholder')" - @input="serverErrors.domain = null" + @input="onDomainInput()" @blur="$v.values.domain.$touch()" ></FormInput> <template #error> @@ -50,16 +50,16 @@ small-label required :label="$t('samlSettingsForm.metadata')" - :error="fieldHasErrors('metadata')" + :error="fieldHasErrors('metadata') || !!serverErrors.metadata" class="margin-bottom-2" > <FormTextarea ref="metadata" v-model="values.metadata" - :rows="12" - :error="fieldHasErrors('metadata') || serverErrors.metadata" + :rows="8" + :error="fieldHasErrors('metadata') || !!serverErrors.metadata" :placeholder="$t('samlSettingsForm.metadataPlaceholder')" - @input="serverErrors.metadata = null" + @input="onMetadataInput()" @blur="$v.values.metadata.$touch()" ></FormTextarea> @@ -73,23 +73,25 @@ </template> </FormGroup> - <FormGroup - small-label - required - :label="$t('samlSettingsForm.relayStateUrl')" - class="margin-bottom-2" - > - <code>{{ getRelayStateUrl() }}</code> - </FormGroup> + <slot name="config"> + <FormGroup + small-label + required + :label="$t('samlSettingsForm.relayStateUrl')" + class="margin-bottom-2" + > + <code>{{ getRelayStateUrl() }}</code> + </FormGroup> - <FormGroup - small-label - required - :label="$t('samlSettingsForm.acsUrl')" - class="margin-bottom-2" - > - <code>{{ getAcsUrl() }}</code> - </FormGroup> + <FormGroup + small-label + required + :label="$t('samlSettingsForm.acsUrl')" + class="margin-bottom-2" + > + <code>{{ getAcsUrl() }}</code> + </FormGroup> + </slot> <Expandable card class="margin-bottom-2"> <template #header="{ toggle, expanded }"> @@ -110,7 +112,7 @@ </div> <div> {{ - usingDefaultAttrs() + usingDefaultAttrs ? $t('samlSettingsForm.defaultAttrs') : $t('samlSettingsForm.customAttrs') }} @@ -190,7 +192,7 @@ <script> import { maxLength, required, helpers } from 'vuelidate/lib/validators' -import form from '@baserow/modules/core/mixins/form' +import authProviderForm from '@baserow/modules/core/mixins/authProviderForm' const alphanumericDotDashUnderscore = helpers.regex( 'alphanumericDotDashUnderscore', @@ -199,41 +201,32 @@ const alphanumericDotDashUnderscore = helpers.regex( export default { name: 'SamlSettingsForm', - mixins: [form], - props: { - authProvider: { - type: Object, - required: false, - default: () => ({}), - }, - authProviderType: { - type: String, - required: false, - default: null, - }, - }, + mixins: [authProviderForm], data() { return { - allowedValues: ['domain', 'metadata'], - serverErrors: {}, + allowedValues: [ + 'domain', + 'metadata', + 'email_attr_key', + 'first_name_attr_key', + 'last_name_attr_key', + ], values: { domain: '', metadata: '', - email_attr_key: '', - first_name_attr_key: '', - last_name_attr_key: '', + email_attr_key: 'user.email', + first_name_attr_key: 'user.first_name', + last_name_attr_key: 'user.last_name', }, } }, computed: { + allSamlProviders() { + return this.authProviders.saml || [] + }, samlDomains() { - const samlAuthProviders = - this.$store.getters['authProviderAdmin/getAll'].saml?.authProviders || - [] - return samlAuthProviders - .filter( - (authProvider) => authProvider.domain !== this.authProvider.domain - ) + return this.allSamlProviders + .filter((authProvider) => authProvider.id !== this.authProvider.id) .map((authProvider) => authProvider.domain) }, defaultAttrs() { @@ -244,10 +237,8 @@ export default { } }, type() { - return this.authProviderType || this.authProvider.type + return this.authProviderType.getType() }, - }, - methods: { usingDefaultAttrs() { return ( this.values.email_attr_key === this.defaultAttrs.email_attr_key && @@ -256,20 +247,13 @@ export default { this.values.last_name_attr_key === this.defaultAttrs.last_name_attr_key ) }, - getDefaultValues() { - const authProviderAttrs = { - email_attr_key: this.authProvider.email_attr_key, - first_name_attr_key: this.authProvider.first_name_attr_key, - last_name_attr_key: this.authProvider.last_name_attr_key, - } - const samlAttrs = this.authProvider.id - ? authProviderAttrs - : this.defaultAttrs - return { - domain: this.authProvider.domain || '', - metadata: this.authProvider.metadata || '', - ...samlAttrs, - } + }, + methods: { + onDomainInput() { + this.serverErrors.domain = null + }, + onMetadataInput() { + this.serverErrors.metadata = null }, getFieldErrorMsg(fieldName) { if (!this.$v.values[fieldName].$dirty) { @@ -285,33 +269,17 @@ export default { } }, getRelayStateUrl() { - return this.$store.getters['authProviderAdmin/getType'](this.type) - .relayStateUrl + return this.authProviderType.getRelayStateUrl() }, getAcsUrl() { - return this.$store.getters['authProviderAdmin/getType'](this.type).acsUrl + return this.authProviderType.getAcsUrl() }, getVerifiedIcon() { - return this.$registry.get('authProvider', this.type).getVerifiedIcon() - }, - submit() { - this.$v.$touch() - if (this.$v.$invalid) { - return - } - this.$emit('submit', this.values) + return this.authProviderType.getVerifiedIcon() }, mustHaveUniqueDomain(domain) { return !this.samlDomains.includes(domain.trim()) }, - handleServerError(error) { - if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false - - for (const [key, value] of Object.entries(error.handler.detail || {})) { - this.serverErrors[key] = value - } - return true - }, }, validations() { return { diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/modals/CreateAuthProviderModal.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/modals/CreateAuthProviderModal.vue index 13014c540..e1db5d398 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/modals/CreateAuthProviderModal.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/modals/CreateAuthProviderModal.vue @@ -3,14 +3,15 @@ <h2 class="box__title"> {{ $t('createSettingsAuthProviderModal.title', { - type: getProviderTypeName(), + type: authProviderType.getName(), }) }} </h2> - <div v-if="authProviderType"> + <div> <component :is="getProviderAdminSettingsFormComponent()" ref="providerSettingsForm" + :auth-providers="appAuthProviderPerTypes" :auth-provider-type="authProviderType" @submit="create($event)" > @@ -21,8 +22,8 @@ </li> </ul> <Button type="primary" :disabled="loading" :loading="loading"> - {{ $t('action.create') }}</Button - > + {{ $t('action.create') }} + </Button> </div> </component> </div> @@ -38,9 +39,8 @@ export default { mixins: [modal], props: { authProviderType: { - type: String, - required: false, - default: null, + type: Object, + required: true, }, }, data() { @@ -48,23 +48,30 @@ export default { loading: false, } }, + computed: { + authProviders() { + return this.$store.getters['authProviderAdmin/getAll'] + }, + appAuthProviderPerTypes() { + return Object.fromEntries( + this.$registry + .getOrderedList('authProvider') + .map((authProviderType) => [ + authProviderType.getType(), + this.authProviders[authProviderType.getType()].authProviders, + ]) + ) + }, + }, methods: { getProviderAdminSettingsFormComponent() { - return this.$registry - .get('authProvider', this.authProviderType) - .getAdminSettingsFormComponent() - }, - getProviderTypeName() { - if (!this.authProviderType) return '' - - return this.$registry.get('authProvider', this.authProviderType).getName() + return this.authProviderType.getAdminSettingsFormComponent() }, async create(values) { this.loading = true - this.serverErrors = {} try { await this.$store.dispatch('authProviderAdmin/create', { - type: this.authProviderType, + type: this.authProviderType.getType(), values, }) this.$emit('created') diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/modals/UpdateSettingsAuthProviderModal.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/modals/UpdateSettingsAuthProviderModal.vue index 838c49eec..45eb7e425 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/admin/modals/UpdateSettingsAuthProviderModal.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/admin/modals/UpdateSettingsAuthProviderModal.vue @@ -11,7 +11,10 @@ <component :is="getProviderAdminSettingsFormComponent()" ref="providerSettingsForm" + :auth-providers="appAuthProviderPerTypes" :auth-provider="authProvider" + :default-values="authProvider" + :auth-provider-type="authProviderType" @submit="onSettingsUpdated" > <div class="actions"> @@ -22,8 +25,8 @@ </ul> <Button type="primary" :disabled="loading" :loading="loading"> - {{ $t('action.save') }}</Button - > + {{ $t('action.save') }} + </Button> </div> </component> </div> @@ -48,16 +51,30 @@ export default { loading: false, } }, + computed: { + authProviderType() { + return this.$registry.get('authProvider', this.authProvider.type) + }, + authProviders() { + return this.$store.getters['authProviderAdmin/getAll'] + }, + appAuthProviderPerTypes() { + return Object.fromEntries( + this.$registry + .getOrderedList('authProvider') + .map((authProviderType) => [ + authProviderType.getType(), + this.authProviders[authProviderType.getType()].authProviders, + ]) + ) + }, + }, methods: { getProviderAdminSettingsFormComponent() { - return this.$registry - .get('authProvider', this.authProvider.type) - .getAdminSettingsFormComponent() + return this.authProviderType.getAdminSettingsFormComponent() }, getProviderName() { - return this.$registry - .get('authProvider', this.authProvider.type) - .getProviderName(this.authProvider) + return this.authProviderType.getProviderName(this.authProvider) }, async onSettingsUpdated(values) { this.loading = true diff --git a/enterprise/web-frontend/modules/baserow_enterprise/integrations/appAuthProviderTypes.js b/enterprise/web-frontend/modules/baserow_enterprise/integrations/appAuthProviderTypes.js index 19eabf8e3..0639edc66 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/integrations/appAuthProviderTypes.js +++ b/enterprise/web-frontend/modules/baserow_enterprise/integrations/appAuthProviderTypes.js @@ -1,5 +1,10 @@ import { AppAuthProviderType } from '@baserow/modules/core/appAuthProviderTypes' +import { SamlAuthProviderTypeMixin } from '@baserow_enterprise/authProviderTypes' + import LocalBaserowUserSourceForm from '@baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowPasswordAppAuthProviderForm' +import LocalBaserowAuthPassword from '@baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowAuthPassword' +import CommonSamlSettingForm from '@baserow_enterprise/integrations/common/components/CommonSamlSettingForm' +import SamlAuthLink from '@baserow_enterprise/integrations/common/components/SamlAuthLink' import { PasswordFieldType } from '@baserow/modules/database/fieldTypes' export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType { @@ -11,6 +16,10 @@ export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType return this.app.i18n.t('appAuthProviderType.localBaserowPassword') } + get component() { + return LocalBaserowAuthPassword + } + get formComponent() { return LocalBaserowUserSourceForm } @@ -23,7 +32,91 @@ export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType return [PasswordFieldType.getType()] } + getLoginOptions(authProvider) { + if (authProvider.password_field_id) { + return {} + } + return null + } + + /** + * We can create only one password provider. + */ + canCreateNew(appAuthProviders) { + return ( + !appAuthProviders[this.getType()] || + appAuthProviders[this.getType()].length === 0 + ) + } + getOrder() { return 10 } } + +export class SamlAppAuthProviderType extends SamlAuthProviderTypeMixin( + AppAuthProviderType +) { + get name() { + return this.app.i18n.t('appAuthProviderType.commonSaml') + } + + get component() { + return SamlAuthLink + } + + get formComponent() { + return CommonSamlSettingForm + } + + handleServerError(vueComponentInstance, error) { + if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false + + if (error.handler.detail?.auth_providers?.length > 0) { + const flatProviders = Object.entries(vueComponentInstance.authProviders) + .map(([, providers]) => providers) + .flat() + // Sort per ID to make sure we have the same order + // as the backend + .sort((a, b) => a.id - b.id) + + for (const [ + index, + authError, + ] of error.handler.detail.auth_providers.entries()) { + if ( + Object.keys(authError).length > 0 && + flatProviders[index].id === vueComponentInstance.authProvider.id + ) { + vueComponentInstance.serverErrors = { + ...vueComponentInstance.serverErrors, + ...authError, + } + return true + } + } + } + + return false + } + + getAuthToken(userSource, authProvider, route) { + // token can be in the query string (SSO) or in the cookies (previous session) + // We use the user source id in order to prevent conflicts when using multiple + // auth forms on the same page. + const queryParamName = `user_source_saml_token__${userSource.id}` + return route.query[queryParamName] + } + + handleError(userSource, authProvider, route) { + const queryParamName = `saml_error__${userSource.id}` + const errorCode = route.query[queryParamName] + if (errorCode) { + return { message: this.app.i18n.t(`loginError.${errorCode}`), code: 500 } + } + } + + getOrder() { + return 20 + } +} diff --git a/enterprise/web-frontend/modules/baserow_enterprise/integrations/common/components/CommonSamlSettingForm.vue b/enterprise/web-frontend/modules/baserow_enterprise/integrations/common/components/CommonSamlSettingForm.vue new file mode 100644 index 000000000..50e35ea9b --- /dev/null +++ b/enterprise/web-frontend/modules/baserow_enterprise/integrations/common/components/CommonSamlSettingForm.vue @@ -0,0 +1,83 @@ +<template> + <div> + <div + class="common-saml-setting-form" + :class="{ 'common-saml-setting-form--error': inError }" + > + <Presentation + class="flex-grow-1" + :title="authProviderType.getProviderName(authProvider)" + size="medium" + avatar-color="neutral" + :image="authProviderType.getIcon()" + @click="onEdit" + /> + <div class="common-saml-setting-form__actions"> + <ButtonIcon + type="secondary" + icon="iconoir-edit" + @click.prevent="onEdit" + /> + <ButtonIcon + type="secondary" + icon="iconoir-bin" + @click.prevent="onDelete" + /> + </div> + </div> + <div v-if="inError" class="error"> + {{ $t('commonSamlSettingForm.authProviderInError') }} + </div> + <CommonSamlSettingModal + ref="samlModal" + v-bind="$props" + :integration="integration" + :user-source="userSource" + @form-valid="onFormValid($event)" + v-on="$listeners" + ></CommonSamlSettingModal> + </div> +</template> + +<script> +import authProviderForm from '@baserow/modules/core/mixins/authProviderForm' +import CommonSamlSettingModal from '@baserow_enterprise/integrations/common/components/CommonSamlSettingModal' + +export default { + name: 'CommonSamlSettingForm', + components: { CommonSamlSettingModal }, + mixins: [authProviderForm], + inject: ['builder'], + props: { + integration: { + type: Object, + required: true, + }, + userSource: { + type: Object, + required: true, + }, + }, + data() { + return { inError: false } + }, + methods: { + onFormValid(value) { + this.inError = !value + }, + onEdit() { + this.$refs.samlModal.show() + }, + onDelete() { + this.$emit('delete') + }, + handleServerError(error) { + if (this.$refs.samlModal.handleServerError(error)) { + this.inError = true + return true + } + return false + }, + }, +} +</script> diff --git a/enterprise/web-frontend/modules/baserow_enterprise/integrations/common/components/CommonSamlSettingModal.vue b/enterprise/web-frontend/modules/baserow_enterprise/integrations/common/components/CommonSamlSettingModal.vue new file mode 100644 index 000000000..6d56d95d4 --- /dev/null +++ b/enterprise/web-frontend/modules/baserow_enterprise/integrations/common/components/CommonSamlSettingModal.vue @@ -0,0 +1,172 @@ +<template> + <Modal ref="modal" keep-content @hidden="onHide"> + <h2 class="box__title">{{ $t('commonSamlSettingModal.title') }}</h2> + <div> + <SamlSettingsForm + v-bind="$props" + ref="samlForm" + @values-changed="checkValidity" + v-on="$listeners" + > + <template #config> + <FormGroup + small-label + required + :label="$t('commonSamlSettingModal.relayStateTitle')" + class="margin-bottom-2" + > + <div class="common-saml-setting-modal__url-block"> + <div + v-for="conf in config" + :key="conf.name" + class="common-saml-setting-modal__url" + @click.prevent=" + ;[copyToClipboard(conf.relay), $refs.copiedRelay.show()] + " + > + <span class="common-saml-setting-modal__url-domain"> + {{ conf.name }} + </span> + <span + class="common-saml-setting-modal__url-dest" + :title="conf.relay" + > + {{ conf.relay }} + </span> + </div> + <Copied ref="copiedRelay"></Copied> + </div> + </FormGroup> + + <FormGroup + small-label + required + :label="$t('commonSamlSettingModal.acsTitle')" + class="margin-bottom-2" + > + <div class="common-saml-setting-modal__url-block"> + <div + v-for="conf in config" + :key="conf.name" + class="common-saml-setting-modal__url" + @click.prevent=" + ;[copyToClipboard(conf.acs), $refs.copiedACS.show()] + " + > + <span class="common-saml-setting-modal__url-domain"> + {{ conf.name }} + </span> + <span + class="common-saml-setting-modal__url-dest" + :title="conf.acs" + > + {{ conf.acs }} + </span> + </div> + <Copied ref="copiedACS"></Copied> + </div> + </FormGroup> + </template> + </SamlSettingsForm> + <div class="actions actions--right"> + <Button size="large" @click.prevent="$refs.modal.hide()"> + {{ $t('action.close') }} + </Button> + </div> + </div> + </Modal> +</template> + +<script> +import SamlSettingsForm from '@baserow_enterprise/components/admin/forms/SamlSettingsForm' +import authProviderForm from '@baserow/modules/core/mixins/authProviderForm' +import error from '@baserow/modules/core/mixins/error' +import { copyToClipboard } from '@baserow/modules/database/utils/clipboard' +import { mapActions, mapGetters } from 'vuex' +import modal from '@baserow/modules/core/mixins/modal' + +export default { + name: 'CommonSamlSettingsModal', + components: { SamlSettingsForm }, + mixins: [error, authProviderForm, modal], + inject: ['builder'], + props: { + integration: { + type: Object, + required: true, + }, + userSource: { + type: Object, + required: true, + }, + }, + async fetch() { + try { + await this.actionFetchDomains({ builderId: this.builder.id }) + } catch (error) { + this.handleError(error) + } + }, + watch: { + '$v.$anyDirty'() { + // Force validity refresh on child touch + this.checkValidity() + }, + }, + computed: { + ...mapGetters({ domains: 'domain/getDomains' }), + config() { + const previewRelay = `${this.$config.PUBLIC_WEB_FRONTEND_URL}/builder/${this.builder.id}/preview/` + const previewACS = `${this.$config.PUBLIC_BACKEND_URL}/api/user-source/${this.userSource.uid}/sso/saml/acs/` + const preview = [ + { + name: this.$t('commonSamlSettingModal.preview'), + acs: previewACS, + relay: previewRelay, + }, + ] + + const others = this.domains.map((domain) => ({ + name: domain.domain_name, + acs: `${this.$config.PUBLIC_BACKEND_URL}/api/user-source/domain_${domain.id}__${this.userSource.uid}/sso/saml/acs/`, + relay: this.getDomainUrl(domain), + })) + return [...preview, ...others] + }, + }, + methods: { + ...mapActions({ + actionFetchDomains: 'domain/fetch', + }), + copyToClipboard(value) { + copyToClipboard(value) + }, + onHide() { + this.checkValidity() + }, + checkValidity() { + if ( + !this.$refs.samlForm.isFormValid() && + this.$refs.samlForm.$v.$anyDirty + ) { + this.$emit('form-valid', false) + } else { + this.$emit('form-valid', true) + } + }, + getDomainUrl(domain) { + const url = new URL(this.$config.PUBLIC_WEB_FRONTEND_URL) + return `${url.protocol}//${domain.domain_name}${ + url.port ? `:${url.port}` : '' + }` + }, + handleServerError(error) { + return this.$refs.samlForm.handleServerError(error) + }, + }, + validations() { + // Keep this to get the `$v` property + return {} + }, +} +</script> diff --git a/enterprise/web-frontend/modules/baserow_enterprise/integrations/common/components/SamlAuthLink.vue b/enterprise/web-frontend/modules/baserow_enterprise/integrations/common/components/SamlAuthLink.vue new file mode 100644 index 000000000..482f6afb8 --- /dev/null +++ b/enterprise/web-frontend/modules/baserow_enterprise/integrations/common/components/SamlAuthLink.vue @@ -0,0 +1,135 @@ +<template> + <div class="saml-auth-link"> + <ABButton @click.prevent="onClick()"> + {{ buttonLabel }} + </ABButton> + <Modal ref="modal"> + <ThemeProvider> + <form @submit.prevent=""> + <ABFormGroup + v-if="hasMultipleSamlProvider" + :label="$t('samlAuthLink.provideEmail')" + :error-message=" + $v.values.email.$dirty + ? !$v.values.email.required + ? $t('error.requiredField') + : !$v.values.email.email + ? $t('error.invalidEmail') + : '' + : '' + " + :autocomplete="isEditMode ? 'off' : ''" + required + > + <ABInput + v-model="values.email" + :placeholder="$t('samlAuthLink.emailPlaceholder')" + @blur="$v.values.email.$touch()" + /> + </ABFormGroup> + <div class="saml-auth-link__modal-footer"> + <ABButton class="margin-top-2" @click.prevent.stop="login()"> + {{ buttonLabel }} + </ABButton> + </div> + </form> + </ThemeProvider> + </Modal> + </div> +</template> + +<script> +import form from '@baserow/modules/core/mixins/form' +import error from '@baserow/modules/core/mixins/error' +import { required, email } from 'vuelidate/lib/validators' +import ThemeProvider from '@baserow/modules/builder/components/theme/ThemeProvider' + +export default { + components: { ThemeProvider }, + mixins: [form, error], + inject: ['builder', 'mode'], + props: { + userSource: { type: Object, required: true }, + authProviders: { + type: Array, + required: true, + }, + loginButtonLabel: { + type: String, + required: true, + }, + }, + data() { + return { + loading: false, + values: { email: '' }, + } + }, + computed: { + isAuthenticated() { + return this.$store.getters['userSourceUser/isAuthenticated'](this.builder) + }, + hasMultipleSamlProvider() { + return this.authProviders.length > 1 + }, + isEditMode() { + return this.mode === 'editing' + }, + buttonLabel() { + return this.$t('samlAuthLink.placeholderWithSaml', { + login: this.loginButtonLabel, + }) + }, + }, + methods: { + async onClick() { + if (this.hasMultipleSamlProvider) { + this.$refs.modal.show() + } else { + await this.login() + } + }, + async login() { + if (this.isAuthenticated) { + await this.$store.dispatch('userSourceUser/logoff', { + application: this.builder, + }) + } + + if (this.hasMultipleSamlProvider) { + this.$v.$touch() + + if (this.$v.$invalid) { + this.focusOnFirstError() + return + } + } + + this.loading = true + this.hideError() + + const dest = `${ + this.$config.PUBLIC_BACKEND_URL + }/api/user-source/${encodeURIComponent( + this.userSource.uid + )}/sso/saml/login/` + + const urlWithParams = new URL(dest) + + if (this.hasMultipleSamlProvider) { + urlWithParams.searchParams.append('email', this.values.email) + } + + // Add the current url as get parameter to be redirected here after the login. + urlWithParams.searchParams.append('original', window.location) + + window.location = urlWithParams.toString() + }, + }, + validations: { + values: { + email: { required, email }, + }, + }, +} +</script> diff --git a/enterprise/web-frontend/modules/baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowAuthPassword.vue b/enterprise/web-frontend/modules/baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowAuthPassword.vue new file mode 100644 index 000000000..bbc8c51ca --- /dev/null +++ b/enterprise/web-frontend/modules/baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowAuthPassword.vue @@ -0,0 +1,148 @@ +<template> + <form class="auth-form-element" @submit.prevent="onLogin"> + <Error :error="error"></Error> + <ABFormGroup + :label="$t('authFormElement.email')" + :error-message=" + $v.values.email.$dirty + ? !$v.values.email.required + ? $t('error.requiredField') + : !$v.values.email.email + ? $t('error.invalidEmail') + : '' + : '' + " + :autocomplete="isEditMode ? 'off' : ''" + required + > + <ABInput + v-model="values.email" + :placeholder="$t('authFormElement.emailPlaceholder')" + @blur="$v.values.email.$touch()" + /> + </ABFormGroup> + <ABFormGroup + :label="$t('authFormElement.password')" + :error-message=" + $v.values.password.$dirty + ? !$v.values.password.required + ? $t('error.requiredField') + : '' + : '' + " + required + > + <ABInput + ref="passwordRef" + v-model="values.password" + type="password" + :placeholder="$t('authFormElement.passwordPlaceholder')" + @blur="$v.values.password.$touch()" + /> + </ABFormGroup> + <div class="auth-form__footer"> + <ABButton :disabled="$v.$error" :loading="loading" size="large"> + {{ loginButtonLabel }} + </ABButton> + </div> + </form> +</template> + +<script> +import form from '@baserow/modules/core/mixins/form' +import error from '@baserow/modules/core/mixins/error' +import { required, email } from 'vuelidate/lib/validators' + +export default { + mixins: [form, error], + inject: ['builder', 'mode'], + props: { + userSource: { type: Object, required: true }, + authProviders: { + type: Array, + required: true, + }, + loginButtonLabel: { + type: String, + required: true, + }, + }, + data() { + return { + loading: false, + values: { email: '', password: '' }, + } + }, + computed: { + isAuthenticated() { + return this.$store.getters['userSourceUser/isAuthenticated'](this.builder) + }, + isEditMode() { + return this.mode === 'editing' + }, + }, + methods: { + async onLogin(event) { + if (this.isAuthenticated) { + await this.$store.dispatch('userSourceUser/logoff', { + application: this.builder, + }) + } + + this.$v.$touch() + if (this.$v.$invalid) { + this.focusOnFirstError() + return + } + this.loading = true + this.hideError() + try { + await this.$store.dispatch('userSourceUser/authenticate', { + application: this.builder, + userSource: this.userSource, + credentials: { + email: this.values.email, + password: this.values.password, + }, + setCookie: this.mode === 'public', + }) + this.values.password = '' + this.values.email = '' + this.$v.$reset() + this.$emit('after-login') + } catch (error) { + if (error.handler) { + const response = error.handler.response + if (response && response.status === 401) { + this.values.password = '' + this.$v.$reset() + this.$v.$touch() + this.$refs.passwordRef.focus() + + if (response.data?.error === 'ERROR_INVALID_CREDENTIALS') { + this.showError( + this.$t('error.incorrectCredentialTitle'), + this.$t('error.incorrectCredentialMessage') + ) + } + } else { + const message = error.handler.getMessage('login') + this.showError(message) + } + + error.handler.handled() + } else { + throw error + } + } + this.loading = false + }, + }, + validations: { + values: { + email: { required, email }, + password: { required }, + }, + }, +} +</script> diff --git a/enterprise/web-frontend/modules/baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowPasswordAppAuthProviderForm.vue b/enterprise/web-frontend/modules/baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowPasswordAppAuthProviderForm.vue index 6854b6415..81f931351 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowPasswordAppAuthProviderForm.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowPasswordAppAuthProviderForm.vue @@ -5,7 +5,6 @@ small-label horizontal horizontal-variable - class="margin-top-2" required > <Dropdown @@ -32,16 +31,16 @@ </template> <script> -import form from '@baserow/modules/core/mixins/form' +import authProviderForm from '@baserow/modules/core/mixins/authProviderForm' export default { - mixins: [form], + mixins: [authProviderForm], props: { integration: { type: Object, required: true, }, - currentUserSource: { + userSource: { type: Object, required: true, }, @@ -55,9 +54,6 @@ export default { } }, computed: { - authProviderType() { - return this.$registry.get('appAuthProvider', 'local_baserow_password') - }, databases() { return this.integration.context_data.databases }, @@ -65,12 +61,12 @@ export default { return this.$registry.getAll('field') }, selectedTable() { - if (!this.currentUserSource.table_id) { + if (!this.userSource.table_id) { return null } for (const database of this.databases) { for (const table of database.tables) { - if (table.id === this.currentUserSource.table_id) { + if (table.id === this.userSource.table_id) { return table } } @@ -91,7 +87,7 @@ export default { }, }, watch: { - 'currentUserSource.table_id'() { + 'userSource.table_id'() { this.values.password_field_id = null }, }, diff --git a/enterprise/web-frontend/modules/baserow_enterprise/integrations/userSourceTypes.js b/enterprise/web-frontend/modules/baserow_enterprise/integrations/userSourceTypes.js index 33debbace..2f35e2ba4 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/integrations/userSourceTypes.js +++ b/enterprise/web-frontend/modules/baserow_enterprise/integrations/userSourceTypes.js @@ -117,17 +117,22 @@ export class LocalBaserowUserSourceType extends UserSourceType { if (!userSource.email_field_id || !userSource.name_field_id) { return {} } - if (userSource.auth_providers.length !== 1) { - return {} - } - const authProvider = userSource.auth_providers[0] - if ( - authProvider.type !== 'local_baserow_password' || - !authProvider.password_field_id - ) { - return {} - } - return { password: {} } + + return userSource.auth_providers.reduce((acc, authProvider) => { + if (!acc[authProvider.type]) { + acc[authProvider.type] = [] + } + + const loginOptions = this.app.$registry + .get('appAuthProvider', authProvider.type) + .getLoginOptions(authProvider) + + if (loginOptions) { + acc[authProvider.type].push(loginOptions) + } + + return acc + }, {}) } getOrder() { diff --git a/enterprise/web-frontend/modules/baserow_enterprise/locales/en.json b/enterprise/web-frontend/modules/baserow_enterprise/locales/en.json index 1dff8fd92..875811b74 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/locales/en.json +++ b/enterprise/web-frontend/modules/baserow_enterprise/locales/en.json @@ -95,14 +95,17 @@ "addProvider": "Add provider" }, "authProviderTypes": { - "password": "Email and password authentication" + "password": "Email and password authentication", + "saml": "SSO SAML provider", + "ssoSamlProviderName": "SSO SAML: {domain}", + "ssoSamlProviderNameUnconfigured": "Unconfigured SSO SAML" }, "editAuthProviderMenuContext": { "edit": "Edit", "delete": "Delete" }, "samlSettingsForm": { - "domain": "Domain", + "domain": "SAML Domain", "domainPlaceholder": "Insert the company domain name...", "invalidDomain": "Invalid domain name", "domainAlreadyExists": "A SAML provider for this domain already exists", @@ -311,12 +314,22 @@ "roleFieldPlaceholder": "Select a field..." }, "appAuthProviderType": { - "localBaserowPassword": "Email/Password" + "localBaserowPassword": "Email/Password", + "commonSaml": "Saml SSO" }, "localBaserowPasswordAppAuthProviderForm": { "passwordFieldLabel": "Select password field", "noFields": "No compatible fields" }, + "commonSamlSettingModal": { + "title": "Edit provider", + "relayStateTitle": "Default Relay State URL per domain (Click to copy)", + "acsTitle": "Single Sign On URL per domain (Click to copy)", + "preview": "Preview" + }, + "commonSamlSettingForm": { + "authProviderInError": "Please edit this provider to fix the error." + }, "enterpriseSettings": { "branding": "Branding", "showBaserowHelpMessage": "Show help message", @@ -383,5 +396,11 @@ "projectIdHelper": "The ID of the project where you want to sync the issues from. Can be found by going to your project page (e.g. https://gitlab.com/baserow/baserow), click on the three dots in the top right corner, and then on 'Copy project ID: 12345678'.", "accessToken": "Access token", "accessTokenHelper": "Can be generated here https://gitlab.com/-/user_settings/personal_access_tokens by clicking on `Add new token`, select `read_api`." + }, + "samlAuthLink": { + "loginWithSaml": "Login with SAML", + "placeholderWithSaml": "{login} with SAML", + "provideEmail": "Provide your SAML account email", + "emailPlaceholder": "Enter your email..." } } diff --git a/enterprise/web-frontend/modules/baserow_enterprise/pages/admin/authProviders.vue b/enterprise/web-frontend/modules/baserow_enterprise/pages/admin/authProviders.vue index 4d3453f02..d8ab36aef 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/pages/admin/authProviders.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/pages/admin/authProviders.vue @@ -13,6 +13,7 @@ @create="showCreateModal($event)" /> <CreateAuthProviderModal + v-if="authProviderTypeToCreate" ref="createModal" :auth-provider-type="authProviderTypeToCreate" @created="$refs.createModal.hide()" @@ -57,9 +58,15 @@ export default { }, computed: { ...mapGetters({ + authProviderMap: 'authProviderAdmin/getAll', authProviders: 'authProviderAdmin/getAllOrdered', - authProviderTypesCanBeCreated: 'authProviderAdmin/getCreatableTypes', }), + authProviderTypesCanBeCreated() { + return Object.values(this.$registry.getAll('authProvider')).filter( + (authProviderType) => + authProviderType.canCreateNew(this.authProviderMap) + ) + }, }, methods: { getAdminListComponent(authProvider) { @@ -75,8 +82,10 @@ export default { 4 ) }, - showCreateModal(authProviderType) { - this.authProviderTypeToCreate = authProviderType.type + async showCreateModal(authProviderType) { + this.authProviderTypeToCreate = authProviderType + // Wait for the modal to appear in DOM + await this.$nextTick() this.$refs.createModal.show() this.$refs.createContext.hide() }, diff --git a/enterprise/web-frontend/modules/baserow_enterprise/plugin.js b/enterprise/web-frontend/modules/baserow_enterprise/plugin.js index b6c3361ec..3b0b8f1cb 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/plugin.js +++ b/enterprise/web-frontend/modules/baserow_enterprise/plugin.js @@ -27,7 +27,10 @@ import { } from '@baserow_enterprise/licenseTypes' import { EnterprisePlugin } from '@baserow_enterprise/plugins' import { LocalBaserowUserSourceType } from '@baserow_enterprise/integrations/userSourceTypes' -import { LocalBaserowPasswordAppAuthProviderType } from '@baserow_enterprise/integrations/appAuthProviderTypes' +import { + LocalBaserowPasswordAppAuthProviderType, + SamlAppAuthProviderType, +} from '@baserow_enterprise/integrations/appAuthProviderTypes' import { AuthFormElementType } from '@baserow_enterprise/builder/elementTypes' import { EnterpriseAdminRoleType, @@ -46,6 +49,8 @@ import { GitLabIssuesDataSyncType, } from '@baserow_enterprise/dataSyncTypes' +import { FF_AB_SSO } from '@baserow/modules/core/plugins/featureFlags' + export default (context) => { const { app, isDev, store } = context @@ -115,6 +120,13 @@ export default (context) => { new LocalBaserowPasswordAppAuthProviderType(context) ) + if (app.$featureFlagIsEnabled(FF_AB_SSO)) { + app.$registry.register( + 'appAuthProvider', + new SamlAppAuthProviderType(context) + ) + } + app.$registry.register('roles', new EnterpriseAdminRoleType(context)) app.$registry.register('roles', new EnterpriseMemberRoleType(context)) app.$registry.register('roles', new EnterpriseBuilderRoleType(context)) diff --git a/enterprise/web-frontend/modules/baserow_enterprise/store/authProviderAdmin.js b/enterprise/web-frontend/modules/baserow_enterprise/store/authProviderAdmin.js index b370a6a86..c56775be5 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/store/authProviderAdmin.js +++ b/enterprise/web-frontend/modules/baserow_enterprise/store/authProviderAdmin.js @@ -118,15 +118,6 @@ export const getters = { } return authProviders }, - getCreatableTypes: (state) => { - const items = [] - for (const authProviderType of Object.values(state.items)) { - if (authProviderType.canCreateNewProviders) { - items.push(authProviderType) - } - } - return items - }, getNextProviderId: (state) => { return state.nextProviderId }, diff --git a/web-frontend/modules/builder/components/PublicSiteErrorPage.vue b/web-frontend/modules/builder/components/PublicSiteErrorPage.vue index eeabfe467..ff7c016d6 100644 --- a/web-frontend/modules/builder/components/PublicSiteErrorPage.vue +++ b/web-frontend/modules/builder/components/PublicSiteErrorPage.vue @@ -16,17 +16,7 @@ </p> <p v-else class="placeholder__content">{{ content }}</p> <div class="placeholder__action"> - <Button - type="primary" - icon="iconoir-home" - size="large" - @click=" - $router.go({ - name: 'application-builder-page', - params: { pathMatch: '/' }, - }) - " - > + <Button type="primary" icon="iconoir-home" size="large" @click="onHome()"> {{ $t('action.backToHome') }} </Button> </div> @@ -62,5 +52,15 @@ export default { return this.error.content || this.$t('errorLayout.error') }, }, + methods: { + onHome() { + this.$router.push({ + name: 'application-builder-page', + params: { pathMatch: '/' }, + // We remove the query parameters. Important if we have some with error + query: {}, + }) + }, + }, } </script> diff --git a/web-frontend/modules/builder/components/page/settings/PageVisibilityForm.vue b/web-frontend/modules/builder/components/page/settings/PageVisibilityForm.vue index 267626d7a..246902d55 100644 --- a/web-frontend/modules/builder/components/page/settings/PageVisibilityForm.vue +++ b/web-frontend/modules/builder/components/page/settings/PageVisibilityForm.vue @@ -11,10 +11,7 @@ <slot name="title">{{ $t('pageVisibilitySettingsTypes.logInPageWarningTitle') }}</slot> - <!-- eslint-disable-next-line vue/no-v-html vue/no-v-text-v-html-on-component --> - <p - v-html="$t('pageVisibilitySettingsTypes.logInPagewarningMessage')" - ></p> + <p>{{ $t('pageVisibilitySettingsTypes.logInPagewarningMessage') }}</p> </Alert> <Alert v-else-if="showLoginPageAlert && !showLogInPageWarning" @@ -23,14 +20,13 @@ <slot name="title">{{ $t('pageVisibilitySettingsTypes.logInPageInfoTitle') }}</slot> - <!-- eslint-disable-next-line vue/no-v-html vue/no-v-text-v-html-on-component --> - <p - v-html=" + <p> + {{ $t('pageVisibilitySettingsTypes.logInPageInfoMessage', { logInPageName: loginPageName, }) - " - ></p> + }} + </p> </Alert> </div> <div class="margin-top-1 visibility-form__visibility-all"> diff --git a/web-frontend/modules/builder/components/settings/UserSourcesSettings.vue b/web-frontend/modules/builder/components/settings/UserSourcesSettings.vue index 3be1ffede..a94260f48 100644 --- a/web-frontend/modules/builder/components/settings/UserSourcesSettings.vue +++ b/web-frontend/modules/builder/components/settings/UserSourcesSettings.vue @@ -53,7 +53,6 @@ <UpdateUserSourceForm ref="userSourceForm" :builder="builder" - :integrations="integrations" :user-source-type="getUserSourceType(editedUserSource)" :default-values="editedUserSource" @submitted="updateUserSource" @@ -71,7 +70,7 @@ :disabled="actionInProgress || invalidForm" :loading="actionInProgress" size="large" - @click="$refs.userSourceForm.submit()" + @click="$refs.userSourceForm.submit(true)" > {{ $t('action.save') }} </Button> @@ -86,7 +85,6 @@ <CreateUserSourceForm ref="userSourceForm" :builder="builder" - :integrations="integrations" @submitted="createUserSource" @values-changed="onValueChange" /> @@ -122,6 +120,9 @@ export default { name: 'UserSourceSettings', components: { CreateUserSourceForm, UpdateUserSourceForm }, mixins: [error], + provide() { + return { builder: this.builder } + }, props: { builder: { type: Object, @@ -169,7 +170,7 @@ export default { return this.$registry.get('userSource', userSource.type) }, onValueChange() { - this.invalidForm = !this.$refs.userSourceForm.isFormValid() + this.invalidForm = !this.$refs.userSourceForm.isFormValid(true) }, async showForm(userSourceToEdit) { if (userSourceToEdit) { @@ -205,6 +206,10 @@ export default { this.actionInProgress = false }, async updateUserSource(newValues) { + if (!this.$refs.userSourceForm.isFormValid(true)) { + return + } + this.actionInProgress = true try { await this.actionUpdateUserSource({ @@ -215,8 +220,10 @@ export default { this.hideForm() } catch (error) { // Restore the previously saved values from the store - this.$refs.userSourceForm.reset() - this.handleError(error) + if (!this.$refs.userSourceForm.handleServerError(error)) { + this.$refs.userSourceForm.reset() + this.handleError(error) + } } this.actionInProgress = false }, diff --git a/web-frontend/modules/builder/components/userSource/CreateUserSourceForm.vue b/web-frontend/modules/builder/components/userSource/CreateUserSourceForm.vue index 190e83595..59ddb6bfe 100644 --- a/web-frontend/modules/builder/components/userSource/CreateUserSourceForm.vue +++ b/web-frontend/modules/builder/components/userSource/CreateUserSourceForm.vue @@ -67,10 +67,6 @@ export default { type: Object, required: true, }, - integrations: { - type: Array, - required: true, - }, }, data() { return { @@ -78,6 +74,9 @@ export default { } }, computed: { + integrations() { + return this.$store.getters['integration/getIntegrations'](this.builder) + }, userSources() { return this.$store.getters['userSource/getUserSources'](this.builder) }, @@ -117,6 +116,10 @@ export default { } return '' }, + + handleServerError() { + return false + }, }, validations: { values: { diff --git a/web-frontend/modules/builder/components/userSource/UpdateUserSourceForm.vue b/web-frontend/modules/builder/components/userSource/UpdateUserSourceForm.vue index 0368371cd..2892364d3 100644 --- a/web-frontend/modules/builder/components/userSource/UpdateUserSourceForm.vue +++ b/web-frontend/modules/builder/components/userSource/UpdateUserSourceForm.vue @@ -47,13 +47,28 @@ <div v-for="appAuthType in appAuthProviderTypes" :key="appAuthType.type" + class="update-user-source-form__auth-provider" > - <Checkbox - :checked="hasAtLeastOneOfThisType(appAuthType)" - @input="onSelect(appAuthType)" - > - {{ appAuthType.name }} - </Checkbox> + <div class="update-user-source-form__auth-provider-header"> + <Checkbox + :checked="hasAtLeastOneOfThisType(appAuthType)" + @input="onSelect(appAuthType)" + > + {{ appAuthType.name }} + </Checkbox> + + <ButtonText + v-if=" + hasAtLeastOneOfThisType(appAuthType) && + appAuthType.canCreateNew(appAuthProviderPerTypes) + " + icon="iconoir-plus" + type="secondary" + @click.prevent="addNew(appAuthType)" + > + {{ $t('updateUserSourceForm.addProvider') }} + </ButtonText> + </div> <div v-for="appAuthProvider in appAuthProviderPerTypes[appAuthType.type]" @@ -63,11 +78,16 @@ <component :is="appAuthType.formComponent" v-if="hasAtLeastOneOfThisType(appAuthType)" - :integration="integration" - :current-user-source="fullValues" - :default-values="appAuthProvider" + :ref="`authProviderForm`" excluded-form + :integration="integration" + :user-source="fullValues" + :auth-providers="appAuthProviderPerTypes" + :auth-provider="appAuthProvider" + :default-values="appAuthProvider" + :auth-provider-type="appAuthType" @values-changed="updateAuthProvider(appAuthProvider, $event)" + @delete="remove(appAuthProvider)" /> </div> </div> @@ -82,7 +102,6 @@ import form from '@baserow/modules/core/mixins/form' import IntegrationDropdown from '@baserow/modules/core/components/integrations/IntegrationDropdown' import { required, maxLength } from 'vuelidate/lib/validators' -import { uuid } from '@baserow/modules/core/utils/string' export default { components: { IntegrationDropdown }, @@ -97,10 +116,6 @@ export default { required: false, default: null, }, - integrations: { - type: Array, - required: true, - }, }, data() { return { @@ -113,6 +128,9 @@ export default { } }, computed: { + integrations() { + return this.$store.getters['integration/getIntegrations'](this.builder) + }, integration() { if (!this.values.integration_id) { return null @@ -140,6 +158,8 @@ export default { methods: { // Override the default getChildFormValues to exclude the provider forms from // final values as they are handled directly by this component + // The problem is that the child provider forms are not handled as a sub array + // so they override the userSource configuration getChildFormsValues() { return Object.assign( {}, @@ -157,6 +177,14 @@ export default { hasAtLeastOneOfThisType(appAuthProviderType) { return this.appAuthProviderPerTypes[appAuthProviderType.type]?.length > 0 }, + /** Return an integer bigger than any of the current auth_provider id to + * keep the right order when we want to map the error coming back from the server. + */ + nextID() { + return ( + Math.max(1, ...this.values.auth_providers.map(({ id }) => id)) + 100 + ) + }, onSelect(appAuthProviderType) { if (this.hasAtLeastOneOfThisType(appAuthProviderType)) { this.values.auth_providers = this.values.auth_providers.filter( @@ -165,10 +193,21 @@ export default { } else { this.values.auth_providers.push({ type: appAuthProviderType.type, - id: uuid(), + id: this.nextID(), }) } }, + addNew(appAuthProviderType) { + this.values.auth_providers.push({ + type: appAuthProviderType.type, + id: this.nextID(), + }) + }, + remove(appAuthProvider) { + this.values.auth_providers = this.values.auth_providers.filter( + ({ id }) => id !== appAuthProvider.id + ) + }, updateAuthProvider(authProviderToChange, values) { this.values.auth_providers = this.values.auth_providers.map( (authProvider) => { @@ -182,6 +221,16 @@ export default { emitChange() { this.fullValues = this.getFormValues() }, + handleServerError(error) { + if ( + this.$refs.authProviderForm + .map((form) => form.handleServerError(error)) + .some((result) => result) + ) { + return true + } + return false + }, getError(fieldName) { if (!this.$v.values[fieldName].$dirty) { return '' diff --git a/web-frontend/modules/builder/locales/en.json b/web-frontend/modules/builder/locales/en.json index 6496950a7..b83a4e05c 100644 --- a/web-frontend/modules/builder/locales/en.json +++ b/web-frontend/modules/builder/locales/en.json @@ -732,7 +732,8 @@ "nameFieldLabel": "Name", "nameFieldPlaceholder": "Enter a name...", "authTitle": "Authentication", - "integrationFieldLabel": "Integration" + "integrationFieldLabel": "Integration", + "addProvider": "Add provider" }, "builderLoginPageForm": { "pageDropdownLabel": "Login Page", diff --git a/web-frontend/modules/builder/pages/publicPage.vue b/web-frontend/modules/builder/pages/publicPage.vue index 18c2063e6..37cfe9b95 100644 --- a/web-frontend/modules/builder/pages/publicPage.vue +++ b/web-frontend/modules/builder/pages/publicPage.vue @@ -24,8 +24,8 @@ import { userCanViewPage } from '@baserow/modules/builder/utils/visibility' import { getTokenIfEnoughTimeLeft, - setToken, userSourceCookieTokenName, + setToken, } from '@baserow/modules/core/utils/auth' const logOffAndReturnToLogin = async ({ builder, store, redirect }) => { @@ -59,8 +59,8 @@ export default { $registry, app, req, - route, redirect, + route, }) { let mode = 'public' const builderId = params.builderId ? parseInt(params.builderId, 10) : null @@ -71,6 +71,7 @@ export default { } let builder = store.getters['application/getSelected'] + let needPostBuilderLoading = false if (!builder || (builderId && builderId !== builder.id)) { try { @@ -105,7 +106,40 @@ export default { }) } + needPostBuilderLoading = true + } + store.dispatch('userSourceUser/setCurrentApplication', { + application: builder, + }) + + if ( + (!process.server || req) && + !store.getters['userSourceUser/isAuthenticated'](builder) + ) { + const refreshToken = getTokenIfEnoughTimeLeft( + app, + userSourceCookieTokenName + ) + if (refreshToken) { + try { + await store.dispatch('userSourceUser/refreshAuth', { + application: builder, + token: refreshToken, + }) + } catch (error) { + if (error.response?.status === 401) { + // We logoff as the token has probably expired or became invalid + logOffAndReturnToLogin({ builder, store, redirect }) + } else { + throw error + } + } + } + } + + if (needPostBuilderLoading) { // Post builder loading task executed once per application + // It's executed here to make sure we are authenticated at that point const sharedPage = await store.getters['page/getSharedPage'](builder) await Promise.all([ store.dispatch('dataSource/fetchPublished', { @@ -127,37 +161,18 @@ export default { } ) } - store.dispatch('userSourceUser/setCurrentApplication', { - application: builder, - }) - if ( - (!process.server || req) && - !store.getters['userSourceUser/isAuthenticated'](builder) - ) { - // token can be in the query string (SSO) or in the cookies (previous session) - let refreshToken = route.query.token - if (refreshToken) { - setToken(app, refreshToken, userSourceCookieTokenName, { - sameSite: 'Lax', - }) - } else { - refreshToken = getTokenIfEnoughTimeLeft(app, userSourceCookieTokenName) - } - - if (refreshToken) { - try { - await store.dispatch('userSourceUser/refreshAuth', { - application: builder, - token: refreshToken, + // Auth providers can get error code from the URL parameters + for (const userSource of builder.user_sources) { + for (const authProvider of userSource.auth_providers) { + const authError = $registry + .get('appAuthProvider', authProvider.type) + .handleError(userSource, authProvider, route) + if (authError) { + return error({ + statusCode: authError.code, + message: authError.message, }) - } catch (error) { - if (error.response?.status === 401) { - // We logoff as the token has probably expired or became invalid - logOffAndReturnToLogin({ builder, store, redirect }) - } else { - throw error - } } } } @@ -348,7 +363,7 @@ export default { } }, }, - async isAuthenticated() { + async isAuthenticated(newIsAuthenticated) { // When the user login or logout, we need to refetch the elements and actions // as they might have changed await this.$store.dispatch('element/fetchPublished', { @@ -364,12 +379,18 @@ export default { page: this.sharedPage, }) - // If the user is on a hidden page, redirect them to the Login page if possible. - await this.maybeRedirectUserToLoginPage() + if (newIsAuthenticated) { + // If the user has just logged in, we redirect him to the next page. + await this.maybeRedirectToNextPage() + } else { + // If the user is on a hidden page, redirect them to the Login page if possible. + await this.maybeRedirectUserToLoginPage() + } }, }, async mounted() { await this.maybeRedirectUserToLoginPage() + await this.checkProviderAuthentication() }, methods: { /** @@ -389,8 +410,57 @@ export default { this.mode ) - if (url !== this.$router.history.current?.fullPath) { - this.$router.push(url) + const currentPath = this.$route.fullPath + if (url !== currentPath) { + const nextPath = encodeURIComponent(currentPath) + this.$router.push({ path: url, query: { next: nextPath } }) + } + } + }, + maybeRedirectToNextPage() { + if (this.$route.query.next) { + const decodedNext = decodeURIComponent(this.$route.query.next) + this.$router.push(decodedNext) + } + }, + async checkProviderAuthentication() { + // Iterate over all auth providers to check if one can get a refresh token + let refreshTokenFromProvider = null + + for (const userSource of this.builder.user_sources) { + for (const authProvider of userSource.auth_providers) { + refreshTokenFromProvider = this.$registry + .get('appAuthProvider', authProvider.type) + .getAuthToken(userSource, authProvider, this.$route) + if (refreshTokenFromProvider) { + break + } + } + if (refreshTokenFromProvider) { + break + } + } + + if (refreshTokenFromProvider) { + setToken(this, refreshTokenFromProvider, userSourceCookieTokenName, { + sameSite: 'Lax', + }) + try { + await this.$store.dispatch('userSourceUser/refreshAuth', { + application: this.builder, + token: refreshTokenFromProvider, + }) + } catch (error) { + if (error.response?.status === 401) { + // We logoff as the token has probably expired or became invalid + logOffAndReturnToLogin({ + builder: this.builder, + store: this.$store, + redirect: (...args) => this.$router.push(...args), + }) + } else { + throw error + } } } }, diff --git a/web-frontend/modules/core/appAuthProviderTypes.js b/web-frontend/modules/core/appAuthProviderTypes.js index 31fd2d743..1cdb6bddf 100644 --- a/web-frontend/modules/core/appAuthProviderTypes.js +++ b/web-frontend/modules/core/appAuthProviderTypes.js @@ -1,18 +1,30 @@ -import { Registerable } from '@baserow/modules/core/registry' +import { BaseAuthProviderType } from '@baserow/modules/core/authProviderTypes' -export class AppAuthProviderType extends Registerable { +export class AppAuthProviderType extends BaseAuthProviderType { get name() { - throw new Error('Must be set on the type.') + return this.getName() + } + + getLoginOptions(authProvider) { + return null + } + + get component() { + return null } /** * The form to edit this user source. */ get formComponent() { + return this.getAdminSettingsFormComponent() + } + + getAuthToken(userSource, authProvider, route) { return null } - getOrder() { - return 0 + handleError(userSource, authProvider, route) { + return null } } diff --git a/web-frontend/modules/core/assets/scss/components/builder/data_source_form.scss b/web-frontend/modules/core/assets/scss/components/builder/data_source_form.scss index bf3515dfa..c95d3488b 100644 --- a/web-frontend/modules/core/assets/scss/components/builder/data_source_form.scss +++ b/web-frontend/modules/core/assets/scss/components/builder/data_source_form.scss @@ -2,8 +2,6 @@ display: flex; flex-direction: column; gap: 20px; - padding: 12px; - padding-bottom: 0; .tabs__header { border-bottom: none; diff --git a/web-frontend/modules/core/assets/scss/components/builder/update_user_source_form.scss b/web-frontend/modules/core/assets/scss/components/builder/update_user_source_form.scss index e12c82ba4..2aa9b901a 100644 --- a/web-frontend/modules/core/assets/scss/components/builder/update_user_source_form.scss +++ b/web-frontend/modules/core/assets/scss/components/builder/update_user_source_form.scss @@ -1,4 +1,16 @@ +.update-user-source-form__auth-provider { + margin-top: 16px; +} + +.update-user-source-form__auth-provider-header { + display: flex; + justify-content: space-between; +} + .update-user-source-form__auth-provider-form { margin-left: 24px; + margin-top: 12px; display: flex; + flex-direction: column; + align-items: stretch; } diff --git a/web-frontend/modules/core/authProviderTypes.js b/web-frontend/modules/core/authProviderTypes.js index 7530d04a3..294c806de 100644 --- a/web-frontend/modules/core/authProviderTypes.js +++ b/web-frontend/modules/core/authProviderTypes.js @@ -2,10 +2,9 @@ import { Registerable } from '@baserow/modules/core/registry' import PasswordAuthIcon from '@baserow/modules/core/assets/images/providers/Key.svg' /** - * The authorization provider type base class that can be extended when creating - * a plugin for the frontend. + * Base class for authorization provider types */ -export class AuthProviderType extends Registerable { +export class BaseAuthProviderType extends Registerable { /** * The icon for the provider */ @@ -14,14 +13,14 @@ export class AuthProviderType extends Registerable { } /** - * A human readable name of the application type. + * A human readable name of the authentication provider. */ getName() { return null } /** - * A human readable name of the application type. + * A human readable name of the authentication provider. */ getProviderName(provider) { return null @@ -81,6 +80,35 @@ export class AuthProviderType extends Registerable { } } + /** + * Whether we can create new providers on this type. Sometimes providers can't be + * created because of permissions reasons or because of unicity constraints. + * @returns a boolean saying if you can create new providers of this type. + */ + canCreateNew(authProviders) { + return true + } + + /** + * + * @param {Object} err to handle + * @param {VueInstance} vueComponentInstance the vue component instance + * @returns true if the error is handled else false. + */ + handleServerError(vueComponentInstance, error) { + return false + } + + getOrder() { + throw new Error('The order of the authentication provider must be set.') + } +} + +/** + * The authorization provider type base class that can be extended when creating + * a plugin for the frontend. + */ +export class AuthProviderType extends BaseAuthProviderType { constructor(...args) { super(...args) this.type = this.getType() @@ -108,10 +136,6 @@ export class AuthProviderType extends Registerable { routeName: this.routeName, } } - - getOrder() { - throw new Error('The order of an application type must be set.') - } } export class PasswordAuthProviderType extends AuthProviderType { @@ -139,6 +163,16 @@ export class PasswordAuthProviderType extends AuthProviderType { return null } + /** + * We can create only one password provider. + */ + canCreateNew(authProviders) { + return ( + !authProviders[this.getType()] || + authProviders[this.getType()].length === 0 + ) + } + getOrder() { return 1 } diff --git a/web-frontend/modules/core/components/Modal.vue b/web-frontend/modules/core/components/Modal.vue index 1ec53ac94..b783d9533 100644 --- a/web-frontend/modules/core/components/Modal.vue +++ b/web-frontend/modules/core/components/Modal.vue @@ -1,6 +1,7 @@ <template> <div - v-if="open" + v-if="open || keepContent" + v-show="(keepContent && open) || !keepContent" ref="modalWrapper" class="modal__wrapper" @click="outside($event)" @@ -172,6 +173,14 @@ export default { default: false, required: false, }, + // This flag allow to keep the modal content in case you want it to be available. + // Useful if you have a form inside the modal that is a sub part of the current + // form. + keepContent: { + type: Boolean, + default: false, + required: false, + }, }, data() { return { diff --git a/web-frontend/modules/core/mixins/authProviderForm.js b/web-frontend/modules/core/mixins/authProviderForm.js new file mode 100644 index 000000000..bbaea9b23 --- /dev/null +++ b/web-frontend/modules/core/mixins/authProviderForm.js @@ -0,0 +1,40 @@ +import form from '@baserow/modules/core/mixins/form' + +export default { + mixins: [form], + props: { + authProviders: { + type: Object, + required: true, + }, + authProvider: { + type: Object, + required: false, + default: () => ({}), + }, + authProviderType: { + type: Object, + required: true, + }, + }, + data() { + return { serverErrors: {} } + }, + computed: { + providerName() { + return this.authProviderType.getProviderName(this.authProvider) + }, + }, + methods: { + submit() { + this.$v.$touch() + if (this.$v.$invalid) { + return + } + this.$emit('submit', this.values) + }, + handleServerError(error) { + return this.authProviderType.handleServerError(this, error) + }, + }, +} diff --git a/web-frontend/modules/core/mixins/form.js b/web-frontend/modules/core/mixins/form.js index 63cf859b5..80355059c 100644 --- a/web-frontend/modules/core/mixins/form.js +++ b/web-frontend/modules/core/mixins/form.js @@ -85,28 +85,54 @@ export default { firstError.scrollIntoView({ behavior: 'smooth' }) } }, - touch() { + /** + * Select all children that match the given predicate. + * @param {Function} predicate a function that receive the current child as parameter and + * should return true if the child should be accepted. + * @param {Boolean} deep true if you want to deeply search for child + * @returns + */ + getChildForms(predicate = (child) => 'isFormValid' in child, deep = false) { + const children = [] + + const getDeep = (child) => { + if (predicate(child)) { + children.push(child) + } + if (deep) { + // Search into children of children + child.$children.forEach(getDeep) + } + } + + for (const child of this.$children) { + getDeep(child) + } + return children + }, + touch(deep = false) { if ('$v' in this) { this.$v.$touch() } // Also touch all the child forms so that all the error messages are going to // be displayed. - for (const child of this.$children) { - if ('isFormValid' in child && '$v' in child) { - child.touch() - } + for (const child of this.getChildForms( + (child) => 'touch' in child, + deep + )) { + child.touch(deep) } }, - submit() { + submit(deep = false) { if (this.selectedFieldIsDeactivated) { return } - this.touch() + this.touch(deep) - if (this.isFormValid()) { - this.$emit('submitted', this.getFormValues()) + if (this.isFormValid(deep)) { + this.$emit('submitted', this.getFormValues(deep)) } else { this.$nextTick(() => this.focusOnFirstError()) } @@ -120,21 +146,6 @@ export default { ? this.$v.values[fieldName].$error : false }, - getChildForms(deep = false) { - const children = [] - const getDeep = (child) => { - if ('isFormValid' in child) { - children.push(child) - } else if (deep) { - child.$children.forEach(getDeep) - } - } - - for (const child of this.$children) { - getDeep(child) - } - return children - }, /** * Returns true is everything is valid. * @@ -151,8 +162,11 @@ export default { * Returns true if all the child form components are valid. */ areChildFormsValid(deep = false) { - for (const child of this.getChildForms(deep)) { - if ('isFormValid' in child && !child.isFormValid()) { + for (const child of this.getChildForms( + (child) => 'isFormValid' in child, + deep + )) { + if (!child.isFormValid(deep)) { return false } } @@ -162,17 +176,21 @@ export default { * A method that can be overridden to do some mutations on the values before * calling the submitted event. */ - getFormValues() { - return Object.assign({}, this.values, this.getChildFormsValues()) + getFormValues(deep = false) { + return Object.assign({}, this.values, this.getChildFormsValues(deep)) }, /** * Returns an object containing the values of the child forms. */ - getChildFormsValues() { + getChildFormsValues(deep = false) { + const children = this.getChildForms( + (child) => 'getChildFormsValues' in child, + deep + ) return Object.assign( {}, - ...this.$children.map((child) => { - return 'getChildFormsValues' in child ? child.getFormValues() : {} + ...children.map((child) => { + return child.getFormValues(deep) }) ) }, @@ -196,16 +214,22 @@ export default { await this.$nextTick() // Also reset the child forms after a tick to allow default values to be updated. - this.getChildForms(deep).forEach((child) => child.reset()) + this.getChildForms((child) => 'reset' in child, deep).forEach((child) => + child.reset() + ) }, /** * Returns if a child form has indicated it handled the error, false otherwise. */ - handleErrorByForm(error) { + handleErrorByForm(error, deep = false) { let childHandledIt = false - for (const child of this.$children) { - if ('handleErrorByForm' in child && child.handleErrorByForm(error)) { + const children = this.getChildForms( + (child) => 'handleErrorByForm' in child, + deep + ) + for (const child of children) { + if (child.handleErrorByForm(error)) { childHandledIt = true } } diff --git a/web-frontend/modules/core/plugins/featureFlags.js b/web-frontend/modules/core/plugins/featureFlags.js index 0941fbd61..670410a5a 100644 --- a/web-frontend/modules/core/plugins/featureFlags.js +++ b/web-frontend/modules/core/plugins/featureFlags.js @@ -1,5 +1,6 @@ const FF_ENABLE_ALL = '*' export const FF_DASHBOARDS = 'dashboards' +export const FF_AB_SSO = 'ab_sso' /** * A comma separated list of feature flags used to enable in-progress or not ready