1
0
Fork 0
mirror of https://gitlab.com/bramw/baserow.git synced 2025-04-06 14:05:28 +00:00

Add Saml auth provider

This commit is contained in:
Jérémie Pardou 2024-12-10 14:27:16 +00:00
parent 993e1c1233
commit 076c1ebf53
99 changed files with 3007 additions and 972 deletions
backend
enterprise
web-frontend/modules

View file

@ -7,6 +7,7 @@ from rest_framework import serializers
from baserow.api.polymorphic import PolymorphicSerializer from baserow.api.polymorphic import PolymorphicSerializer
from baserow.core.app_auth_providers.models import AppAuthProvider from baserow.core.app_auth_providers.models import AppAuthProvider
from baserow.core.app_auth_providers.registries import app_auth_provider_type_registry from baserow.core.app_auth_providers.registries import app_auth_provider_type_registry
from baserow.core.auth_provider.validators import validate_domain
class AppAuthProviderSerializer(serializers.ModelSerializer): class AppAuthProviderSerializer(serializers.ModelSerializer):
@ -46,6 +47,13 @@ class BaseAppAuthProviderSerializer(serializers.ModelSerializer):
help_text="The type of the app_auth_provider.", help_text="The type of the app_auth_provider.",
) )
domain = serializers.CharField(
validators=[validate_domain],
required=False,
allow_null=True,
help_text=AppAuthProvider._meta.get_field("domain").help_text,
)
class Meta: class Meta:
model = AppAuthProvider model = AppAuthProvider
fields = ("type", "user_source_id", "domain") fields = ("type", "user_source_id", "domain")

View file

@ -3,12 +3,19 @@ from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers from rest_framework import serializers
from baserow.core.auth_provider.models import AuthProviderModel from baserow.core.auth_provider.models import AuthProviderModel
from baserow.core.auth_provider.validators import validate_domain
from baserow.core.registries import auth_provider_type_registry from baserow.core.registries import auth_provider_type_registry
class AuthProviderSerializer(serializers.ModelSerializer): class AuthProviderSerializer(serializers.ModelSerializer):
type = serializers.SerializerMethodField(help_text="The type of the related field.") type = serializers.SerializerMethodField(help_text="The type of the related field.")
domain = serializers.CharField(
validators=[validate_domain],
required=True,
help_text=AuthProviderModel._meta.get_field("domain").help_text,
)
class Meta: class Meta:
model = AuthProviderModel model = AuthProviderModel
fields = ("id", "type", "domain", "enabled") fields = ("id", "type", "domain", "enabled")

View file

@ -104,6 +104,7 @@ class PolymorphicSerializer(serializers.Serializer):
base_class=self.base_class, base_class=self.base_class,
request=self.request, request=self.request,
context=self.context, context=self.context,
extra_params=self.extra_params,
) )
ret = serializer.to_representation(instance) ret = serializer.to_representation(instance)
@ -122,6 +123,7 @@ class PolymorphicSerializer(serializers.Serializer):
base_class=self.base_class, base_class=self.base_class,
request=self.request, request=self.request,
context=self.context, context=self.context,
extra_params=self.extra_params,
) )
return serializer.to_internal_value(data) return serializer.to_internal_value(data)
@ -134,6 +136,7 @@ class PolymorphicSerializer(serializers.Serializer):
base_class=self.base_class, base_class=self.base_class,
request=self.request, request=self.request,
context=self.context, context=self.context,
extra_params=self.extra_params,
) )
return serializer.create(validated_data) return serializer.create(validated_data)
@ -150,6 +153,7 @@ class PolymorphicSerializer(serializers.Serializer):
base_class=self.base_class, base_class=self.base_class,
request=self.request, request=self.request,
context=self.context, context=self.context,
extra_params=self.extra_params,
) )
return serializer.update(instance, validated_data) return serializer.update(instance, validated_data)
@ -170,6 +174,7 @@ class PolymorphicSerializer(serializers.Serializer):
context=self.context, context=self.context,
data=self.data, data=self.data,
partial=self.partial, partial=self.partial,
extra_params=self.extra_params,
) )
except serializers.ValidationError: except serializers.ValidationError:
child_valid = False child_valid = False
@ -194,6 +199,7 @@ class PolymorphicSerializer(serializers.Serializer):
request=self.request, request=self.request,
context=self.context, context=self.context,
partial=self.partial, partial=self.partial,
extra_params=self.extra_params,
) )
validated_data = serializer.run_validation(data) validated_data = serializer.run_validation(data)

View file

@ -2,7 +2,11 @@ from django.urls import include, path
from drf_spectacular.views import SpectacularRedocView from drf_spectacular.views import SpectacularRedocView
from baserow.core.registries import application_type_registry, plugin_registry from baserow.core.registries import (
application_type_registry,
auth_provider_type_registry,
plugin_registry,
)
from .applications import urls as application_urls from .applications import urls as application_urls
from .auth_provider import urls as auth_provider_urls from .auth_provider import urls as auth_provider_urls
@ -53,5 +57,6 @@ urlpatterns = (
), ),
] ]
+ application_type_registry.api_urls + application_type_registry.api_urls
+ auth_provider_type_registry.api_urls
+ plugin_registry.api_urls + plugin_registry.api_urls
) )

View file

@ -3,6 +3,7 @@ from typing import Optional, Tuple, TypeVar
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser, AnonymousUser from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
from drf_spectacular.extensions import OpenApiAuthenticationExtension
from rest_framework import HTTP_HEADER_ENCODING, exceptions from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework_simplejwt.authentication import JWTAuthentication from rest_framework_simplejwt.authentication import JWTAuthentication
@ -157,3 +158,19 @@ class UserSourceJSONWebTokenAuthentication(JWTAuthentication):
user, user,
validated_token, validated_token,
) )
class UserSourceJSONWebTokenAuthenticationExtension(OpenApiAuthenticationExtension):
target_class = (
"baserow.api.user_sources.authentication.UserSourceJSONWebTokenAuthentication"
)
name = "UserSource JWT"
match_subclasses = True
priority = -1
def get_security_definition(self, auto_schema):
return {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT your_token",
}

View file

@ -6,9 +6,7 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers from rest_framework import serializers
from baserow.api.app_auth_providers.serializers import ( from baserow.api.app_auth_providers.serializers import AppAuthProviderSerializer
ReadPolymorphicAppAuthProviderSerializer,
)
from baserow.api.polymorphic import PolymorphicSerializer from baserow.api.polymorphic import PolymorphicSerializer
from baserow.api.services.serializers import PublicServiceSerializer from baserow.api.services.serializers import PublicServiceSerializer
from baserow.api.user_files.serializers import UserFileField, UserFileSerializer from baserow.api.user_files.serializers import UserFileField, UserFileSerializer
@ -26,6 +24,7 @@ from baserow.contrib.builder.elements.registries import element_type_registry
from baserow.contrib.builder.models import Builder from baserow.contrib.builder.models import Builder
from baserow.contrib.builder.pages.handler import PageHandler from baserow.contrib.builder.pages.handler import PageHandler
from baserow.contrib.builder.pages.models import Page from baserow.contrib.builder.pages.models import Page
from baserow.core.app_auth_providers.registries import app_auth_provider_type_registry
from baserow.core.services.registries import service_type_registry from baserow.core.services.registries import service_type_registry
from baserow.core.user_sources.models import UserSource from baserow.core.user_sources.models import UserSource
from baserow.core.user_sources.registries import user_source_type_registry from baserow.core.user_sources.registries import user_source_type_registry
@ -175,6 +174,16 @@ class PublicPageSerializer(serializers.ModelSerializer):
} }
class PublicPolymorphicAppAuthProviderSerializer(PolymorphicSerializer):
"""
Polymorphic serializer for App Auth providers.
"""
base_class = AppAuthProviderSerializer
registry = app_auth_provider_type_registry
extra_params = {"public": True}
class BasePublicUserSourceSerializer(serializers.ModelSerializer): class BasePublicUserSourceSerializer(serializers.ModelSerializer):
""" """
Basic user source serializer mostly for returned values. Basic user source serializer mostly for returned values.
@ -186,7 +195,7 @@ class BasePublicUserSourceSerializer(serializers.ModelSerializer):
def get_type(self, instance): def get_type(self, instance):
return user_source_type_registry.get_by_model(instance.specific_class).type return user_source_type_registry.get_by_model(instance.specific_class).type
auth_providers = ReadPolymorphicAppAuthProviderSerializer( auth_providers = PublicPolymorphicAppAuthProviderSerializer(
required=False, required=False,
many=True, many=True,
help_text="Auth providers related to this user source.", help_text="Auth providers related to this user source.",

View file

@ -1,6 +1,8 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import urljoin
from zipfile import ZipFile from zipfile import ZipFile
from django.conf import settings
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from django.core.files.storage import Storage from django.core.files.storage import Storage
from django.db import transaction from django.db import transaction
@ -419,6 +421,30 @@ class BuilderApplicationType(ApplicationType):
return builder return builder
def get_default_application_urls(self, application: Builder) -> list[str]:
"""
Returns the default frontend urls of a builder application.
"""
from baserow.contrib.builder.domains.handler import DomainHandler
domain = DomainHandler().get_domain_for_builder(application)
if domain is not None:
# Let's also return the preview url so that it's easier to test
preview_url = urljoin(
settings.PUBLIC_WEB_FRONTEND_URL,
f"/builder/{domain.builder_id}/preview/",
)
return [domain.get_public_url(), preview_url]
preview_url = urljoin(
settings.PUBLIC_WEB_FRONTEND_URL,
f"/builder/{application.id}/preview/",
)
# It's an unpublished version let's return to the home preview page
return [preview_url]
def enhance_queryset(self, queryset): def enhance_queryset(self, queryset):
queryset = queryset.prefetch_related("page_set") queryset = queryset.prefetch_related("page_set")
queryset = queryset.prefetch_related("user_sources") queryset = queryset.prefetch_related("user_sources")

View file

@ -85,6 +85,17 @@ class DomainHandler:
return domain.published_to return domain.published_to
def get_domain_for_builder(self, builder: Builder) -> Domain | None:
"""
Returns the domain the builder is published for or None if it's not a published
builder.
"""
try:
return Domain.objects.get(published_to=builder)
except Domain.DoesNotExist:
return None
def create_domain( def create_domain(
self, domain_type: DomainType, builder: Builder, **kwargs self, domain_type: DomainType, builder: Builder, **kwargs
) -> Domain: ) -> Domain:

View file

@ -1,3 +1,6 @@
from urllib.parse import urlparse
from django.conf import settings
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import models from django.db import models
from django.db.models import CASCADE, SET_NULL from django.db.models import CASCADE, SET_NULL
@ -80,6 +83,16 @@ class Domain(
class Meta: class Meta:
ordering = ("order",) ordering = ("order",)
def get_public_url(self):
"""
Returns the URL for this domain.
"""
# Parse the PUBLIC_WEB_FRONTEND_URL to extract the scheme and port
parsed_url = urlparse(settings.PUBLIC_WEB_FRONTEND_URL)
port_string = f":{parsed_url.port}" if parsed_url.port else ""
return f"{parsed_url.scheme}://{self.domain_name}{port_string}"
@classmethod @classmethod
def get_last_order(cls, builder): def get_last_order(cls, builder):
queryset = Domain.objects.filter(builder=builder) queryset = Domain.objects.filter(builder=builder)

View file

@ -4,6 +4,7 @@ from baserow.contrib.builder.data_sources.operations import (
DispatchDataSourceOperationType, DispatchDataSourceOperationType,
ListDataSourcesPageOperationType, ListDataSourcesPageOperationType,
) )
from baserow.contrib.builder.domains.handler import DomainHandler
from baserow.contrib.builder.elements.operations import ListElementsPageOperationType from baserow.contrib.builder.elements.operations import ListElementsPageOperationType
from baserow.contrib.builder.models import Builder from baserow.contrib.builder.models import Builder
from baserow.contrib.builder.workflow_actions.operations import ( from baserow.contrib.builder.workflow_actions.operations import (
@ -20,8 +21,6 @@ from baserow.core.user_sources.operations import (
) )
from baserow.core.user_sources.subjects import UserSourceUserSubjectType from baserow.core.user_sources.subjects import UserSourceUserSubjectType
from .models import Domain
User = get_user_model() User = get_user_model()
@ -101,10 +100,7 @@ class AllowPublicBuilderManagerType(PermissionManagerType):
# give access to specific data. # give access to specific data.
continue continue
if ( if DomainHandler().get_domain_for_builder(builder) is not None:
builder.workspace is None
and Domain.objects.filter(published_to=builder).exists()
):
# it's a public builder, we allow it. # it's a public builder, we allow it.
result[check] = True result[check] = True

View file

@ -1,72 +0,0 @@
from typing import Dict, List, Optional
from baserow.core.registry import CustomFieldsInstanceMixin
class PublicCustomFieldsInstanceMixin(CustomFieldsInstanceMixin):
public_serializer_field_names = []
"""The field names that must be added to the serializer if it's public."""
public_request_serializer_field_names = []
"""
The field names that must be added to the public request serializer if different
from the `public_serializer_field_names`.
"""
request_serializer_field_overrides = None
"""
The fields that must be added to the request serializer if different from the
`serializer_field_overrides` property.
"""
public_serializer_field_overrides = None
"""The fields that must be added to the public serializer."""
public_request_serializer_field_overrides = None
"""
The fields that must be added to the public request serializer if different from the
`public_serializer_field_overrides` property.
"""
def get_field_overrides(
self, request_serializer: bool, extra_params=None, **kwargs
) -> Dict:
public = extra_params.get("public", False)
if public:
if request_serializer and self.public_request_serializer_field_overrides:
return self.public_request_serializer_field_overrides
if self.public_serializer_field_overrides:
return self.public_serializer_field_overrides
return super().get_field_overrides(request_serializer, extra_params, **kwargs)
def get_field_names(
self, request_serializer: bool, extra_params=None, **kwargs
) -> List[str]:
public = extra_params.get("public", False)
if public:
if request_serializer and self.public_request_serializer_field_names:
return self.public_request_serializer_field_names
if self.public_serializer_field_names:
return self.public_serializer_field_names
return super().get_field_names(request_serializer, extra_params, **kwargs)
def get_meta_ref_name(
self,
request_serializer: bool,
extra_params=None,
**kwargs,
) -> Optional[str]:
meta_ref_name = super().get_meta_ref_name(
request_serializer, extra_params, **kwargs
)
public = extra_params.get("public", False)
if public:
meta_ref_name = f"Public{meta_ref_name}"
return meta_ref_name

View file

@ -4,11 +4,11 @@ from django.contrib.auth.models import AbstractUser
from baserow.contrib.builder.formula_importer import import_formula from baserow.contrib.builder.formula_importer import import_formula
from baserow.contrib.builder.mixins import BuilderInstanceWithFormulaMixin from baserow.contrib.builder.mixins import BuilderInstanceWithFormulaMixin
from baserow.contrib.builder.registries import PublicCustomFieldsInstanceMixin
from baserow.contrib.builder.workflow_actions.models import BuilderWorkflowAction from baserow.contrib.builder.workflow_actions.models import BuilderWorkflowAction
from baserow.core.registry import ( from baserow.core.registry import (
CustomFieldsRegistryMixin, CustomFieldsRegistryMixin,
ModelRegistryMixin, ModelRegistryMixin,
PublicCustomFieldsInstanceMixin,
Registry, Registry,
) )
from baserow.core.workflow_actions.registries import WorkflowActionType from baserow.core.workflow_actions.registries import WorkflowActionType
@ -88,6 +88,7 @@ class BuilderWorkflowActionType(
cache = {} cache = {}
element_id = serialized_values["element_id"] element_id = serialized_values["element_id"]
import_context = {}
if element_id: if element_id:
imported_element_id = id_mapping["builder_page_elements"][element_id] imported_element_id = id_mapping["builder_page_elements"][element_id]
import_context = ElementHandler().get_import_context_addition( import_context = ElementHandler().get_import_context_addition(

View file

@ -1,18 +1,20 @@
from typing import TYPE_CHECKING, Callable, List, Type, Union from typing import TYPE_CHECKING, Callable, List, Tuple, Type, Union
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from baserow.core.app_auth_providers.exceptions import IncompatibleUserSourceType from baserow.core.app_auth_providers.exceptions import IncompatibleUserSourceType
from baserow.core.app_auth_providers.types import AppAuthProviderTypeDict from baserow.core.app_auth_providers.types import AppAuthProviderTypeDict
from baserow.core.auth_provider.registries import BaseAuthProviderType from baserow.core.auth_provider.registries import BaseAuthProviderType
from baserow.core.auth_provider.types import AuthProviderModelSubClass from baserow.core.auth_provider.types import AuthProviderModelSubClass, UserInfo
from baserow.core.registry import EasyImportExportMixin from baserow.core.registry import EasyImportExportMixin, PublicCustomFieldsInstanceMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from baserow.core.user_sources.types import UserSourceSubClass from baserow.core.user_sources.types import UserSourceSubClass
class AppAuthProviderType(EasyImportExportMixin, BaseAuthProviderType): class AppAuthProviderType(
EasyImportExportMixin, PublicCustomFieldsInstanceMixin, BaseAuthProviderType
):
""" """
Authentication provider for application user sources. Authentication provider for application user sources.
""" """
@ -70,3 +72,16 @@ class AppAuthProviderType(EasyImportExportMixin, BaseAuthProviderType):
:param instance: The auth provider instance related to the user source. :param instance: The auth provider instance related to the user source.
:param user_source: The user source being updated. :param user_source: The user source being updated.
""" """
def get_or_create_user_and_sign_in(
self, auth_provider: AuthProviderModelSubClass, user_info: UserInfo
) -> Tuple[AbstractUser, bool]:
"""
Get or create a user for the given UserInfo. Calls the related userSource
get_or_create_user.
"""
user_source = auth_provider.user_source.specific
return user_source.get_type().get_or_create_user(
user_source, email=user_info.email, name=user_info.name
)

View file

@ -37,4 +37,4 @@ class AppAuthProvider(BaseAuthProviderModel, HierarchicalModelMixin):
return app_auth_provider_type_registry return app_auth_provider_type_registry
class Meta: class Meta:
ordering = ["domain", "id"] ordering = ["id"]

View file

@ -17,8 +17,14 @@ class BaseAuthProviderModel(
Base abstract model for app_providers. Base abstract model for app_providers.
""" """
domain = models.CharField(max_length=255, null=True) domain = models.CharField(
enabled = models.BooleanField(default=True) max_length=255,
null=True,
help_text="The email domain registered with this provider.",
)
enabled = models.BooleanField(
help_text="Whether the provider is enabled or not.", default=True
)
class Meta: class Meta:
abstract = True abstract = True

View file

@ -47,13 +47,6 @@ class BaseAuthProviderType(
default_create_allowed_fields = ["domain", "enabled"] default_create_allowed_fields = ["domain", "enabled"]
default_update_allowed_fields = ["domain", "enabled"] default_update_allowed_fields = ["domain", "enabled"]
@abstractmethod
def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]:
"""
Returns a dictionary containing the login options
to populate the login component accordingly.
"""
def can_create_new_providers(self, **kwargs) -> bool: def can_create_new_providers(self, **kwargs) -> bool:
""" """
Returns True if it's possible to create an authentication provider of this type. Returns True if it's possible to create an authentication provider of this type.
@ -249,6 +242,13 @@ class AuthenticationProviderTypeRegistry(
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._default = None self._default = None
@abstractmethod
def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]:
"""
Returns a dictionary containing the login options
to populate the login component accordingly.
"""
def get_all_available_login_options(self): def get_all_available_login_options(self):
login_options = {} login_options = {}
for provider_type in self.get_all(): for provider_type in self.get_all():

View file

@ -0,0 +1,33 @@
from django.core.management.base import BaseCommand
from django.urls import URLPattern, URLResolver, get_resolver
class Command(BaseCommand):
help = "List all registered full URLs with their namespaces"
def handle(self, *args, **kwargs):
resolver = get_resolver()
self.list_urls(resolver.url_patterns)
def list_urls(self, urlpatterns, prefix="", namespace=None):
for pattern in urlpatterns:
if isinstance(pattern, URLPattern):
# Construct the full URL path
full_url = f"{prefix}{pattern.pattern}"
full_namespace = (
f"{namespace}:{pattern.name}"
if namespace and pattern.name
else pattern.name or "None"
)
self.stdout.write(f"URL: {full_url}, Namespace: {full_namespace}")
elif isinstance(pattern, URLResolver):
# Construct the full namespace and recurse
new_prefix = f"{prefix}{pattern.pattern}"
new_namespace = (
f"{namespace}:{pattern.namespace}"
if namespace and pattern.namespace
else pattern.namespace or namespace
)
self.list_urls(
pattern.url_patterns, prefix=new_prefix, namespace=new_namespace
)

View file

@ -0,0 +1,48 @@
# Generated by Django 5.0.9 on 2024-11-19 09:43
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("core", "0092_alter_userprofile_language"),
]
operations = [
migrations.AlterModelOptions(
name="appauthprovider",
options={"ordering": ["id"]},
),
migrations.AlterField(
model_name="appauthprovider",
name="domain",
field=models.CharField(
help_text="The email domain registered with this provider.",
max_length=255,
null=True,
),
),
migrations.AlterField(
model_name="appauthprovider",
name="enabled",
field=models.BooleanField(
default=True, help_text="Whether the provider is enabled or not."
),
),
migrations.AlterField(
model_name="authprovidermodel",
name="domain",
field=models.CharField(
help_text="The email domain registered with this provider.",
max_length=255,
null=True,
),
),
migrations.AlterField(
model_name="authprovidermodel",
name="enabled",
field=models.BooleanField(
default=True, help_text="Whether the provider is enabled or not."
),
),
]

View file

@ -507,6 +507,13 @@ class ApplicationType(
def enhance_queryset(self, queryset): def enhance_queryset(self, queryset):
return queryset return queryset
def get_default_application_urls(self, application: "Application") -> list[str]:
"""
Returns the default frontend urls of the application if any.
"""
return []
ApplicationSubClassInstance = TypeVar( ApplicationSubClassInstance = TypeVar(
"ApplicationSubClassInstance", bound="Application" "ApplicationSubClassInstance", bound="Application"

View file

@ -276,13 +276,95 @@ class CustomFieldsInstanceMixin:
return None return None
class PublicCustomFieldsInstanceMixin(CustomFieldsInstanceMixin):
"""
A mixin for instance with custom fields but some field should remains private
when used in some APIs.
"""
public_serializer_field_names = None
"""The field names that must be added to the serializer if it's public."""
public_request_serializer_field_names = None
"""
The field names that must be added to the public request serializer if different
from the `public_serializer_field_names`.
"""
request_serializer_field_overrides = None
"""
The fields that must be added to the request serializer if different from the
`serializer_field_overrides` property.
"""
public_serializer_field_overrides = None
"""The fields that must be added to the public serializer."""
public_request_serializer_field_overrides = None
"""
The fields that must be added to the public request serializer if different from the
`public_serializer_field_overrides` property.
"""
def get_field_overrides(
self, request_serializer: bool, extra_params=None, **kwargs
) -> Dict:
public = extra_params.get("public", False)
if public:
if (
request_serializer is not None
and self.public_request_serializer_field_overrides is not None
):
return self.public_request_serializer_field_overrides
if self.public_serializer_field_overrides is not None:
return self.public_serializer_field_overrides
return super().get_field_overrides(request_serializer, extra_params, **kwargs)
def get_field_names(
self, request_serializer: bool, extra_params=None, **kwargs
) -> List[str]:
public = extra_params.get("public", False)
if public:
if (
request_serializer is not None
and self.public_request_serializer_field_names is not None
):
return self.public_request_serializer_field_names
if self.public_serializer_field_names is not None:
return self.public_serializer_field_names
return super().get_field_names(request_serializer, extra_params, **kwargs)
def get_meta_ref_name(
self,
request_serializer: bool,
extra_params=None,
**kwargs,
) -> Optional[str]:
meta_ref_name = super().get_meta_ref_name(
request_serializer, extra_params, **kwargs
)
public = extra_params.get("public", False)
if public:
meta_ref_name = f"Public{meta_ref_name}"
return meta_ref_name
class APIUrlsInstanceMixin: class APIUrlsInstanceMixin:
def get_api_urls(self): def get_api_urls(self) -> List:
""" """
If needed custom api related urls to the instance can be added here. If needed custom api related urls to the instance can be added here.
Example: Example:
from django.urls import include, path
def get_api_urls(self): def get_api_urls(self):
from . import api_urls from . import api_urls
@ -298,7 +380,6 @@ class APIUrlsInstanceMixin:
] ]
:return: A list containing the urls. :return: A list containing the urls.
:rtype: list
""" """
return [] return []

View file

@ -1,6 +1,6 @@
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
@ -16,6 +16,7 @@ from baserow.core.registry import (
ModelRegistryMixin, ModelRegistryMixin,
Registry, Registry,
) )
from baserow.core.user.exceptions import UserNotFound
from baserow.core.user_sources.constants import DEFAULT_USER_ROLE_PREFIX from baserow.core.user_sources.constants import DEFAULT_USER_ROLE_PREFIX
from baserow.core.user_sources.user_source_user import UserSourceUser from baserow.core.user_sources.user_source_user import UserSourceUser
@ -233,15 +234,41 @@ class UserSourceType(
""" """
@abstractmethod @abstractmethod
def get_user(self, user_source: UserSource, **kwargs) -> Optional[UserSourceUser]: def create_user(
self, user_source: UserSource, email: str, name: str
) -> UserSourceUser:
"""
Create a user for the given user source instance from it's email and name.
:param user_source: The user source we want to create the user for.
:param email: Email of the user to create.
:param name: Name of the user to create.
:return: A user instance.
"""
@abstractmethod
def get_user(self, user_source: UserSource, **kwargs) -> UserSourceUser:
""" """
Returns a user given some args. Returns a user given some args.
:param user_source: The user source used to get the user. :param user_source: The user source used to get the user.
:param kwargs: Keyword arguments to get the user. :param kwargs: Keyword arguments to get the user.
:raises UserNotFound: When the user can't be found.
:return: A user instance if any found with the given parameters. :return: A user instance if any found with the given parameters.
""" """
def get_or_create_user(
self, user_source: UserSource, email: str, name: str
) -> Tuple[UserSourceUser, bool]:
"""
Shorthand to create a user if he doesn't exist.
"""
try:
return self.get_user(user_source, email=email), False
except UserNotFound:
return self.create_user(user_source, email, name), True
@abstractmethod @abstractmethod
def authenticate(self, user_source: UserSource, **kwargs) -> UserSourceUser: def authenticate(self, user_source: UserSource, **kwargs) -> UserSourceUser:
""" """

View file

@ -265,6 +265,9 @@ def stub_user_source_registry(data_fixture, mutable_user_source_registry, fake):
return get_user_return return get_user_return
return data_fixture.create_user_source_user(user_source=user_source) return data_fixture.create_user_source_user(user_source=user_source)
def create_user(self, user_source, email, name):
return data_fixture.create_user_source_user(user_source=user_source)
def authenticate(self, user_source, **kwargs): def authenticate(self, user_source, **kwargs):
if authenticate_return: if authenticate_return:
if callable(authenticate_return): if callable(authenticate_return):

View file

@ -193,7 +193,10 @@ def test_create_user_source_w_auth_providers(api_client, data_fixture):
"name": "test", "name": "test",
"integration_id": integration.id, "integration_id": integration.id,
"auth_providers": [ "auth_providers": [
{"type": "local_baserow_password", "enabled": False, "domain": "test1"}, {
"type": "local_baserow_password",
"enabled": False,
},
], ],
}, },
format="json", format="json",
@ -208,14 +211,87 @@ def test_create_user_source_w_auth_providers(api_client, data_fixture):
assert response_json["auth_providers"] == [ assert response_json["auth_providers"] == [
{ {
"domain": "test1",
"id": first.id, "id": first.id,
"password_field_id": None, "password_field_id": None,
"type": "local_baserow_password", "type": "local_baserow_password",
"domain": None,
}, },
] ]
@pytest.mark.django_db
def test_create_user_source_w_auth_providers_w_domain(api_client, data_fixture):
user, token = data_fixture.create_user_and_token()
workspace = data_fixture.create_workspace(user=user)
application = data_fixture.create_builder_application(workspace=workspace)
integration = data_fixture.create_local_baserow_integration(application=application)
url = reverse("api:user_sources:list", kwargs={"application_id": application.id})
response = api_client.post(
url,
{
"type": "local_baserow",
"name": "test",
"integration_id": integration.id,
"auth_providers": [
{
"domain": "domain.com",
"type": "local_baserow_password",
"enabled": False,
},
],
},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
response_json = response.json()
assert response.status_code == HTTP_200_OK
assert AppAuthProvider.objects.count() == 1
first = AppAuthProvider.objects.first()
assert response_json["auth_providers"] == [
{
"id": first.id,
"password_field_id": None,
"type": "local_baserow_password",
"domain": "domain.com",
},
]
@pytest.mark.django_db
def test_create_user_source_w_auth_providers_w_wrong_domain(api_client, data_fixture):
user, token = data_fixture.create_user_and_token()
workspace = data_fixture.create_workspace(user=user)
application = data_fixture.create_builder_application(workspace=workspace)
integration = data_fixture.create_local_baserow_integration(application=application)
url = reverse("api:user_sources:list", kwargs={"application_id": application.id})
response = api_client.post(
url,
{
"type": "local_baserow",
"name": "test",
"integration_id": integration.id,
"auth_providers": [
{
"domain": "baddomain",
"type": "local_baserow_password",
"enabled": False,
},
],
},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
response_json = response.json()
assert response.status_code == HTTP_400_BAD_REQUEST
assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION"
@pytest.mark.django_db @pytest.mark.django_db
def test_create_user_source_w_auth_provider_wrong_type(api_client, data_fixture): def test_create_user_source_w_auth_provider_wrong_type(api_client, data_fixture):
user, token = data_fixture.create_user_and_token() user, token = data_fixture.create_user_and_token()
@ -237,7 +313,6 @@ def test_create_user_source_w_auth_provider_wrong_type(api_client, data_fixture)
"integration_id": integration.id, "integration_id": integration.id,
"auth_providers": [ "auth_providers": [
{ {
"domain": "test_domain",
"enabled": True, "enabled": True,
"type": app_auth_provider_type.type, "type": app_auth_provider_type.type,
}, },
@ -270,7 +345,6 @@ def test_create_user_source_w_auth_provider_missing_type(api_client, data_fixtur
"integration_id": integration.id, "integration_id": integration.id,
"auth_providers": [ "auth_providers": [
{ {
"domain": "test_domain",
"enabled": True, "enabled": True,
"type": "bad_type", "type": "bad_type",
}, },
@ -374,7 +448,11 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture):
url, url,
{ {
"auth_providers": [ "auth_providers": [
{"type": "local_baserow_password", "enabled": False, "domain": "test1"}, {
"type": "local_baserow_password",
"enabled": False,
"domain": "test2.com",
},
], ],
}, },
format="json", format="json",
@ -388,7 +466,7 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture):
assert response.json()["auth_providers"] == [ assert response.json()["auth_providers"] == [
{ {
"domain": "test1", "domain": "test2.com",
"id": first.id, "id": first.id,
"password_field_id": None, "password_field_id": None,
"type": "local_baserow_password", "type": "local_baserow_password",
@ -399,7 +477,11 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture):
url, url,
{ {
"auth_providers": [ "auth_providers": [
{"type": "local_baserow_password", "enabled": False, "domain": "test3"}, {
"type": "local_baserow_password",
"enabled": False,
"domain": "test3.com",
},
], ],
}, },
format="json", format="json",
@ -411,7 +493,7 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture):
assert response.json()["auth_providers"] == [ assert response.json()["auth_providers"] == [
{ {
"domain": "test3", "domain": "test3.com",
"id": first.id, "id": first.id,
"password_field_id": None, "password_field_id": None,
"type": "local_baserow_password", "type": "local_baserow_password",

View file

@ -196,3 +196,34 @@ def test_domain_publishing(data_fixture):
DomainHandler().publish(domain1, progress) DomainHandler().publish(domain1, progress)
assert Builder.objects.count() == 2 assert Builder.objects.count() == 2
@pytest.mark.django_db
def test_get_domain_for_builder(data_fixture):
user = data_fixture.create_user()
builder = data_fixture.create_builder_application(user=user)
builder_to = data_fixture.create_builder_application(workspace=None)
domain1 = data_fixture.create_builder_custom_domain(
builder=builder, published_to=builder_to, domain_name="mytest.com"
)
domain2 = data_fixture.create_builder_custom_domain(
builder=builder, domain_name="mytest2.com"
)
assert (
DomainHandler().get_domain_for_builder(builder_to).domain_name == "mytest.com"
)
assert DomainHandler().get_domain_for_builder(builder) is None
@pytest.mark.django_db
def test_get_domain_public_url(data_fixture):
user = data_fixture.create_user()
builder = data_fixture.create_builder_application(user=user)
builder_to = data_fixture.create_builder_application(workspace=None)
domain1 = data_fixture.create_builder_custom_domain(
builder=builder, published_to=builder_to, domain_name="mytest.com"
)
assert domain1.get_public_url() == "http://mytest.com:3000"

View file

@ -1636,3 +1636,21 @@ def test_builder_application_exports_file_with_zip_file(
serialized_image_element = visible_pages[0]["elements"][0] serialized_image_element = visible_pages[0]["elements"][0]
assert serialized_image_element["image_source_type"] == "upload" assert serialized_image_element["image_source_type"] == "upload"
assert serialized_image_element["image_file_id"] == serialized_file assert serialized_image_element["image_file_id"] == serialized_file
@pytest.mark.django_db
def test_get_default_application_urls(data_fixture):
user = data_fixture.create_user()
builder = data_fixture.create_builder_application(user=user)
builder_to = data_fixture.create_builder_application(workspace=None)
domain1 = data_fixture.create_builder_custom_domain(
builder=builder, published_to=builder_to, domain_name="mytest.com"
)
assert builder.get_type().get_default_application_urls(builder) == [
f"http://localhost:3000/builder/{builder.id}/preview/"
]
assert builder_to.get_type().get_default_application_urls(builder_to) == [
"http://mytest.com:3000",
f"http://localhost:3000/builder/{builder.id}/preview/",
]

View file

@ -14,7 +14,7 @@ class CreateAuthProviderSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = AuthProviderModel model = AuthProviderModel
fields = ("domain", "type") fields = ("domain", "type", "enabled")
class UpdateAuthProviderSerializer(serializers.ModelSerializer): class UpdateAuthProviderSerializer(serializers.ModelSerializer):
@ -29,7 +29,7 @@ class UpdateAuthProviderSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = AuthProviderModel model = AuthProviderModel
fields = ("domain", "type") fields = ("domain", "type", "enabled")
extra_kwargs = { extra_kwargs = {
"domain": {"required": False}, "domain": {"required": False},
} }

View file

@ -0,0 +1,15 @@
from rest_framework import serializers
from baserow_enterprise.api.sso.saml.serializers import SAMLResponseSerializer
from baserow_enterprise.api.sso.serializers import BaseSsoLoginRequestSerializer
class CommonSsoLoginRequestSerializer(BaseSsoLoginRequestSerializer):
next = serializers.CharField(
required=False,
help_text="If provided, the user will be redirected to that path after login",
)
class CommonSAMLResponseSerializer(SAMLResponseSerializer):
query_param_serializer = CommonSsoLoginRequestSerializer

View file

@ -0,0 +1,19 @@
from django.urls import re_path
from .views import (
SamlAppAuthProviderAssertionConsumerServiceView,
SamlAppAuthProviderBaserowInitiatedSingleSignOn,
)
app_name = "baserow_enterprise.api.integrations.common.sso.saml"
urlpatterns = [
re_path(
r"acs/$", SamlAppAuthProviderAssertionConsumerServiceView.as_view(), name="acs"
),
re_path(
r"login/$",
SamlAppAuthProviderBaserowInitiatedSingleSignOn.as_view(),
name="login",
),
]

View file

@ -0,0 +1,245 @@
from django.db import transaction
from django.http import HttpResponseRedirect
from django.shortcuts import redirect
from drf_spectacular.openapi import OpenApiParameter, OpenApiTypes
from drf_spectacular.utils import extend_schema
from rest_framework.permissions import AllowAny
from rest_framework.request import Request
from rest_framework.views import APIView
from baserow.api.decorators import map_exceptions
from baserow.api.exceptions import (
QueryParameterValidationException,
RequestBodyValidationException,
)
from baserow.api.user_sources.errors import ERROR_USER_SOURCE_DOES_NOT_EXIST
from baserow.api.utils import validate_data
from baserow.core.user.exceptions import DeactivatedUserException
from baserow.core.user_sources.exceptions import (
UserSourceDoesNotExist,
UserSourceImproperlyConfigured,
)
from baserow.core.user_sources.handler import UserSourceHandler
from baserow_enterprise.api.integrations.common.sso.saml.serializers import (
CommonSAMLResponseSerializer,
CommonSsoLoginRequestSerializer,
)
from baserow_enterprise.api.sso.serializers import BaseSsoLoginRequestSerializer
from baserow_enterprise.api.sso.utils import (
SsoErrorCode,
get_valid_frontend_url,
map_sso_exceptions,
urlencode_query_params,
)
from baserow_enterprise.integrations.common.sso.saml.handler import (
SamlAppAuthProviderHandler,
)
from baserow_enterprise.integrations.common.sso.saml.models import (
SamlAppAuthProviderModel,
)
from baserow_enterprise.sso.saml.exceptions import (
InvalidSamlConfiguration,
InvalidSamlRequest,
InvalidSamlResponse,
)
class SamlAppAuthProviderAssertionConsumerServiceView(APIView):
permission_classes = (AllowAny,)
@extend_schema(
tags=["Auth"],
request=CommonSAMLResponseSerializer,
operation_id="auth_provider_saml_acs_url",
description=(
"Complete the SAML authentication flow by validating the SAML response. "
"Sign in the user if already exists in user_source or create a new one "
"otherwise."
"Once authenticated, the user will be redirected to the original "
"URL they were trying to access. If the response is invalid, the user "
"will be redirected to an error page with a specific error message."
"It accepts the language code and the workspace invitation token as query "
"parameters if provided."
),
responses={302: None},
auth=[],
)
@transaction.atomic
@map_exceptions(
{
UserSourceDoesNotExist: ERROR_USER_SOURCE_DOES_NOT_EXIST,
}
)
def post(
self,
request: Request,
user_source_uid,
) -> HttpResponseRedirect:
user_source = UserSourceHandler().get_user_source_by_uid(user_source_uid)
default_frontend_urls = (
user_source.application.get_type().get_default_application_urls(
user_source.application.specific
)
)
error_raised = {"code": None}
def on_error(error_code):
error_raised["code"] = error_code
with map_sso_exceptions(
{
InvalidSamlConfiguration: SsoErrorCode.INVALID_SAML_RESPONSE,
InvalidSamlResponse: SsoErrorCode.INVALID_SAML_RESPONSE,
DeactivatedUserException: SsoErrorCode.USER_DEACTIVATED,
RequestBodyValidationException: SsoErrorCode.INVALID_SAML_RESPONSE,
UserSourceDoesNotExist: SsoErrorCode.INVALID_SAML_REQUEST,
UserSourceImproperlyConfigured: SsoErrorCode.INVALID_SAML_REQUEST,
},
on_error=on_error,
):
# We can't use the decorator here because the redirect_url is related
# to the user source and we don't have it before.
data = validate_data(
CommonSAMLResponseSerializer,
request.data,
return_validated=True,
)
next_path = data["saml_request_data"].pop("next", None)
user = SamlAppAuthProviderHandler.sign_in_user_from_saml_response(
data["SAMLResponse"],
data["saml_request_data"],
base_queryset=SamlAppAuthProviderModel.objects.filter(
user_source=user_source
),
)
if error_raised["code"]:
# We redirect to the default frontend url with an error code
error_url = urlencode_query_params(
default_frontend_urls[0],
{f"saml_error__{user_source.id}": error_raised["code"].value},
)
return redirect(error_url)
query_params = {
f"user_source_saml_token__{user_source.id}": user.get_refresh_token()
}
if next_path:
query_params["next"] = next_path
# Otherwise it a success, we redirect to the login page
redirect_url = get_valid_frontend_url(
data["RelayState"],
default_frontend_urls=default_frontend_urls,
# Add the refresh token as query parameter
query_params=query_params,
allow_any_path=False,
)
return redirect(redirect_url)
class SamlAppAuthProviderBaserowInitiatedSingleSignOn(APIView):
permission_classes = (AllowAny,)
@extend_schema(
parameters=[
OpenApiParameter(
name="email",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description="The email address of the user that want to sign in using SAML.",
),
OpenApiParameter(
name="original",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description=(
"The url to which the user should be redirected after a successful "
"login or sign up."
),
),
],
tags=["User sources"],
request=BaseSsoLoginRequestSerializer,
operation_id="app_auth_provider_saml_sp_login",
description=(
"This is the endpoint that is called when the user wants to initiate a "
"SSO SAML login from Baserow (the service provider). The user will be "
"redirected to the SAML identity provider (IdP) where the user "
"can authenticate. "
"Once logged in in the IdP, the user will be redirected back "
"to the assertion consumer service endpoint (ACS) where the SAML response "
"will be validated and a new JWT session token will be provided to work "
"with Baserow APIs."
),
responses={302: None},
auth=[],
)
@map_exceptions(
{
UserSourceDoesNotExist: ERROR_USER_SOURCE_DOES_NOT_EXIST,
}
)
def get(self, request: Request, user_source_uid: str) -> HttpResponseRedirect:
user_source = UserSourceHandler().get_user_source_by_uid(user_source_uid)
default_frontend_urls = (
user_source.application.get_type().get_default_application_urls(
user_source.application.specific
)
)
error_raised = {"code": None}
def on_error(error_code):
error_raised["code"] = error_code
with map_sso_exceptions(
{
InvalidSamlConfiguration: SsoErrorCode.INVALID_SAML_REQUEST,
InvalidSamlRequest: SsoErrorCode.INVALID_SAML_REQUEST,
RequestBodyValidationException: SsoErrorCode.INVALID_SAML_REQUEST,
},
on_error=on_error,
):
# Validate query parameters
query_params = validate_data(
CommonSsoLoginRequestSerializer,
request.GET.dict(),
partial=False,
exception_to_raise=QueryParameterValidationException,
return_validated=True,
)
original_url = query_params.pop("original", "")
valid_relay_state_url = get_valid_frontend_url(
original_url,
query_params,
default_frontend_urls=default_frontend_urls,
allow_any_path=False,
)
idp_sign_in_url = SamlAppAuthProviderHandler.get_sign_in_url(
query_params,
SamlAppAuthProviderHandler.model_class.objects.filter(
user_source__uid=user_source.uid
),
redirect_to=valid_relay_state_url,
)
if error_raised["code"]:
# We redirect to the default frontend url with an error code
error_url = urlencode_query_params(
default_frontend_urls[0],
{f"saml_error__{user_source.id}": error_raised["code"].value},
)
return redirect(error_url)
return redirect(idp_sign_in_url)

View file

@ -7,6 +7,8 @@ from baserow_enterprise.sso.saml.exceptions import InvalidSamlResponse
class SAMLResponseSerializer(serializers.Serializer): class SAMLResponseSerializer(serializers.Serializer):
query_param_serializer = SsoLoginRequestSerializer
SAMLResponse = serializers.CharField( SAMLResponse = serializers.CharField(
required=True, help_text="The encoded SAML response from the IdP." required=True, help_text="The encoded SAML response from the IdP."
) )
@ -25,11 +27,12 @@ class SAMLResponseSerializer(serializers.Serializer):
parsed_relay_state = urlparse(relay_state) parsed_relay_state = urlparse(relay_state)
query_params = dict(parse_qsl(parsed_relay_state.query)) query_params = dict(parse_qsl(parsed_relay_state.query))
if query_params: if query_params:
request_data_serializer = SsoLoginRequestSerializer(data=query_params) request_data_serializer = self.query_param_serializer(data=query_params)
if request_data_serializer.is_valid(): if request_data_serializer.is_valid():
data["saml_request_data"] = request_data_serializer.validated_data data["saml_request_data"] = request_data_serializer.validated_data
else: else:
raise InvalidSamlResponse("Invalid RelayState query parameters.") raise InvalidSamlResponse("Invalid RelayState query parameters.")
data["RelayState"] = parsed_relay_state._replace(query="").geturl() data["RelayState"] = parsed_relay_state._replace(query="").geturl()
return data return data

View file

@ -1,5 +1,7 @@
import io import io
from django.db.models import QuerySet
from rest_framework import serializers from rest_framework import serializers
from saml2.xml.schema import XMLSchemaError from saml2.xml.schema import XMLSchemaError
from saml2.xml.schema import validate as validate_saml_metadata_schema from saml2.xml.schema import validate as validate_saml_metadata_schema
@ -9,14 +11,17 @@ from baserow_enterprise.sso.saml.models import SamlAuthProviderModel
def validate_unique_saml_domain( def validate_unique_saml_domain(
domain, instance=None, model_class=SamlAuthProviderModel domain, instance=None, base_queryset: QuerySet | None = None
): ):
queryset = model_class.objects.filter(domain=domain) if base_queryset is None:
base_queryset = SamlAuthProviderModel.objects
queryset = base_queryset.filter(domain=domain)
if instance: if instance:
queryset = queryset.exclude(id=instance.id) queryset = queryset.exclude(id=instance.id)
if queryset.exists(): if queryset.exists():
raise SamlProviderForDomainAlreadyExists( raise SamlProviderForDomainAlreadyExists(
f"There is already a {model_class.__name__} for this domain." "There is already a provider for this domain."
) )
return domain return domain

View file

@ -230,7 +230,7 @@ class AdminAuthProvidersLoginUrlView(APIView):
) )
saml_login_url = urljoin( saml_login_url = urljoin(
settings.PUBLIC_BACKEND_URL, reverse("api:enterprise:sso:saml:login") settings.PUBLIC_BACKEND_URL, reverse("api:enterprise_sso_saml:login")
) )
saml_login_url = urlencode_query_params(saml_login_url, query_params) saml_login_url = urlencode_query_params(saml_login_url, query_params)
return Response({"redirect_url": saml_login_url}) return Response({"redirect_url": saml_login_url})

View file

@ -7,13 +7,13 @@ from baserow.api.user.serializers import NormalizedEmailField
from baserow.api.user.validators import language_validation from baserow.api.user.validators import language_validation
class SsoLoginRequestSerializer(serializers.Serializer): class BaseSsoLoginRequestSerializer(serializers.Serializer):
email = NormalizedEmailField( email = NormalizedEmailField(
required=False, help_text="The email address of the user." required=False, help_text="The email address of the user."
) )
original = serializers.CharField( original = serializers.CharField(
required=False, required=False,
help_text="The relative part of URL that the user wanted to access.", help_text="The original URL that the user wanted to access.",
) )
language = serializers.CharField( language = serializers.CharField(
required=False, required=False,
@ -23,6 +23,13 @@ class SsoLoginRequestSerializer(serializers.Serializer):
help_text="An ISO 639 language code (with optional variant) " help_text="An ISO 639 language code (with optional variant) "
"selected by the user. Ex: en-GB.", "selected by the user. Ex: en-GB.",
) )
def to_internal_value(self, instance) -> Dict[str, str]:
data = super().to_internal_value(instance)
return {k: v for k, v in data.items() if v is not None}
class SsoLoginRequestSerializer(BaseSsoLoginRequestSerializer):
workspace_invitation_token = serializers.CharField( workspace_invitation_token = serializers.CharField(
required=False, required=False,
help_text="If provided and valid, the user accepts the workspace invitation and" help_text="If provided and valid, the user accepts the workspace invitation and"
@ -34,8 +41,5 @@ class SsoLoginRequestSerializer(serializers.Serializer):
if urlparse(value).hostname: if urlparse(value).hostname:
return None return None
return value
def to_internal_value(self, instance) -> Dict[str, str]: return value
data = super().to_internal_value(instance)
return {k: v for k, v in data.items() if v is not None}

View file

@ -1,11 +0,0 @@
from django.urls import include, path
from .oauth2 import urls as oauth2_urls
from .saml import urls as saml_urls
app_name = "baserow_enterprise.api.sso"
urlpatterns = [
path("saml/", include(saml_urls, namespace="saml")),
path("oauth2/", include(oauth2_urls, namespace="oauth2")),
]

View file

@ -1,5 +1,7 @@
from contextlib import ContextDecorator
from enum import Enum from enum import Enum
from typing import Dict, Optional from functools import wraps
from typing import Callable, Dict, Optional, Type
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
from django.conf import settings from django.conf import settings
@ -25,31 +27,58 @@ class SsoErrorCode(Enum):
SIGNUP_DISABLED = "errorSignupDisabled" SIGNUP_DISABLED = "errorSignupDisabled"
def map_sso_exceptions(mapping: Dict[Exception, SsoErrorCode]): class map_sso_exceptions(ContextDecorator):
""" """
This decorator can be used to map exceptions to SSO error codes. If the A context manager and decorator to map exceptions to SSO error codes. If the
decorated function raises an exception that is in the mapping, the enclosed code block or decorated function raises an exception that is in the
redirect_to_sign_in_error_page() function will be called with the mapped mapping, the provided redirect function will be called with the mapped error code.
error code. If the exception is not in the mapping, it will be raised If the exception is not in the mapping, it will be raised normally.
normally.
:param mapping: A dictionary that maps exceptions to SSO error codes. :param mapping: A dictionary that maps exceptions to SSO error codes.
:return: The decorator. :param on_error: A callable that takes an error code and handles the action.
""" """
def decorator(func): def __init__(
def wrapper(*args, **kwargs): self,
mapping: Dict[Type[Exception], str],
on_error: Callable[[str], None] | None = None,
):
self.mapping = mapping
if on_error is None:
self.on_error = redirect_to_sign_in_error_page
else:
self.on_error = on_error
def __enter__(self):
pass
# Context manager
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
# No exception occurred
return False
for exception, error_code in self.mapping.items():
if isinstance(exc_value, exception):
self.on_error(error_code)
return True # Swallow the exception after handling it
# If exception not handled, propagate it
return False
# Decorator version
def __call__(self, func):
@wraps(func)
def wrapped_function(*args, **kwargs):
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except Exception as e: except Exception as e:
for exception, error_code in mapping.items(): for exception, error_code in self.mapping.items():
if isinstance(e, exception): if isinstance(e, exception):
return redirect_to_sign_in_error_page(error_code) return self.on_error(error_code)
raise e raise e
return wrapper return wrapped_function
return decorator
def urlencode_query_params(url: str, query_params: Dict[str, str]) -> str: def urlencode_query_params(url: str, query_params: Dict[str, str]) -> str:
@ -79,6 +108,7 @@ def redirect_to_sign_in_error_page(
""" """
frontend_error_page_url = get_frontend_login_error_url() frontend_error_page_url = get_frontend_login_error_url()
if error_code: if error_code:
frontend_error_page_url = urlencode_query_params( frontend_error_page_url = urlencode_query_params(
frontend_error_page_url, {"error": error_code.value} frontend_error_page_url, {"error": error_code.value}
@ -89,40 +119,64 @@ def redirect_to_sign_in_error_page(
def get_valid_frontend_url( def get_valid_frontend_url(
requested_original_url: Optional[str] = None, requested_original_url: Optional[str] = None,
query_params: Optional[Dict[str, str]] = None, query_params: Optional[Dict[str, str]] = None,
) -> str: default_frontend_urls: list[str] | None = None,
allow_any_path: bool = True,
):
""" """
Returns a valid absolute frontend url based on the original url requested Returns a valid absolute frontend url based on the original url requested
before the redirection to the login (can be relative or absolute). If the before the redirection to the login (can be relative or absolute). If the
original url is relative, it will be prefixed with the frontend hostname to original url is relative, it will be prefixed with the default hostname to
make the IdP redirection work. If the original url is external to Baserow, make the IdP redirection work. If the original url doesn't match any of the given
the default frontend dashboard url will be returned instead. default_front_urls, the first default frontend url will be used instead.
:param requested_original_url: The url to which the user should be :param requested_original_url: The url to which the user should be
redirected after a successful login. redirected after a successful login.
:param query_params: The query parameters to add to the URL.
:param default_frontend_urls: The first one is the default fallback frontend URL.
Baserow one is used if None. Others are also allowed as valid URLs.
:return: The url with the token as a query parameter. :return: The url with the token as a query parameter.
""" """
requested_url_parsed = urlparse(requested_original_url or "") requested_url_parsed = urlparse(requested_original_url or "")
default_frontend_url_parsed = urlparse(get_frontend_default_redirect_url())
if requested_url_parsed.path in ["", "/"]: if default_frontend_urls is None:
# use the default frontend path if the requested one is empty default_frontend_urls = [get_frontend_default_redirect_url()]
requested_url_parsed = requested_url_parsed._replace(
path=default_frontend_url_parsed.path default_frontend_urls_parsed = [urlparse(u) for u in default_frontend_urls]
) default_frontend_url_parsed = default_frontend_urls_parsed[0]
matching_url = default_frontend_url_parsed
if requested_url_parsed.hostname is None: if requested_url_parsed.hostname is None:
# provide a correct absolute url if the requested one is relative # provide a correct absolute url if the requested one is relative
requested_url_parsed = default_frontend_url_parsed._replace( requested_url_parsed = default_frontend_url_parsed._replace(
path=requested_url_parsed.path path=requested_url_parsed.path
) )
elif requested_url_parsed.hostname != default_frontend_url_parsed.hostname:
# return the default url if the requested url is external to Baserow else:
requested_url_parsed = default_frontend_url_parsed found = False
for allowed_url in default_frontend_urls_parsed:
if requested_url_parsed.hostname == allowed_url.hostname:
matching_url = allowed_url
found = True
if not found:
# None are matching -> redirecting to main homepage
requested_url_parsed = default_frontend_url_parsed
matching_url = default_frontend_url_parsed
if allow_any_path:
if requested_url_parsed.path in ["", "/"]:
# use the default frontend path if the requested one is empty
requested_url_parsed = requested_url_parsed._replace(path=matching_url.path)
elif not requested_url_parsed.geturl().startswith(matching_url.geturl()):
# if using a path that doesn't match the allowed urls, we reset to default url
requested_url_parsed = matching_url
if query_params: if query_params:
return urlencode_query_params(requested_url_parsed.geturl(), query_params) return urlencode_query_params(requested_url_parsed.geturl(), query_params)
return str(requested_url_parsed.geturl()) return requested_url_parsed.geturl()
def urlencode_user_token(frontend_url: str, user: AbstractUser) -> str: def urlencode_user_token(frontend_url: str, user: AbstractUser) -> str:

View file

@ -4,7 +4,6 @@ from .admin import urls as admin_urls
from .audit_log import urls as audit_log_urls from .audit_log import urls as audit_log_urls
from .role import urls as role_urls from .role import urls as role_urls
from .secure_file_serve import urls as secure_file_serve_urls from .secure_file_serve import urls as secure_file_serve_urls
from .sso import urls as sso_urls
from .teams import urls as teams_urls from .teams import urls as teams_urls
app_name = "baserow_enterprise.api" app_name = "baserow_enterprise.api"
@ -13,7 +12,6 @@ urlpatterns = [
path("teams/", include(teams_urls, namespace="teams")), path("teams/", include(teams_urls, namespace="teams")),
path("role/", include(role_urls, namespace="role")), path("role/", include(role_urls, namespace="role")),
path("admin/", include(admin_urls, namespace="admin")), path("admin/", include(admin_urls, namespace="admin")),
path("sso/", include(sso_urls, namespace="sso")),
path("audit-log/", include(audit_log_urls, namespace="audit_log")), path("audit-log/", include(audit_log_urls, namespace="audit_log")),
path("files/", include(secure_file_serve_urls, namespace="files")), path("files/", include(secure_file_serve_urls, namespace="files")),
] ]

View file

@ -174,6 +174,12 @@ class BaserowEnterpriseConfig(AppConfig):
LocalBaserowPasswordAppAuthProviderType() LocalBaserowPasswordAppAuthProviderType()
) )
from baserow_enterprise.integrations.common.sso.saml.app_auth_provider_types import (
SamlAppAuthProviderType,
)
app_auth_provider_type_registry.register(SamlAppAuthProviderType())
from baserow.contrib.builder.elements.registries import element_type_registry from baserow.contrib.builder.elements.registries import element_type_registry
from baserow_enterprise.builder.elements.element_types import ( from baserow_enterprise.builder.elements.element_types import (
AuthFormElementType, AuthFormElementType,

View file

@ -0,0 +1,113 @@
from typing import List
from urllib.parse import urljoin
from django.conf import settings
from django.urls import include, path, reverse
from baserow.core.app_auth_providers.auth_provider_types import AppAuthProviderType
from baserow.core.app_auth_providers.types import AppAuthProviderTypeDict
from baserow_enterprise.api.sso.saml.validators import validate_unique_saml_domain
from baserow_enterprise.integrations.local_baserow.user_source_types import (
LocalBaserowUserSourceType,
)
from baserow_enterprise.sso.saml.auth_provider_types import SamlAuthProviderTypeMixin
from .models import SamlAppAuthProviderModel
class SamlAppAuthProviderType(SamlAuthProviderTypeMixin, AppAuthProviderType):
"""
The SAML authentication provider type allows users to login using SAML.
"""
model_class = SamlAppAuthProviderModel
compatible_user_source_types = [LocalBaserowUserSourceType.type]
class SerializedDict(
AppAuthProviderTypeDict, SamlAuthProviderTypeMixin.SamlSerializedDict
):
...
@property
def allowed_fields(self) -> List[str]:
return SamlAuthProviderTypeMixin.saml_allowed_fields
@property
def serializer_field_names(self):
return SamlAuthProviderTypeMixin.saml_serializer_field_names
public_serializer_field_names = []
@property
def serializer_field_overrides(self):
return SamlAuthProviderTypeMixin.saml_serializer_field_overrides
public_serializer_field_overrides = {}
def get_api_urls(self):
from baserow_enterprise.api.integrations.common.sso.saml import urls
return [
path(
"user-source/<str:user_source_uid>/sso/saml/",
include(urls, namespace="sso_saml"),
)
]
def before_create(self, user, **values):
user_source = values["user_source"]
if "domain" in values:
validate_unique_saml_domain(
values["domain"],
base_queryset=SamlAppAuthProviderModel.objects.filter(
user_source=user_source
),
)
return super().before_create(user, **values)
def before_update(self, user, provider, **values):
if "domain" in values:
user_source = values.get("user_source", provider.user_source)
validate_unique_saml_domain(
values["domain"],
provider,
base_queryset=SamlAppAuthProviderModel.objects.filter(
user_source=user_source
),
)
return super().before_update(user, provider, **values)
@classmethod
def get_acs_absolute_url(
cls, auth_provider: "SamlAuthProviderTypeMixin | None" = None
):
"""
Returns the ACS url for SAML authentication purpose. The user is redirected
to this URL after a successful login.
"""
return urljoin(
settings.PUBLIC_BACKEND_URL,
reverse(
"api:user_sources:sso_saml:acs",
kwargs={"user_source_uid": auth_provider.user_source.uid},
),
)
@classmethod
def get_login_absolute_url(
cls, auth_provider: "SamlAuthProviderTypeMixin | None" = None
):
"""
Returns the login URL for this auth_provider. The login URL is used to initiate
the Saml login process.
"""
return urljoin(
settings.PUBLIC_BACKEND_URL,
reverse(
"api:user_sources:sso_saml:login",
kwargs={"user_source_uid": auth_provider.user_source.uid},
),
)

View file

@ -0,0 +1,8 @@
from baserow_enterprise.integrations.common.sso.saml.models import (
SamlAppAuthProviderModel,
)
from baserow_enterprise.sso.saml.handler import SamlAuthProviderHandler
class SamlAppAuthProviderHandler(SamlAuthProviderHandler):
model_class = SamlAppAuthProviderModel

View file

@ -0,0 +1,8 @@
from baserow.core.app_auth_providers.models import AppAuthProvider
from baserow_enterprise.sso.saml.models import SamlAuthProviderModelMixin
class SamlAppAuthProviderModel(SamlAuthProviderModelMixin, AppAuthProvider):
# Restore ordering
class Meta(AppAuthProvider.Meta):
ordering = ["id"]

View file

@ -31,6 +31,8 @@ class LocalBaserowPasswordAppAuthProviderType(AppAuthProviderType):
] ]
serializer_field_names = ["password_field_id"] serializer_field_names = ["password_field_id"]
public_serializer_field_names = []
allowed_fields = ["password_field"] allowed_fields = ["password_field"]
serializer_field_overrides = { serializer_field_overrides = {
@ -40,6 +42,7 @@ class LocalBaserowPasswordAppAuthProviderType(AppAuthProviderType):
help_text="The id of the field to use as password for the user account.", help_text="The id of the field to use as password for the user account.",
), ),
} }
public_serializer_field_overrides = {}
class SerializedDict(AppAuthProviderTypeDict): class SerializedDict(AppAuthProviderTypeDict):
password_field_id: int password_field_id: int
@ -151,13 +154,6 @@ class LocalBaserowPasswordAppAuthProviderType(AppAuthProviderType):
instance.password_field = None instance.password_field = None
instance.save() instance.save()
def get_login_options(self, **kwargs) -> Dict[str, Any]:
"""
Not implemented yet.
"""
return {}
def get_or_create_user_and_sign_in( def get_or_create_user_and_sign_in(
self, auth_provider: AuthProviderModelSubClass, user_info: Dict[str, Any] self, auth_provider: AuthProviderModelSubClass, user_info: Dict[str, Any]
) -> Tuple[AbstractUser, bool]: ) -> Tuple[AbstractUser, bool]:

View file

@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from loguru import logger
from rest_framework import serializers from rest_framework import serializers
from baserow.api.exceptions import RequestBodyValidationException from baserow.api.exceptions import RequestBodyValidationException
@ -17,6 +18,7 @@ from baserow.contrib.database.fields.field_types import (
) )
from baserow.contrib.database.fields.handler import FieldHandler from baserow.contrib.database.fields.handler import FieldHandler
from baserow.contrib.database.fields.registries import FieldType from baserow.contrib.database.fields.registries import FieldType
from baserow.contrib.database.rows.actions import CreateRowsActionType
from baserow.contrib.database.rows.operations import ReadDatabaseRowOperationType from baserow.contrib.database.rows.operations import ReadDatabaseRowOperationType
from baserow.contrib.database.search.handler import SearchHandler from baserow.contrib.database.search.handler import SearchHandler
from baserow.contrib.database.table.exceptions import TableDoesNotExist from baserow.contrib.database.table.exceptions import TableDoesNotExist
@ -450,9 +452,9 @@ class LocalBaserowUserSourceType(UserSourceType):
return ( return (
f"{user_source.id}" f"{user_source.id}"
f"_{user_source.table_id if user_source.table_id else '?'}" f"_{user_source.table_id if user_source.table_id else '0'}"
f"_{user_source.email_field_id if user_source.email_field_id else '?'}" f"_{user_source.email_field_id if user_source.email_field_id else '0'}"
f"_{user_source.role_field_id if user_source.role_field_id else '?'}" f"_{user_source.role_field_id if user_source.role_field_id else '0'}"
) )
def role_type_is_allowed(self, role_field: Optional[FieldType]) -> bool: def role_type_is_allowed(self, role_field: Optional[FieldType]) -> bool:
@ -596,6 +598,54 @@ class LocalBaserowUserSourceType(UserSourceType):
raise UserNotFound() raise UserNotFound()
def create_user(self, user_source: UserSource, email, name, role=None):
"""
Creates the user in the configured table.
"""
if not self.is_configured(user_source):
raise UserSourceImproperlyConfigured()
try:
# Use table handler to exclude trashed table
table = TableHandler().get_table(user_source.table_id)
except TableDoesNotExist as exc:
# As we CASCADE when a table is deleted, the table shouldn't
# exist only if it's trashed and not yet deleted.
raise UserSourceImproperlyConfigured("The table doesn't exist.") from exc
integration = user_source.integration.specific
model = table.get_model()
values = {
user_source.name_field.db_column: name,
user_source.email_field.db_column: email,
}
if role and user_source.role_field_id:
values[user_source.role_field.db_column] = role
try:
# Use the action to keep track on what's going on
(user,) = CreateRowsActionType.do(
user=integration.authorized_user,
table=table,
rows_values=[values],
model=model,
)
except Exception as e:
logger.exception(e)
raise ("Error while creating the user") from e
return UserSourceUser(
user_source,
user,
user.id,
getattr(user, user_source.name_field.db_column),
getattr(user, user_source.email_field.db_column),
self.get_user_role(user, user_source),
)
def authenticate(self, user_source: UserSource, **kwargs): def authenticate(self, user_source: UserSource, **kwargs):
""" """
Authenticates using the given credentials. It uses the password auth provider. Authenticates using the given credentials. It uses the password auth provider.

View file

@ -0,0 +1,81 @@
# Generated by Django 5.0.9 on 2024-12-05 11:23
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("baserow_enterprise", "0033_samlauthprovidermodel_email_attr_key_and_more"),
("core", "0093_alter_appauthprovider_options_and_more"),
]
operations = [
migrations.CreateModel(
name="SamlAppAuthProviderModel",
fields=[
(
"appauthprovider_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="core.appauthprovider",
),
),
(
"metadata",
models.TextField(
blank=True,
help_text="The XML metadata downloaded from the metadata_url.",
),
),
(
"is_verified",
models.BooleanField(
default=False,
help_text="This will be set to True only after a user successfully login with this IdP. This must be True to disable normal username/password login and make SAML the only authentication provider. ",
),
),
(
"email_attr_key",
models.CharField(
db_default="user.email",
default="user.email",
help_text="The key in the SAML response that contains the email address of the user. If this is not set, the email will be taken from the user's profile.",
max_length=32,
),
),
(
"first_name_attr_key",
models.CharField(
db_default="user.first_name",
default="user.first_name",
help_text="The key in the SAML response that contains the first name of the user. If this is not set, the first name will be taken from the user's profile.",
max_length=32,
),
),
(
"last_name_attr_key",
models.CharField(
blank=True,
db_default="user.last_name",
default="user.last_name",
help_text="The key in the SAML response that contains the last name of the user. If this is not set, the last name will be taken from the user's profile.",
max_length=32,
),
),
],
options={
"ordering": ["id"],
"abstract": False,
},
bases=("core.appauthprovider", models.Model),
),
migrations.AlterModelOptions(
name="samlauthprovidermodel",
options={"ordering": ["domain"]},
),
]

View file

@ -1,6 +1,12 @@
from baserow_enterprise.builder.elements.models import AuthFormElement from baserow_enterprise.builder.elements.models import AuthFormElement
from baserow_enterprise.data_sync.models import LocalBaserowTableDataSync from baserow_enterprise.data_sync.models import LocalBaserowTableDataSync
from baserow_enterprise.integrations.models import LocalBaserowUserSource from baserow_enterprise.integrations.common.sso.saml.models import (
SamlAppAuthProviderModel,
)
from baserow_enterprise.integrations.models import (
LocalBaserowPasswordAppAuthProvider,
LocalBaserowUserSource,
)
from baserow_enterprise.role.models import Role, RoleAssignment from baserow_enterprise.role.models import Role, RoleAssignment
from baserow_enterprise.teams.models import Team, TeamSubject from baserow_enterprise.teams.models import Team, TeamSubject
@ -12,4 +18,6 @@ __all__ = [
"LocalBaserowUserSource", "LocalBaserowUserSource",
"AuthFormElement", "AuthFormElement",
"LocalBaserowTableDataSync", "LocalBaserowTableDataSync",
"LocalBaserowPasswordAppAuthProvider",
"SamlAppAuthProviderModel",
] ]

View file

@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Tuple
from django.conf import settings from django.conf import settings
from django.contrib.sessions.backends.base import SessionBase from django.contrib.sessions.backends.base import SessionBase
from django.urls import reverse from django.urls import include, path, reverse
import requests import requests
from loguru import logger from loguru import logger
@ -30,6 +30,8 @@ from .models import (
OAUTH_BACKEND_URL = settings.PUBLIC_BACKEND_URL OAUTH_BACKEND_URL = settings.PUBLIC_BACKEND_URL
_is_url_already_loaded = False
@dataclass @dataclass
class WellKnownUrls: class WellKnownUrls:
@ -50,6 +52,20 @@ class OAuth2AuthProviderMixin:
- self.SCOPE - self.SCOPE
""" """
def get_api_urls(self):
global _is_url_already_loaded
from baserow_enterprise.api.sso.oauth2 import urls
if not _is_url_already_loaded:
_is_url_already_loaded = True
# We need to register this only once
return [
path("sso/oauth2/", include(urls, namespace="enterprise_sso_oauth2"))
]
else:
return []
def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]: def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]:
if not is_sso_feature_active(): if not is_sso_feature_active():
return None return None
@ -64,7 +80,7 @@ class OAuth2AuthProviderMixin:
{ {
"redirect_url": urllib.parse.urljoin( "redirect_url": urllib.parse.urljoin(
OAUTH_BACKEND_URL, OAUTH_BACKEND_URL,
reverse("api:enterprise:sso:oauth2:login", args=(instance.id,)), reverse("api:enterprise_sso_oauth2:login", args=(instance.id,)),
), ),
"name": instance.name, "name": instance.name,
"type": self.type, "type": self.type,
@ -144,7 +160,7 @@ class OAuth2AuthProviderMixin:
redirect_uri = urllib.parse.urljoin( redirect_uri = urllib.parse.urljoin(
OAUTH_BACKEND_URL, OAUTH_BACKEND_URL,
reverse("api:enterprise:sso:oauth2:callback", args=(instance.id,)), reverse("api:enterprise_sso_oauth2:callback", args=(instance.id,)),
) )
if "oauth_state" in session: if "oauth_state" in session:
return OAuth2Session( return OAuth2Session(

View file

@ -1,8 +1,9 @@
from typing import Any, Dict, List, Optional from abc import abstractmethod
from typing import Any, Dict, List, Optional, TypedDict
from urllib.parse import urljoin from urllib.parse import urljoin
from django.conf import settings from django.conf import settings
from django.urls import reverse from django.urls import include, path, reverse
from rest_framework import serializers from rest_framework import serializers
@ -26,43 +27,33 @@ from baserow_enterprise.sso.utils import is_sso_feature_active
from .models import SamlAuthProviderModel from .models import SamlAuthProviderModel
class SamlAuthProviderType(AuthProviderType): class SamlAuthProviderTypeMixin:
""" """
The SAML authentication provider type allows users to login using SAML. The SAML authentication provider type allows users to login using SAML.
""" """
type = "saml" type = "saml"
model_class = SamlAuthProviderModel
allowed_fields: List[str] = [ class SamlSerializedDict(TypedDict):
"id", metadata: Dict
"domain", is_verified: bool
"type",
"enabled", saml_allowed_fields: List[str] = [
"metadata", "metadata",
"is_verified", "is_verified",
"email_attr_key", "email_attr_key",
"first_name_attr_key", "first_name_attr_key",
"last_name_attr_key", "last_name_attr_key",
] ]
serializer_field_names = [ saml_serializer_field_names = [
"domain",
"metadata", "metadata",
"enabled",
"is_verified", "is_verified",
"email_attr_key", "email_attr_key",
"first_name_attr_key", "first_name_attr_key",
"last_name_attr_key", "last_name_attr_key",
] ]
serializer_field_overrides = {
"domain": serializers.CharField( saml_serializer_field_overrides = {
validators=[validate_domain],
required=True,
help_text="The email domain registered with this provider.",
),
"enabled": serializers.BooleanField(
help_text="Whether the provider is enabled or not.",
required=False,
),
"metadata": serializers.CharField( "metadata": serializers.CharField(
validators=[validate_saml_metadata], validators=[validate_saml_metadata],
required=True, required=True,
@ -98,21 +89,61 @@ class SamlAuthProviderType(AuthProviderType):
SamlProviderForDomainAlreadyExists: ERROR_SAML_PROVIDER_FOR_DOMAIN_ALREADY_EXISTS SamlProviderForDomainAlreadyExists: ERROR_SAML_PROVIDER_FOR_DOMAIN_ALREADY_EXISTS
} }
def before_create(self, user, **values): @classmethod
validate_unique_saml_domain(values["domain"]) @abstractmethod
return super().before_create(user, **values) def get_acs_absolute_url(
cls, auth_provider: "SamlAuthProviderTypeMixin | None" = None
):
"""
Returns the ACS url for SAML authentication purpose. The user is redirected
to this URL after a successful login.
"""
def before_update(self, user, provider, **values): @classmethod
if "domain" in values: @abstractmethod
validate_unique_saml_domain(values["domain"], provider) def get_login_absolute_url(cls):
return super().before_update(user, provider, **values) """
Returns the login URL for this auth_provider. The login URL is used to initiate
the Saml login process.
"""
class SamlAuthProviderType(SamlAuthProviderTypeMixin, AuthProviderType):
"""
The SAML authentication provider type allows users to login using SAML.
"""
model_class = SamlAuthProviderModel
@property
def allowed_fields(self) -> List[str]:
return SamlAuthProviderTypeMixin.saml_allowed_fields
@property
def serializer_field_names(self):
return SamlAuthProviderTypeMixin.saml_serializer_field_names
@property
def serializer_field_overrides(self):
return SamlAuthProviderTypeMixin.saml_serializer_field_overrides | {
"domain": serializers.CharField(
validators=[validate_domain],
required=True,
help_text="The email domain registered with this provider.",
),
}
def get_api_urls(self):
from baserow_enterprise.api.sso.saml import urls
return [path("sso/saml/", include(urls, namespace="enterprise_sso_saml"))]
def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]: def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]:
single_sign_on_feature_active = is_sso_feature_active() single_sign_on_feature_active = is_sso_feature_active()
if not single_sign_on_feature_active: if not single_sign_on_feature_active:
return None return None
configured_domains = SamlAuthProviderModel.objects.filter(enabled=True).count() configured_domains = self.model_class.objects.filter(enabled=True).count()
if not configured_domains: if not configured_domains:
return None return None
@ -130,16 +161,28 @@ class SamlAuthProviderType(AuthProviderType):
"default_redirect_url": default_redirect_url, "default_redirect_url": default_redirect_url,
} }
def before_create(self, user, **values):
validate_unique_saml_domain(values["domain"])
return super().before_create(user, **values)
def before_update(self, user, provider, **values):
if "domain" in values:
validate_unique_saml_domain(values["domain"], provider)
return super().before_update(user, provider, **values)
@classmethod @classmethod
def get_acs_absolute_url(cls): def get_acs_absolute_url(
cls, auth_provider: SamlAuthProviderTypeMixin | None = None
):
return urljoin( return urljoin(
settings.PUBLIC_BACKEND_URL, reverse("api:enterprise:sso:saml:acs") settings.PUBLIC_BACKEND_URL, reverse("api:enterprise_sso_saml:acs")
) )
@classmethod @classmethod
def get_login_absolute_url(cls): def get_login_absolute_url(cls):
return urljoin( return urljoin(
settings.PUBLIC_BACKEND_URL, reverse("api:enterprise:sso:saml:login") settings.PUBLIC_BACKEND_URL,
reverse("api:enterprise_sso_saml:login"),
) )
def export_serialized(self) -> Dict[str, Any]: def export_serialized(self) -> Dict[str, Any]:

View file

@ -4,6 +4,7 @@ from typing import Any, Dict, Optional
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from django.db.models import Model, QuerySet
from defusedxml import ElementTree from defusedxml import ElementTree
from loguru import logger from loguru import logger
@ -13,9 +14,11 @@ from saml2.config import Config as Saml2Config
from saml2.response import AuthnResponse from saml2.response import AuthnResponse
from baserow.core.auth_provider.types import UserInfo from baserow.core.auth_provider.types import UserInfo
from baserow.core.registries import auth_provider_type_registry
from baserow_enterprise.api.sso.utils import get_valid_frontend_url from baserow_enterprise.api.sso.utils import get_valid_frontend_url
from baserow_enterprise.sso.saml.models import SamlAuthProviderModel from baserow_enterprise.sso.saml.models import (
SamlAuthProviderModel,
SamlAuthProviderModelMixin,
)
from .exceptions import ( from .exceptions import (
InvalidSamlConfiguration, InvalidSamlConfiguration,
@ -25,10 +28,12 @@ from .exceptions import (
class SamlAuthProviderHandler: class SamlAuthProviderHandler:
model_class: Model = SamlAuthProviderModel
@classmethod @classmethod
def prepare_saml_client( def prepare_saml_client(
cls, cls,
saml_auth_provider: SamlAuthProviderModel, saml_auth_provider: SamlAuthProviderModelMixin,
) -> Saml2Client: ) -> Saml2Client:
""" """
Returns a SAML client with the correct configuration for the given Returns a SAML client with the correct configuration for the given
@ -39,10 +44,9 @@ class SamlAuthProviderHandler:
:return: The SAML client that can be used to authenticate the user. :return: The SAML client that can be used to authenticate the user.
""" """
saml_provider_type = auth_provider_type_registry.get_by_model( acs_url = saml_auth_provider.get_type().get_acs_absolute_url(
saml_auth_provider saml_auth_provider.specific
) )
acs_url = saml_provider_type.get_acs_absolute_url()
saml_settings: Dict[str, Any] = { saml_settings: Dict[str, Any] = {
"entityid": acs_url, "entityid": acs_url,
@ -95,9 +99,8 @@ class SamlAuthProviderHandler:
@classmethod @classmethod
def get_saml_auth_provider_from_saml_response( def get_saml_auth_provider_from_saml_response(
cls, cls, raw_saml_response: str, base_queryset: QuerySet | None = None
raw_saml_response: str, ) -> SamlAuthProviderModelMixin:
) -> SamlAuthProviderModel:
""" """
Parses the saml response and returns the authentication provider that needs to Parses the saml response and returns the authentication provider that needs to
be used to authenticate the user. be used to authenticate the user.
@ -120,7 +123,10 @@ class SamlAuthProviderHandler:
except (ElementTree.ParseError, AttributeError): except (ElementTree.ParseError, AttributeError):
raise InvalidSamlResponse("Impossible decode SAML response.") raise InvalidSamlResponse("Impossible decode SAML response.")
saml_auth_provider = SamlAuthProviderModel.objects.filter( if base_queryset is None:
base_queryset = cls.model_class.objects
saml_auth_provider = base_queryset.filter(
enabled=True, metadata__contains=issuer enabled=True, metadata__contains=issuer
).first() ).first()
if not saml_auth_provider: if not saml_auth_provider:
@ -130,7 +136,7 @@ class SamlAuthProviderHandler:
@classmethod @classmethod
def get_user_info_from_authn_user_identity( def get_user_info_from_authn_user_identity(
cls, cls,
saml_auth_provider: SamlAuthProviderModel, saml_auth_provider: SamlAuthProviderModelMixin,
authn_identity: Dict[str, str], authn_identity: Dict[str, str],
saml_request_data: Optional[Dict[str, str]] = None, saml_request_data: Optional[Dict[str, str]] = None,
) -> UserInfo: ) -> UserInfo:
@ -172,8 +178,7 @@ class SamlAuthProviderHandler:
@classmethod @classmethod
def get_saml_auth_provider_from_email( def get_saml_auth_provider_from_email(
cls, cls, email: Optional[str] = None, base_queryset: QuerySet | None = None
email: Optional[str] = None,
) -> SamlAuthProviderModel: ) -> SamlAuthProviderModel:
""" """
It returns the Saml Identity Provider for the the given email address. It returns the Saml Identity Provider for the the given email address.
@ -187,20 +192,24 @@ class SamlAuthProviderHandler:
address provided. address provided.
""" """
base_queryset = SamlAuthProviderModel.objects.filter(enabled=True) if base_queryset is None:
base_queryset = cls.model_class.objects
queryset = base_queryset.filter(enabled=True)
if email is not None: if email is not None:
try: try:
domain = email.rsplit("@", 1)[1] domain = email.rsplit("@", 1)[1]
except IndexError: except IndexError:
raise InvalidSamlRequest("Invalid mail address provided.") raise InvalidSamlRequest("Invalid mail address provided.")
base_queryset = base_queryset.filter(domain=domain)
queryset = queryset.filter(domain=domain)
try: try:
return base_queryset.get() return queryset.get()
except ( except (
SamlAuthProviderModel.DoesNotExist, cls.model_class.DoesNotExist,
SamlAuthProviderModel.MultipleObjectsReturned, cls.model_class.MultipleObjectsReturned,
): ):
raise InvalidSamlRequest("No valid SAML identity provider found.") raise InvalidSamlRequest("No valid SAML identity provider found.")
@ -213,7 +222,10 @@ class SamlAuthProviderHandler:
@classmethod @classmethod
def sign_in_user_from_saml_response( def sign_in_user_from_saml_response(
cls, saml_response: str, saml_request_data: Optional[Dict[str, str]] = None cls,
saml_response: str,
saml_request_data: Optional[Dict[str, str]] = None,
base_queryset: QuerySet | None = None,
) -> AbstractUser: ) -> AbstractUser:
""" """
Signs in the user using the SAML response received from the identity Signs in the user using the SAML response received from the identity
@ -230,7 +242,7 @@ class SamlAuthProviderHandler:
try: try:
saml_auth_provider = cls.get_saml_auth_provider_from_saml_response( saml_auth_provider = cls.get_saml_auth_provider_from_saml_response(
saml_response saml_response, base_queryset=base_queryset
) )
saml_client = cls.prepare_saml_client(saml_auth_provider) saml_client = cls.prepare_saml_client(saml_auth_provider)
@ -249,11 +261,10 @@ class SamlAuthProviderHandler:
logger.exception(exc) logger.exception(exc)
raise InvalidSamlResponse(str(exc)) raise InvalidSamlResponse(str(exc))
saml_provider_type = saml_auth_provider.get_type()
( (
user, user,
_, _,
) = saml_provider_type.get_or_create_user_and_sign_in( ) = saml_auth_provider.get_type().get_or_create_user_and_sign_in(
saml_auth_provider, idp_provided_user_info saml_auth_provider, idp_provided_user_info
) )
@ -268,7 +279,7 @@ class SamlAuthProviderHandler:
@classmethod @classmethod
def get_sign_in_url_for_auth_provider( def get_sign_in_url_for_auth_provider(
cls, cls,
saml_auth_provider: SamlAuthProviderModel, saml_auth_provider: SamlAuthProviderModelMixin,
original_url: str = "", original_url: str = "",
) -> str: ) -> str:
""" """
@ -284,6 +295,7 @@ class SamlAuthProviderHandler:
""" """
saml_client = cls.prepare_saml_client(saml_auth_provider) saml_client = cls.prepare_saml_client(saml_auth_provider)
_, info = saml_client.prepare_for_authenticate(relay_state=original_url) _, info = saml_client.prepare_for_authenticate(relay_state=original_url)
for key, value in info["headers"]: for key, value in info["headers"]:
@ -294,24 +306,36 @@ class SamlAuthProviderHandler:
raise InvalidSamlConfiguration("No Location header found in SAML response.") raise InvalidSamlConfiguration("No Location header found in SAML response.")
@classmethod @classmethod
def get_sign_in_url(cls, query_params: Dict[str, str]) -> str: def get_sign_in_url(
cls,
query_params: Dict[str, str],
base_queryset: QuerySet | None = None,
redirect_to: str | None = None,
) -> str:
""" """
Returns the sign in url for the correct identity provider. This url is Returns the sign in url for the correct identity provider. This url is
used to initiate the SAML authentication flow from the service provider. used to initiate the SAML authentication flow from the service provider.
:param query_params: A dict containing the query parameters from the :param query_params: A dict containing the query parameters from the
sign in request. sign in request.
:param redirect_to: if set, used as relay state url.
:raises InvalidSamlRequest: If the email address is invalid. :raises InvalidSamlRequest: If the email address is invalid.
:raises InvalidSamlConfiguration: If the SAML configuration is invalid. :raises InvalidSamlConfiguration: If the SAML configuration is invalid.
:return: The redirect url to the identity provider. :return: The redirect url to the identity provider.
""" """
user_email = query_params.pop("email", None) user_email = query_params.pop("email", None)
original_url = query_params.pop("original", "")
valid_relay_state_url = get_valid_frontend_url(original_url, query_params) if redirect_to:
valid_relay_state_url = redirect_to
else:
original_url = query_params.pop("original", "")
valid_relay_state_url = get_valid_frontend_url(original_url, query_params)
try: try:
saml_auth_provider = cls.get_saml_auth_provider_from_email(user_email) saml_auth_provider = cls.get_saml_auth_provider_from_email(
user_email, base_queryset=base_queryset
)
return cls.get_sign_in_url_for_auth_provider( return cls.get_sign_in_url_for_auth_provider(
saml_auth_provider, valid_relay_state_url saml_auth_provider, valid_relay_state_url
) )

View file

@ -5,7 +5,7 @@ from django.dispatch import receiver
from baserow.core.auth_provider.models import AuthProviderModel from baserow.core.auth_provider.models import AuthProviderModel
class SamlAuthProviderModel(AuthProviderModel): class SamlAuthProviderModelMixin(models.Model):
metadata = models.TextField( metadata = models.TextField(
blank=True, help_text="The XML metadata downloaded from the metadata_url." blank=True, help_text="The XML metadata downloaded from the metadata_url."
) )
@ -46,6 +46,15 @@ class SamlAuthProviderModel(AuthProviderModel):
), ),
) )
class Meta:
abstract = True
class SamlAuthProviderModel(SamlAuthProviderModelMixin, AuthProviderModel):
# Restore ordering
class Meta(AuthProviderModel.Meta):
ordering = ["domain"]
@receiver(pre_save, sender=SamlAuthProviderModel) @receiver(pre_save, sender=SamlAuthProviderModel)
def reset_is_verified_if_metadata_changed(sender, instance, **kwargs): def reset_is_verified_if_metadata_changed(sender, instance, **kwargs):

View file

@ -52,7 +52,7 @@ def test_oauth2_login_feature_not_active(api_client, enterprise_data_fixture):
) )
response = api_client.get( response = api_client.get(
reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": provider.id}), reverse("api:enterprise_sso_oauth2:login", kwargs={"provider_id": provider.id}),
format="json", format="json",
) )
@ -68,7 +68,7 @@ def test_oauth2_login_feature_not_active(api_client, enterprise_data_fixture):
def test_oauth2_login_provider_doesnt_exist(api_client, enterprise_data_fixture): def test_oauth2_login_provider_doesnt_exist(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise() enterprise_data_fixture.enable_enterprise()
response = api_client.get( response = api_client.get(
reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": 50}), reverse("api:enterprise_sso_oauth2:login", kwargs={"provider_id": 50}),
format="json", format="json",
) )
@ -88,7 +88,7 @@ def test_oauth2_login_with_url_param(api_client, enterprise_data_fixture):
) )
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:login", "api:enterprise_sso_oauth2:login",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?original=templates&workspace_invitation_token=t&language=en", + "?original=templates&workspace_invitation_token=t&language=en",
@ -118,7 +118,7 @@ def test_oauth2_callback_feature_not_active(api_client, enterprise_data_fixture)
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", kwargs={"provider_id": provider.id} "api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}
), ),
format="json", format="json",
) )
@ -135,7 +135,7 @@ def test_oauth2_callback_feature_not_active(api_client, enterprise_data_fixture)
def test_oauth2_callback_provider_doesnt_exist(api_client, enterprise_data_fixture): def test_oauth2_callback_provider_doesnt_exist(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise() enterprise_data_fixture.enable_enterprise()
response = api_client.get( response = api_client.get(
reverse("api:enterprise:sso:oauth2:callback", kwargs={"provider_id": 50}), reverse("api:enterprise_sso_oauth2:callback", kwargs={"provider_id": 50}),
format="json", format="json",
) )
@ -167,7 +167,7 @@ def test_oauth2_callback_signup_success(api_client, enterprise_data_fixture):
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", "api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?code=validcode", + "?code=validcode",
@ -211,7 +211,7 @@ def test_oauth2_callback_signup_set_language(api_client, enterprise_data_fixture
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", "api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?code=validcode", + "?code=validcode",
@ -263,7 +263,7 @@ def test_oauth2_callback_signup_workspace_invitation(
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", "api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?code=validcode", + "?code=validcode",
@ -318,7 +318,7 @@ def test_oauth2_callback_signup_workspace_invitation_email_mismatch(
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", "api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?code=validcode", + "?code=validcode",
@ -360,7 +360,7 @@ def test_oauth2_callback_signup_disabled(api_client, enterprise_data_fixture):
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", "api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?code=validcode", + "?code=validcode",
@ -408,7 +408,7 @@ def test_oauth2_callback_login_success(
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", "api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?code=validcode", + "?code=validcode",
@ -460,7 +460,7 @@ def test_oauth2_callback_login_deactivated_user(
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", "api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?code=validcode", + "?code=validcode",
@ -505,7 +505,7 @@ def test_oauth2_callback_login_different_provider(
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", "api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?code=validcode", + "?code=validcode",
@ -549,7 +549,7 @@ def test_oauth2_callback_login_auth_flow_error(
response = api_client.get( response = api_client.get(
reverse( reverse(
"api:enterprise:sso:oauth2:callback", "api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id}, kwargs={"provider_id": provider.id},
) )
+ "?code=validcode", + "?code=validcode",

View file

@ -33,8 +33,8 @@ def test_saml_provider_get_login_url(api_client, data_fixture, enterprise_data_f
auth_provider_1 = enterprise_data_fixture.create_saml_auth_provider( auth_provider_1 = enterprise_data_fixture.create_saml_auth_provider(
domain="test1.com" domain="test1.com"
) )
auth_provider_login = reverse("api:enterprise:sso:saml:login") auth_provider_login = reverse("api:enterprise_sso_saml:login")
auth_provider_login_url = reverse("api:enterprise:sso:saml:login_url") auth_provider_login_url = reverse("api:enterprise_sso_saml:login_url")
_, unauthorized_token = data_fixture.create_user_and_token() _, unauthorized_token = data_fixture.create_user_and_token()
@ -147,7 +147,7 @@ def test_user_cannot_initiate_saml_sso_without_enterprise_license(
api_client, enterprise_data_fixture api_client, enterprise_data_fixture
): ):
enterprise_data_fixture.create_saml_auth_provider(domain="test1.com") enterprise_data_fixture.create_saml_auth_provider(domain="test1.com")
response = api_client.get(reverse("api:enterprise:sso:saml:login")) response = api_client.get(reverse("api:enterprise_sso_saml:login"))
assert response.status_code == HTTP_302_FOUND assert response.status_code == HTTP_302_FOUND
assert ( assert (
response.headers["Location"] response.headers["Location"]
@ -173,7 +173,7 @@ def test_user_can_initiate_saml_sso_with_enterprise_license(
enterprise_data_fixture.create_saml_auth_provider(domain="test1.com") enterprise_data_fixture.create_saml_auth_provider(domain="test1.com")
enterprise_data_fixture.enable_enterprise() enterprise_data_fixture.enable_enterprise()
sp_sso_saml_login_url = reverse("api:enterprise:sso:saml:login") sp_sso_saml_login_url = reverse("api:enterprise_sso_saml:login")
original_relative_url = "database/1/table/1/" original_relative_url = "database/1/table/1/"
request_query_string = urlencode({"original": original_relative_url}) request_query_string = urlencode({"original": original_relative_url})
@ -204,7 +204,7 @@ def test_user_can_initiate_saml_sso_with_enterprise_license(
@override_settings(DEBUG=True) @override_settings(DEBUG=True)
def test_saml_assertion_consumer_service(api_client, enterprise_data_fixture): def test_saml_assertion_consumer_service(api_client, enterprise_data_fixture):
user, _ = enterprise_data_fixture.create_enterprise_admin_user_and_token() user, _ = enterprise_data_fixture.create_enterprise_admin_user_and_token()
sp_sso_saml_acs_url = reverse("api:enterprise:sso:saml:acs") sp_sso_saml_acs_url = reverse("api:enterprise_sso_saml:acs")
( (
metadata, metadata,

View file

@ -654,6 +654,95 @@ def test_create_user_source_field_from_other_table(api_client, data_fixture):
assert response.json()["detail"]["name_field_id"][0]["code"] == "missing_table" assert response.json()["detail"]["name_field_id"][0]["code"] == "missing_table"
@pytest.mark.django_db
def test_user_source_create_user(data_fixture):
user = data_fixture.create_user()
workspace = data_fixture.create_workspace(user=user)
application = data_fixture.create_builder_application(workspace=workspace)
database = data_fixture.create_database_application(workspace=workspace)
integration = data_fixture.create_local_baserow_integration(
application=application, user=user
)
table_from_same_workspace1, fields, rows = data_fixture.build_table(
user=user,
database=database,
columns=[
("Email", "text"),
("Name", "text"),
],
rows=[
["test@baserow.io", "Test"],
],
)
email_field, name_field = fields
user_source = data_fixture.create_user_source_with_first_type(
integration=integration,
name="Test name",
table=table_from_same_workspace1,
email_field=email_field,
name_field=name_field,
uid="uid",
)
created_user = user_source.get_type().create_user(
user_source, "test2@baserow.io", "Test2"
)
model = table_from_same_workspace1.get_model()
assert created_user.email == "test2@baserow.io"
assert model.objects.count() == 2
assert getattr(model.objects.last(), email_field.db_column) == "test2@baserow.io"
@pytest.mark.django_db
def test_user_source_create_user_w_role(data_fixture):
user = data_fixture.create_user()
workspace = data_fixture.create_workspace(user=user)
application = data_fixture.create_builder_application(workspace=workspace)
database = data_fixture.create_database_application(workspace=workspace)
integration = data_fixture.create_local_baserow_integration(
application=application, user=user
)
table_from_same_workspace1, fields, rows = data_fixture.build_table(
user=user,
database=database,
columns=[
("Email", "text"),
("Name", "text"),
("Role", "text"),
],
rows=[
["test@baserow.io", "Test", "role1"],
],
)
email_field, name_field, role_field = fields
user_source = data_fixture.create_user_source_with_first_type(
integration=integration,
name="Test name",
table=table_from_same_workspace1,
email_field=email_field,
name_field=name_field,
role_field=role_field,
uid="uid",
)
created_user = user_source.get_type().create_user(
user_source, "test2@baserow.io", "Test2", "role2"
)
model = table_from_same_workspace1.get_model()
assert created_user.role == "role2"
assert getattr(model.objects.last(), role_field.db_column) == "role2"
@pytest.mark.django_db @pytest.mark.django_db
def test_export_user_source(data_fixture): def test_export_user_source(data_fixture):
user = data_fixture.create_user() user = data_fixture.create_user()
@ -803,7 +892,7 @@ def test_create_local_baserow_user_source_w_auth_providers(api_client, data_fixt
{ {
"type": "local_baserow_password", "type": "local_baserow_password",
"enabled": True, "enabled": True,
"domain": "test1", "domain": "test1.com",
"password_field_id": password_field.id, "password_field_id": password_field.id,
} }
], ],
@ -820,7 +909,7 @@ def test_create_local_baserow_user_source_w_auth_providers(api_client, data_fixt
assert response_json["auth_providers"] == [ assert response_json["auth_providers"] == [
{ {
"domain": "test1", "domain": "test1.com",
"id": first.id, "id": first.id,
"password_field_id": password_field.id, "password_field_id": password_field.id,
"type": "local_baserow_password", "type": "local_baserow_password",

View file

@ -0,0 +1,100 @@
from baserow_enterprise.api.sso.utils import get_valid_frontend_url
def test_get_valid_front_url():
assert get_valid_frontend_url() == "http://localhost:3000/dashboard"
assert (
get_valid_frontend_url("http://localhost:3000/dashboard")
== "http://localhost:3000/dashboard"
)
assert (
get_valid_frontend_url("http://localhost:3000/dashboard/after")
== "http://localhost:3000/dashboard/after"
)
assert (
get_valid_frontend_url("http://localhost:3000/other")
== "http://localhost:3000/other"
)
assert (
get_valid_frontend_url("http://localhost:3000/other", allow_any_path=False)
== "http://localhost:3000/dashboard"
)
assert (
get_valid_frontend_url("http://localhost:3000/")
== "http://localhost:3000/dashboard"
)
assert (
get_valid_frontend_url("http://something.com/")
== "http://localhost:3000/dashboard"
)
assert (
get_valid_frontend_url("http://something.com/dashboard/test")
== "http://localhost:3000/dashboard"
)
def test_get_valid_front_url_with_defaults():
defaults = ["https://test.com/toto", "http://random.net/"]
assert (
get_valid_frontend_url(default_frontend_urls=defaults)
== "https://test.com/toto"
)
assert (
get_valid_frontend_url("https://test.com/toto", default_frontend_urls=defaults)
== "https://test.com/toto"
)
assert (
get_valid_frontend_url("http://random.net/", default_frontend_urls=defaults)
== "http://random.net/"
)
assert (
get_valid_frontend_url(
"https://test.com/toto/subpath/", default_frontend_urls=defaults
)
== "https://test.com/toto/subpath/"
)
assert (
get_valid_frontend_url("https://test.com/titi/", default_frontend_urls=defaults)
== "https://test.com/titi/"
)
assert (
get_valid_frontend_url(
"https://test.com/titi/",
default_frontend_urls=defaults,
)
== "https://test.com/titi/"
)
assert (
get_valid_frontend_url(
"https://test.com/titi/",
default_frontend_urls=defaults,
allow_any_path=False,
)
== "https://test.com/toto"
)
assert (
get_valid_frontend_url("http://random.net/", default_frontend_urls=defaults)
== "http://random.net/"
)
assert (
get_valid_frontend_url("http://random.net/path", default_frontend_urls=defaults)
== "http://random.net/path"
)
assert (
get_valid_frontend_url("http://other.net/path", default_frontend_urls=defaults)
== "https://test.com/toto"
)
def test_get_valid_front_url_w_params():
assert (
get_valid_frontend_url(query_params={"test": "value"})
== "http://localhost:3000/dashboard?test=value"
)
assert (
get_valid_frontend_url(
"http://localhost:3000/dashboard", query_params={"test": "value"}
)
== "http://localhost:3000/dashboard?test=value"
)

View file

@ -14,3 +14,6 @@
@import 'long_text_field'; @import 'long_text_field';
@import 'highest_role_field'; @import 'highest_role_field';
@import 'auth_form_element'; @import 'auth_form_element';
@import 'common_saml_setting_form';
@import 'common_saml_setting_modal';
@import 'saml_auth_link';

View file

@ -15,3 +15,11 @@
font-size: var(--label-font-size, 13px); font-size: var(--label-font-size, 13px);
padding-top: 0.5em; padding-top: 0.5em;
} }
.auth-form-element__provider {
padding-top: 16px;
&:first-child {
padding-top: 0;
}
}

View file

@ -0,0 +1,13 @@
.common-saml-setting-form {
@include elevation($elevation-low);
@include rounded($rounded);
padding: 20px;
border: 1px solid $palette-neutral-200;
display: flex;
cursor: pointer;
&--error {
border-color: $palette-red-600;
}
}

View file

@ -0,0 +1,24 @@
.common-saml-setting-modal__url-block {
@include rounded($rounded);
position: relative;
font-family: monospace;
padding: 12px 16px;
background-color: $palette-neutral-50;
border: 1px solid $palette-neutral-400;
}
.common-saml-setting-modal__url {
padding-bottom: 8px;
cursor: pointer;
display: flex;
gap: 5px;
}
.common-saml-setting-modal__url-domain {
color: $palette-neutral-800;
}
.common-saml-setting-modal__url-dest {
@extend %ellipsis;
}

View file

@ -0,0 +1,6 @@
.saml-auth-link,
.saml-auth-link__modal-footer {
display: flex;
flex-direction: column;
font-size: var(--label-font-size, 13px);
}

View file

@ -42,68 +42,150 @@ export class PasswordAuthProviderType extends AuthProviderType {
return null return null
} }
/**
* We can create only one password provider.
*/
canCreateNew(authProviders) {
return (
!authProviders[this.getType()] ||
authProviders[this.getType()].length === 0
)
}
getOrder() { getOrder() {
return 1 return 1
} }
} }
export class SamlAuthProviderType extends AuthProviderType { export const SamlAuthProviderTypeMixin = (Base) =>
static getType() { class extends Base {
return 'saml' static getType() {
} return 'saml'
}
getIcon() {
return SAMLIcon getIcon() {
} return SAMLIcon
}
getVerifiedIcon() {
return VerifiedProviderIcon getVerifiedIcon() {
} return VerifiedProviderIcon
}
getName() {
return 'SSO SAML provider' getName() {
} return this.app.i18n.t('authProviderTypes.saml')
}
getProviderName(provider) {
return `SSO SAML: ${provider.domain}` getProviderName(provider) {
} if (provider.domain) {
return this.app.i18n.t('authProviderTypes.ssoSamlProviderName', {
getLoginActionComponent() { domain: provider.domain,
return SamlLoginAction })
} } else {
return this.app.i18n.t(
getAdminListComponent() { 'authProviderTypes.ssoSamlProviderNameUnconfigured'
return AuthProviderItem )
} }
}
getAdminSettingsFormComponent() {
return SamlSettingsForm getLoginActionComponent() {
return SamlLoginAction
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return SamlSettingsForm
}
getRelayStateUrl() {
return this.app.store.getters['authProviderAdmin/getType'](this.getType())
.relayStateUrl
}
getAcsUrl() {
return this.app.store.getters['authProviderAdmin/getType'](this.getType())
.acsUrl
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
redirectUrl: authProviderOption.redirect_url,
domainRequired: authProviderOption.domain_required,
...loginOptions,
}
}
handleServerError(vueComponentInstance, error) {
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
for (const [key, value] of Object.entries(error.handler.detail || {})) {
vueComponentInstance.serverErrors[key] = value
}
return true
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
acsUrl: authProviderType.acs_url,
relayStateUrl: authProviderType.relay_state_url,
...populated,
}
}
} }
export class SamlAuthProviderType extends SamlAuthProviderTypeMixin(
AuthProviderType
) {
getOrder() { getOrder() {
return 50 return 50
} }
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
redirectUrl: authProviderOption.redirect_url,
domainRequired: authProviderOption.domain_required,
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
acsUrl: authProviderType.acs_url,
relayStateUrl: authProviderType.relay_state_url,
...populated,
}
}
} }
export class GoogleAuthProviderType extends AuthProviderType { export const OAuth2AuthProviderTypeMixin = (Base) =>
class extends Base {
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return OAuth2SettingsForm
}
getCallbackUrl(authProvider) {
if (!authProvider.id) {
const nextProviderId =
this.app.store.getters['authProviderAdmin/getNextProviderId']
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${authProvider.id}/`
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
}
export class GoogleAuthProviderType extends OAuth2AuthProviderTypeMixin(
AuthProviderType
) {
static getType() { static getType() {
return 'google' return 'google'
} }
@ -120,38 +202,14 @@ export class GoogleAuthProviderType extends AuthProviderType {
return provider.name ? provider.name : `Google` return provider.name ? provider.name : `Google`
} }
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return OAuth2SettingsForm
}
getOrder() { getOrder() {
return 50 return 50
} }
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
} }
export class FacebookAuthProviderType extends AuthProviderType { export class FacebookAuthProviderType extends OAuth2AuthProviderTypeMixin(
AuthProviderType
) {
static getType() { static getType() {
return 'facebook' return 'facebook'
} }
@ -168,38 +226,14 @@ export class FacebookAuthProviderType extends AuthProviderType {
return provider.name ? provider.name : this.getName() return provider.name ? provider.name : this.getName()
} }
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return OAuth2SettingsForm
}
getOrder() { getOrder() {
return 50 return 50
} }
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
} }
export class GitHubAuthProviderType extends AuthProviderType { export class GitHubAuthProviderType extends OAuth2AuthProviderTypeMixin(
AuthProviderType
) {
static getType() { static getType() {
return 'github' return 'github'
} }
@ -216,35 +250,9 @@ export class GitHubAuthProviderType extends AuthProviderType {
return provider.name ? provider.name : this.getName() return provider.name ? provider.name : this.getName()
} }
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return OAuth2SettingsForm
}
getOrder() { getOrder() {
return 50 return 50
} }
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
} }
export class GitLabAuthProviderType extends AuthProviderType { export class GitLabAuthProviderType extends AuthProviderType {
@ -276,26 +284,23 @@ export class GitLabAuthProviderType extends AuthProviderType {
return GitLabSettingsForm return GitLabSettingsForm
} }
getCallbackUrl(authProvider) {
if (!authProvider.id) {
const nextProviderId =
this.app.store.getters['authProviderAdmin/getNextProviderId']
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${authProvider.id}/`
}
getOrder() { getOrder() {
return 50 return 50
} }
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
} }
export class OpenIdConnectAuthProviderType extends AuthProviderType { export class OpenIdConnectAuthProviderType extends OAuth2AuthProviderTypeMixin(
AuthProviderType
) {
static getType() { static getType() {
return 'openid_connect' return 'openid_connect'
} }
@ -312,33 +317,38 @@ export class OpenIdConnectAuthProviderType extends AuthProviderType {
return provider.name ? provider.name : this.getName() return provider.name ? provider.name : this.getName()
} }
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() { getAdminSettingsFormComponent() {
return OpenIdConnectSettingsForm return OpenIdConnectSettingsForm
} }
getCallbackUrl(authProvider) {
if (!authProvider.id) {
const nextProviderId =
this.app.store.getters['authProviderAdmin/getNextProviderId']
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${authProvider.id}/`
}
handleServerError(vueComponentInstance, error) {
if (error.handler.code === 'ERROR_INVALID_PROVIDER_URL') {
vueComponentInstance.serverErrors = {
...vueComponentInstance.serverErrors,
baseUrl: error.handler.detail,
}
return true
}
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
vueComponentInstance.serverErrors = structuredClone(
error.handler.detail || {}
)
return true
}
getOrder() { getOrder() {
return 50 return 50
} }
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
} }

View file

@ -1,64 +1,30 @@
<template> <template>
<form <div v-if="hasAtLeastOneLoginOption" :style="fullStyle">
v-if="hasAtLeastOneLoginOption" <template v-for="appAuthType in appAuthProviderTypes">
class="auth-form-element" <div
:style="getStyleOverride('input')" v-if="hasAtLeastOneProvider(appAuthType)"
@submit.prevent="onLogin" :key="appAuthType.type"
> class="auth-form-element__provider"
<Error :error="error"></Error> >
<ABFormGroup <component
:label="$t('authFormElement.email')" :is="appAuthType.component"
:error-message=" :user-source="selectedUserSource"
$v.values.email.$dirty :auth-providers="appAuthProviderPerTypes[appAuthType.type]"
? !$v.values.email.required :login-button-label="resolvedLoginButtonLabel"
? $t('error.requiredField') @after-login="afterLogin"
: !$v.values.email.email />
? $t('error.invalidEmail') </div>
: '' </template>
: '' </div>
" <p v-else>
:autocomplete="isEditMode ? 'off' : ''" {{ $t('authFormElement.selectOrConfigureUserSourceFirst') }}
required </p>
>
<ABInput
v-model="values.email"
:placeholder="$t('authFormElement.emailPlaceholder')"
@blur="$v.values.email.$touch()"
/>
</ABFormGroup>
<ABFormGroup
:label="$t('authFormElement.password')"
:error-message="
$v.values.password.$dirty
? !$v.values.password.required
? $t('error.requiredField')
: ''
: ''
"
required
>
<ABInput
ref="passwordRef"
v-model="values.password"
type="password"
:placeholder="$t('authFormElement.passwordPlaceholder')"
@blur="$v.values.password.$touch()"
/>
</ABFormGroup>
<div :style="getStyleOverride('login_button')" class="auth-form__footer">
<ABButton :disabled="$v.$error" :loading="loading" size="large">
{{ resolvedLoginButtonLabel }}
</ABButton>
</div>
</form>
<p v-else>{{ $t('authFormElement.selectOrConfigureUserSourceFirst') }}</p>
</template> </template>
<script> <script>
import form from '@baserow/modules/core/mixins/form' import form from '@baserow/modules/core/mixins/form'
import error from '@baserow/modules/core/mixins/error' import error from '@baserow/modules/core/mixins/error'
import element from '@baserow/modules/builder/mixins/element' import element from '@baserow/modules/builder/mixins/element'
import { required, email } from 'vuelidate/lib/validators'
import { ensureString } from '@baserow/modules/core/utils/validator' import { ensureString } from '@baserow/modules/core/utils/validator'
import { mapActions } from 'vuex' import { mapActions } from 'vuex'
@ -79,12 +45,15 @@ export default {
}, },
}, },
data() { data() {
return { return {}
loading: false,
values: { email: '', password: '' },
}
}, },
computed: { computed: {
fullStyle() {
return {
...this.getStyleOverride('input'),
...this.getStyleOverride('login_button'),
}
},
selectedUserSource() { selectedUserSource() {
return this.$store.getters['userSource/getUserSourceById']( return this.$store.getters['userSource/getUserSourceById'](
this.builder, this.builder,
@ -97,8 +66,21 @@ export default {
} }
return this.$registry.get('userSource', this.selectedUserSource.type) return this.$registry.get('userSource', this.selectedUserSource.type)
}, },
isAuthenticated() { authProviders() {
return this.$store.getters['userSourceUser/isAuthenticated'](this.builder) return this.selectedUserSource?.auth_providers || []
},
appAuthProviderTypes() {
return this.$registry.getOrderedList('appAuthProvider')
},
appAuthProviderPerTypes() {
return Object.fromEntries(
this.appAuthProviderTypes.map((authType) => {
return [
authType.type,
this.authProviders.filter(({ type }) => type === authType.type),
]
})
)
}, },
loginOptions() { loginOptions() {
if (!this.selectedUserSourceType) { if (!this.selectedUserSourceType) {
@ -141,68 +123,15 @@ export default {
...mapActions({ ...mapActions({
actionForceUpdateElement: 'element/forceUpdate', actionForceUpdateElement: 'element/forceUpdate',
}), }),
async onLogin(event) { hasAtLeastOneProvider(authProviderType) {
if (this.isAuthenticated) { return (
await this.$store.dispatch('userSourceUser/logoff', { this.appAuthProviderPerTypes[authProviderType.getType()]?.length > 0
application: this.builder, )
})
}
this.$v.$touch()
if (this.$v.$invalid) {
this.focusOnFirstError()
return
}
this.loading = true
this.hideError()
try {
await this.$store.dispatch('userSourceUser/authenticate', {
application: this.builder,
userSource: this.selectedUserSource,
credentials: {
email: this.values.email,
password: this.values.password,
},
setCookie: this.mode === 'public',
})
this.values.password = ''
this.values.email = ''
this.$v.$reset()
this.fireEvent(
this.elementType.getEventByName(this.element, 'after_login')
)
} catch (error) {
if (error.handler) {
const response = error.handler.response
if (response && response.status === 401) {
this.values.password = ''
this.$v.$reset()
this.$v.$touch()
this.$refs.passwordRef.focus()
if (response.data?.error === 'ERROR_INVALID_CREDENTIALS') {
this.showError(
this.$t('error.incorrectCredentialTitle'),
this.$t('error.incorrectCredentialMessage')
)
}
} else {
const message = error.handler.getMessage('login')
this.showError(message)
}
error.handler.handled()
} else {
throw error
}
}
this.loading = false
}, },
}, afterLogin() {
validations: { this.fireEvent(
values: { this.elementType.getEventByName(this.element, 'after_login')
email: { required, email }, )
password: { required },
}, },
}, },
} }

View file

@ -10,8 +10,8 @@
class="context__menu-item-link" class="context__menu-item-link"
@click="$emit('create', authProviderType)" @click="$emit('create', authProviderType)"
> >
<AuthProviderIcon :icon="getIcon(authProviderType)" /> <AuthProviderIcon :icon="authProviderType.getIcon()" />
{{ getName(authProviderType) }} {{ authProviderType.getName() }}
</a> </a>
</li> </li>
</ul> </ul>
@ -32,13 +32,5 @@ export default {
required: true, required: true,
}, },
}, },
methods: {
getIcon(providerType) {
return this.$registry.get('authProvider', providerType.type).getIcon()
},
getName(providerType) {
return this.$registry.get('authProvider', providerType.type).getName()
},
},
} }
</script> </script>

View file

@ -107,59 +107,25 @@
<script> <script>
import { required, url } from 'vuelidate/lib/validators' import { required, url } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form' import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
export default { export default {
name: 'GitLabSettingsForm', name: 'GitLabSettingsForm',
mixins: [form], mixins: [authProviderForm],
props: {
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
},
data() { data() {
return { return {
allowedValues: ['name', 'base_url', 'client_id', 'secret'], allowedValues: ['name', 'base_url', 'client_id', 'secret'],
values: { values: {
name: '', name: '',
base_url: '', base_url: 'https://gitlab.com',
client_id: '', client_id: '',
secret: '', secret: '',
}, },
} }
}, },
computed: { computed: {
providerName() {
return this.$registry
.get('authProvider', 'gitlab')
.getProviderName(this.authProvider)
},
callbackUrl() { callbackUrl() {
if (!this.authProvider.id) { return this.authProviderType.getCallbackUrl(this.authProvider)
const nextProviderId =
this.$store.getters['authProviderAdmin/getNextProviderId']
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${this.authProvider.id}/`
},
},
methods: {
getDefaultValues() {
return {
name: this.providerName,
base_url: this.authProvider.base_url || 'https://gitlab.com',
client_id: this.authProvider.client_id || '',
secret: this.authProvider.secret || '',
}
},
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
}, },
}, },
validations() { validations() {

View file

@ -81,23 +81,11 @@
<script> <script>
import { required } from 'vuelidate/lib/validators' import { required } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form' import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
export default { export default {
name: 'OAuth2SettingsForm', name: 'OAuth2SettingsForm',
mixins: [form], mixins: [authProviderForm],
props: {
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
authProviderType: {
type: String,
required: false,
default: null,
},
},
data() { data() {
return { return {
allowedValues: ['name', 'client_id', 'secret'], allowedValues: ['name', 'client_id', 'secret'],
@ -109,37 +97,8 @@ export default {
} }
}, },
computed: { computed: {
providerName() {
const type = this.authProviderType
? this.authProviderType
: this.authProvider.type
return this.$registry
.get('authProvider', type)
.getProviderName(this.authProvider)
},
callbackUrl() { callbackUrl() {
if (!this.authProvider.id) { return this.authProviderType.getCallbackUrl(this.authProvider)
const nextProviderId =
this.$store.getters['authProviderAdmin/getNextProviderId']
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${this.authProvider.id}/`
},
},
methods: {
getDefaultValues() {
return {
name: this.providerName,
client_id: this.authProvider.client_id || '',
secret: this.authProvider.secret || '',
}
},
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
}, },
}, },
validations() { validations() {

View file

@ -109,23 +109,11 @@
<script> <script>
import { required, url } from 'vuelidate/lib/validators' import { required, url } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form' import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
export default { export default {
name: 'OpenIdConnectSettingsForm', name: 'OpenIdConnectSettingsForm',
mixins: [form], mixins: [authProviderForm],
props: {
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
serverErrors: {
type: Object,
required: false,
default: () => ({}),
},
},
data() { data() {
return { return {
allowedValues: ['name', 'base_url', 'client_id', 'secret'], allowedValues: ['name', 'base_url', 'client_id', 'secret'],
@ -139,42 +127,7 @@ export default {
}, },
computed: { computed: {
callbackUrl() { callbackUrl() {
if (!this.authProvider.id) { return this.authProviderType.getCallbackUrl(this.authProvider)
const nextProviderId =
this.$store.getters['authProviderAdmin/getNextProviderId']
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${this.authProvider.id}/`
},
},
methods: {
getDefaultValues() {
return {
name: this.authProvider.name || '',
base_url: this.authProvider.base_url || '',
client_id: this.authProvider.client_id || '',
secret: this.authProvider.secret || '',
}
},
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
},
handleServerError(error) {
if (error.handler.code === 'ERROR_INVALID_PROVIDER_URL') {
this.serverErrors.baseUrl = error.handler.detail
return true
}
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
for (const [key, value] of Object.entries(error.handler.detail || {})) {
this.serverErrors[key] = value
}
return true
}, },
}, },
validations() { validations() {

View file

@ -21,9 +21,9 @@
ref="domain" ref="domain"
v-model="values.domain" v-model="values.domain"
size="large" size="large"
:error="fieldHasErrors('domain') || serverErrors.domain" :error="fieldHasErrors('domain') || !!serverErrors.domain"
:placeholder="$t('samlSettingsForm.domainPlaceholder')" :placeholder="$t('samlSettingsForm.domainPlaceholder')"
@input="serverErrors.domain = null" @input="onDomainInput()"
@blur="$v.values.domain.$touch()" @blur="$v.values.domain.$touch()"
></FormInput> ></FormInput>
<template #error> <template #error>
@ -50,16 +50,16 @@
small-label small-label
required required
:label="$t('samlSettingsForm.metadata')" :label="$t('samlSettingsForm.metadata')"
:error="fieldHasErrors('metadata')" :error="fieldHasErrors('metadata') || !!serverErrors.metadata"
class="margin-bottom-2" class="margin-bottom-2"
> >
<FormTextarea <FormTextarea
ref="metadata" ref="metadata"
v-model="values.metadata" v-model="values.metadata"
:rows="12" :rows="8"
:error="fieldHasErrors('metadata') || serverErrors.metadata" :error="fieldHasErrors('metadata') || !!serverErrors.metadata"
:placeholder="$t('samlSettingsForm.metadataPlaceholder')" :placeholder="$t('samlSettingsForm.metadataPlaceholder')"
@input="serverErrors.metadata = null" @input="onMetadataInput()"
@blur="$v.values.metadata.$touch()" @blur="$v.values.metadata.$touch()"
></FormTextarea> ></FormTextarea>
@ -73,23 +73,25 @@
</template> </template>
</FormGroup> </FormGroup>
<FormGroup <slot name="config">
small-label <FormGroup
required small-label
:label="$t('samlSettingsForm.relayStateUrl')" required
class="margin-bottom-2" :label="$t('samlSettingsForm.relayStateUrl')"
> class="margin-bottom-2"
<code>{{ getRelayStateUrl() }}</code> >
</FormGroup> <code>{{ getRelayStateUrl() }}</code>
</FormGroup>
<FormGroup <FormGroup
small-label small-label
required required
:label="$t('samlSettingsForm.acsUrl')" :label="$t('samlSettingsForm.acsUrl')"
class="margin-bottom-2" class="margin-bottom-2"
> >
<code>{{ getAcsUrl() }}</code> <code>{{ getAcsUrl() }}</code>
</FormGroup> </FormGroup>
</slot>
<Expandable card class="margin-bottom-2"> <Expandable card class="margin-bottom-2">
<template #header="{ toggle, expanded }"> <template #header="{ toggle, expanded }">
@ -110,7 +112,7 @@
</div> </div>
<div> <div>
{{ {{
usingDefaultAttrs() usingDefaultAttrs
? $t('samlSettingsForm.defaultAttrs') ? $t('samlSettingsForm.defaultAttrs')
: $t('samlSettingsForm.customAttrs') : $t('samlSettingsForm.customAttrs')
}} }}
@ -190,7 +192,7 @@
<script> <script>
import { maxLength, required, helpers } from 'vuelidate/lib/validators' import { maxLength, required, helpers } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form' import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
const alphanumericDotDashUnderscore = helpers.regex( const alphanumericDotDashUnderscore = helpers.regex(
'alphanumericDotDashUnderscore', 'alphanumericDotDashUnderscore',
@ -199,41 +201,32 @@ const alphanumericDotDashUnderscore = helpers.regex(
export default { export default {
name: 'SamlSettingsForm', name: 'SamlSettingsForm',
mixins: [form], mixins: [authProviderForm],
props: {
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
authProviderType: {
type: String,
required: false,
default: null,
},
},
data() { data() {
return { return {
allowedValues: ['domain', 'metadata'], allowedValues: [
serverErrors: {}, 'domain',
'metadata',
'email_attr_key',
'first_name_attr_key',
'last_name_attr_key',
],
values: { values: {
domain: '', domain: '',
metadata: '', metadata: '',
email_attr_key: '', email_attr_key: 'user.email',
first_name_attr_key: '', first_name_attr_key: 'user.first_name',
last_name_attr_key: '', last_name_attr_key: 'user.last_name',
}, },
} }
}, },
computed: { computed: {
allSamlProviders() {
return this.authProviders.saml || []
},
samlDomains() { samlDomains() {
const samlAuthProviders = return this.allSamlProviders
this.$store.getters['authProviderAdmin/getAll'].saml?.authProviders || .filter((authProvider) => authProvider.id !== this.authProvider.id)
[]
return samlAuthProviders
.filter(
(authProvider) => authProvider.domain !== this.authProvider.domain
)
.map((authProvider) => authProvider.domain) .map((authProvider) => authProvider.domain)
}, },
defaultAttrs() { defaultAttrs() {
@ -244,10 +237,8 @@ export default {
} }
}, },
type() { type() {
return this.authProviderType || this.authProvider.type return this.authProviderType.getType()
}, },
},
methods: {
usingDefaultAttrs() { usingDefaultAttrs() {
return ( return (
this.values.email_attr_key === this.defaultAttrs.email_attr_key && this.values.email_attr_key === this.defaultAttrs.email_attr_key &&
@ -256,20 +247,13 @@ export default {
this.values.last_name_attr_key === this.defaultAttrs.last_name_attr_key this.values.last_name_attr_key === this.defaultAttrs.last_name_attr_key
) )
}, },
getDefaultValues() { },
const authProviderAttrs = { methods: {
email_attr_key: this.authProvider.email_attr_key, onDomainInput() {
first_name_attr_key: this.authProvider.first_name_attr_key, this.serverErrors.domain = null
last_name_attr_key: this.authProvider.last_name_attr_key, },
} onMetadataInput() {
const samlAttrs = this.authProvider.id this.serverErrors.metadata = null
? authProviderAttrs
: this.defaultAttrs
return {
domain: this.authProvider.domain || '',
metadata: this.authProvider.metadata || '',
...samlAttrs,
}
}, },
getFieldErrorMsg(fieldName) { getFieldErrorMsg(fieldName) {
if (!this.$v.values[fieldName].$dirty) { if (!this.$v.values[fieldName].$dirty) {
@ -285,33 +269,17 @@ export default {
} }
}, },
getRelayStateUrl() { getRelayStateUrl() {
return this.$store.getters['authProviderAdmin/getType'](this.type) return this.authProviderType.getRelayStateUrl()
.relayStateUrl
}, },
getAcsUrl() { getAcsUrl() {
return this.$store.getters['authProviderAdmin/getType'](this.type).acsUrl return this.authProviderType.getAcsUrl()
}, },
getVerifiedIcon() { getVerifiedIcon() {
return this.$registry.get('authProvider', this.type).getVerifiedIcon() return this.authProviderType.getVerifiedIcon()
},
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
}, },
mustHaveUniqueDomain(domain) { mustHaveUniqueDomain(domain) {
return !this.samlDomains.includes(domain.trim()) return !this.samlDomains.includes(domain.trim())
}, },
handleServerError(error) {
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
for (const [key, value] of Object.entries(error.handler.detail || {})) {
this.serverErrors[key] = value
}
return true
},
}, },
validations() { validations() {
return { return {

View file

@ -3,14 +3,15 @@
<h2 class="box__title"> <h2 class="box__title">
{{ {{
$t('createSettingsAuthProviderModal.title', { $t('createSettingsAuthProviderModal.title', {
type: getProviderTypeName(), type: authProviderType.getName(),
}) })
}} }}
</h2> </h2>
<div v-if="authProviderType"> <div>
<component <component
:is="getProviderAdminSettingsFormComponent()" :is="getProviderAdminSettingsFormComponent()"
ref="providerSettingsForm" ref="providerSettingsForm"
:auth-providers="appAuthProviderPerTypes"
:auth-provider-type="authProviderType" :auth-provider-type="authProviderType"
@submit="create($event)" @submit="create($event)"
> >
@ -21,8 +22,8 @@
</li> </li>
</ul> </ul>
<Button type="primary" :disabled="loading" :loading="loading"> <Button type="primary" :disabled="loading" :loading="loading">
{{ $t('action.create') }}</Button {{ $t('action.create') }}
> </Button>
</div> </div>
</component> </component>
</div> </div>
@ -38,9 +39,8 @@ export default {
mixins: [modal], mixins: [modal],
props: { props: {
authProviderType: { authProviderType: {
type: String, type: Object,
required: false, required: true,
default: null,
}, },
}, },
data() { data() {
@ -48,23 +48,30 @@ export default {
loading: false, loading: false,
} }
}, },
computed: {
authProviders() {
return this.$store.getters['authProviderAdmin/getAll']
},
appAuthProviderPerTypes() {
return Object.fromEntries(
this.$registry
.getOrderedList('authProvider')
.map((authProviderType) => [
authProviderType.getType(),
this.authProviders[authProviderType.getType()].authProviders,
])
)
},
},
methods: { methods: {
getProviderAdminSettingsFormComponent() { getProviderAdminSettingsFormComponent() {
return this.$registry return this.authProviderType.getAdminSettingsFormComponent()
.get('authProvider', this.authProviderType)
.getAdminSettingsFormComponent()
},
getProviderTypeName() {
if (!this.authProviderType) return ''
return this.$registry.get('authProvider', this.authProviderType).getName()
}, },
async create(values) { async create(values) {
this.loading = true this.loading = true
this.serverErrors = {}
try { try {
await this.$store.dispatch('authProviderAdmin/create', { await this.$store.dispatch('authProviderAdmin/create', {
type: this.authProviderType, type: this.authProviderType.getType(),
values, values,
}) })
this.$emit('created') this.$emit('created')

View file

@ -11,7 +11,10 @@
<component <component
:is="getProviderAdminSettingsFormComponent()" :is="getProviderAdminSettingsFormComponent()"
ref="providerSettingsForm" ref="providerSettingsForm"
:auth-providers="appAuthProviderPerTypes"
:auth-provider="authProvider" :auth-provider="authProvider"
:default-values="authProvider"
:auth-provider-type="authProviderType"
@submit="onSettingsUpdated" @submit="onSettingsUpdated"
> >
<div class="actions"> <div class="actions">
@ -22,8 +25,8 @@
</ul> </ul>
<Button type="primary" :disabled="loading" :loading="loading"> <Button type="primary" :disabled="loading" :loading="loading">
{{ $t('action.save') }}</Button {{ $t('action.save') }}
> </Button>
</div> </div>
</component> </component>
</div> </div>
@ -48,16 +51,30 @@ export default {
loading: false, loading: false,
} }
}, },
computed: {
authProviderType() {
return this.$registry.get('authProvider', this.authProvider.type)
},
authProviders() {
return this.$store.getters['authProviderAdmin/getAll']
},
appAuthProviderPerTypes() {
return Object.fromEntries(
this.$registry
.getOrderedList('authProvider')
.map((authProviderType) => [
authProviderType.getType(),
this.authProviders[authProviderType.getType()].authProviders,
])
)
},
},
methods: { methods: {
getProviderAdminSettingsFormComponent() { getProviderAdminSettingsFormComponent() {
return this.$registry return this.authProviderType.getAdminSettingsFormComponent()
.get('authProvider', this.authProvider.type)
.getAdminSettingsFormComponent()
}, },
getProviderName() { getProviderName() {
return this.$registry return this.authProviderType.getProviderName(this.authProvider)
.get('authProvider', this.authProvider.type)
.getProviderName(this.authProvider)
}, },
async onSettingsUpdated(values) { async onSettingsUpdated(values) {
this.loading = true this.loading = true

View file

@ -1,5 +1,10 @@
import { AppAuthProviderType } from '@baserow/modules/core/appAuthProviderTypes' import { AppAuthProviderType } from '@baserow/modules/core/appAuthProviderTypes'
import { SamlAuthProviderTypeMixin } from '@baserow_enterprise/authProviderTypes'
import LocalBaserowUserSourceForm from '@baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowPasswordAppAuthProviderForm' import LocalBaserowUserSourceForm from '@baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowPasswordAppAuthProviderForm'
import LocalBaserowAuthPassword from '@baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowAuthPassword'
import CommonSamlSettingForm from '@baserow_enterprise/integrations/common/components/CommonSamlSettingForm'
import SamlAuthLink from '@baserow_enterprise/integrations/common/components/SamlAuthLink'
import { PasswordFieldType } from '@baserow/modules/database/fieldTypes' import { PasswordFieldType } from '@baserow/modules/database/fieldTypes'
export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType { export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType {
@ -11,6 +16,10 @@ export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType
return this.app.i18n.t('appAuthProviderType.localBaserowPassword') return this.app.i18n.t('appAuthProviderType.localBaserowPassword')
} }
get component() {
return LocalBaserowAuthPassword
}
get formComponent() { get formComponent() {
return LocalBaserowUserSourceForm return LocalBaserowUserSourceForm
} }
@ -23,7 +32,91 @@ export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType
return [PasswordFieldType.getType()] return [PasswordFieldType.getType()]
} }
getLoginOptions(authProvider) {
if (authProvider.password_field_id) {
return {}
}
return null
}
/**
* We can create only one password provider.
*/
canCreateNew(appAuthProviders) {
return (
!appAuthProviders[this.getType()] ||
appAuthProviders[this.getType()].length === 0
)
}
getOrder() { getOrder() {
return 10 return 10
} }
} }
export class SamlAppAuthProviderType extends SamlAuthProviderTypeMixin(
AppAuthProviderType
) {
get name() {
return this.app.i18n.t('appAuthProviderType.commonSaml')
}
get component() {
return SamlAuthLink
}
get formComponent() {
return CommonSamlSettingForm
}
handleServerError(vueComponentInstance, error) {
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
if (error.handler.detail?.auth_providers?.length > 0) {
const flatProviders = Object.entries(vueComponentInstance.authProviders)
.map(([, providers]) => providers)
.flat()
// Sort per ID to make sure we have the same order
// as the backend
.sort((a, b) => a.id - b.id)
for (const [
index,
authError,
] of error.handler.detail.auth_providers.entries()) {
if (
Object.keys(authError).length > 0 &&
flatProviders[index].id === vueComponentInstance.authProvider.id
) {
vueComponentInstance.serverErrors = {
...vueComponentInstance.serverErrors,
...authError,
}
return true
}
}
}
return false
}
getAuthToken(userSource, authProvider, route) {
// token can be in the query string (SSO) or in the cookies (previous session)
// We use the user source id in order to prevent conflicts when using multiple
// auth forms on the same page.
const queryParamName = `user_source_saml_token__${userSource.id}`
return route.query[queryParamName]
}
handleError(userSource, authProvider, route) {
const queryParamName = `saml_error__${userSource.id}`
const errorCode = route.query[queryParamName]
if (errorCode) {
return { message: this.app.i18n.t(`loginError.${errorCode}`), code: 500 }
}
}
getOrder() {
return 20
}
}

View file

@ -0,0 +1,83 @@
<template>
<div>
<div
class="common-saml-setting-form"
:class="{ 'common-saml-setting-form--error': inError }"
>
<Presentation
class="flex-grow-1"
:title="authProviderType.getProviderName(authProvider)"
size="medium"
avatar-color="neutral"
:image="authProviderType.getIcon()"
@click="onEdit"
/>
<div class="common-saml-setting-form__actions">
<ButtonIcon
type="secondary"
icon="iconoir-edit"
@click.prevent="onEdit"
/>
<ButtonIcon
type="secondary"
icon="iconoir-bin"
@click.prevent="onDelete"
/>
</div>
</div>
<div v-if="inError" class="error">
{{ $t('commonSamlSettingForm.authProviderInError') }}
</div>
<CommonSamlSettingModal
ref="samlModal"
v-bind="$props"
:integration="integration"
:user-source="userSource"
@form-valid="onFormValid($event)"
v-on="$listeners"
></CommonSamlSettingModal>
</div>
</template>
<script>
import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
import CommonSamlSettingModal from '@baserow_enterprise/integrations/common/components/CommonSamlSettingModal'
export default {
name: 'CommonSamlSettingForm',
components: { CommonSamlSettingModal },
mixins: [authProviderForm],
inject: ['builder'],
props: {
integration: {
type: Object,
required: true,
},
userSource: {
type: Object,
required: true,
},
},
data() {
return { inError: false }
},
methods: {
onFormValid(value) {
this.inError = !value
},
onEdit() {
this.$refs.samlModal.show()
},
onDelete() {
this.$emit('delete')
},
handleServerError(error) {
if (this.$refs.samlModal.handleServerError(error)) {
this.inError = true
return true
}
return false
},
},
}
</script>

View file

@ -0,0 +1,172 @@
<template>
<Modal ref="modal" keep-content @hidden="onHide">
<h2 class="box__title">{{ $t('commonSamlSettingModal.title') }}</h2>
<div>
<SamlSettingsForm
v-bind="$props"
ref="samlForm"
@values-changed="checkValidity"
v-on="$listeners"
>
<template #config>
<FormGroup
small-label
required
:label="$t('commonSamlSettingModal.relayStateTitle')"
class="margin-bottom-2"
>
<div class="common-saml-setting-modal__url-block">
<div
v-for="conf in config"
:key="conf.name"
class="common-saml-setting-modal__url"
@click.prevent="
;[copyToClipboard(conf.relay), $refs.copiedRelay.show()]
"
>
<span class="common-saml-setting-modal__url-domain">
{{ conf.name }}
</span>
<span
class="common-saml-setting-modal__url-dest"
:title="conf.relay"
>
{{ conf.relay }}
</span>
</div>
<Copied ref="copiedRelay"></Copied>
</div>
</FormGroup>
<FormGroup
small-label
required
:label="$t('commonSamlSettingModal.acsTitle')"
class="margin-bottom-2"
>
<div class="common-saml-setting-modal__url-block">
<div
v-for="conf in config"
:key="conf.name"
class="common-saml-setting-modal__url"
@click.prevent="
;[copyToClipboard(conf.acs), $refs.copiedACS.show()]
"
>
<span class="common-saml-setting-modal__url-domain">
{{ conf.name }}
</span>
<span
class="common-saml-setting-modal__url-dest"
:title="conf.acs"
>
{{ conf.acs }}
</span>
</div>
<Copied ref="copiedACS"></Copied>
</div>
</FormGroup>
</template>
</SamlSettingsForm>
<div class="actions actions--right">
<Button size="large" @click.prevent="$refs.modal.hide()">
{{ $t('action.close') }}
</Button>
</div>
</div>
</Modal>
</template>
<script>
import SamlSettingsForm from '@baserow_enterprise/components/admin/forms/SamlSettingsForm'
import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
import error from '@baserow/modules/core/mixins/error'
import { copyToClipboard } from '@baserow/modules/database/utils/clipboard'
import { mapActions, mapGetters } from 'vuex'
import modal from '@baserow/modules/core/mixins/modal'
export default {
name: 'CommonSamlSettingsModal',
components: { SamlSettingsForm },
mixins: [error, authProviderForm, modal],
inject: ['builder'],
props: {
integration: {
type: Object,
required: true,
},
userSource: {
type: Object,
required: true,
},
},
async fetch() {
try {
await this.actionFetchDomains({ builderId: this.builder.id })
} catch (error) {
this.handleError(error)
}
},
watch: {
'$v.$anyDirty'() {
// Force validity refresh on child touch
this.checkValidity()
},
},
computed: {
...mapGetters({ domains: 'domain/getDomains' }),
config() {
const previewRelay = `${this.$config.PUBLIC_WEB_FRONTEND_URL}/builder/${this.builder.id}/preview/`
const previewACS = `${this.$config.PUBLIC_BACKEND_URL}/api/user-source/${this.userSource.uid}/sso/saml/acs/`
const preview = [
{
name: this.$t('commonSamlSettingModal.preview'),
acs: previewACS,
relay: previewRelay,
},
]
const others = this.domains.map((domain) => ({
name: domain.domain_name,
acs: `${this.$config.PUBLIC_BACKEND_URL}/api/user-source/domain_${domain.id}__${this.userSource.uid}/sso/saml/acs/`,
relay: this.getDomainUrl(domain),
}))
return [...preview, ...others]
},
},
methods: {
...mapActions({
actionFetchDomains: 'domain/fetch',
}),
copyToClipboard(value) {
copyToClipboard(value)
},
onHide() {
this.checkValidity()
},
checkValidity() {
if (
!this.$refs.samlForm.isFormValid() &&
this.$refs.samlForm.$v.$anyDirty
) {
this.$emit('form-valid', false)
} else {
this.$emit('form-valid', true)
}
},
getDomainUrl(domain) {
const url = new URL(this.$config.PUBLIC_WEB_FRONTEND_URL)
return `${url.protocol}//${domain.domain_name}${
url.port ? `:${url.port}` : ''
}`
},
handleServerError(error) {
return this.$refs.samlForm.handleServerError(error)
},
},
validations() {
// Keep this to get the `$v` property
return {}
},
}
</script>

View file

@ -0,0 +1,135 @@
<template>
<div class="saml-auth-link">
<ABButton @click.prevent="onClick()">
{{ buttonLabel }}
</ABButton>
<Modal ref="modal">
<ThemeProvider>
<form @submit.prevent="">
<ABFormGroup
v-if="hasMultipleSamlProvider"
:label="$t('samlAuthLink.provideEmail')"
:error-message="
$v.values.email.$dirty
? !$v.values.email.required
? $t('error.requiredField')
: !$v.values.email.email
? $t('error.invalidEmail')
: ''
: ''
"
:autocomplete="isEditMode ? 'off' : ''"
required
>
<ABInput
v-model="values.email"
:placeholder="$t('samlAuthLink.emailPlaceholder')"
@blur="$v.values.email.$touch()"
/>
</ABFormGroup>
<div class="saml-auth-link__modal-footer">
<ABButton class="margin-top-2" @click.prevent.stop="login()">
{{ buttonLabel }}
</ABButton>
</div>
</form>
</ThemeProvider>
</Modal>
</div>
</template>
<script>
import form from '@baserow/modules/core/mixins/form'
import error from '@baserow/modules/core/mixins/error'
import { required, email } from 'vuelidate/lib/validators'
import ThemeProvider from '@baserow/modules/builder/components/theme/ThemeProvider'
export default {
components: { ThemeProvider },
mixins: [form, error],
inject: ['builder', 'mode'],
props: {
userSource: { type: Object, required: true },
authProviders: {
type: Array,
required: true,
},
loginButtonLabel: {
type: String,
required: true,
},
},
data() {
return {
loading: false,
values: { email: '' },
}
},
computed: {
isAuthenticated() {
return this.$store.getters['userSourceUser/isAuthenticated'](this.builder)
},
hasMultipleSamlProvider() {
return this.authProviders.length > 1
},
isEditMode() {
return this.mode === 'editing'
},
buttonLabel() {
return this.$t('samlAuthLink.placeholderWithSaml', {
login: this.loginButtonLabel,
})
},
},
methods: {
async onClick() {
if (this.hasMultipleSamlProvider) {
this.$refs.modal.show()
} else {
await this.login()
}
},
async login() {
if (this.isAuthenticated) {
await this.$store.dispatch('userSourceUser/logoff', {
application: this.builder,
})
}
if (this.hasMultipleSamlProvider) {
this.$v.$touch()
if (this.$v.$invalid) {
this.focusOnFirstError()
return
}
}
this.loading = true
this.hideError()
const dest = `${
this.$config.PUBLIC_BACKEND_URL
}/api/user-source/${encodeURIComponent(
this.userSource.uid
)}/sso/saml/login/`
const urlWithParams = new URL(dest)
if (this.hasMultipleSamlProvider) {
urlWithParams.searchParams.append('email', this.values.email)
}
// Add the current url as get parameter to be redirected here after the login.
urlWithParams.searchParams.append('original', window.location)
window.location = urlWithParams.toString()
},
},
validations: {
values: {
email: { required, email },
},
},
}
</script>

View file

@ -0,0 +1,148 @@
<template>
<form class="auth-form-element" @submit.prevent="onLogin">
<Error :error="error"></Error>
<ABFormGroup
:label="$t('authFormElement.email')"
:error-message="
$v.values.email.$dirty
? !$v.values.email.required
? $t('error.requiredField')
: !$v.values.email.email
? $t('error.invalidEmail')
: ''
: ''
"
:autocomplete="isEditMode ? 'off' : ''"
required
>
<ABInput
v-model="values.email"
:placeholder="$t('authFormElement.emailPlaceholder')"
@blur="$v.values.email.$touch()"
/>
</ABFormGroup>
<ABFormGroup
:label="$t('authFormElement.password')"
:error-message="
$v.values.password.$dirty
? !$v.values.password.required
? $t('error.requiredField')
: ''
: ''
"
required
>
<ABInput
ref="passwordRef"
v-model="values.password"
type="password"
:placeholder="$t('authFormElement.passwordPlaceholder')"
@blur="$v.values.password.$touch()"
/>
</ABFormGroup>
<div class="auth-form__footer">
<ABButton :disabled="$v.$error" :loading="loading" size="large">
{{ loginButtonLabel }}
</ABButton>
</div>
</form>
</template>
<script>
import form from '@baserow/modules/core/mixins/form'
import error from '@baserow/modules/core/mixins/error'
import { required, email } from 'vuelidate/lib/validators'
export default {
mixins: [form, error],
inject: ['builder', 'mode'],
props: {
userSource: { type: Object, required: true },
authProviders: {
type: Array,
required: true,
},
loginButtonLabel: {
type: String,
required: true,
},
},
data() {
return {
loading: false,
values: { email: '', password: '' },
}
},
computed: {
isAuthenticated() {
return this.$store.getters['userSourceUser/isAuthenticated'](this.builder)
},
isEditMode() {
return this.mode === 'editing'
},
},
methods: {
async onLogin(event) {
if (this.isAuthenticated) {
await this.$store.dispatch('userSourceUser/logoff', {
application: this.builder,
})
}
this.$v.$touch()
if (this.$v.$invalid) {
this.focusOnFirstError()
return
}
this.loading = true
this.hideError()
try {
await this.$store.dispatch('userSourceUser/authenticate', {
application: this.builder,
userSource: this.userSource,
credentials: {
email: this.values.email,
password: this.values.password,
},
setCookie: this.mode === 'public',
})
this.values.password = ''
this.values.email = ''
this.$v.$reset()
this.$emit('after-login')
} catch (error) {
if (error.handler) {
const response = error.handler.response
if (response && response.status === 401) {
this.values.password = ''
this.$v.$reset()
this.$v.$touch()
this.$refs.passwordRef.focus()
if (response.data?.error === 'ERROR_INVALID_CREDENTIALS') {
this.showError(
this.$t('error.incorrectCredentialTitle'),
this.$t('error.incorrectCredentialMessage')
)
}
} else {
const message = error.handler.getMessage('login')
this.showError(message)
}
error.handler.handled()
} else {
throw error
}
}
this.loading = false
},
},
validations: {
values: {
email: { required, email },
password: { required },
},
},
}
</script>

View file

@ -5,7 +5,6 @@
small-label small-label
horizontal horizontal
horizontal-variable horizontal-variable
class="margin-top-2"
required required
> >
<Dropdown <Dropdown
@ -32,16 +31,16 @@
</template> </template>
<script> <script>
import form from '@baserow/modules/core/mixins/form' import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
export default { export default {
mixins: [form], mixins: [authProviderForm],
props: { props: {
integration: { integration: {
type: Object, type: Object,
required: true, required: true,
}, },
currentUserSource: { userSource: {
type: Object, type: Object,
required: true, required: true,
}, },
@ -55,9 +54,6 @@ export default {
} }
}, },
computed: { computed: {
authProviderType() {
return this.$registry.get('appAuthProvider', 'local_baserow_password')
},
databases() { databases() {
return this.integration.context_data.databases return this.integration.context_data.databases
}, },
@ -65,12 +61,12 @@ export default {
return this.$registry.getAll('field') return this.$registry.getAll('field')
}, },
selectedTable() { selectedTable() {
if (!this.currentUserSource.table_id) { if (!this.userSource.table_id) {
return null return null
} }
for (const database of this.databases) { for (const database of this.databases) {
for (const table of database.tables) { for (const table of database.tables) {
if (table.id === this.currentUserSource.table_id) { if (table.id === this.userSource.table_id) {
return table return table
} }
} }
@ -91,7 +87,7 @@ export default {
}, },
}, },
watch: { watch: {
'currentUserSource.table_id'() { 'userSource.table_id'() {
this.values.password_field_id = null this.values.password_field_id = null
}, },
}, },

View file

@ -117,17 +117,22 @@ export class LocalBaserowUserSourceType extends UserSourceType {
if (!userSource.email_field_id || !userSource.name_field_id) { if (!userSource.email_field_id || !userSource.name_field_id) {
return {} return {}
} }
if (userSource.auth_providers.length !== 1) {
return {} return userSource.auth_providers.reduce((acc, authProvider) => {
} if (!acc[authProvider.type]) {
const authProvider = userSource.auth_providers[0] acc[authProvider.type] = []
if ( }
authProvider.type !== 'local_baserow_password' ||
!authProvider.password_field_id const loginOptions = this.app.$registry
) { .get('appAuthProvider', authProvider.type)
return {} .getLoginOptions(authProvider)
}
return { password: {} } if (loginOptions) {
acc[authProvider.type].push(loginOptions)
}
return acc
}, {})
} }
getOrder() { getOrder() {

View file

@ -95,14 +95,17 @@
"addProvider": "Add provider" "addProvider": "Add provider"
}, },
"authProviderTypes": { "authProviderTypes": {
"password": "Email and password authentication" "password": "Email and password authentication",
"saml": "SSO SAML provider",
"ssoSamlProviderName": "SSO SAML: {domain}",
"ssoSamlProviderNameUnconfigured": "Unconfigured SSO SAML"
}, },
"editAuthProviderMenuContext": { "editAuthProviderMenuContext": {
"edit": "Edit", "edit": "Edit",
"delete": "Delete" "delete": "Delete"
}, },
"samlSettingsForm": { "samlSettingsForm": {
"domain": "Domain", "domain": "SAML Domain",
"domainPlaceholder": "Insert the company domain name...", "domainPlaceholder": "Insert the company domain name...",
"invalidDomain": "Invalid domain name", "invalidDomain": "Invalid domain name",
"domainAlreadyExists": "A SAML provider for this domain already exists", "domainAlreadyExists": "A SAML provider for this domain already exists",
@ -311,12 +314,22 @@
"roleFieldPlaceholder": "Select a field..." "roleFieldPlaceholder": "Select a field..."
}, },
"appAuthProviderType": { "appAuthProviderType": {
"localBaserowPassword": "Email/Password" "localBaserowPassword": "Email/Password",
"commonSaml": "Saml SSO"
}, },
"localBaserowPasswordAppAuthProviderForm": { "localBaserowPasswordAppAuthProviderForm": {
"passwordFieldLabel": "Select password field", "passwordFieldLabel": "Select password field",
"noFields": "No compatible fields" "noFields": "No compatible fields"
}, },
"commonSamlSettingModal": {
"title": "Edit provider",
"relayStateTitle": "Default Relay State URL per domain (Click to copy)",
"acsTitle": "Single Sign On URL per domain (Click to copy)",
"preview": "Preview"
},
"commonSamlSettingForm": {
"authProviderInError": "Please edit this provider to fix the error."
},
"enterpriseSettings": { "enterpriseSettings": {
"branding": "Branding", "branding": "Branding",
"showBaserowHelpMessage": "Show help message", "showBaserowHelpMessage": "Show help message",
@ -383,5 +396,11 @@
"projectIdHelper": "The ID of the project where you want to sync the issues from. Can be found by going to your project page (e.g. https://gitlab.com/baserow/baserow), click on the three dots in the top right corner, and then on 'Copy project ID: 12345678'.", "projectIdHelper": "The ID of the project where you want to sync the issues from. Can be found by going to your project page (e.g. https://gitlab.com/baserow/baserow), click on the three dots in the top right corner, and then on 'Copy project ID: 12345678'.",
"accessToken": "Access token", "accessToken": "Access token",
"accessTokenHelper": "Can be generated here https://gitlab.com/-/user_settings/personal_access_tokens by clicking on `Add new token`, select `read_api`." "accessTokenHelper": "Can be generated here https://gitlab.com/-/user_settings/personal_access_tokens by clicking on `Add new token`, select `read_api`."
},
"samlAuthLink": {
"loginWithSaml": "Login with SAML",
"placeholderWithSaml": "{login} with SAML",
"provideEmail": "Provide your SAML account email",
"emailPlaceholder": "Enter your email..."
} }
} }

View file

@ -13,6 +13,7 @@
@create="showCreateModal($event)" @create="showCreateModal($event)"
/> />
<CreateAuthProviderModal <CreateAuthProviderModal
v-if="authProviderTypeToCreate"
ref="createModal" ref="createModal"
:auth-provider-type="authProviderTypeToCreate" :auth-provider-type="authProviderTypeToCreate"
@created="$refs.createModal.hide()" @created="$refs.createModal.hide()"
@ -57,9 +58,15 @@ export default {
}, },
computed: { computed: {
...mapGetters({ ...mapGetters({
authProviderMap: 'authProviderAdmin/getAll',
authProviders: 'authProviderAdmin/getAllOrdered', authProviders: 'authProviderAdmin/getAllOrdered',
authProviderTypesCanBeCreated: 'authProviderAdmin/getCreatableTypes',
}), }),
authProviderTypesCanBeCreated() {
return Object.values(this.$registry.getAll('authProvider')).filter(
(authProviderType) =>
authProviderType.canCreateNew(this.authProviderMap)
)
},
}, },
methods: { methods: {
getAdminListComponent(authProvider) { getAdminListComponent(authProvider) {
@ -75,8 +82,10 @@ export default {
4 4
) )
}, },
showCreateModal(authProviderType) { async showCreateModal(authProviderType) {
this.authProviderTypeToCreate = authProviderType.type this.authProviderTypeToCreate = authProviderType
// Wait for the modal to appear in DOM
await this.$nextTick()
this.$refs.createModal.show() this.$refs.createModal.show()
this.$refs.createContext.hide() this.$refs.createContext.hide()
}, },

View file

@ -27,7 +27,10 @@ import {
} from '@baserow_enterprise/licenseTypes' } from '@baserow_enterprise/licenseTypes'
import { EnterprisePlugin } from '@baserow_enterprise/plugins' import { EnterprisePlugin } from '@baserow_enterprise/plugins'
import { LocalBaserowUserSourceType } from '@baserow_enterprise/integrations/userSourceTypes' import { LocalBaserowUserSourceType } from '@baserow_enterprise/integrations/userSourceTypes'
import { LocalBaserowPasswordAppAuthProviderType } from '@baserow_enterprise/integrations/appAuthProviderTypes' import {
LocalBaserowPasswordAppAuthProviderType,
SamlAppAuthProviderType,
} from '@baserow_enterprise/integrations/appAuthProviderTypes'
import { AuthFormElementType } from '@baserow_enterprise/builder/elementTypes' import { AuthFormElementType } from '@baserow_enterprise/builder/elementTypes'
import { import {
EnterpriseAdminRoleType, EnterpriseAdminRoleType,
@ -46,6 +49,8 @@ import {
GitLabIssuesDataSyncType, GitLabIssuesDataSyncType,
} from '@baserow_enterprise/dataSyncTypes' } from '@baserow_enterprise/dataSyncTypes'
import { FF_AB_SSO } from '@baserow/modules/core/plugins/featureFlags'
export default (context) => { export default (context) => {
const { app, isDev, store } = context const { app, isDev, store } = context
@ -115,6 +120,13 @@ export default (context) => {
new LocalBaserowPasswordAppAuthProviderType(context) new LocalBaserowPasswordAppAuthProviderType(context)
) )
if (app.$featureFlagIsEnabled(FF_AB_SSO)) {
app.$registry.register(
'appAuthProvider',
new SamlAppAuthProviderType(context)
)
}
app.$registry.register('roles', new EnterpriseAdminRoleType(context)) app.$registry.register('roles', new EnterpriseAdminRoleType(context))
app.$registry.register('roles', new EnterpriseMemberRoleType(context)) app.$registry.register('roles', new EnterpriseMemberRoleType(context))
app.$registry.register('roles', new EnterpriseBuilderRoleType(context)) app.$registry.register('roles', new EnterpriseBuilderRoleType(context))

View file

@ -118,15 +118,6 @@ export const getters = {
} }
return authProviders return authProviders
}, },
getCreatableTypes: (state) => {
const items = []
for (const authProviderType of Object.values(state.items)) {
if (authProviderType.canCreateNewProviders) {
items.push(authProviderType)
}
}
return items
},
getNextProviderId: (state) => { getNextProviderId: (state) => {
return state.nextProviderId return state.nextProviderId
}, },

View file

@ -16,17 +16,7 @@
</p> </p>
<p v-else class="placeholder__content">{{ content }}</p> <p v-else class="placeholder__content">{{ content }}</p>
<div class="placeholder__action"> <div class="placeholder__action">
<Button <Button type="primary" icon="iconoir-home" size="large" @click="onHome()">
type="primary"
icon="iconoir-home"
size="large"
@click="
$router.go({
name: 'application-builder-page',
params: { pathMatch: '/' },
})
"
>
{{ $t('action.backToHome') }} {{ $t('action.backToHome') }}
</Button> </Button>
</div> </div>
@ -62,5 +52,15 @@ export default {
return this.error.content || this.$t('errorLayout.error') return this.error.content || this.$t('errorLayout.error')
}, },
}, },
methods: {
onHome() {
this.$router.push({
name: 'application-builder-page',
params: { pathMatch: '/' },
// We remove the query parameters. Important if we have some with error
query: {},
})
},
},
} }
</script> </script>

View file

@ -11,10 +11,7 @@
<slot name="title">{{ <slot name="title">{{
$t('pageVisibilitySettingsTypes.logInPageWarningTitle') $t('pageVisibilitySettingsTypes.logInPageWarningTitle')
}}</slot> }}</slot>
<!-- eslint-disable-next-line vue/no-v-html vue/no-v-text-v-html-on-component --> <p>{{ $t('pageVisibilitySettingsTypes.logInPagewarningMessage') }}</p>
<p
v-html="$t('pageVisibilitySettingsTypes.logInPagewarningMessage')"
></p>
</Alert> </Alert>
<Alert <Alert
v-else-if="showLoginPageAlert && !showLogInPageWarning" v-else-if="showLoginPageAlert && !showLogInPageWarning"
@ -23,14 +20,13 @@
<slot name="title">{{ <slot name="title">{{
$t('pageVisibilitySettingsTypes.logInPageInfoTitle') $t('pageVisibilitySettingsTypes.logInPageInfoTitle')
}}</slot> }}</slot>
<!-- eslint-disable-next-line vue/no-v-html vue/no-v-text-v-html-on-component --> <p>
<p {{
v-html="
$t('pageVisibilitySettingsTypes.logInPageInfoMessage', { $t('pageVisibilitySettingsTypes.logInPageInfoMessage', {
logInPageName: loginPageName, logInPageName: loginPageName,
}) })
" }}
></p> </p>
</Alert> </Alert>
</div> </div>
<div class="margin-top-1 visibility-form__visibility-all"> <div class="margin-top-1 visibility-form__visibility-all">

View file

@ -53,7 +53,6 @@
<UpdateUserSourceForm <UpdateUserSourceForm
ref="userSourceForm" ref="userSourceForm"
:builder="builder" :builder="builder"
:integrations="integrations"
:user-source-type="getUserSourceType(editedUserSource)" :user-source-type="getUserSourceType(editedUserSource)"
:default-values="editedUserSource" :default-values="editedUserSource"
@submitted="updateUserSource" @submitted="updateUserSource"
@ -71,7 +70,7 @@
:disabled="actionInProgress || invalidForm" :disabled="actionInProgress || invalidForm"
:loading="actionInProgress" :loading="actionInProgress"
size="large" size="large"
@click="$refs.userSourceForm.submit()" @click="$refs.userSourceForm.submit(true)"
> >
{{ $t('action.save') }} {{ $t('action.save') }}
</Button> </Button>
@ -86,7 +85,6 @@
<CreateUserSourceForm <CreateUserSourceForm
ref="userSourceForm" ref="userSourceForm"
:builder="builder" :builder="builder"
:integrations="integrations"
@submitted="createUserSource" @submitted="createUserSource"
@values-changed="onValueChange" @values-changed="onValueChange"
/> />
@ -122,6 +120,9 @@ export default {
name: 'UserSourceSettings', name: 'UserSourceSettings',
components: { CreateUserSourceForm, UpdateUserSourceForm }, components: { CreateUserSourceForm, UpdateUserSourceForm },
mixins: [error], mixins: [error],
provide() {
return { builder: this.builder }
},
props: { props: {
builder: { builder: {
type: Object, type: Object,
@ -169,7 +170,7 @@ export default {
return this.$registry.get('userSource', userSource.type) return this.$registry.get('userSource', userSource.type)
}, },
onValueChange() { onValueChange() {
this.invalidForm = !this.$refs.userSourceForm.isFormValid() this.invalidForm = !this.$refs.userSourceForm.isFormValid(true)
}, },
async showForm(userSourceToEdit) { async showForm(userSourceToEdit) {
if (userSourceToEdit) { if (userSourceToEdit) {
@ -205,6 +206,10 @@ export default {
this.actionInProgress = false this.actionInProgress = false
}, },
async updateUserSource(newValues) { async updateUserSource(newValues) {
if (!this.$refs.userSourceForm.isFormValid(true)) {
return
}
this.actionInProgress = true this.actionInProgress = true
try { try {
await this.actionUpdateUserSource({ await this.actionUpdateUserSource({
@ -215,8 +220,10 @@ export default {
this.hideForm() this.hideForm()
} catch (error) { } catch (error) {
// Restore the previously saved values from the store // Restore the previously saved values from the store
this.$refs.userSourceForm.reset() if (!this.$refs.userSourceForm.handleServerError(error)) {
this.handleError(error) this.$refs.userSourceForm.reset()
this.handleError(error)
}
} }
this.actionInProgress = false this.actionInProgress = false
}, },

View file

@ -67,10 +67,6 @@ export default {
type: Object, type: Object,
required: true, required: true,
}, },
integrations: {
type: Array,
required: true,
},
}, },
data() { data() {
return { return {
@ -78,6 +74,9 @@ export default {
} }
}, },
computed: { computed: {
integrations() {
return this.$store.getters['integration/getIntegrations'](this.builder)
},
userSources() { userSources() {
return this.$store.getters['userSource/getUserSources'](this.builder) return this.$store.getters['userSource/getUserSources'](this.builder)
}, },
@ -117,6 +116,10 @@ export default {
} }
return '' return ''
}, },
handleServerError() {
return false
},
}, },
validations: { validations: {
values: { values: {

View file

@ -47,13 +47,28 @@
<div <div
v-for="appAuthType in appAuthProviderTypes" v-for="appAuthType in appAuthProviderTypes"
:key="appAuthType.type" :key="appAuthType.type"
class="update-user-source-form__auth-provider"
> >
<Checkbox <div class="update-user-source-form__auth-provider-header">
:checked="hasAtLeastOneOfThisType(appAuthType)" <Checkbox
@input="onSelect(appAuthType)" :checked="hasAtLeastOneOfThisType(appAuthType)"
> @input="onSelect(appAuthType)"
{{ appAuthType.name }} >
</Checkbox> {{ appAuthType.name }}
</Checkbox>
<ButtonText
v-if="
hasAtLeastOneOfThisType(appAuthType) &&
appAuthType.canCreateNew(appAuthProviderPerTypes)
"
icon="iconoir-plus"
type="secondary"
@click.prevent="addNew(appAuthType)"
>
{{ $t('updateUserSourceForm.addProvider') }}
</ButtonText>
</div>
<div <div
v-for="appAuthProvider in appAuthProviderPerTypes[appAuthType.type]" v-for="appAuthProvider in appAuthProviderPerTypes[appAuthType.type]"
@ -63,11 +78,16 @@
<component <component
:is="appAuthType.formComponent" :is="appAuthType.formComponent"
v-if="hasAtLeastOneOfThisType(appAuthType)" v-if="hasAtLeastOneOfThisType(appAuthType)"
:integration="integration" :ref="`authProviderForm`"
:current-user-source="fullValues"
:default-values="appAuthProvider"
excluded-form excluded-form
:integration="integration"
:user-source="fullValues"
:auth-providers="appAuthProviderPerTypes"
:auth-provider="appAuthProvider"
:default-values="appAuthProvider"
:auth-provider-type="appAuthType"
@values-changed="updateAuthProvider(appAuthProvider, $event)" @values-changed="updateAuthProvider(appAuthProvider, $event)"
@delete="remove(appAuthProvider)"
/> />
</div> </div>
</div> </div>
@ -82,7 +102,6 @@
import form from '@baserow/modules/core/mixins/form' import form from '@baserow/modules/core/mixins/form'
import IntegrationDropdown from '@baserow/modules/core/components/integrations/IntegrationDropdown' import IntegrationDropdown from '@baserow/modules/core/components/integrations/IntegrationDropdown'
import { required, maxLength } from 'vuelidate/lib/validators' import { required, maxLength } from 'vuelidate/lib/validators'
import { uuid } from '@baserow/modules/core/utils/string'
export default { export default {
components: { IntegrationDropdown }, components: { IntegrationDropdown },
@ -97,10 +116,6 @@ export default {
required: false, required: false,
default: null, default: null,
}, },
integrations: {
type: Array,
required: true,
},
}, },
data() { data() {
return { return {
@ -113,6 +128,9 @@ export default {
} }
}, },
computed: { computed: {
integrations() {
return this.$store.getters['integration/getIntegrations'](this.builder)
},
integration() { integration() {
if (!this.values.integration_id) { if (!this.values.integration_id) {
return null return null
@ -140,6 +158,8 @@ export default {
methods: { methods: {
// Override the default getChildFormValues to exclude the provider forms from // Override the default getChildFormValues to exclude the provider forms from
// final values as they are handled directly by this component // final values as they are handled directly by this component
// The problem is that the child provider forms are not handled as a sub array
// so they override the userSource configuration
getChildFormsValues() { getChildFormsValues() {
return Object.assign( return Object.assign(
{}, {},
@ -157,6 +177,14 @@ export default {
hasAtLeastOneOfThisType(appAuthProviderType) { hasAtLeastOneOfThisType(appAuthProviderType) {
return this.appAuthProviderPerTypes[appAuthProviderType.type]?.length > 0 return this.appAuthProviderPerTypes[appAuthProviderType.type]?.length > 0
}, },
/** Return an integer bigger than any of the current auth_provider id to
* keep the right order when we want to map the error coming back from the server.
*/
nextID() {
return (
Math.max(1, ...this.values.auth_providers.map(({ id }) => id)) + 100
)
},
onSelect(appAuthProviderType) { onSelect(appAuthProviderType) {
if (this.hasAtLeastOneOfThisType(appAuthProviderType)) { if (this.hasAtLeastOneOfThisType(appAuthProviderType)) {
this.values.auth_providers = this.values.auth_providers.filter( this.values.auth_providers = this.values.auth_providers.filter(
@ -165,10 +193,21 @@ export default {
} else { } else {
this.values.auth_providers.push({ this.values.auth_providers.push({
type: appAuthProviderType.type, type: appAuthProviderType.type,
id: uuid(), id: this.nextID(),
}) })
} }
}, },
addNew(appAuthProviderType) {
this.values.auth_providers.push({
type: appAuthProviderType.type,
id: this.nextID(),
})
},
remove(appAuthProvider) {
this.values.auth_providers = this.values.auth_providers.filter(
({ id }) => id !== appAuthProvider.id
)
},
updateAuthProvider(authProviderToChange, values) { updateAuthProvider(authProviderToChange, values) {
this.values.auth_providers = this.values.auth_providers.map( this.values.auth_providers = this.values.auth_providers.map(
(authProvider) => { (authProvider) => {
@ -182,6 +221,16 @@ export default {
emitChange() { emitChange() {
this.fullValues = this.getFormValues() this.fullValues = this.getFormValues()
}, },
handleServerError(error) {
if (
this.$refs.authProviderForm
.map((form) => form.handleServerError(error))
.some((result) => result)
) {
return true
}
return false
},
getError(fieldName) { getError(fieldName) {
if (!this.$v.values[fieldName].$dirty) { if (!this.$v.values[fieldName].$dirty) {
return '' return ''

View file

@ -732,7 +732,8 @@
"nameFieldLabel": "Name", "nameFieldLabel": "Name",
"nameFieldPlaceholder": "Enter a name...", "nameFieldPlaceholder": "Enter a name...",
"authTitle": "Authentication", "authTitle": "Authentication",
"integrationFieldLabel": "Integration" "integrationFieldLabel": "Integration",
"addProvider": "Add provider"
}, },
"builderLoginPageForm": { "builderLoginPageForm": {
"pageDropdownLabel": "Login Page", "pageDropdownLabel": "Login Page",

View file

@ -24,8 +24,8 @@ import { userCanViewPage } from '@baserow/modules/builder/utils/visibility'
import { import {
getTokenIfEnoughTimeLeft, getTokenIfEnoughTimeLeft,
setToken,
userSourceCookieTokenName, userSourceCookieTokenName,
setToken,
} from '@baserow/modules/core/utils/auth' } from '@baserow/modules/core/utils/auth'
const logOffAndReturnToLogin = async ({ builder, store, redirect }) => { const logOffAndReturnToLogin = async ({ builder, store, redirect }) => {
@ -59,8 +59,8 @@ export default {
$registry, $registry,
app, app,
req, req,
route,
redirect, redirect,
route,
}) { }) {
let mode = 'public' let mode = 'public'
const builderId = params.builderId ? parseInt(params.builderId, 10) : null const builderId = params.builderId ? parseInt(params.builderId, 10) : null
@ -71,6 +71,7 @@ export default {
} }
let builder = store.getters['application/getSelected'] let builder = store.getters['application/getSelected']
let needPostBuilderLoading = false
if (!builder || (builderId && builderId !== builder.id)) { if (!builder || (builderId && builderId !== builder.id)) {
try { try {
@ -105,7 +106,40 @@ export default {
}) })
} }
needPostBuilderLoading = true
}
store.dispatch('userSourceUser/setCurrentApplication', {
application: builder,
})
if (
(!process.server || req) &&
!store.getters['userSourceUser/isAuthenticated'](builder)
) {
const refreshToken = getTokenIfEnoughTimeLeft(
app,
userSourceCookieTokenName
)
if (refreshToken) {
try {
await store.dispatch('userSourceUser/refreshAuth', {
application: builder,
token: refreshToken,
})
} catch (error) {
if (error.response?.status === 401) {
// We logoff as the token has probably expired or became invalid
logOffAndReturnToLogin({ builder, store, redirect })
} else {
throw error
}
}
}
}
if (needPostBuilderLoading) {
// Post builder loading task executed once per application // Post builder loading task executed once per application
// It's executed here to make sure we are authenticated at that point
const sharedPage = await store.getters['page/getSharedPage'](builder) const sharedPage = await store.getters['page/getSharedPage'](builder)
await Promise.all([ await Promise.all([
store.dispatch('dataSource/fetchPublished', { store.dispatch('dataSource/fetchPublished', {
@ -127,37 +161,18 @@ export default {
} }
) )
} }
store.dispatch('userSourceUser/setCurrentApplication', {
application: builder,
})
if ( // Auth providers can get error code from the URL parameters
(!process.server || req) && for (const userSource of builder.user_sources) {
!store.getters['userSourceUser/isAuthenticated'](builder) for (const authProvider of userSource.auth_providers) {
) { const authError = $registry
// token can be in the query string (SSO) or in the cookies (previous session) .get('appAuthProvider', authProvider.type)
let refreshToken = route.query.token .handleError(userSource, authProvider, route)
if (refreshToken) { if (authError) {
setToken(app, refreshToken, userSourceCookieTokenName, { return error({
sameSite: 'Lax', statusCode: authError.code,
}) message: authError.message,
} else {
refreshToken = getTokenIfEnoughTimeLeft(app, userSourceCookieTokenName)
}
if (refreshToken) {
try {
await store.dispatch('userSourceUser/refreshAuth', {
application: builder,
token: refreshToken,
}) })
} catch (error) {
if (error.response?.status === 401) {
// We logoff as the token has probably expired or became invalid
logOffAndReturnToLogin({ builder, store, redirect })
} else {
throw error
}
} }
} }
} }
@ -348,7 +363,7 @@ export default {
} }
}, },
}, },
async isAuthenticated() { async isAuthenticated(newIsAuthenticated) {
// When the user login or logout, we need to refetch the elements and actions // When the user login or logout, we need to refetch the elements and actions
// as they might have changed // as they might have changed
await this.$store.dispatch('element/fetchPublished', { await this.$store.dispatch('element/fetchPublished', {
@ -364,12 +379,18 @@ export default {
page: this.sharedPage, page: this.sharedPage,
}) })
// If the user is on a hidden page, redirect them to the Login page if possible. if (newIsAuthenticated) {
await this.maybeRedirectUserToLoginPage() // If the user has just logged in, we redirect him to the next page.
await this.maybeRedirectToNextPage()
} else {
// If the user is on a hidden page, redirect them to the Login page if possible.
await this.maybeRedirectUserToLoginPage()
}
}, },
}, },
async mounted() { async mounted() {
await this.maybeRedirectUserToLoginPage() await this.maybeRedirectUserToLoginPage()
await this.checkProviderAuthentication()
}, },
methods: { methods: {
/** /**
@ -389,8 +410,57 @@ export default {
this.mode this.mode
) )
if (url !== this.$router.history.current?.fullPath) { const currentPath = this.$route.fullPath
this.$router.push(url) if (url !== currentPath) {
const nextPath = encodeURIComponent(currentPath)
this.$router.push({ path: url, query: { next: nextPath } })
}
}
},
maybeRedirectToNextPage() {
if (this.$route.query.next) {
const decodedNext = decodeURIComponent(this.$route.query.next)
this.$router.push(decodedNext)
}
},
async checkProviderAuthentication() {
// Iterate over all auth providers to check if one can get a refresh token
let refreshTokenFromProvider = null
for (const userSource of this.builder.user_sources) {
for (const authProvider of userSource.auth_providers) {
refreshTokenFromProvider = this.$registry
.get('appAuthProvider', authProvider.type)
.getAuthToken(userSource, authProvider, this.$route)
if (refreshTokenFromProvider) {
break
}
}
if (refreshTokenFromProvider) {
break
}
}
if (refreshTokenFromProvider) {
setToken(this, refreshTokenFromProvider, userSourceCookieTokenName, {
sameSite: 'Lax',
})
try {
await this.$store.dispatch('userSourceUser/refreshAuth', {
application: this.builder,
token: refreshTokenFromProvider,
})
} catch (error) {
if (error.response?.status === 401) {
// We logoff as the token has probably expired or became invalid
logOffAndReturnToLogin({
builder: this.builder,
store: this.$store,
redirect: (...args) => this.$router.push(...args),
})
} else {
throw error
}
} }
} }
}, },

View file

@ -1,18 +1,30 @@
import { Registerable } from '@baserow/modules/core/registry' import { BaseAuthProviderType } from '@baserow/modules/core/authProviderTypes'
export class AppAuthProviderType extends Registerable { export class AppAuthProviderType extends BaseAuthProviderType {
get name() { get name() {
throw new Error('Must be set on the type.') return this.getName()
}
getLoginOptions(authProvider) {
return null
}
get component() {
return null
} }
/** /**
* The form to edit this user source. * The form to edit this user source.
*/ */
get formComponent() { get formComponent() {
return this.getAdminSettingsFormComponent()
}
getAuthToken(userSource, authProvider, route) {
return null return null
} }
getOrder() { handleError(userSource, authProvider, route) {
return 0 return null
} }
} }

View file

@ -2,8 +2,6 @@
display: flex; display: flex;
flex-direction: column; flex-direction: column;
gap: 20px; gap: 20px;
padding: 12px;
padding-bottom: 0;
.tabs__header { .tabs__header {
border-bottom: none; border-bottom: none;

View file

@ -1,4 +1,16 @@
.update-user-source-form__auth-provider {
margin-top: 16px;
}
.update-user-source-form__auth-provider-header {
display: flex;
justify-content: space-between;
}
.update-user-source-form__auth-provider-form { .update-user-source-form__auth-provider-form {
margin-left: 24px; margin-left: 24px;
margin-top: 12px;
display: flex; display: flex;
flex-direction: column;
align-items: stretch;
} }

View file

@ -2,10 +2,9 @@ import { Registerable } from '@baserow/modules/core/registry'
import PasswordAuthIcon from '@baserow/modules/core/assets/images/providers/Key.svg' import PasswordAuthIcon from '@baserow/modules/core/assets/images/providers/Key.svg'
/** /**
* The authorization provider type base class that can be extended when creating * Base class for authorization provider types
* a plugin for the frontend.
*/ */
export class AuthProviderType extends Registerable { export class BaseAuthProviderType extends Registerable {
/** /**
* The icon for the provider * The icon for the provider
*/ */
@ -14,14 +13,14 @@ export class AuthProviderType extends Registerable {
} }
/** /**
* A human readable name of the application type. * A human readable name of the authentication provider.
*/ */
getName() { getName() {
return null return null
} }
/** /**
* A human readable name of the application type. * A human readable name of the authentication provider.
*/ */
getProviderName(provider) { getProviderName(provider) {
return null return null
@ -81,6 +80,35 @@ export class AuthProviderType extends Registerable {
} }
} }
/**
* Whether we can create new providers on this type. Sometimes providers can't be
* created because of permissions reasons or because of unicity constraints.
* @returns a boolean saying if you can create new providers of this type.
*/
canCreateNew(authProviders) {
return true
}
/**
*
* @param {Object} err to handle
* @param {VueInstance} vueComponentInstance the vue component instance
* @returns true if the error is handled else false.
*/
handleServerError(vueComponentInstance, error) {
return false
}
getOrder() {
throw new Error('The order of the authentication provider must be set.')
}
}
/**
* The authorization provider type base class that can be extended when creating
* a plugin for the frontend.
*/
export class AuthProviderType extends BaseAuthProviderType {
constructor(...args) { constructor(...args) {
super(...args) super(...args)
this.type = this.getType() this.type = this.getType()
@ -108,10 +136,6 @@ export class AuthProviderType extends Registerable {
routeName: this.routeName, routeName: this.routeName,
} }
} }
getOrder() {
throw new Error('The order of an application type must be set.')
}
} }
export class PasswordAuthProviderType extends AuthProviderType { export class PasswordAuthProviderType extends AuthProviderType {
@ -139,6 +163,16 @@ export class PasswordAuthProviderType extends AuthProviderType {
return null return null
} }
/**
* We can create only one password provider.
*/
canCreateNew(authProviders) {
return (
!authProviders[this.getType()] ||
authProviders[this.getType()].length === 0
)
}
getOrder() { getOrder() {
return 1 return 1
} }

View file

@ -1,6 +1,7 @@
<template> <template>
<div <div
v-if="open" v-if="open || keepContent"
v-show="(keepContent && open) || !keepContent"
ref="modalWrapper" ref="modalWrapper"
class="modal__wrapper" class="modal__wrapper"
@click="outside($event)" @click="outside($event)"
@ -172,6 +173,14 @@ export default {
default: false, default: false,
required: false, required: false,
}, },
// This flag allow to keep the modal content in case you want it to be available.
// Useful if you have a form inside the modal that is a sub part of the current
// form.
keepContent: {
type: Boolean,
default: false,
required: false,
},
}, },
data() { data() {
return { return {

View file

@ -0,0 +1,40 @@
import form from '@baserow/modules/core/mixins/form'
export default {
mixins: [form],
props: {
authProviders: {
type: Object,
required: true,
},
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
authProviderType: {
type: Object,
required: true,
},
},
data() {
return { serverErrors: {} }
},
computed: {
providerName() {
return this.authProviderType.getProviderName(this.authProvider)
},
},
methods: {
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
},
handleServerError(error) {
return this.authProviderType.handleServerError(this, error)
},
},
}

View file

@ -85,28 +85,54 @@ export default {
firstError.scrollIntoView({ behavior: 'smooth' }) firstError.scrollIntoView({ behavior: 'smooth' })
} }
}, },
touch() { /**
* Select all children that match the given predicate.
* @param {Function} predicate a function that receive the current child as parameter and
* should return true if the child should be accepted.
* @param {Boolean} deep true if you want to deeply search for child
* @returns
*/
getChildForms(predicate = (child) => 'isFormValid' in child, deep = false) {
const children = []
const getDeep = (child) => {
if (predicate(child)) {
children.push(child)
}
if (deep) {
// Search into children of children
child.$children.forEach(getDeep)
}
}
for (const child of this.$children) {
getDeep(child)
}
return children
},
touch(deep = false) {
if ('$v' in this) { if ('$v' in this) {
this.$v.$touch() this.$v.$touch()
} }
// Also touch all the child forms so that all the error messages are going to // Also touch all the child forms so that all the error messages are going to
// be displayed. // be displayed.
for (const child of this.$children) { for (const child of this.getChildForms(
if ('isFormValid' in child && '$v' in child) { (child) => 'touch' in child,
child.touch() deep
} )) {
child.touch(deep)
} }
}, },
submit() { submit(deep = false) {
if (this.selectedFieldIsDeactivated) { if (this.selectedFieldIsDeactivated) {
return return
} }
this.touch() this.touch(deep)
if (this.isFormValid()) { if (this.isFormValid(deep)) {
this.$emit('submitted', this.getFormValues()) this.$emit('submitted', this.getFormValues(deep))
} else { } else {
this.$nextTick(() => this.focusOnFirstError()) this.$nextTick(() => this.focusOnFirstError())
} }
@ -120,21 +146,6 @@ export default {
? this.$v.values[fieldName].$error ? this.$v.values[fieldName].$error
: false : false
}, },
getChildForms(deep = false) {
const children = []
const getDeep = (child) => {
if ('isFormValid' in child) {
children.push(child)
} else if (deep) {
child.$children.forEach(getDeep)
}
}
for (const child of this.$children) {
getDeep(child)
}
return children
},
/** /**
* Returns true is everything is valid. * Returns true is everything is valid.
* *
@ -151,8 +162,11 @@ export default {
* Returns true if all the child form components are valid. * Returns true if all the child form components are valid.
*/ */
areChildFormsValid(deep = false) { areChildFormsValid(deep = false) {
for (const child of this.getChildForms(deep)) { for (const child of this.getChildForms(
if ('isFormValid' in child && !child.isFormValid()) { (child) => 'isFormValid' in child,
deep
)) {
if (!child.isFormValid(deep)) {
return false return false
} }
} }
@ -162,17 +176,21 @@ export default {
* A method that can be overridden to do some mutations on the values before * A method that can be overridden to do some mutations on the values before
* calling the submitted event. * calling the submitted event.
*/ */
getFormValues() { getFormValues(deep = false) {
return Object.assign({}, this.values, this.getChildFormsValues()) return Object.assign({}, this.values, this.getChildFormsValues(deep))
}, },
/** /**
* Returns an object containing the values of the child forms. * Returns an object containing the values of the child forms.
*/ */
getChildFormsValues() { getChildFormsValues(deep = false) {
const children = this.getChildForms(
(child) => 'getChildFormsValues' in child,
deep
)
return Object.assign( return Object.assign(
{}, {},
...this.$children.map((child) => { ...children.map((child) => {
return 'getChildFormsValues' in child ? child.getFormValues() : {} return child.getFormValues(deep)
}) })
) )
}, },
@ -196,16 +214,22 @@ export default {
await this.$nextTick() await this.$nextTick()
// Also reset the child forms after a tick to allow default values to be updated. // Also reset the child forms after a tick to allow default values to be updated.
this.getChildForms(deep).forEach((child) => child.reset()) this.getChildForms((child) => 'reset' in child, deep).forEach((child) =>
child.reset()
)
}, },
/** /**
* Returns if a child form has indicated it handled the error, false otherwise. * Returns if a child form has indicated it handled the error, false otherwise.
*/ */
handleErrorByForm(error) { handleErrorByForm(error, deep = false) {
let childHandledIt = false let childHandledIt = false
for (const child of this.$children) { const children = this.getChildForms(
if ('handleErrorByForm' in child && child.handleErrorByForm(error)) { (child) => 'handleErrorByForm' in child,
deep
)
for (const child of children) {
if (child.handleErrorByForm(error)) {
childHandledIt = true childHandledIt = true
} }
} }

View file

@ -1,5 +1,6 @@
const FF_ENABLE_ALL = '*' const FF_ENABLE_ALL = '*'
export const FF_DASHBOARDS = 'dashboards' export const FF_DASHBOARDS = 'dashboards'
export const FF_AB_SSO = 'ab_sso'
/** /**
* A comma separated list of feature flags used to enable in-progress or not ready * A comma separated list of feature flags used to enable in-progress or not ready