1
0
Fork 0
mirror of https://gitlab.com/bramw/baserow.git synced 2025-04-03 04:35:31 +00:00

Add Saml auth provider

This commit is contained in:
Jérémie Pardou 2024-12-10 14:27:16 +00:00
parent 993e1c1233
commit 076c1ebf53
99 changed files with 3007 additions and 972 deletions
backend
enterprise
web-frontend/modules

View file

@ -7,6 +7,7 @@ from rest_framework import serializers
from baserow.api.polymorphic import PolymorphicSerializer
from baserow.core.app_auth_providers.models import AppAuthProvider
from baserow.core.app_auth_providers.registries import app_auth_provider_type_registry
from baserow.core.auth_provider.validators import validate_domain
class AppAuthProviderSerializer(serializers.ModelSerializer):
@ -46,6 +47,13 @@ class BaseAppAuthProviderSerializer(serializers.ModelSerializer):
help_text="The type of the app_auth_provider.",
)
domain = serializers.CharField(
validators=[validate_domain],
required=False,
allow_null=True,
help_text=AppAuthProvider._meta.get_field("domain").help_text,
)
class Meta:
model = AppAuthProvider
fields = ("type", "user_source_id", "domain")

View file

@ -3,12 +3,19 @@ from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
from baserow.core.auth_provider.models import AuthProviderModel
from baserow.core.auth_provider.validators import validate_domain
from baserow.core.registries import auth_provider_type_registry
class AuthProviderSerializer(serializers.ModelSerializer):
type = serializers.SerializerMethodField(help_text="The type of the related field.")
domain = serializers.CharField(
validators=[validate_domain],
required=True,
help_text=AuthProviderModel._meta.get_field("domain").help_text,
)
class Meta:
model = AuthProviderModel
fields = ("id", "type", "domain", "enabled")

View file

@ -104,6 +104,7 @@ class PolymorphicSerializer(serializers.Serializer):
base_class=self.base_class,
request=self.request,
context=self.context,
extra_params=self.extra_params,
)
ret = serializer.to_representation(instance)
@ -122,6 +123,7 @@ class PolymorphicSerializer(serializers.Serializer):
base_class=self.base_class,
request=self.request,
context=self.context,
extra_params=self.extra_params,
)
return serializer.to_internal_value(data)
@ -134,6 +136,7 @@ class PolymorphicSerializer(serializers.Serializer):
base_class=self.base_class,
request=self.request,
context=self.context,
extra_params=self.extra_params,
)
return serializer.create(validated_data)
@ -150,6 +153,7 @@ class PolymorphicSerializer(serializers.Serializer):
base_class=self.base_class,
request=self.request,
context=self.context,
extra_params=self.extra_params,
)
return serializer.update(instance, validated_data)
@ -170,6 +174,7 @@ class PolymorphicSerializer(serializers.Serializer):
context=self.context,
data=self.data,
partial=self.partial,
extra_params=self.extra_params,
)
except serializers.ValidationError:
child_valid = False
@ -194,6 +199,7 @@ class PolymorphicSerializer(serializers.Serializer):
request=self.request,
context=self.context,
partial=self.partial,
extra_params=self.extra_params,
)
validated_data = serializer.run_validation(data)

View file

@ -2,7 +2,11 @@ from django.urls import include, path
from drf_spectacular.views import SpectacularRedocView
from baserow.core.registries import application_type_registry, plugin_registry
from baserow.core.registries import (
application_type_registry,
auth_provider_type_registry,
plugin_registry,
)
from .applications import urls as application_urls
from .auth_provider import urls as auth_provider_urls
@ -53,5 +57,6 @@ urlpatterns = (
),
]
+ application_type_registry.api_urls
+ auth_provider_type_registry.api_urls
+ plugin_registry.api_urls
)

View file

@ -3,6 +3,7 @@ from typing import Optional, Tuple, TypeVar
from django.conf import settings
from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
from drf_spectacular.extensions import OpenApiAuthenticationExtension
from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.request import Request
from rest_framework_simplejwt.authentication import JWTAuthentication
@ -157,3 +158,19 @@ class UserSourceJSONWebTokenAuthentication(JWTAuthentication):
user,
validated_token,
)
class UserSourceJSONWebTokenAuthenticationExtension(OpenApiAuthenticationExtension):
target_class = (
"baserow.api.user_sources.authentication.UserSourceJSONWebTokenAuthentication"
)
name = "UserSource JWT"
match_subclasses = True
priority = -1
def get_security_definition(self, auto_schema):
return {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT your_token",
}

View file

@ -6,9 +6,7 @@ from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
from baserow.api.app_auth_providers.serializers import (
ReadPolymorphicAppAuthProviderSerializer,
)
from baserow.api.app_auth_providers.serializers import AppAuthProviderSerializer
from baserow.api.polymorphic import PolymorphicSerializer
from baserow.api.services.serializers import PublicServiceSerializer
from baserow.api.user_files.serializers import UserFileField, UserFileSerializer
@ -26,6 +24,7 @@ from baserow.contrib.builder.elements.registries import element_type_registry
from baserow.contrib.builder.models import Builder
from baserow.contrib.builder.pages.handler import PageHandler
from baserow.contrib.builder.pages.models import Page
from baserow.core.app_auth_providers.registries import app_auth_provider_type_registry
from baserow.core.services.registries import service_type_registry
from baserow.core.user_sources.models import UserSource
from baserow.core.user_sources.registries import user_source_type_registry
@ -175,6 +174,16 @@ class PublicPageSerializer(serializers.ModelSerializer):
}
class PublicPolymorphicAppAuthProviderSerializer(PolymorphicSerializer):
"""
Polymorphic serializer for App Auth providers.
"""
base_class = AppAuthProviderSerializer
registry = app_auth_provider_type_registry
extra_params = {"public": True}
class BasePublicUserSourceSerializer(serializers.ModelSerializer):
"""
Basic user source serializer mostly for returned values.
@ -186,7 +195,7 @@ class BasePublicUserSourceSerializer(serializers.ModelSerializer):
def get_type(self, instance):
return user_source_type_registry.get_by_model(instance.specific_class).type
auth_providers = ReadPolymorphicAppAuthProviderSerializer(
auth_providers = PublicPolymorphicAppAuthProviderSerializer(
required=False,
many=True,
help_text="Auth providers related to this user source.",

View file

@ -1,6 +1,8 @@
from typing import Any, Dict, List, Optional
from urllib.parse import urljoin
from zipfile import ZipFile
from django.conf import settings
from django.contrib.auth.models import AbstractUser
from django.core.files.storage import Storage
from django.db import transaction
@ -419,6 +421,30 @@ class BuilderApplicationType(ApplicationType):
return builder
def get_default_application_urls(self, application: Builder) -> list[str]:
"""
Returns the default frontend urls of a builder application.
"""
from baserow.contrib.builder.domains.handler import DomainHandler
domain = DomainHandler().get_domain_for_builder(application)
if domain is not None:
# Let's also return the preview url so that it's easier to test
preview_url = urljoin(
settings.PUBLIC_WEB_FRONTEND_URL,
f"/builder/{domain.builder_id}/preview/",
)
return [domain.get_public_url(), preview_url]
preview_url = urljoin(
settings.PUBLIC_WEB_FRONTEND_URL,
f"/builder/{application.id}/preview/",
)
# It's an unpublished version let's return to the home preview page
return [preview_url]
def enhance_queryset(self, queryset):
queryset = queryset.prefetch_related("page_set")
queryset = queryset.prefetch_related("user_sources")

View file

@ -85,6 +85,17 @@ class DomainHandler:
return domain.published_to
def get_domain_for_builder(self, builder: Builder) -> Domain | None:
"""
Returns the domain the builder is published for or None if it's not a published
builder.
"""
try:
return Domain.objects.get(published_to=builder)
except Domain.DoesNotExist:
return None
def create_domain(
self, domain_type: DomainType, builder: Builder, **kwargs
) -> Domain:

View file

@ -1,3 +1,6 @@
from urllib.parse import urlparse
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db.models import CASCADE, SET_NULL
@ -80,6 +83,16 @@ class Domain(
class Meta:
ordering = ("order",)
def get_public_url(self):
"""
Returns the URL for this domain.
"""
# Parse the PUBLIC_WEB_FRONTEND_URL to extract the scheme and port
parsed_url = urlparse(settings.PUBLIC_WEB_FRONTEND_URL)
port_string = f":{parsed_url.port}" if parsed_url.port else ""
return f"{parsed_url.scheme}://{self.domain_name}{port_string}"
@classmethod
def get_last_order(cls, builder):
queryset = Domain.objects.filter(builder=builder)

View file

@ -4,6 +4,7 @@ from baserow.contrib.builder.data_sources.operations import (
DispatchDataSourceOperationType,
ListDataSourcesPageOperationType,
)
from baserow.contrib.builder.domains.handler import DomainHandler
from baserow.contrib.builder.elements.operations import ListElementsPageOperationType
from baserow.contrib.builder.models import Builder
from baserow.contrib.builder.workflow_actions.operations import (
@ -20,8 +21,6 @@ from baserow.core.user_sources.operations import (
)
from baserow.core.user_sources.subjects import UserSourceUserSubjectType
from .models import Domain
User = get_user_model()
@ -101,10 +100,7 @@ class AllowPublicBuilderManagerType(PermissionManagerType):
# give access to specific data.
continue
if (
builder.workspace is None
and Domain.objects.filter(published_to=builder).exists()
):
if DomainHandler().get_domain_for_builder(builder) is not None:
# it's a public builder, we allow it.
result[check] = True

View file

@ -1,72 +0,0 @@
from typing import Dict, List, Optional
from baserow.core.registry import CustomFieldsInstanceMixin
class PublicCustomFieldsInstanceMixin(CustomFieldsInstanceMixin):
public_serializer_field_names = []
"""The field names that must be added to the serializer if it's public."""
public_request_serializer_field_names = []
"""
The field names that must be added to the public request serializer if different
from the `public_serializer_field_names`.
"""
request_serializer_field_overrides = None
"""
The fields that must be added to the request serializer if different from the
`serializer_field_overrides` property.
"""
public_serializer_field_overrides = None
"""The fields that must be added to the public serializer."""
public_request_serializer_field_overrides = None
"""
The fields that must be added to the public request serializer if different from the
`public_serializer_field_overrides` property.
"""
def get_field_overrides(
self, request_serializer: bool, extra_params=None, **kwargs
) -> Dict:
public = extra_params.get("public", False)
if public:
if request_serializer and self.public_request_serializer_field_overrides:
return self.public_request_serializer_field_overrides
if self.public_serializer_field_overrides:
return self.public_serializer_field_overrides
return super().get_field_overrides(request_serializer, extra_params, **kwargs)
def get_field_names(
self, request_serializer: bool, extra_params=None, **kwargs
) -> List[str]:
public = extra_params.get("public", False)
if public:
if request_serializer and self.public_request_serializer_field_names:
return self.public_request_serializer_field_names
if self.public_serializer_field_names:
return self.public_serializer_field_names
return super().get_field_names(request_serializer, extra_params, **kwargs)
def get_meta_ref_name(
self,
request_serializer: bool,
extra_params=None,
**kwargs,
) -> Optional[str]:
meta_ref_name = super().get_meta_ref_name(
request_serializer, extra_params, **kwargs
)
public = extra_params.get("public", False)
if public:
meta_ref_name = f"Public{meta_ref_name}"
return meta_ref_name

View file

@ -4,11 +4,11 @@ from django.contrib.auth.models import AbstractUser
from baserow.contrib.builder.formula_importer import import_formula
from baserow.contrib.builder.mixins import BuilderInstanceWithFormulaMixin
from baserow.contrib.builder.registries import PublicCustomFieldsInstanceMixin
from baserow.contrib.builder.workflow_actions.models import BuilderWorkflowAction
from baserow.core.registry import (
CustomFieldsRegistryMixin,
ModelRegistryMixin,
PublicCustomFieldsInstanceMixin,
Registry,
)
from baserow.core.workflow_actions.registries import WorkflowActionType
@ -88,6 +88,7 @@ class BuilderWorkflowActionType(
cache = {}
element_id = serialized_values["element_id"]
import_context = {}
if element_id:
imported_element_id = id_mapping["builder_page_elements"][element_id]
import_context = ElementHandler().get_import_context_addition(

View file

@ -1,18 +1,20 @@
from typing import TYPE_CHECKING, Callable, List, Type, Union
from typing import TYPE_CHECKING, Callable, List, Tuple, Type, Union
from django.contrib.auth.models import AbstractUser
from baserow.core.app_auth_providers.exceptions import IncompatibleUserSourceType
from baserow.core.app_auth_providers.types import AppAuthProviderTypeDict
from baserow.core.auth_provider.registries import BaseAuthProviderType
from baserow.core.auth_provider.types import AuthProviderModelSubClass
from baserow.core.registry import EasyImportExportMixin
from baserow.core.auth_provider.types import AuthProviderModelSubClass, UserInfo
from baserow.core.registry import EasyImportExportMixin, PublicCustomFieldsInstanceMixin
if TYPE_CHECKING:
from baserow.core.user_sources.types import UserSourceSubClass
class AppAuthProviderType(EasyImportExportMixin, BaseAuthProviderType):
class AppAuthProviderType(
EasyImportExportMixin, PublicCustomFieldsInstanceMixin, BaseAuthProviderType
):
"""
Authentication provider for application user sources.
"""
@ -70,3 +72,16 @@ class AppAuthProviderType(EasyImportExportMixin, BaseAuthProviderType):
:param instance: The auth provider instance related to the user source.
:param user_source: The user source being updated.
"""
def get_or_create_user_and_sign_in(
self, auth_provider: AuthProviderModelSubClass, user_info: UserInfo
) -> Tuple[AbstractUser, bool]:
"""
Get or create a user for the given UserInfo. Calls the related userSource
get_or_create_user.
"""
user_source = auth_provider.user_source.specific
return user_source.get_type().get_or_create_user(
user_source, email=user_info.email, name=user_info.name
)

View file

@ -37,4 +37,4 @@ class AppAuthProvider(BaseAuthProviderModel, HierarchicalModelMixin):
return app_auth_provider_type_registry
class Meta:
ordering = ["domain", "id"]
ordering = ["id"]

View file

@ -17,8 +17,14 @@ class BaseAuthProviderModel(
Base abstract model for app_providers.
"""
domain = models.CharField(max_length=255, null=True)
enabled = models.BooleanField(default=True)
domain = models.CharField(
max_length=255,
null=True,
help_text="The email domain registered with this provider.",
)
enabled = models.BooleanField(
help_text="Whether the provider is enabled or not.", default=True
)
class Meta:
abstract = True

View file

@ -47,13 +47,6 @@ class BaseAuthProviderType(
default_create_allowed_fields = ["domain", "enabled"]
default_update_allowed_fields = ["domain", "enabled"]
@abstractmethod
def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]:
"""
Returns a dictionary containing the login options
to populate the login component accordingly.
"""
def can_create_new_providers(self, **kwargs) -> bool:
"""
Returns True if it's possible to create an authentication provider of this type.
@ -249,6 +242,13 @@ class AuthenticationProviderTypeRegistry(
super().__init__(*args, **kwargs)
self._default = None
@abstractmethod
def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]:
"""
Returns a dictionary containing the login options
to populate the login component accordingly.
"""
def get_all_available_login_options(self):
login_options = {}
for provider_type in self.get_all():

View file

@ -0,0 +1,33 @@
from django.core.management.base import BaseCommand
from django.urls import URLPattern, URLResolver, get_resolver
class Command(BaseCommand):
help = "List all registered full URLs with their namespaces"
def handle(self, *args, **kwargs):
resolver = get_resolver()
self.list_urls(resolver.url_patterns)
def list_urls(self, urlpatterns, prefix="", namespace=None):
for pattern in urlpatterns:
if isinstance(pattern, URLPattern):
# Construct the full URL path
full_url = f"{prefix}{pattern.pattern}"
full_namespace = (
f"{namespace}:{pattern.name}"
if namespace and pattern.name
else pattern.name or "None"
)
self.stdout.write(f"URL: {full_url}, Namespace: {full_namespace}")
elif isinstance(pattern, URLResolver):
# Construct the full namespace and recurse
new_prefix = f"{prefix}{pattern.pattern}"
new_namespace = (
f"{namespace}:{pattern.namespace}"
if namespace and pattern.namespace
else pattern.namespace or namespace
)
self.list_urls(
pattern.url_patterns, prefix=new_prefix, namespace=new_namespace
)

View file

@ -0,0 +1,48 @@
# Generated by Django 5.0.9 on 2024-11-19 09:43
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("core", "0092_alter_userprofile_language"),
]
operations = [
migrations.AlterModelOptions(
name="appauthprovider",
options={"ordering": ["id"]},
),
migrations.AlterField(
model_name="appauthprovider",
name="domain",
field=models.CharField(
help_text="The email domain registered with this provider.",
max_length=255,
null=True,
),
),
migrations.AlterField(
model_name="appauthprovider",
name="enabled",
field=models.BooleanField(
default=True, help_text="Whether the provider is enabled or not."
),
),
migrations.AlterField(
model_name="authprovidermodel",
name="domain",
field=models.CharField(
help_text="The email domain registered with this provider.",
max_length=255,
null=True,
),
),
migrations.AlterField(
model_name="authprovidermodel",
name="enabled",
field=models.BooleanField(
default=True, help_text="Whether the provider is enabled or not."
),
),
]

View file

@ -507,6 +507,13 @@ class ApplicationType(
def enhance_queryset(self, queryset):
return queryset
def get_default_application_urls(self, application: "Application") -> list[str]:
"""
Returns the default frontend urls of the application if any.
"""
return []
ApplicationSubClassInstance = TypeVar(
"ApplicationSubClassInstance", bound="Application"

View file

@ -276,13 +276,95 @@ class CustomFieldsInstanceMixin:
return None
class PublicCustomFieldsInstanceMixin(CustomFieldsInstanceMixin):
"""
A mixin for instance with custom fields but some field should remains private
when used in some APIs.
"""
public_serializer_field_names = None
"""The field names that must be added to the serializer if it's public."""
public_request_serializer_field_names = None
"""
The field names that must be added to the public request serializer if different
from the `public_serializer_field_names`.
"""
request_serializer_field_overrides = None
"""
The fields that must be added to the request serializer if different from the
`serializer_field_overrides` property.
"""
public_serializer_field_overrides = None
"""The fields that must be added to the public serializer."""
public_request_serializer_field_overrides = None
"""
The fields that must be added to the public request serializer if different from the
`public_serializer_field_overrides` property.
"""
def get_field_overrides(
self, request_serializer: bool, extra_params=None, **kwargs
) -> Dict:
public = extra_params.get("public", False)
if public:
if (
request_serializer is not None
and self.public_request_serializer_field_overrides is not None
):
return self.public_request_serializer_field_overrides
if self.public_serializer_field_overrides is not None:
return self.public_serializer_field_overrides
return super().get_field_overrides(request_serializer, extra_params, **kwargs)
def get_field_names(
self, request_serializer: bool, extra_params=None, **kwargs
) -> List[str]:
public = extra_params.get("public", False)
if public:
if (
request_serializer is not None
and self.public_request_serializer_field_names is not None
):
return self.public_request_serializer_field_names
if self.public_serializer_field_names is not None:
return self.public_serializer_field_names
return super().get_field_names(request_serializer, extra_params, **kwargs)
def get_meta_ref_name(
self,
request_serializer: bool,
extra_params=None,
**kwargs,
) -> Optional[str]:
meta_ref_name = super().get_meta_ref_name(
request_serializer, extra_params, **kwargs
)
public = extra_params.get("public", False)
if public:
meta_ref_name = f"Public{meta_ref_name}"
return meta_ref_name
class APIUrlsInstanceMixin:
def get_api_urls(self):
def get_api_urls(self) -> List:
"""
If needed custom api related urls to the instance can be added here.
Example:
from django.urls import include, path
def get_api_urls(self):
from . import api_urls
@ -298,7 +380,6 @@ class APIUrlsInstanceMixin:
]
:return: A list containing the urls.
:rtype: list
"""
return []

View file

@ -1,6 +1,6 @@
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
from django.contrib.auth.models import AbstractUser
@ -16,6 +16,7 @@ from baserow.core.registry import (
ModelRegistryMixin,
Registry,
)
from baserow.core.user.exceptions import UserNotFound
from baserow.core.user_sources.constants import DEFAULT_USER_ROLE_PREFIX
from baserow.core.user_sources.user_source_user import UserSourceUser
@ -233,15 +234,41 @@ class UserSourceType(
"""
@abstractmethod
def get_user(self, user_source: UserSource, **kwargs) -> Optional[UserSourceUser]:
def create_user(
self, user_source: UserSource, email: str, name: str
) -> UserSourceUser:
"""
Create a user for the given user source instance from it's email and name.
:param user_source: The user source we want to create the user for.
:param email: Email of the user to create.
:param name: Name of the user to create.
:return: A user instance.
"""
@abstractmethod
def get_user(self, user_source: UserSource, **kwargs) -> UserSourceUser:
"""
Returns a user given some args.
:param user_source: The user source used to get the user.
:param kwargs: Keyword arguments to get the user.
:raises UserNotFound: When the user can't be found.
:return: A user instance if any found with the given parameters.
"""
def get_or_create_user(
self, user_source: UserSource, email: str, name: str
) -> Tuple[UserSourceUser, bool]:
"""
Shorthand to create a user if he doesn't exist.
"""
try:
return self.get_user(user_source, email=email), False
except UserNotFound:
return self.create_user(user_source, email, name), True
@abstractmethod
def authenticate(self, user_source: UserSource, **kwargs) -> UserSourceUser:
"""

View file

@ -265,6 +265,9 @@ def stub_user_source_registry(data_fixture, mutable_user_source_registry, fake):
return get_user_return
return data_fixture.create_user_source_user(user_source=user_source)
def create_user(self, user_source, email, name):
return data_fixture.create_user_source_user(user_source=user_source)
def authenticate(self, user_source, **kwargs):
if authenticate_return:
if callable(authenticate_return):

View file

@ -193,7 +193,10 @@ def test_create_user_source_w_auth_providers(api_client, data_fixture):
"name": "test",
"integration_id": integration.id,
"auth_providers": [
{"type": "local_baserow_password", "enabled": False, "domain": "test1"},
{
"type": "local_baserow_password",
"enabled": False,
},
],
},
format="json",
@ -208,14 +211,87 @@ def test_create_user_source_w_auth_providers(api_client, data_fixture):
assert response_json["auth_providers"] == [
{
"domain": "test1",
"id": first.id,
"password_field_id": None,
"type": "local_baserow_password",
"domain": None,
},
]
@pytest.mark.django_db
def test_create_user_source_w_auth_providers_w_domain(api_client, data_fixture):
user, token = data_fixture.create_user_and_token()
workspace = data_fixture.create_workspace(user=user)
application = data_fixture.create_builder_application(workspace=workspace)
integration = data_fixture.create_local_baserow_integration(application=application)
url = reverse("api:user_sources:list", kwargs={"application_id": application.id})
response = api_client.post(
url,
{
"type": "local_baserow",
"name": "test",
"integration_id": integration.id,
"auth_providers": [
{
"domain": "domain.com",
"type": "local_baserow_password",
"enabled": False,
},
],
},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
response_json = response.json()
assert response.status_code == HTTP_200_OK
assert AppAuthProvider.objects.count() == 1
first = AppAuthProvider.objects.first()
assert response_json["auth_providers"] == [
{
"id": first.id,
"password_field_id": None,
"type": "local_baserow_password",
"domain": "domain.com",
},
]
@pytest.mark.django_db
def test_create_user_source_w_auth_providers_w_wrong_domain(api_client, data_fixture):
user, token = data_fixture.create_user_and_token()
workspace = data_fixture.create_workspace(user=user)
application = data_fixture.create_builder_application(workspace=workspace)
integration = data_fixture.create_local_baserow_integration(application=application)
url = reverse("api:user_sources:list", kwargs={"application_id": application.id})
response = api_client.post(
url,
{
"type": "local_baserow",
"name": "test",
"integration_id": integration.id,
"auth_providers": [
{
"domain": "baddomain",
"type": "local_baserow_password",
"enabled": False,
},
],
},
format="json",
HTTP_AUTHORIZATION=f"JWT {token}",
)
response_json = response.json()
assert response.status_code == HTTP_400_BAD_REQUEST
assert response_json["error"] == "ERROR_REQUEST_BODY_VALIDATION"
@pytest.mark.django_db
def test_create_user_source_w_auth_provider_wrong_type(api_client, data_fixture):
user, token = data_fixture.create_user_and_token()
@ -237,7 +313,6 @@ def test_create_user_source_w_auth_provider_wrong_type(api_client, data_fixture)
"integration_id": integration.id,
"auth_providers": [
{
"domain": "test_domain",
"enabled": True,
"type": app_auth_provider_type.type,
},
@ -270,7 +345,6 @@ def test_create_user_source_w_auth_provider_missing_type(api_client, data_fixtur
"integration_id": integration.id,
"auth_providers": [
{
"domain": "test_domain",
"enabled": True,
"type": "bad_type",
},
@ -374,7 +448,11 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture):
url,
{
"auth_providers": [
{"type": "local_baserow_password", "enabled": False, "domain": "test1"},
{
"type": "local_baserow_password",
"enabled": False,
"domain": "test2.com",
},
],
},
format="json",
@ -388,7 +466,7 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture):
assert response.json()["auth_providers"] == [
{
"domain": "test1",
"domain": "test2.com",
"id": first.id,
"password_field_id": None,
"type": "local_baserow_password",
@ -399,7 +477,11 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture):
url,
{
"auth_providers": [
{"type": "local_baserow_password", "enabled": False, "domain": "test3"},
{
"type": "local_baserow_password",
"enabled": False,
"domain": "test3.com",
},
],
},
format="json",
@ -411,7 +493,7 @@ def test_update_user_source_w_auth_providers(api_client, data_fixture):
assert response.json()["auth_providers"] == [
{
"domain": "test3",
"domain": "test3.com",
"id": first.id,
"password_field_id": None,
"type": "local_baserow_password",

View file

@ -196,3 +196,34 @@ def test_domain_publishing(data_fixture):
DomainHandler().publish(domain1, progress)
assert Builder.objects.count() == 2
@pytest.mark.django_db
def test_get_domain_for_builder(data_fixture):
user = data_fixture.create_user()
builder = data_fixture.create_builder_application(user=user)
builder_to = data_fixture.create_builder_application(workspace=None)
domain1 = data_fixture.create_builder_custom_domain(
builder=builder, published_to=builder_to, domain_name="mytest.com"
)
domain2 = data_fixture.create_builder_custom_domain(
builder=builder, domain_name="mytest2.com"
)
assert (
DomainHandler().get_domain_for_builder(builder_to).domain_name == "mytest.com"
)
assert DomainHandler().get_domain_for_builder(builder) is None
@pytest.mark.django_db
def test_get_domain_public_url(data_fixture):
user = data_fixture.create_user()
builder = data_fixture.create_builder_application(user=user)
builder_to = data_fixture.create_builder_application(workspace=None)
domain1 = data_fixture.create_builder_custom_domain(
builder=builder, published_to=builder_to, domain_name="mytest.com"
)
assert domain1.get_public_url() == "http://mytest.com:3000"

View file

@ -1636,3 +1636,21 @@ def test_builder_application_exports_file_with_zip_file(
serialized_image_element = visible_pages[0]["elements"][0]
assert serialized_image_element["image_source_type"] == "upload"
assert serialized_image_element["image_file_id"] == serialized_file
@pytest.mark.django_db
def test_get_default_application_urls(data_fixture):
user = data_fixture.create_user()
builder = data_fixture.create_builder_application(user=user)
builder_to = data_fixture.create_builder_application(workspace=None)
domain1 = data_fixture.create_builder_custom_domain(
builder=builder, published_to=builder_to, domain_name="mytest.com"
)
assert builder.get_type().get_default_application_urls(builder) == [
f"http://localhost:3000/builder/{builder.id}/preview/"
]
assert builder_to.get_type().get_default_application_urls(builder_to) == [
"http://mytest.com:3000",
f"http://localhost:3000/builder/{builder.id}/preview/",
]

View file

@ -14,7 +14,7 @@ class CreateAuthProviderSerializer(serializers.ModelSerializer):
class Meta:
model = AuthProviderModel
fields = ("domain", "type")
fields = ("domain", "type", "enabled")
class UpdateAuthProviderSerializer(serializers.ModelSerializer):
@ -29,7 +29,7 @@ class UpdateAuthProviderSerializer(serializers.ModelSerializer):
class Meta:
model = AuthProviderModel
fields = ("domain", "type")
fields = ("domain", "type", "enabled")
extra_kwargs = {
"domain": {"required": False},
}

View file

@ -0,0 +1,15 @@
from rest_framework import serializers
from baserow_enterprise.api.sso.saml.serializers import SAMLResponseSerializer
from baserow_enterprise.api.sso.serializers import BaseSsoLoginRequestSerializer
class CommonSsoLoginRequestSerializer(BaseSsoLoginRequestSerializer):
next = serializers.CharField(
required=False,
help_text="If provided, the user will be redirected to that path after login",
)
class CommonSAMLResponseSerializer(SAMLResponseSerializer):
query_param_serializer = CommonSsoLoginRequestSerializer

View file

@ -0,0 +1,19 @@
from django.urls import re_path
from .views import (
SamlAppAuthProviderAssertionConsumerServiceView,
SamlAppAuthProviderBaserowInitiatedSingleSignOn,
)
app_name = "baserow_enterprise.api.integrations.common.sso.saml"
urlpatterns = [
re_path(
r"acs/$", SamlAppAuthProviderAssertionConsumerServiceView.as_view(), name="acs"
),
re_path(
r"login/$",
SamlAppAuthProviderBaserowInitiatedSingleSignOn.as_view(),
name="login",
),
]

View file

@ -0,0 +1,245 @@
from django.db import transaction
from django.http import HttpResponseRedirect
from django.shortcuts import redirect
from drf_spectacular.openapi import OpenApiParameter, OpenApiTypes
from drf_spectacular.utils import extend_schema
from rest_framework.permissions import AllowAny
from rest_framework.request import Request
from rest_framework.views import APIView
from baserow.api.decorators import map_exceptions
from baserow.api.exceptions import (
QueryParameterValidationException,
RequestBodyValidationException,
)
from baserow.api.user_sources.errors import ERROR_USER_SOURCE_DOES_NOT_EXIST
from baserow.api.utils import validate_data
from baserow.core.user.exceptions import DeactivatedUserException
from baserow.core.user_sources.exceptions import (
UserSourceDoesNotExist,
UserSourceImproperlyConfigured,
)
from baserow.core.user_sources.handler import UserSourceHandler
from baserow_enterprise.api.integrations.common.sso.saml.serializers import (
CommonSAMLResponseSerializer,
CommonSsoLoginRequestSerializer,
)
from baserow_enterprise.api.sso.serializers import BaseSsoLoginRequestSerializer
from baserow_enterprise.api.sso.utils import (
SsoErrorCode,
get_valid_frontend_url,
map_sso_exceptions,
urlencode_query_params,
)
from baserow_enterprise.integrations.common.sso.saml.handler import (
SamlAppAuthProviderHandler,
)
from baserow_enterprise.integrations.common.sso.saml.models import (
SamlAppAuthProviderModel,
)
from baserow_enterprise.sso.saml.exceptions import (
InvalidSamlConfiguration,
InvalidSamlRequest,
InvalidSamlResponse,
)
class SamlAppAuthProviderAssertionConsumerServiceView(APIView):
permission_classes = (AllowAny,)
@extend_schema(
tags=["Auth"],
request=CommonSAMLResponseSerializer,
operation_id="auth_provider_saml_acs_url",
description=(
"Complete the SAML authentication flow by validating the SAML response. "
"Sign in the user if already exists in user_source or create a new one "
"otherwise."
"Once authenticated, the user will be redirected to the original "
"URL they were trying to access. If the response is invalid, the user "
"will be redirected to an error page with a specific error message."
"It accepts the language code and the workspace invitation token as query "
"parameters if provided."
),
responses={302: None},
auth=[],
)
@transaction.atomic
@map_exceptions(
{
UserSourceDoesNotExist: ERROR_USER_SOURCE_DOES_NOT_EXIST,
}
)
def post(
self,
request: Request,
user_source_uid,
) -> HttpResponseRedirect:
user_source = UserSourceHandler().get_user_source_by_uid(user_source_uid)
default_frontend_urls = (
user_source.application.get_type().get_default_application_urls(
user_source.application.specific
)
)
error_raised = {"code": None}
def on_error(error_code):
error_raised["code"] = error_code
with map_sso_exceptions(
{
InvalidSamlConfiguration: SsoErrorCode.INVALID_SAML_RESPONSE,
InvalidSamlResponse: SsoErrorCode.INVALID_SAML_RESPONSE,
DeactivatedUserException: SsoErrorCode.USER_DEACTIVATED,
RequestBodyValidationException: SsoErrorCode.INVALID_SAML_RESPONSE,
UserSourceDoesNotExist: SsoErrorCode.INVALID_SAML_REQUEST,
UserSourceImproperlyConfigured: SsoErrorCode.INVALID_SAML_REQUEST,
},
on_error=on_error,
):
# We can't use the decorator here because the redirect_url is related
# to the user source and we don't have it before.
data = validate_data(
CommonSAMLResponseSerializer,
request.data,
return_validated=True,
)
next_path = data["saml_request_data"].pop("next", None)
user = SamlAppAuthProviderHandler.sign_in_user_from_saml_response(
data["SAMLResponse"],
data["saml_request_data"],
base_queryset=SamlAppAuthProviderModel.objects.filter(
user_source=user_source
),
)
if error_raised["code"]:
# We redirect to the default frontend url with an error code
error_url = urlencode_query_params(
default_frontend_urls[0],
{f"saml_error__{user_source.id}": error_raised["code"].value},
)
return redirect(error_url)
query_params = {
f"user_source_saml_token__{user_source.id}": user.get_refresh_token()
}
if next_path:
query_params["next"] = next_path
# Otherwise it a success, we redirect to the login page
redirect_url = get_valid_frontend_url(
data["RelayState"],
default_frontend_urls=default_frontend_urls,
# Add the refresh token as query parameter
query_params=query_params,
allow_any_path=False,
)
return redirect(redirect_url)
class SamlAppAuthProviderBaserowInitiatedSingleSignOn(APIView):
permission_classes = (AllowAny,)
@extend_schema(
parameters=[
OpenApiParameter(
name="email",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description="The email address of the user that want to sign in using SAML.",
),
OpenApiParameter(
name="original",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description=(
"The url to which the user should be redirected after a successful "
"login or sign up."
),
),
],
tags=["User sources"],
request=BaseSsoLoginRequestSerializer,
operation_id="app_auth_provider_saml_sp_login",
description=(
"This is the endpoint that is called when the user wants to initiate a "
"SSO SAML login from Baserow (the service provider). The user will be "
"redirected to the SAML identity provider (IdP) where the user "
"can authenticate. "
"Once logged in in the IdP, the user will be redirected back "
"to the assertion consumer service endpoint (ACS) where the SAML response "
"will be validated and a new JWT session token will be provided to work "
"with Baserow APIs."
),
responses={302: None},
auth=[],
)
@map_exceptions(
{
UserSourceDoesNotExist: ERROR_USER_SOURCE_DOES_NOT_EXIST,
}
)
def get(self, request: Request, user_source_uid: str) -> HttpResponseRedirect:
user_source = UserSourceHandler().get_user_source_by_uid(user_source_uid)
default_frontend_urls = (
user_source.application.get_type().get_default_application_urls(
user_source.application.specific
)
)
error_raised = {"code": None}
def on_error(error_code):
error_raised["code"] = error_code
with map_sso_exceptions(
{
InvalidSamlConfiguration: SsoErrorCode.INVALID_SAML_REQUEST,
InvalidSamlRequest: SsoErrorCode.INVALID_SAML_REQUEST,
RequestBodyValidationException: SsoErrorCode.INVALID_SAML_REQUEST,
},
on_error=on_error,
):
# Validate query parameters
query_params = validate_data(
CommonSsoLoginRequestSerializer,
request.GET.dict(),
partial=False,
exception_to_raise=QueryParameterValidationException,
return_validated=True,
)
original_url = query_params.pop("original", "")
valid_relay_state_url = get_valid_frontend_url(
original_url,
query_params,
default_frontend_urls=default_frontend_urls,
allow_any_path=False,
)
idp_sign_in_url = SamlAppAuthProviderHandler.get_sign_in_url(
query_params,
SamlAppAuthProviderHandler.model_class.objects.filter(
user_source__uid=user_source.uid
),
redirect_to=valid_relay_state_url,
)
if error_raised["code"]:
# We redirect to the default frontend url with an error code
error_url = urlencode_query_params(
default_frontend_urls[0],
{f"saml_error__{user_source.id}": error_raised["code"].value},
)
return redirect(error_url)
return redirect(idp_sign_in_url)

View file

@ -7,6 +7,8 @@ from baserow_enterprise.sso.saml.exceptions import InvalidSamlResponse
class SAMLResponseSerializer(serializers.Serializer):
query_param_serializer = SsoLoginRequestSerializer
SAMLResponse = serializers.CharField(
required=True, help_text="The encoded SAML response from the IdP."
)
@ -25,11 +27,12 @@ class SAMLResponseSerializer(serializers.Serializer):
parsed_relay_state = urlparse(relay_state)
query_params = dict(parse_qsl(parsed_relay_state.query))
if query_params:
request_data_serializer = SsoLoginRequestSerializer(data=query_params)
request_data_serializer = self.query_param_serializer(data=query_params)
if request_data_serializer.is_valid():
data["saml_request_data"] = request_data_serializer.validated_data
else:
raise InvalidSamlResponse("Invalid RelayState query parameters.")
data["RelayState"] = parsed_relay_state._replace(query="").geturl()
return data

View file

@ -1,5 +1,7 @@
import io
from django.db.models import QuerySet
from rest_framework import serializers
from saml2.xml.schema import XMLSchemaError
from saml2.xml.schema import validate as validate_saml_metadata_schema
@ -9,14 +11,17 @@ from baserow_enterprise.sso.saml.models import SamlAuthProviderModel
def validate_unique_saml_domain(
domain, instance=None, model_class=SamlAuthProviderModel
domain, instance=None, base_queryset: QuerySet | None = None
):
queryset = model_class.objects.filter(domain=domain)
if base_queryset is None:
base_queryset = SamlAuthProviderModel.objects
queryset = base_queryset.filter(domain=domain)
if instance:
queryset = queryset.exclude(id=instance.id)
if queryset.exists():
raise SamlProviderForDomainAlreadyExists(
f"There is already a {model_class.__name__} for this domain."
"There is already a provider for this domain."
)
return domain

View file

@ -230,7 +230,7 @@ class AdminAuthProvidersLoginUrlView(APIView):
)
saml_login_url = urljoin(
settings.PUBLIC_BACKEND_URL, reverse("api:enterprise:sso:saml:login")
settings.PUBLIC_BACKEND_URL, reverse("api:enterprise_sso_saml:login")
)
saml_login_url = urlencode_query_params(saml_login_url, query_params)
return Response({"redirect_url": saml_login_url})

View file

@ -7,13 +7,13 @@ from baserow.api.user.serializers import NormalizedEmailField
from baserow.api.user.validators import language_validation
class SsoLoginRequestSerializer(serializers.Serializer):
class BaseSsoLoginRequestSerializer(serializers.Serializer):
email = NormalizedEmailField(
required=False, help_text="The email address of the user."
)
original = serializers.CharField(
required=False,
help_text="The relative part of URL that the user wanted to access.",
help_text="The original URL that the user wanted to access.",
)
language = serializers.CharField(
required=False,
@ -23,6 +23,13 @@ class SsoLoginRequestSerializer(serializers.Serializer):
help_text="An ISO 639 language code (with optional variant) "
"selected by the user. Ex: en-GB.",
)
def to_internal_value(self, instance) -> Dict[str, str]:
data = super().to_internal_value(instance)
return {k: v for k, v in data.items() if v is not None}
class SsoLoginRequestSerializer(BaseSsoLoginRequestSerializer):
workspace_invitation_token = serializers.CharField(
required=False,
help_text="If provided and valid, the user accepts the workspace invitation and"
@ -34,8 +41,5 @@ class SsoLoginRequestSerializer(serializers.Serializer):
if urlparse(value).hostname:
return None
return value
def to_internal_value(self, instance) -> Dict[str, str]:
data = super().to_internal_value(instance)
return {k: v for k, v in data.items() if v is not None}
return value

View file

@ -1,11 +0,0 @@
from django.urls import include, path
from .oauth2 import urls as oauth2_urls
from .saml import urls as saml_urls
app_name = "baserow_enterprise.api.sso"
urlpatterns = [
path("saml/", include(saml_urls, namespace="saml")),
path("oauth2/", include(oauth2_urls, namespace="oauth2")),
]

View file

@ -1,5 +1,7 @@
from contextlib import ContextDecorator
from enum import Enum
from typing import Dict, Optional
from functools import wraps
from typing import Callable, Dict, Optional, Type
from urllib.parse import urljoin, urlparse
from django.conf import settings
@ -25,31 +27,58 @@ class SsoErrorCode(Enum):
SIGNUP_DISABLED = "errorSignupDisabled"
def map_sso_exceptions(mapping: Dict[Exception, SsoErrorCode]):
class map_sso_exceptions(ContextDecorator):
"""
This decorator can be used to map exceptions to SSO error codes. If the
decorated function raises an exception that is in the mapping, the
redirect_to_sign_in_error_page() function will be called with the mapped
error code. If the exception is not in the mapping, it will be raised
normally.
A context manager and decorator to map exceptions to SSO error codes. If the
enclosed code block or decorated function raises an exception that is in the
mapping, the provided redirect function will be called with the mapped error code.
If the exception is not in the mapping, it will be raised normally.
:param mapping: A dictionary that maps exceptions to SSO error codes.
:return: The decorator.
:param on_error: A callable that takes an error code and handles the action.
"""
def decorator(func):
def wrapper(*args, **kwargs):
def __init__(
self,
mapping: Dict[Type[Exception], str],
on_error: Callable[[str], None] | None = None,
):
self.mapping = mapping
if on_error is None:
self.on_error = redirect_to_sign_in_error_page
else:
self.on_error = on_error
def __enter__(self):
pass
# Context manager
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
# No exception occurred
return False
for exception, error_code in self.mapping.items():
if isinstance(exc_value, exception):
self.on_error(error_code)
return True # Swallow the exception after handling it
# If exception not handled, propagate it
return False
# Decorator version
def __call__(self, func):
@wraps(func)
def wrapped_function(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
for exception, error_code in mapping.items():
for exception, error_code in self.mapping.items():
if isinstance(e, exception):
return redirect_to_sign_in_error_page(error_code)
return self.on_error(error_code)
raise e
return wrapper
return decorator
return wrapped_function
def urlencode_query_params(url: str, query_params: Dict[str, str]) -> str:
@ -79,6 +108,7 @@ def redirect_to_sign_in_error_page(
"""
frontend_error_page_url = get_frontend_login_error_url()
if error_code:
frontend_error_page_url = urlencode_query_params(
frontend_error_page_url, {"error": error_code.value}
@ -89,40 +119,64 @@ def redirect_to_sign_in_error_page(
def get_valid_frontend_url(
requested_original_url: Optional[str] = None,
query_params: Optional[Dict[str, str]] = None,
) -> str:
default_frontend_urls: list[str] | None = None,
allow_any_path: bool = True,
):
"""
Returns a valid absolute frontend url based on the original url requested
before the redirection to the login (can be relative or absolute). If the
original url is relative, it will be prefixed with the frontend hostname to
make the IdP redirection work. If the original url is external to Baserow,
the default frontend dashboard url will be returned instead.
original url is relative, it will be prefixed with the default hostname to
make the IdP redirection work. If the original url doesn't match any of the given
default_front_urls, the first default frontend url will be used instead.
:param requested_original_url: The url to which the user should be
redirected after a successful login.
:param query_params: The query parameters to add to the URL.
:param default_frontend_urls: The first one is the default fallback frontend URL.
Baserow one is used if None. Others are also allowed as valid URLs.
:return: The url with the token as a query parameter.
"""
requested_url_parsed = urlparse(requested_original_url or "")
default_frontend_url_parsed = urlparse(get_frontend_default_redirect_url())
if requested_url_parsed.path in ["", "/"]:
# use the default frontend path if the requested one is empty
requested_url_parsed = requested_url_parsed._replace(
path=default_frontend_url_parsed.path
)
if default_frontend_urls is None:
default_frontend_urls = [get_frontend_default_redirect_url()]
default_frontend_urls_parsed = [urlparse(u) for u in default_frontend_urls]
default_frontend_url_parsed = default_frontend_urls_parsed[0]
matching_url = default_frontend_url_parsed
if requested_url_parsed.hostname is None:
# provide a correct absolute url if the requested one is relative
requested_url_parsed = default_frontend_url_parsed._replace(
path=requested_url_parsed.path
)
elif requested_url_parsed.hostname != default_frontend_url_parsed.hostname:
# return the default url if the requested url is external to Baserow
requested_url_parsed = default_frontend_url_parsed
else:
found = False
for allowed_url in default_frontend_urls_parsed:
if requested_url_parsed.hostname == allowed_url.hostname:
matching_url = allowed_url
found = True
if not found:
# None are matching -> redirecting to main homepage
requested_url_parsed = default_frontend_url_parsed
matching_url = default_frontend_url_parsed
if allow_any_path:
if requested_url_parsed.path in ["", "/"]:
# use the default frontend path if the requested one is empty
requested_url_parsed = requested_url_parsed._replace(path=matching_url.path)
elif not requested_url_parsed.geturl().startswith(matching_url.geturl()):
# if using a path that doesn't match the allowed urls, we reset to default url
requested_url_parsed = matching_url
if query_params:
return urlencode_query_params(requested_url_parsed.geturl(), query_params)
return str(requested_url_parsed.geturl())
return requested_url_parsed.geturl()
def urlencode_user_token(frontend_url: str, user: AbstractUser) -> str:

View file

@ -4,7 +4,6 @@ from .admin import urls as admin_urls
from .audit_log import urls as audit_log_urls
from .role import urls as role_urls
from .secure_file_serve import urls as secure_file_serve_urls
from .sso import urls as sso_urls
from .teams import urls as teams_urls
app_name = "baserow_enterprise.api"
@ -13,7 +12,6 @@ urlpatterns = [
path("teams/", include(teams_urls, namespace="teams")),
path("role/", include(role_urls, namespace="role")),
path("admin/", include(admin_urls, namespace="admin")),
path("sso/", include(sso_urls, namespace="sso")),
path("audit-log/", include(audit_log_urls, namespace="audit_log")),
path("files/", include(secure_file_serve_urls, namespace="files")),
]

View file

@ -174,6 +174,12 @@ class BaserowEnterpriseConfig(AppConfig):
LocalBaserowPasswordAppAuthProviderType()
)
from baserow_enterprise.integrations.common.sso.saml.app_auth_provider_types import (
SamlAppAuthProviderType,
)
app_auth_provider_type_registry.register(SamlAppAuthProviderType())
from baserow.contrib.builder.elements.registries import element_type_registry
from baserow_enterprise.builder.elements.element_types import (
AuthFormElementType,

View file

@ -0,0 +1,113 @@
from typing import List
from urllib.parse import urljoin
from django.conf import settings
from django.urls import include, path, reverse
from baserow.core.app_auth_providers.auth_provider_types import AppAuthProviderType
from baserow.core.app_auth_providers.types import AppAuthProviderTypeDict
from baserow_enterprise.api.sso.saml.validators import validate_unique_saml_domain
from baserow_enterprise.integrations.local_baserow.user_source_types import (
LocalBaserowUserSourceType,
)
from baserow_enterprise.sso.saml.auth_provider_types import SamlAuthProviderTypeMixin
from .models import SamlAppAuthProviderModel
class SamlAppAuthProviderType(SamlAuthProviderTypeMixin, AppAuthProviderType):
"""
The SAML authentication provider type allows users to login using SAML.
"""
model_class = SamlAppAuthProviderModel
compatible_user_source_types = [LocalBaserowUserSourceType.type]
class SerializedDict(
AppAuthProviderTypeDict, SamlAuthProviderTypeMixin.SamlSerializedDict
):
...
@property
def allowed_fields(self) -> List[str]:
return SamlAuthProviderTypeMixin.saml_allowed_fields
@property
def serializer_field_names(self):
return SamlAuthProviderTypeMixin.saml_serializer_field_names
public_serializer_field_names = []
@property
def serializer_field_overrides(self):
return SamlAuthProviderTypeMixin.saml_serializer_field_overrides
public_serializer_field_overrides = {}
def get_api_urls(self):
from baserow_enterprise.api.integrations.common.sso.saml import urls
return [
path(
"user-source/<str:user_source_uid>/sso/saml/",
include(urls, namespace="sso_saml"),
)
]
def before_create(self, user, **values):
user_source = values["user_source"]
if "domain" in values:
validate_unique_saml_domain(
values["domain"],
base_queryset=SamlAppAuthProviderModel.objects.filter(
user_source=user_source
),
)
return super().before_create(user, **values)
def before_update(self, user, provider, **values):
if "domain" in values:
user_source = values.get("user_source", provider.user_source)
validate_unique_saml_domain(
values["domain"],
provider,
base_queryset=SamlAppAuthProviderModel.objects.filter(
user_source=user_source
),
)
return super().before_update(user, provider, **values)
@classmethod
def get_acs_absolute_url(
cls, auth_provider: "SamlAuthProviderTypeMixin | None" = None
):
"""
Returns the ACS url for SAML authentication purpose. The user is redirected
to this URL after a successful login.
"""
return urljoin(
settings.PUBLIC_BACKEND_URL,
reverse(
"api:user_sources:sso_saml:acs",
kwargs={"user_source_uid": auth_provider.user_source.uid},
),
)
@classmethod
def get_login_absolute_url(
cls, auth_provider: "SamlAuthProviderTypeMixin | None" = None
):
"""
Returns the login URL for this auth_provider. The login URL is used to initiate
the Saml login process.
"""
return urljoin(
settings.PUBLIC_BACKEND_URL,
reverse(
"api:user_sources:sso_saml:login",
kwargs={"user_source_uid": auth_provider.user_source.uid},
),
)

View file

@ -0,0 +1,8 @@
from baserow_enterprise.integrations.common.sso.saml.models import (
SamlAppAuthProviderModel,
)
from baserow_enterprise.sso.saml.handler import SamlAuthProviderHandler
class SamlAppAuthProviderHandler(SamlAuthProviderHandler):
model_class = SamlAppAuthProviderModel

View file

@ -0,0 +1,8 @@
from baserow.core.app_auth_providers.models import AppAuthProvider
from baserow_enterprise.sso.saml.models import SamlAuthProviderModelMixin
class SamlAppAuthProviderModel(SamlAuthProviderModelMixin, AppAuthProvider):
# Restore ordering
class Meta(AppAuthProvider.Meta):
ordering = ["id"]

View file

@ -31,6 +31,8 @@ class LocalBaserowPasswordAppAuthProviderType(AppAuthProviderType):
]
serializer_field_names = ["password_field_id"]
public_serializer_field_names = []
allowed_fields = ["password_field"]
serializer_field_overrides = {
@ -40,6 +42,7 @@ class LocalBaserowPasswordAppAuthProviderType(AppAuthProviderType):
help_text="The id of the field to use as password for the user account.",
),
}
public_serializer_field_overrides = {}
class SerializedDict(AppAuthProviderTypeDict):
password_field_id: int
@ -151,13 +154,6 @@ class LocalBaserowPasswordAppAuthProviderType(AppAuthProviderType):
instance.password_field = None
instance.save()
def get_login_options(self, **kwargs) -> Dict[str, Any]:
"""
Not implemented yet.
"""
return {}
def get_or_create_user_and_sign_in(
self, auth_provider: AuthProviderModelSubClass, user_info: Dict[str, Any]
) -> Tuple[AbstractUser, bool]:

View file

@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional
from django.contrib.auth.models import AbstractUser
from loguru import logger
from rest_framework import serializers
from baserow.api.exceptions import RequestBodyValidationException
@ -17,6 +18,7 @@ from baserow.contrib.database.fields.field_types import (
)
from baserow.contrib.database.fields.handler import FieldHandler
from baserow.contrib.database.fields.registries import FieldType
from baserow.contrib.database.rows.actions import CreateRowsActionType
from baserow.contrib.database.rows.operations import ReadDatabaseRowOperationType
from baserow.contrib.database.search.handler import SearchHandler
from baserow.contrib.database.table.exceptions import TableDoesNotExist
@ -450,9 +452,9 @@ class LocalBaserowUserSourceType(UserSourceType):
return (
f"{user_source.id}"
f"_{user_source.table_id if user_source.table_id else '?'}"
f"_{user_source.email_field_id if user_source.email_field_id else '?'}"
f"_{user_source.role_field_id if user_source.role_field_id else '?'}"
f"_{user_source.table_id if user_source.table_id else '0'}"
f"_{user_source.email_field_id if user_source.email_field_id else '0'}"
f"_{user_source.role_field_id if user_source.role_field_id else '0'}"
)
def role_type_is_allowed(self, role_field: Optional[FieldType]) -> bool:
@ -596,6 +598,54 @@ class LocalBaserowUserSourceType(UserSourceType):
raise UserNotFound()
def create_user(self, user_source: UserSource, email, name, role=None):
"""
Creates the user in the configured table.
"""
if not self.is_configured(user_source):
raise UserSourceImproperlyConfigured()
try:
# Use table handler to exclude trashed table
table = TableHandler().get_table(user_source.table_id)
except TableDoesNotExist as exc:
# As we CASCADE when a table is deleted, the table shouldn't
# exist only if it's trashed and not yet deleted.
raise UserSourceImproperlyConfigured("The table doesn't exist.") from exc
integration = user_source.integration.specific
model = table.get_model()
values = {
user_source.name_field.db_column: name,
user_source.email_field.db_column: email,
}
if role and user_source.role_field_id:
values[user_source.role_field.db_column] = role
try:
# Use the action to keep track on what's going on
(user,) = CreateRowsActionType.do(
user=integration.authorized_user,
table=table,
rows_values=[values],
model=model,
)
except Exception as e:
logger.exception(e)
raise ("Error while creating the user") from e
return UserSourceUser(
user_source,
user,
user.id,
getattr(user, user_source.name_field.db_column),
getattr(user, user_source.email_field.db_column),
self.get_user_role(user, user_source),
)
def authenticate(self, user_source: UserSource, **kwargs):
"""
Authenticates using the given credentials. It uses the password auth provider.

View file

@ -0,0 +1,81 @@
# Generated by Django 5.0.9 on 2024-12-05 11:23
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("baserow_enterprise", "0033_samlauthprovidermodel_email_attr_key_and_more"),
("core", "0093_alter_appauthprovider_options_and_more"),
]
operations = [
migrations.CreateModel(
name="SamlAppAuthProviderModel",
fields=[
(
"appauthprovider_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="core.appauthprovider",
),
),
(
"metadata",
models.TextField(
blank=True,
help_text="The XML metadata downloaded from the metadata_url.",
),
),
(
"is_verified",
models.BooleanField(
default=False,
help_text="This will be set to True only after a user successfully login with this IdP. This must be True to disable normal username/password login and make SAML the only authentication provider. ",
),
),
(
"email_attr_key",
models.CharField(
db_default="user.email",
default="user.email",
help_text="The key in the SAML response that contains the email address of the user. If this is not set, the email will be taken from the user's profile.",
max_length=32,
),
),
(
"first_name_attr_key",
models.CharField(
db_default="user.first_name",
default="user.first_name",
help_text="The key in the SAML response that contains the first name of the user. If this is not set, the first name will be taken from the user's profile.",
max_length=32,
),
),
(
"last_name_attr_key",
models.CharField(
blank=True,
db_default="user.last_name",
default="user.last_name",
help_text="The key in the SAML response that contains the last name of the user. If this is not set, the last name will be taken from the user's profile.",
max_length=32,
),
),
],
options={
"ordering": ["id"],
"abstract": False,
},
bases=("core.appauthprovider", models.Model),
),
migrations.AlterModelOptions(
name="samlauthprovidermodel",
options={"ordering": ["domain"]},
),
]

View file

@ -1,6 +1,12 @@
from baserow_enterprise.builder.elements.models import AuthFormElement
from baserow_enterprise.data_sync.models import LocalBaserowTableDataSync
from baserow_enterprise.integrations.models import LocalBaserowUserSource
from baserow_enterprise.integrations.common.sso.saml.models import (
SamlAppAuthProviderModel,
)
from baserow_enterprise.integrations.models import (
LocalBaserowPasswordAppAuthProvider,
LocalBaserowUserSource,
)
from baserow_enterprise.role.models import Role, RoleAssignment
from baserow_enterprise.teams.models import Team, TeamSubject
@ -12,4 +18,6 @@ __all__ = [
"LocalBaserowUserSource",
"AuthFormElement",
"LocalBaserowTableDataSync",
"LocalBaserowPasswordAppAuthProvider",
"SamlAppAuthProviderModel",
]

View file

@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Tuple
from django.conf import settings
from django.contrib.sessions.backends.base import SessionBase
from django.urls import reverse
from django.urls import include, path, reverse
import requests
from loguru import logger
@ -30,6 +30,8 @@ from .models import (
OAUTH_BACKEND_URL = settings.PUBLIC_BACKEND_URL
_is_url_already_loaded = False
@dataclass
class WellKnownUrls:
@ -50,6 +52,20 @@ class OAuth2AuthProviderMixin:
- self.SCOPE
"""
def get_api_urls(self):
global _is_url_already_loaded
from baserow_enterprise.api.sso.oauth2 import urls
if not _is_url_already_loaded:
_is_url_already_loaded = True
# We need to register this only once
return [
path("sso/oauth2/", include(urls, namespace="enterprise_sso_oauth2"))
]
else:
return []
def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]:
if not is_sso_feature_active():
return None
@ -64,7 +80,7 @@ class OAuth2AuthProviderMixin:
{
"redirect_url": urllib.parse.urljoin(
OAUTH_BACKEND_URL,
reverse("api:enterprise:sso:oauth2:login", args=(instance.id,)),
reverse("api:enterprise_sso_oauth2:login", args=(instance.id,)),
),
"name": instance.name,
"type": self.type,
@ -144,7 +160,7 @@ class OAuth2AuthProviderMixin:
redirect_uri = urllib.parse.urljoin(
OAUTH_BACKEND_URL,
reverse("api:enterprise:sso:oauth2:callback", args=(instance.id,)),
reverse("api:enterprise_sso_oauth2:callback", args=(instance.id,)),
)
if "oauth_state" in session:
return OAuth2Session(

View file

@ -1,8 +1,9 @@
from typing import Any, Dict, List, Optional
from abc import abstractmethod
from typing import Any, Dict, List, Optional, TypedDict
from urllib.parse import urljoin
from django.conf import settings
from django.urls import reverse
from django.urls import include, path, reverse
from rest_framework import serializers
@ -26,43 +27,33 @@ from baserow_enterprise.sso.utils import is_sso_feature_active
from .models import SamlAuthProviderModel
class SamlAuthProviderType(AuthProviderType):
class SamlAuthProviderTypeMixin:
"""
The SAML authentication provider type allows users to login using SAML.
"""
type = "saml"
model_class = SamlAuthProviderModel
allowed_fields: List[str] = [
"id",
"domain",
"type",
"enabled",
class SamlSerializedDict(TypedDict):
metadata: Dict
is_verified: bool
saml_allowed_fields: List[str] = [
"metadata",
"is_verified",
"email_attr_key",
"first_name_attr_key",
"last_name_attr_key",
]
serializer_field_names = [
"domain",
saml_serializer_field_names = [
"metadata",
"enabled",
"is_verified",
"email_attr_key",
"first_name_attr_key",
"last_name_attr_key",
]
serializer_field_overrides = {
"domain": serializers.CharField(
validators=[validate_domain],
required=True,
help_text="The email domain registered with this provider.",
),
"enabled": serializers.BooleanField(
help_text="Whether the provider is enabled or not.",
required=False,
),
saml_serializer_field_overrides = {
"metadata": serializers.CharField(
validators=[validate_saml_metadata],
required=True,
@ -98,21 +89,61 @@ class SamlAuthProviderType(AuthProviderType):
SamlProviderForDomainAlreadyExists: ERROR_SAML_PROVIDER_FOR_DOMAIN_ALREADY_EXISTS
}
def before_create(self, user, **values):
validate_unique_saml_domain(values["domain"])
return super().before_create(user, **values)
@classmethod
@abstractmethod
def get_acs_absolute_url(
cls, auth_provider: "SamlAuthProviderTypeMixin | None" = None
):
"""
Returns the ACS url for SAML authentication purpose. The user is redirected
to this URL after a successful login.
"""
def before_update(self, user, provider, **values):
if "domain" in values:
validate_unique_saml_domain(values["domain"], provider)
return super().before_update(user, provider, **values)
@classmethod
@abstractmethod
def get_login_absolute_url(cls):
"""
Returns the login URL for this auth_provider. The login URL is used to initiate
the Saml login process.
"""
class SamlAuthProviderType(SamlAuthProviderTypeMixin, AuthProviderType):
"""
The SAML authentication provider type allows users to login using SAML.
"""
model_class = SamlAuthProviderModel
@property
def allowed_fields(self) -> List[str]:
return SamlAuthProviderTypeMixin.saml_allowed_fields
@property
def serializer_field_names(self):
return SamlAuthProviderTypeMixin.saml_serializer_field_names
@property
def serializer_field_overrides(self):
return SamlAuthProviderTypeMixin.saml_serializer_field_overrides | {
"domain": serializers.CharField(
validators=[validate_domain],
required=True,
help_text="The email domain registered with this provider.",
),
}
def get_api_urls(self):
from baserow_enterprise.api.sso.saml import urls
return [path("sso/saml/", include(urls, namespace="enterprise_sso_saml"))]
def get_login_options(self, **kwargs) -> Optional[Dict[str, Any]]:
single_sign_on_feature_active = is_sso_feature_active()
if not single_sign_on_feature_active:
return None
configured_domains = SamlAuthProviderModel.objects.filter(enabled=True).count()
configured_domains = self.model_class.objects.filter(enabled=True).count()
if not configured_domains:
return None
@ -130,16 +161,28 @@ class SamlAuthProviderType(AuthProviderType):
"default_redirect_url": default_redirect_url,
}
def before_create(self, user, **values):
validate_unique_saml_domain(values["domain"])
return super().before_create(user, **values)
def before_update(self, user, provider, **values):
if "domain" in values:
validate_unique_saml_domain(values["domain"], provider)
return super().before_update(user, provider, **values)
@classmethod
def get_acs_absolute_url(cls):
def get_acs_absolute_url(
cls, auth_provider: SamlAuthProviderTypeMixin | None = None
):
return urljoin(
settings.PUBLIC_BACKEND_URL, reverse("api:enterprise:sso:saml:acs")
settings.PUBLIC_BACKEND_URL, reverse("api:enterprise_sso_saml:acs")
)
@classmethod
def get_login_absolute_url(cls):
return urljoin(
settings.PUBLIC_BACKEND_URL, reverse("api:enterprise:sso:saml:login")
settings.PUBLIC_BACKEND_URL,
reverse("api:enterprise_sso_saml:login"),
)
def export_serialized(self) -> Dict[str, Any]:

View file

@ -4,6 +4,7 @@ from typing import Any, Dict, Optional
from django.conf import settings
from django.contrib.auth.models import AbstractUser
from django.db.models import Model, QuerySet
from defusedxml import ElementTree
from loguru import logger
@ -13,9 +14,11 @@ from saml2.config import Config as Saml2Config
from saml2.response import AuthnResponse
from baserow.core.auth_provider.types import UserInfo
from baserow.core.registries import auth_provider_type_registry
from baserow_enterprise.api.sso.utils import get_valid_frontend_url
from baserow_enterprise.sso.saml.models import SamlAuthProviderModel
from baserow_enterprise.sso.saml.models import (
SamlAuthProviderModel,
SamlAuthProviderModelMixin,
)
from .exceptions import (
InvalidSamlConfiguration,
@ -25,10 +28,12 @@ from .exceptions import (
class SamlAuthProviderHandler:
model_class: Model = SamlAuthProviderModel
@classmethod
def prepare_saml_client(
cls,
saml_auth_provider: SamlAuthProviderModel,
saml_auth_provider: SamlAuthProviderModelMixin,
) -> Saml2Client:
"""
Returns a SAML client with the correct configuration for the given
@ -39,10 +44,9 @@ class SamlAuthProviderHandler:
:return: The SAML client that can be used to authenticate the user.
"""
saml_provider_type = auth_provider_type_registry.get_by_model(
saml_auth_provider
acs_url = saml_auth_provider.get_type().get_acs_absolute_url(
saml_auth_provider.specific
)
acs_url = saml_provider_type.get_acs_absolute_url()
saml_settings: Dict[str, Any] = {
"entityid": acs_url,
@ -95,9 +99,8 @@ class SamlAuthProviderHandler:
@classmethod
def get_saml_auth_provider_from_saml_response(
cls,
raw_saml_response: str,
) -> SamlAuthProviderModel:
cls, raw_saml_response: str, base_queryset: QuerySet | None = None
) -> SamlAuthProviderModelMixin:
"""
Parses the saml response and returns the authentication provider that needs to
be used to authenticate the user.
@ -120,7 +123,10 @@ class SamlAuthProviderHandler:
except (ElementTree.ParseError, AttributeError):
raise InvalidSamlResponse("Impossible decode SAML response.")
saml_auth_provider = SamlAuthProviderModel.objects.filter(
if base_queryset is None:
base_queryset = cls.model_class.objects
saml_auth_provider = base_queryset.filter(
enabled=True, metadata__contains=issuer
).first()
if not saml_auth_provider:
@ -130,7 +136,7 @@ class SamlAuthProviderHandler:
@classmethod
def get_user_info_from_authn_user_identity(
cls,
saml_auth_provider: SamlAuthProviderModel,
saml_auth_provider: SamlAuthProviderModelMixin,
authn_identity: Dict[str, str],
saml_request_data: Optional[Dict[str, str]] = None,
) -> UserInfo:
@ -172,8 +178,7 @@ class SamlAuthProviderHandler:
@classmethod
def get_saml_auth_provider_from_email(
cls,
email: Optional[str] = None,
cls, email: Optional[str] = None, base_queryset: QuerySet | None = None
) -> SamlAuthProviderModel:
"""
It returns the Saml Identity Provider for the the given email address.
@ -187,20 +192,24 @@ class SamlAuthProviderHandler:
address provided.
"""
base_queryset = SamlAuthProviderModel.objects.filter(enabled=True)
if base_queryset is None:
base_queryset = cls.model_class.objects
queryset = base_queryset.filter(enabled=True)
if email is not None:
try:
domain = email.rsplit("@", 1)[1]
except IndexError:
raise InvalidSamlRequest("Invalid mail address provided.")
base_queryset = base_queryset.filter(domain=domain)
queryset = queryset.filter(domain=domain)
try:
return base_queryset.get()
return queryset.get()
except (
SamlAuthProviderModel.DoesNotExist,
SamlAuthProviderModel.MultipleObjectsReturned,
cls.model_class.DoesNotExist,
cls.model_class.MultipleObjectsReturned,
):
raise InvalidSamlRequest("No valid SAML identity provider found.")
@ -213,7 +222,10 @@ class SamlAuthProviderHandler:
@classmethod
def sign_in_user_from_saml_response(
cls, saml_response: str, saml_request_data: Optional[Dict[str, str]] = None
cls,
saml_response: str,
saml_request_data: Optional[Dict[str, str]] = None,
base_queryset: QuerySet | None = None,
) -> AbstractUser:
"""
Signs in the user using the SAML response received from the identity
@ -230,7 +242,7 @@ class SamlAuthProviderHandler:
try:
saml_auth_provider = cls.get_saml_auth_provider_from_saml_response(
saml_response
saml_response, base_queryset=base_queryset
)
saml_client = cls.prepare_saml_client(saml_auth_provider)
@ -249,11 +261,10 @@ class SamlAuthProviderHandler:
logger.exception(exc)
raise InvalidSamlResponse(str(exc))
saml_provider_type = saml_auth_provider.get_type()
(
user,
_,
) = saml_provider_type.get_or_create_user_and_sign_in(
) = saml_auth_provider.get_type().get_or_create_user_and_sign_in(
saml_auth_provider, idp_provided_user_info
)
@ -268,7 +279,7 @@ class SamlAuthProviderHandler:
@classmethod
def get_sign_in_url_for_auth_provider(
cls,
saml_auth_provider: SamlAuthProviderModel,
saml_auth_provider: SamlAuthProviderModelMixin,
original_url: str = "",
) -> str:
"""
@ -284,6 +295,7 @@ class SamlAuthProviderHandler:
"""
saml_client = cls.prepare_saml_client(saml_auth_provider)
_, info = saml_client.prepare_for_authenticate(relay_state=original_url)
for key, value in info["headers"]:
@ -294,24 +306,36 @@ class SamlAuthProviderHandler:
raise InvalidSamlConfiguration("No Location header found in SAML response.")
@classmethod
def get_sign_in_url(cls, query_params: Dict[str, str]) -> str:
def get_sign_in_url(
cls,
query_params: Dict[str, str],
base_queryset: QuerySet | None = None,
redirect_to: str | None = None,
) -> str:
"""
Returns the sign in url for the correct identity provider. This url is
used to initiate the SAML authentication flow from the service provider.
:param query_params: A dict containing the query parameters from the
sign in request.
:param redirect_to: if set, used as relay state url.
:raises InvalidSamlRequest: If the email address is invalid.
:raises InvalidSamlConfiguration: If the SAML configuration is invalid.
:return: The redirect url to the identity provider.
"""
user_email = query_params.pop("email", None)
original_url = query_params.pop("original", "")
valid_relay_state_url = get_valid_frontend_url(original_url, query_params)
if redirect_to:
valid_relay_state_url = redirect_to
else:
original_url = query_params.pop("original", "")
valid_relay_state_url = get_valid_frontend_url(original_url, query_params)
try:
saml_auth_provider = cls.get_saml_auth_provider_from_email(user_email)
saml_auth_provider = cls.get_saml_auth_provider_from_email(
user_email, base_queryset=base_queryset
)
return cls.get_sign_in_url_for_auth_provider(
saml_auth_provider, valid_relay_state_url
)

View file

@ -5,7 +5,7 @@ from django.dispatch import receiver
from baserow.core.auth_provider.models import AuthProviderModel
class SamlAuthProviderModel(AuthProviderModel):
class SamlAuthProviderModelMixin(models.Model):
metadata = models.TextField(
blank=True, help_text="The XML metadata downloaded from the metadata_url."
)
@ -46,6 +46,15 @@ class SamlAuthProviderModel(AuthProviderModel):
),
)
class Meta:
abstract = True
class SamlAuthProviderModel(SamlAuthProviderModelMixin, AuthProviderModel):
# Restore ordering
class Meta(AuthProviderModel.Meta):
ordering = ["domain"]
@receiver(pre_save, sender=SamlAuthProviderModel)
def reset_is_verified_if_metadata_changed(sender, instance, **kwargs):

View file

@ -52,7 +52,7 @@ def test_oauth2_login_feature_not_active(api_client, enterprise_data_fixture):
)
response = api_client.get(
reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": provider.id}),
reverse("api:enterprise_sso_oauth2:login", kwargs={"provider_id": provider.id}),
format="json",
)
@ -68,7 +68,7 @@ def test_oauth2_login_feature_not_active(api_client, enterprise_data_fixture):
def test_oauth2_login_provider_doesnt_exist(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise()
response = api_client.get(
reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": 50}),
reverse("api:enterprise_sso_oauth2:login", kwargs={"provider_id": 50}),
format="json",
)
@ -88,7 +88,7 @@ def test_oauth2_login_with_url_param(api_client, enterprise_data_fixture):
)
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:login",
"api:enterprise_sso_oauth2:login",
kwargs={"provider_id": provider.id},
)
+ "?original=templates&workspace_invitation_token=t&language=en",
@ -118,7 +118,7 @@ def test_oauth2_callback_feature_not_active(api_client, enterprise_data_fixture)
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback", kwargs={"provider_id": provider.id}
"api:enterprise_sso_oauth2:callback", kwargs={"provider_id": provider.id}
),
format="json",
)
@ -135,7 +135,7 @@ def test_oauth2_callback_feature_not_active(api_client, enterprise_data_fixture)
def test_oauth2_callback_provider_doesnt_exist(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise()
response = api_client.get(
reverse("api:enterprise:sso:oauth2:callback", kwargs={"provider_id": 50}),
reverse("api:enterprise_sso_oauth2:callback", kwargs={"provider_id": 50}),
format="json",
)
@ -167,7 +167,7 @@ def test_oauth2_callback_signup_success(api_client, enterprise_data_fixture):
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
"api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
@ -211,7 +211,7 @@ def test_oauth2_callback_signup_set_language(api_client, enterprise_data_fixture
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
"api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
@ -263,7 +263,7 @@ def test_oauth2_callback_signup_workspace_invitation(
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
"api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
@ -318,7 +318,7 @@ def test_oauth2_callback_signup_workspace_invitation_email_mismatch(
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
"api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
@ -360,7 +360,7 @@ def test_oauth2_callback_signup_disabled(api_client, enterprise_data_fixture):
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
"api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
@ -408,7 +408,7 @@ def test_oauth2_callback_login_success(
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
"api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
@ -460,7 +460,7 @@ def test_oauth2_callback_login_deactivated_user(
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
"api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
@ -505,7 +505,7 @@ def test_oauth2_callback_login_different_provider(
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
"api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
@ -549,7 +549,7 @@ def test_oauth2_callback_login_auth_flow_error(
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
"api:enterprise_sso_oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",

View file

@ -33,8 +33,8 @@ def test_saml_provider_get_login_url(api_client, data_fixture, enterprise_data_f
auth_provider_1 = enterprise_data_fixture.create_saml_auth_provider(
domain="test1.com"
)
auth_provider_login = reverse("api:enterprise:sso:saml:login")
auth_provider_login_url = reverse("api:enterprise:sso:saml:login_url")
auth_provider_login = reverse("api:enterprise_sso_saml:login")
auth_provider_login_url = reverse("api:enterprise_sso_saml:login_url")
_, unauthorized_token = data_fixture.create_user_and_token()
@ -147,7 +147,7 @@ def test_user_cannot_initiate_saml_sso_without_enterprise_license(
api_client, enterprise_data_fixture
):
enterprise_data_fixture.create_saml_auth_provider(domain="test1.com")
response = api_client.get(reverse("api:enterprise:sso:saml:login"))
response = api_client.get(reverse("api:enterprise_sso_saml:login"))
assert response.status_code == HTTP_302_FOUND
assert (
response.headers["Location"]
@ -173,7 +173,7 @@ def test_user_can_initiate_saml_sso_with_enterprise_license(
enterprise_data_fixture.create_saml_auth_provider(domain="test1.com")
enterprise_data_fixture.enable_enterprise()
sp_sso_saml_login_url = reverse("api:enterprise:sso:saml:login")
sp_sso_saml_login_url = reverse("api:enterprise_sso_saml:login")
original_relative_url = "database/1/table/1/"
request_query_string = urlencode({"original": original_relative_url})
@ -204,7 +204,7 @@ def test_user_can_initiate_saml_sso_with_enterprise_license(
@override_settings(DEBUG=True)
def test_saml_assertion_consumer_service(api_client, enterprise_data_fixture):
user, _ = enterprise_data_fixture.create_enterprise_admin_user_and_token()
sp_sso_saml_acs_url = reverse("api:enterprise:sso:saml:acs")
sp_sso_saml_acs_url = reverse("api:enterprise_sso_saml:acs")
(
metadata,

View file

@ -654,6 +654,95 @@ def test_create_user_source_field_from_other_table(api_client, data_fixture):
assert response.json()["detail"]["name_field_id"][0]["code"] == "missing_table"
@pytest.mark.django_db
def test_user_source_create_user(data_fixture):
user = data_fixture.create_user()
workspace = data_fixture.create_workspace(user=user)
application = data_fixture.create_builder_application(workspace=workspace)
database = data_fixture.create_database_application(workspace=workspace)
integration = data_fixture.create_local_baserow_integration(
application=application, user=user
)
table_from_same_workspace1, fields, rows = data_fixture.build_table(
user=user,
database=database,
columns=[
("Email", "text"),
("Name", "text"),
],
rows=[
["test@baserow.io", "Test"],
],
)
email_field, name_field = fields
user_source = data_fixture.create_user_source_with_first_type(
integration=integration,
name="Test name",
table=table_from_same_workspace1,
email_field=email_field,
name_field=name_field,
uid="uid",
)
created_user = user_source.get_type().create_user(
user_source, "test2@baserow.io", "Test2"
)
model = table_from_same_workspace1.get_model()
assert created_user.email == "test2@baserow.io"
assert model.objects.count() == 2
assert getattr(model.objects.last(), email_field.db_column) == "test2@baserow.io"
@pytest.mark.django_db
def test_user_source_create_user_w_role(data_fixture):
user = data_fixture.create_user()
workspace = data_fixture.create_workspace(user=user)
application = data_fixture.create_builder_application(workspace=workspace)
database = data_fixture.create_database_application(workspace=workspace)
integration = data_fixture.create_local_baserow_integration(
application=application, user=user
)
table_from_same_workspace1, fields, rows = data_fixture.build_table(
user=user,
database=database,
columns=[
("Email", "text"),
("Name", "text"),
("Role", "text"),
],
rows=[
["test@baserow.io", "Test", "role1"],
],
)
email_field, name_field, role_field = fields
user_source = data_fixture.create_user_source_with_first_type(
integration=integration,
name="Test name",
table=table_from_same_workspace1,
email_field=email_field,
name_field=name_field,
role_field=role_field,
uid="uid",
)
created_user = user_source.get_type().create_user(
user_source, "test2@baserow.io", "Test2", "role2"
)
model = table_from_same_workspace1.get_model()
assert created_user.role == "role2"
assert getattr(model.objects.last(), role_field.db_column) == "role2"
@pytest.mark.django_db
def test_export_user_source(data_fixture):
user = data_fixture.create_user()
@ -803,7 +892,7 @@ def test_create_local_baserow_user_source_w_auth_providers(api_client, data_fixt
{
"type": "local_baserow_password",
"enabled": True,
"domain": "test1",
"domain": "test1.com",
"password_field_id": password_field.id,
}
],
@ -820,7 +909,7 @@ def test_create_local_baserow_user_source_w_auth_providers(api_client, data_fixt
assert response_json["auth_providers"] == [
{
"domain": "test1",
"domain": "test1.com",
"id": first.id,
"password_field_id": password_field.id,
"type": "local_baserow_password",

View file

@ -0,0 +1,100 @@
from baserow_enterprise.api.sso.utils import get_valid_frontend_url
def test_get_valid_front_url():
assert get_valid_frontend_url() == "http://localhost:3000/dashboard"
assert (
get_valid_frontend_url("http://localhost:3000/dashboard")
== "http://localhost:3000/dashboard"
)
assert (
get_valid_frontend_url("http://localhost:3000/dashboard/after")
== "http://localhost:3000/dashboard/after"
)
assert (
get_valid_frontend_url("http://localhost:3000/other")
== "http://localhost:3000/other"
)
assert (
get_valid_frontend_url("http://localhost:3000/other", allow_any_path=False)
== "http://localhost:3000/dashboard"
)
assert (
get_valid_frontend_url("http://localhost:3000/")
== "http://localhost:3000/dashboard"
)
assert (
get_valid_frontend_url("http://something.com/")
== "http://localhost:3000/dashboard"
)
assert (
get_valid_frontend_url("http://something.com/dashboard/test")
== "http://localhost:3000/dashboard"
)
def test_get_valid_front_url_with_defaults():
defaults = ["https://test.com/toto", "http://random.net/"]
assert (
get_valid_frontend_url(default_frontend_urls=defaults)
== "https://test.com/toto"
)
assert (
get_valid_frontend_url("https://test.com/toto", default_frontend_urls=defaults)
== "https://test.com/toto"
)
assert (
get_valid_frontend_url("http://random.net/", default_frontend_urls=defaults)
== "http://random.net/"
)
assert (
get_valid_frontend_url(
"https://test.com/toto/subpath/", default_frontend_urls=defaults
)
== "https://test.com/toto/subpath/"
)
assert (
get_valid_frontend_url("https://test.com/titi/", default_frontend_urls=defaults)
== "https://test.com/titi/"
)
assert (
get_valid_frontend_url(
"https://test.com/titi/",
default_frontend_urls=defaults,
)
== "https://test.com/titi/"
)
assert (
get_valid_frontend_url(
"https://test.com/titi/",
default_frontend_urls=defaults,
allow_any_path=False,
)
== "https://test.com/toto"
)
assert (
get_valid_frontend_url("http://random.net/", default_frontend_urls=defaults)
== "http://random.net/"
)
assert (
get_valid_frontend_url("http://random.net/path", default_frontend_urls=defaults)
== "http://random.net/path"
)
assert (
get_valid_frontend_url("http://other.net/path", default_frontend_urls=defaults)
== "https://test.com/toto"
)
def test_get_valid_front_url_w_params():
assert (
get_valid_frontend_url(query_params={"test": "value"})
== "http://localhost:3000/dashboard?test=value"
)
assert (
get_valid_frontend_url(
"http://localhost:3000/dashboard", query_params={"test": "value"}
)
== "http://localhost:3000/dashboard?test=value"
)

View file

@ -14,3 +14,6 @@
@import 'long_text_field';
@import 'highest_role_field';
@import 'auth_form_element';
@import 'common_saml_setting_form';
@import 'common_saml_setting_modal';
@import 'saml_auth_link';

View file

@ -15,3 +15,11 @@
font-size: var(--label-font-size, 13px);
padding-top: 0.5em;
}
.auth-form-element__provider {
padding-top: 16px;
&:first-child {
padding-top: 0;
}
}

View file

@ -0,0 +1,13 @@
.common-saml-setting-form {
@include elevation($elevation-low);
@include rounded($rounded);
padding: 20px;
border: 1px solid $palette-neutral-200;
display: flex;
cursor: pointer;
&--error {
border-color: $palette-red-600;
}
}

View file

@ -0,0 +1,24 @@
.common-saml-setting-modal__url-block {
@include rounded($rounded);
position: relative;
font-family: monospace;
padding: 12px 16px;
background-color: $palette-neutral-50;
border: 1px solid $palette-neutral-400;
}
.common-saml-setting-modal__url {
padding-bottom: 8px;
cursor: pointer;
display: flex;
gap: 5px;
}
.common-saml-setting-modal__url-domain {
color: $palette-neutral-800;
}
.common-saml-setting-modal__url-dest {
@extend %ellipsis;
}

View file

@ -0,0 +1,6 @@
.saml-auth-link,
.saml-auth-link__modal-footer {
display: flex;
flex-direction: column;
font-size: var(--label-font-size, 13px);
}

View file

@ -42,68 +42,150 @@ export class PasswordAuthProviderType extends AuthProviderType {
return null
}
/**
* We can create only one password provider.
*/
canCreateNew(authProviders) {
return (
!authProviders[this.getType()] ||
authProviders[this.getType()].length === 0
)
}
getOrder() {
return 1
}
}
export class SamlAuthProviderType extends AuthProviderType {
static getType() {
return 'saml'
}
getIcon() {
return SAMLIcon
}
getVerifiedIcon() {
return VerifiedProviderIcon
}
getName() {
return 'SSO SAML provider'
}
getProviderName(provider) {
return `SSO SAML: ${provider.domain}`
}
getLoginActionComponent() {
return SamlLoginAction
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return SamlSettingsForm
export const SamlAuthProviderTypeMixin = (Base) =>
class extends Base {
static getType() {
return 'saml'
}
getIcon() {
return SAMLIcon
}
getVerifiedIcon() {
return VerifiedProviderIcon
}
getName() {
return this.app.i18n.t('authProviderTypes.saml')
}
getProviderName(provider) {
if (provider.domain) {
return this.app.i18n.t('authProviderTypes.ssoSamlProviderName', {
domain: provider.domain,
})
} else {
return this.app.i18n.t(
'authProviderTypes.ssoSamlProviderNameUnconfigured'
)
}
}
getLoginActionComponent() {
return SamlLoginAction
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return SamlSettingsForm
}
getRelayStateUrl() {
return this.app.store.getters['authProviderAdmin/getType'](this.getType())
.relayStateUrl
}
getAcsUrl() {
return this.app.store.getters['authProviderAdmin/getType'](this.getType())
.acsUrl
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
redirectUrl: authProviderOption.redirect_url,
domainRequired: authProviderOption.domain_required,
...loginOptions,
}
}
handleServerError(vueComponentInstance, error) {
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
for (const [key, value] of Object.entries(error.handler.detail || {})) {
vueComponentInstance.serverErrors[key] = value
}
return true
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
acsUrl: authProviderType.acs_url,
relayStateUrl: authProviderType.relay_state_url,
...populated,
}
}
}
export class SamlAuthProviderType extends SamlAuthProviderTypeMixin(
AuthProviderType
) {
getOrder() {
return 50
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
redirectUrl: authProviderOption.redirect_url,
domainRequired: authProviderOption.domain_required,
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
acsUrl: authProviderType.acs_url,
relayStateUrl: authProviderType.relay_state_url,
...populated,
}
}
}
export class GoogleAuthProviderType extends AuthProviderType {
export const OAuth2AuthProviderTypeMixin = (Base) =>
class extends Base {
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return OAuth2SettingsForm
}
getCallbackUrl(authProvider) {
if (!authProvider.id) {
const nextProviderId =
this.app.store.getters['authProviderAdmin/getNextProviderId']
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${authProvider.id}/`
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
}
export class GoogleAuthProviderType extends OAuth2AuthProviderTypeMixin(
AuthProviderType
) {
static getType() {
return 'google'
}
@ -120,38 +202,14 @@ export class GoogleAuthProviderType extends AuthProviderType {
return provider.name ? provider.name : `Google`
}
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return OAuth2SettingsForm
}
getOrder() {
return 50
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
}
export class FacebookAuthProviderType extends AuthProviderType {
export class FacebookAuthProviderType extends OAuth2AuthProviderTypeMixin(
AuthProviderType
) {
static getType() {
return 'facebook'
}
@ -168,38 +226,14 @@ export class FacebookAuthProviderType extends AuthProviderType {
return provider.name ? provider.name : this.getName()
}
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return OAuth2SettingsForm
}
getOrder() {
return 50
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
}
export class GitHubAuthProviderType extends AuthProviderType {
export class GitHubAuthProviderType extends OAuth2AuthProviderTypeMixin(
AuthProviderType
) {
static getType() {
return 'github'
}
@ -216,35 +250,9 @@ export class GitHubAuthProviderType extends AuthProviderType {
return provider.name ? provider.name : this.getName()
}
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return OAuth2SettingsForm
}
getOrder() {
return 50
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
}
export class GitLabAuthProviderType extends AuthProviderType {
@ -276,26 +284,23 @@ export class GitLabAuthProviderType extends AuthProviderType {
return GitLabSettingsForm
}
getCallbackUrl(authProvider) {
if (!authProvider.id) {
const nextProviderId =
this.app.store.getters['authProviderAdmin/getNextProviderId']
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${authProvider.id}/`
}
getOrder() {
return 50
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
}
export class OpenIdConnectAuthProviderType extends AuthProviderType {
export class OpenIdConnectAuthProviderType extends OAuth2AuthProviderTypeMixin(
AuthProviderType
) {
static getType() {
return 'openid_connect'
}
@ -312,33 +317,38 @@ export class OpenIdConnectAuthProviderType extends AuthProviderType {
return provider.name ? provider.name : this.getName()
}
getLoginButtonComponent() {
return LoginButton
}
getAdminListComponent() {
return AuthProviderItem
}
getAdminSettingsFormComponent() {
return OpenIdConnectSettingsForm
}
getCallbackUrl(authProvider) {
if (!authProvider.id) {
const nextProviderId =
this.app.store.getters['authProviderAdmin/getNextProviderId']
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.app.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${authProvider.id}/`
}
handleServerError(vueComponentInstance, error) {
if (error.handler.code === 'ERROR_INVALID_PROVIDER_URL') {
vueComponentInstance.serverErrors = {
...vueComponentInstance.serverErrors,
baseUrl: error.handler.detail,
}
return true
}
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
vueComponentInstance.serverErrors = structuredClone(
error.handler.detail || {}
)
return true
}
getOrder() {
return 50
}
populateLoginOptions(authProviderOption) {
const loginOptions = super.populateLoginOptions(authProviderOption)
return {
...loginOptions,
}
}
populate(authProviderType) {
const populated = super.populate(authProviderType)
return {
...populated,
}
}
}

View file

@ -1,64 +1,30 @@
<template>
<form
v-if="hasAtLeastOneLoginOption"
class="auth-form-element"
:style="getStyleOverride('input')"
@submit.prevent="onLogin"
>
<Error :error="error"></Error>
<ABFormGroup
:label="$t('authFormElement.email')"
:error-message="
$v.values.email.$dirty
? !$v.values.email.required
? $t('error.requiredField')
: !$v.values.email.email
? $t('error.invalidEmail')
: ''
: ''
"
:autocomplete="isEditMode ? 'off' : ''"
required
>
<ABInput
v-model="values.email"
:placeholder="$t('authFormElement.emailPlaceholder')"
@blur="$v.values.email.$touch()"
/>
</ABFormGroup>
<ABFormGroup
:label="$t('authFormElement.password')"
:error-message="
$v.values.password.$dirty
? !$v.values.password.required
? $t('error.requiredField')
: ''
: ''
"
required
>
<ABInput
ref="passwordRef"
v-model="values.password"
type="password"
:placeholder="$t('authFormElement.passwordPlaceholder')"
@blur="$v.values.password.$touch()"
/>
</ABFormGroup>
<div :style="getStyleOverride('login_button')" class="auth-form__footer">
<ABButton :disabled="$v.$error" :loading="loading" size="large">
{{ resolvedLoginButtonLabel }}
</ABButton>
</div>
</form>
<p v-else>{{ $t('authFormElement.selectOrConfigureUserSourceFirst') }}</p>
<div v-if="hasAtLeastOneLoginOption" :style="fullStyle">
<template v-for="appAuthType in appAuthProviderTypes">
<div
v-if="hasAtLeastOneProvider(appAuthType)"
:key="appAuthType.type"
class="auth-form-element__provider"
>
<component
:is="appAuthType.component"
:user-source="selectedUserSource"
:auth-providers="appAuthProviderPerTypes[appAuthType.type]"
:login-button-label="resolvedLoginButtonLabel"
@after-login="afterLogin"
/>
</div>
</template>
</div>
<p v-else>
{{ $t('authFormElement.selectOrConfigureUserSourceFirst') }}
</p>
</template>
<script>
import form from '@baserow/modules/core/mixins/form'
import error from '@baserow/modules/core/mixins/error'
import element from '@baserow/modules/builder/mixins/element'
import { required, email } from 'vuelidate/lib/validators'
import { ensureString } from '@baserow/modules/core/utils/validator'
import { mapActions } from 'vuex'
@ -79,12 +45,15 @@ export default {
},
},
data() {
return {
loading: false,
values: { email: '', password: '' },
}
return {}
},
computed: {
fullStyle() {
return {
...this.getStyleOverride('input'),
...this.getStyleOverride('login_button'),
}
},
selectedUserSource() {
return this.$store.getters['userSource/getUserSourceById'](
this.builder,
@ -97,8 +66,21 @@ export default {
}
return this.$registry.get('userSource', this.selectedUserSource.type)
},
isAuthenticated() {
return this.$store.getters['userSourceUser/isAuthenticated'](this.builder)
authProviders() {
return this.selectedUserSource?.auth_providers || []
},
appAuthProviderTypes() {
return this.$registry.getOrderedList('appAuthProvider')
},
appAuthProviderPerTypes() {
return Object.fromEntries(
this.appAuthProviderTypes.map((authType) => {
return [
authType.type,
this.authProviders.filter(({ type }) => type === authType.type),
]
})
)
},
loginOptions() {
if (!this.selectedUserSourceType) {
@ -141,68 +123,15 @@ export default {
...mapActions({
actionForceUpdateElement: 'element/forceUpdate',
}),
async onLogin(event) {
if (this.isAuthenticated) {
await this.$store.dispatch('userSourceUser/logoff', {
application: this.builder,
})
}
this.$v.$touch()
if (this.$v.$invalid) {
this.focusOnFirstError()
return
}
this.loading = true
this.hideError()
try {
await this.$store.dispatch('userSourceUser/authenticate', {
application: this.builder,
userSource: this.selectedUserSource,
credentials: {
email: this.values.email,
password: this.values.password,
},
setCookie: this.mode === 'public',
})
this.values.password = ''
this.values.email = ''
this.$v.$reset()
this.fireEvent(
this.elementType.getEventByName(this.element, 'after_login')
)
} catch (error) {
if (error.handler) {
const response = error.handler.response
if (response && response.status === 401) {
this.values.password = ''
this.$v.$reset()
this.$v.$touch()
this.$refs.passwordRef.focus()
if (response.data?.error === 'ERROR_INVALID_CREDENTIALS') {
this.showError(
this.$t('error.incorrectCredentialTitle'),
this.$t('error.incorrectCredentialMessage')
)
}
} else {
const message = error.handler.getMessage('login')
this.showError(message)
}
error.handler.handled()
} else {
throw error
}
}
this.loading = false
hasAtLeastOneProvider(authProviderType) {
return (
this.appAuthProviderPerTypes[authProviderType.getType()]?.length > 0
)
},
},
validations: {
values: {
email: { required, email },
password: { required },
afterLogin() {
this.fireEvent(
this.elementType.getEventByName(this.element, 'after_login')
)
},
},
}

View file

@ -10,8 +10,8 @@
class="context__menu-item-link"
@click="$emit('create', authProviderType)"
>
<AuthProviderIcon :icon="getIcon(authProviderType)" />
{{ getName(authProviderType) }}
<AuthProviderIcon :icon="authProviderType.getIcon()" />
{{ authProviderType.getName() }}
</a>
</li>
</ul>
@ -32,13 +32,5 @@ export default {
required: true,
},
},
methods: {
getIcon(providerType) {
return this.$registry.get('authProvider', providerType.type).getIcon()
},
getName(providerType) {
return this.$registry.get('authProvider', providerType.type).getName()
},
},
}
</script>

View file

@ -107,59 +107,25 @@
<script>
import { required, url } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form'
import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
export default {
name: 'GitLabSettingsForm',
mixins: [form],
props: {
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
},
mixins: [authProviderForm],
data() {
return {
allowedValues: ['name', 'base_url', 'client_id', 'secret'],
values: {
name: '',
base_url: '',
base_url: 'https://gitlab.com',
client_id: '',
secret: '',
},
}
},
computed: {
providerName() {
return this.$registry
.get('authProvider', 'gitlab')
.getProviderName(this.authProvider)
},
callbackUrl() {
if (!this.authProvider.id) {
const nextProviderId =
this.$store.getters['authProviderAdmin/getNextProviderId']
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${this.authProvider.id}/`
},
},
methods: {
getDefaultValues() {
return {
name: this.providerName,
base_url: this.authProvider.base_url || 'https://gitlab.com',
client_id: this.authProvider.client_id || '',
secret: this.authProvider.secret || '',
}
},
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
return this.authProviderType.getCallbackUrl(this.authProvider)
},
},
validations() {

View file

@ -81,23 +81,11 @@
<script>
import { required } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form'
import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
export default {
name: 'OAuth2SettingsForm',
mixins: [form],
props: {
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
authProviderType: {
type: String,
required: false,
default: null,
},
},
mixins: [authProviderForm],
data() {
return {
allowedValues: ['name', 'client_id', 'secret'],
@ -109,37 +97,8 @@ export default {
}
},
computed: {
providerName() {
const type = this.authProviderType
? this.authProviderType
: this.authProvider.type
return this.$registry
.get('authProvider', type)
.getProviderName(this.authProvider)
},
callbackUrl() {
if (!this.authProvider.id) {
const nextProviderId =
this.$store.getters['authProviderAdmin/getNextProviderId']
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${this.authProvider.id}/`
},
},
methods: {
getDefaultValues() {
return {
name: this.providerName,
client_id: this.authProvider.client_id || '',
secret: this.authProvider.secret || '',
}
},
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
return this.authProviderType.getCallbackUrl(this.authProvider)
},
},
validations() {

View file

@ -109,23 +109,11 @@
<script>
import { required, url } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form'
import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
export default {
name: 'OpenIdConnectSettingsForm',
mixins: [form],
props: {
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
serverErrors: {
type: Object,
required: false,
default: () => ({}),
},
},
mixins: [authProviderForm],
data() {
return {
allowedValues: ['name', 'base_url', 'client_id', 'secret'],
@ -139,42 +127,7 @@ export default {
},
computed: {
callbackUrl() {
if (!this.authProvider.id) {
const nextProviderId =
this.$store.getters['authProviderAdmin/getNextProviderId']
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${nextProviderId}/`
}
return `${this.$config.PUBLIC_BACKEND_URL}/api/sso/oauth2/callback/${this.authProvider.id}/`
},
},
methods: {
getDefaultValues() {
return {
name: this.authProvider.name || '',
base_url: this.authProvider.base_url || '',
client_id: this.authProvider.client_id || '',
secret: this.authProvider.secret || '',
}
},
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
},
handleServerError(error) {
if (error.handler.code === 'ERROR_INVALID_PROVIDER_URL') {
this.serverErrors.baseUrl = error.handler.detail
return true
}
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
for (const [key, value] of Object.entries(error.handler.detail || {})) {
this.serverErrors[key] = value
}
return true
return this.authProviderType.getCallbackUrl(this.authProvider)
},
},
validations() {

View file

@ -21,9 +21,9 @@
ref="domain"
v-model="values.domain"
size="large"
:error="fieldHasErrors('domain') || serverErrors.domain"
:error="fieldHasErrors('domain') || !!serverErrors.domain"
:placeholder="$t('samlSettingsForm.domainPlaceholder')"
@input="serverErrors.domain = null"
@input="onDomainInput()"
@blur="$v.values.domain.$touch()"
></FormInput>
<template #error>
@ -50,16 +50,16 @@
small-label
required
:label="$t('samlSettingsForm.metadata')"
:error="fieldHasErrors('metadata')"
:error="fieldHasErrors('metadata') || !!serverErrors.metadata"
class="margin-bottom-2"
>
<FormTextarea
ref="metadata"
v-model="values.metadata"
:rows="12"
:error="fieldHasErrors('metadata') || serverErrors.metadata"
:rows="8"
:error="fieldHasErrors('metadata') || !!serverErrors.metadata"
:placeholder="$t('samlSettingsForm.metadataPlaceholder')"
@input="serverErrors.metadata = null"
@input="onMetadataInput()"
@blur="$v.values.metadata.$touch()"
></FormTextarea>
@ -73,23 +73,25 @@
</template>
</FormGroup>
<FormGroup
small-label
required
:label="$t('samlSettingsForm.relayStateUrl')"
class="margin-bottom-2"
>
<code>{{ getRelayStateUrl() }}</code>
</FormGroup>
<slot name="config">
<FormGroup
small-label
required
:label="$t('samlSettingsForm.relayStateUrl')"
class="margin-bottom-2"
>
<code>{{ getRelayStateUrl() }}</code>
</FormGroup>
<FormGroup
small-label
required
:label="$t('samlSettingsForm.acsUrl')"
class="margin-bottom-2"
>
<code>{{ getAcsUrl() }}</code>
</FormGroup>
<FormGroup
small-label
required
:label="$t('samlSettingsForm.acsUrl')"
class="margin-bottom-2"
>
<code>{{ getAcsUrl() }}</code>
</FormGroup>
</slot>
<Expandable card class="margin-bottom-2">
<template #header="{ toggle, expanded }">
@ -110,7 +112,7 @@
</div>
<div>
{{
usingDefaultAttrs()
usingDefaultAttrs
? $t('samlSettingsForm.defaultAttrs')
: $t('samlSettingsForm.customAttrs')
}}
@ -190,7 +192,7 @@
<script>
import { maxLength, required, helpers } from 'vuelidate/lib/validators'
import form from '@baserow/modules/core/mixins/form'
import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
const alphanumericDotDashUnderscore = helpers.regex(
'alphanumericDotDashUnderscore',
@ -199,41 +201,32 @@ const alphanumericDotDashUnderscore = helpers.regex(
export default {
name: 'SamlSettingsForm',
mixins: [form],
props: {
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
authProviderType: {
type: String,
required: false,
default: null,
},
},
mixins: [authProviderForm],
data() {
return {
allowedValues: ['domain', 'metadata'],
serverErrors: {},
allowedValues: [
'domain',
'metadata',
'email_attr_key',
'first_name_attr_key',
'last_name_attr_key',
],
values: {
domain: '',
metadata: '',
email_attr_key: '',
first_name_attr_key: '',
last_name_attr_key: '',
email_attr_key: 'user.email',
first_name_attr_key: 'user.first_name',
last_name_attr_key: 'user.last_name',
},
}
},
computed: {
allSamlProviders() {
return this.authProviders.saml || []
},
samlDomains() {
const samlAuthProviders =
this.$store.getters['authProviderAdmin/getAll'].saml?.authProviders ||
[]
return samlAuthProviders
.filter(
(authProvider) => authProvider.domain !== this.authProvider.domain
)
return this.allSamlProviders
.filter((authProvider) => authProvider.id !== this.authProvider.id)
.map((authProvider) => authProvider.domain)
},
defaultAttrs() {
@ -244,10 +237,8 @@ export default {
}
},
type() {
return this.authProviderType || this.authProvider.type
return this.authProviderType.getType()
},
},
methods: {
usingDefaultAttrs() {
return (
this.values.email_attr_key === this.defaultAttrs.email_attr_key &&
@ -256,20 +247,13 @@ export default {
this.values.last_name_attr_key === this.defaultAttrs.last_name_attr_key
)
},
getDefaultValues() {
const authProviderAttrs = {
email_attr_key: this.authProvider.email_attr_key,
first_name_attr_key: this.authProvider.first_name_attr_key,
last_name_attr_key: this.authProvider.last_name_attr_key,
}
const samlAttrs = this.authProvider.id
? authProviderAttrs
: this.defaultAttrs
return {
domain: this.authProvider.domain || '',
metadata: this.authProvider.metadata || '',
...samlAttrs,
}
},
methods: {
onDomainInput() {
this.serverErrors.domain = null
},
onMetadataInput() {
this.serverErrors.metadata = null
},
getFieldErrorMsg(fieldName) {
if (!this.$v.values[fieldName].$dirty) {
@ -285,33 +269,17 @@ export default {
}
},
getRelayStateUrl() {
return this.$store.getters['authProviderAdmin/getType'](this.type)
.relayStateUrl
return this.authProviderType.getRelayStateUrl()
},
getAcsUrl() {
return this.$store.getters['authProviderAdmin/getType'](this.type).acsUrl
return this.authProviderType.getAcsUrl()
},
getVerifiedIcon() {
return this.$registry.get('authProvider', this.type).getVerifiedIcon()
},
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
return this.authProviderType.getVerifiedIcon()
},
mustHaveUniqueDomain(domain) {
return !this.samlDomains.includes(domain.trim())
},
handleServerError(error) {
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
for (const [key, value] of Object.entries(error.handler.detail || {})) {
this.serverErrors[key] = value
}
return true
},
},
validations() {
return {

View file

@ -3,14 +3,15 @@
<h2 class="box__title">
{{
$t('createSettingsAuthProviderModal.title', {
type: getProviderTypeName(),
type: authProviderType.getName(),
})
}}
</h2>
<div v-if="authProviderType">
<div>
<component
:is="getProviderAdminSettingsFormComponent()"
ref="providerSettingsForm"
:auth-providers="appAuthProviderPerTypes"
:auth-provider-type="authProviderType"
@submit="create($event)"
>
@ -21,8 +22,8 @@
</li>
</ul>
<Button type="primary" :disabled="loading" :loading="loading">
{{ $t('action.create') }}</Button
>
{{ $t('action.create') }}
</Button>
</div>
</component>
</div>
@ -38,9 +39,8 @@ export default {
mixins: [modal],
props: {
authProviderType: {
type: String,
required: false,
default: null,
type: Object,
required: true,
},
},
data() {
@ -48,23 +48,30 @@ export default {
loading: false,
}
},
computed: {
authProviders() {
return this.$store.getters['authProviderAdmin/getAll']
},
appAuthProviderPerTypes() {
return Object.fromEntries(
this.$registry
.getOrderedList('authProvider')
.map((authProviderType) => [
authProviderType.getType(),
this.authProviders[authProviderType.getType()].authProviders,
])
)
},
},
methods: {
getProviderAdminSettingsFormComponent() {
return this.$registry
.get('authProvider', this.authProviderType)
.getAdminSettingsFormComponent()
},
getProviderTypeName() {
if (!this.authProviderType) return ''
return this.$registry.get('authProvider', this.authProviderType).getName()
return this.authProviderType.getAdminSettingsFormComponent()
},
async create(values) {
this.loading = true
this.serverErrors = {}
try {
await this.$store.dispatch('authProviderAdmin/create', {
type: this.authProviderType,
type: this.authProviderType.getType(),
values,
})
this.$emit('created')

View file

@ -11,7 +11,10 @@
<component
:is="getProviderAdminSettingsFormComponent()"
ref="providerSettingsForm"
:auth-providers="appAuthProviderPerTypes"
:auth-provider="authProvider"
:default-values="authProvider"
:auth-provider-type="authProviderType"
@submit="onSettingsUpdated"
>
<div class="actions">
@ -22,8 +25,8 @@
</ul>
<Button type="primary" :disabled="loading" :loading="loading">
{{ $t('action.save') }}</Button
>
{{ $t('action.save') }}
</Button>
</div>
</component>
</div>
@ -48,16 +51,30 @@ export default {
loading: false,
}
},
computed: {
authProviderType() {
return this.$registry.get('authProvider', this.authProvider.type)
},
authProviders() {
return this.$store.getters['authProviderAdmin/getAll']
},
appAuthProviderPerTypes() {
return Object.fromEntries(
this.$registry
.getOrderedList('authProvider')
.map((authProviderType) => [
authProviderType.getType(),
this.authProviders[authProviderType.getType()].authProviders,
])
)
},
},
methods: {
getProviderAdminSettingsFormComponent() {
return this.$registry
.get('authProvider', this.authProvider.type)
.getAdminSettingsFormComponent()
return this.authProviderType.getAdminSettingsFormComponent()
},
getProviderName() {
return this.$registry
.get('authProvider', this.authProvider.type)
.getProviderName(this.authProvider)
return this.authProviderType.getProviderName(this.authProvider)
},
async onSettingsUpdated(values) {
this.loading = true

View file

@ -1,5 +1,10 @@
import { AppAuthProviderType } from '@baserow/modules/core/appAuthProviderTypes'
import { SamlAuthProviderTypeMixin } from '@baserow_enterprise/authProviderTypes'
import LocalBaserowUserSourceForm from '@baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowPasswordAppAuthProviderForm'
import LocalBaserowAuthPassword from '@baserow_enterprise/integrations/localBaserow/components/appAuthProviders/LocalBaserowAuthPassword'
import CommonSamlSettingForm from '@baserow_enterprise/integrations/common/components/CommonSamlSettingForm'
import SamlAuthLink from '@baserow_enterprise/integrations/common/components/SamlAuthLink'
import { PasswordFieldType } from '@baserow/modules/database/fieldTypes'
export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType {
@ -11,6 +16,10 @@ export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType
return this.app.i18n.t('appAuthProviderType.localBaserowPassword')
}
get component() {
return LocalBaserowAuthPassword
}
get formComponent() {
return LocalBaserowUserSourceForm
}
@ -23,7 +32,91 @@ export class LocalBaserowPasswordAppAuthProviderType extends AppAuthProviderType
return [PasswordFieldType.getType()]
}
getLoginOptions(authProvider) {
if (authProvider.password_field_id) {
return {}
}
return null
}
/**
* We can create only one password provider.
*/
canCreateNew(appAuthProviders) {
return (
!appAuthProviders[this.getType()] ||
appAuthProviders[this.getType()].length === 0
)
}
getOrder() {
return 10
}
}
export class SamlAppAuthProviderType extends SamlAuthProviderTypeMixin(
AppAuthProviderType
) {
get name() {
return this.app.i18n.t('appAuthProviderType.commonSaml')
}
get component() {
return SamlAuthLink
}
get formComponent() {
return CommonSamlSettingForm
}
handleServerError(vueComponentInstance, error) {
if (error.handler.code !== 'ERROR_REQUEST_BODY_VALIDATION') return false
if (error.handler.detail?.auth_providers?.length > 0) {
const flatProviders = Object.entries(vueComponentInstance.authProviders)
.map(([, providers]) => providers)
.flat()
// Sort per ID to make sure we have the same order
// as the backend
.sort((a, b) => a.id - b.id)
for (const [
index,
authError,
] of error.handler.detail.auth_providers.entries()) {
if (
Object.keys(authError).length > 0 &&
flatProviders[index].id === vueComponentInstance.authProvider.id
) {
vueComponentInstance.serverErrors = {
...vueComponentInstance.serverErrors,
...authError,
}
return true
}
}
}
return false
}
getAuthToken(userSource, authProvider, route) {
// token can be in the query string (SSO) or in the cookies (previous session)
// We use the user source id in order to prevent conflicts when using multiple
// auth forms on the same page.
const queryParamName = `user_source_saml_token__${userSource.id}`
return route.query[queryParamName]
}
handleError(userSource, authProvider, route) {
const queryParamName = `saml_error__${userSource.id}`
const errorCode = route.query[queryParamName]
if (errorCode) {
return { message: this.app.i18n.t(`loginError.${errorCode}`), code: 500 }
}
}
getOrder() {
return 20
}
}

View file

@ -0,0 +1,83 @@
<template>
<div>
<div
class="common-saml-setting-form"
:class="{ 'common-saml-setting-form--error': inError }"
>
<Presentation
class="flex-grow-1"
:title="authProviderType.getProviderName(authProvider)"
size="medium"
avatar-color="neutral"
:image="authProviderType.getIcon()"
@click="onEdit"
/>
<div class="common-saml-setting-form__actions">
<ButtonIcon
type="secondary"
icon="iconoir-edit"
@click.prevent="onEdit"
/>
<ButtonIcon
type="secondary"
icon="iconoir-bin"
@click.prevent="onDelete"
/>
</div>
</div>
<div v-if="inError" class="error">
{{ $t('commonSamlSettingForm.authProviderInError') }}
</div>
<CommonSamlSettingModal
ref="samlModal"
v-bind="$props"
:integration="integration"
:user-source="userSource"
@form-valid="onFormValid($event)"
v-on="$listeners"
></CommonSamlSettingModal>
</div>
</template>
<script>
import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
import CommonSamlSettingModal from '@baserow_enterprise/integrations/common/components/CommonSamlSettingModal'
export default {
name: 'CommonSamlSettingForm',
components: { CommonSamlSettingModal },
mixins: [authProviderForm],
inject: ['builder'],
props: {
integration: {
type: Object,
required: true,
},
userSource: {
type: Object,
required: true,
},
},
data() {
return { inError: false }
},
methods: {
onFormValid(value) {
this.inError = !value
},
onEdit() {
this.$refs.samlModal.show()
},
onDelete() {
this.$emit('delete')
},
handleServerError(error) {
if (this.$refs.samlModal.handleServerError(error)) {
this.inError = true
return true
}
return false
},
},
}
</script>

View file

@ -0,0 +1,172 @@
<template>
<Modal ref="modal" keep-content @hidden="onHide">
<h2 class="box__title">{{ $t('commonSamlSettingModal.title') }}</h2>
<div>
<SamlSettingsForm
v-bind="$props"
ref="samlForm"
@values-changed="checkValidity"
v-on="$listeners"
>
<template #config>
<FormGroup
small-label
required
:label="$t('commonSamlSettingModal.relayStateTitle')"
class="margin-bottom-2"
>
<div class="common-saml-setting-modal__url-block">
<div
v-for="conf in config"
:key="conf.name"
class="common-saml-setting-modal__url"
@click.prevent="
;[copyToClipboard(conf.relay), $refs.copiedRelay.show()]
"
>
<span class="common-saml-setting-modal__url-domain">
{{ conf.name }}
</span>
<span
class="common-saml-setting-modal__url-dest"
:title="conf.relay"
>
{{ conf.relay }}
</span>
</div>
<Copied ref="copiedRelay"></Copied>
</div>
</FormGroup>
<FormGroup
small-label
required
:label="$t('commonSamlSettingModal.acsTitle')"
class="margin-bottom-2"
>
<div class="common-saml-setting-modal__url-block">
<div
v-for="conf in config"
:key="conf.name"
class="common-saml-setting-modal__url"
@click.prevent="
;[copyToClipboard(conf.acs), $refs.copiedACS.show()]
"
>
<span class="common-saml-setting-modal__url-domain">
{{ conf.name }}
</span>
<span
class="common-saml-setting-modal__url-dest"
:title="conf.acs"
>
{{ conf.acs }}
</span>
</div>
<Copied ref="copiedACS"></Copied>
</div>
</FormGroup>
</template>
</SamlSettingsForm>
<div class="actions actions--right">
<Button size="large" @click.prevent="$refs.modal.hide()">
{{ $t('action.close') }}
</Button>
</div>
</div>
</Modal>
</template>
<script>
import SamlSettingsForm from '@baserow_enterprise/components/admin/forms/SamlSettingsForm'
import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
import error from '@baserow/modules/core/mixins/error'
import { copyToClipboard } from '@baserow/modules/database/utils/clipboard'
import { mapActions, mapGetters } from 'vuex'
import modal from '@baserow/modules/core/mixins/modal'
export default {
name: 'CommonSamlSettingsModal',
components: { SamlSettingsForm },
mixins: [error, authProviderForm, modal],
inject: ['builder'],
props: {
integration: {
type: Object,
required: true,
},
userSource: {
type: Object,
required: true,
},
},
async fetch() {
try {
await this.actionFetchDomains({ builderId: this.builder.id })
} catch (error) {
this.handleError(error)
}
},
watch: {
'$v.$anyDirty'() {
// Force validity refresh on child touch
this.checkValidity()
},
},
computed: {
...mapGetters({ domains: 'domain/getDomains' }),
config() {
const previewRelay = `${this.$config.PUBLIC_WEB_FRONTEND_URL}/builder/${this.builder.id}/preview/`
const previewACS = `${this.$config.PUBLIC_BACKEND_URL}/api/user-source/${this.userSource.uid}/sso/saml/acs/`
const preview = [
{
name: this.$t('commonSamlSettingModal.preview'),
acs: previewACS,
relay: previewRelay,
},
]
const others = this.domains.map((domain) => ({
name: domain.domain_name,
acs: `${this.$config.PUBLIC_BACKEND_URL}/api/user-source/domain_${domain.id}__${this.userSource.uid}/sso/saml/acs/`,
relay: this.getDomainUrl(domain),
}))
return [...preview, ...others]
},
},
methods: {
...mapActions({
actionFetchDomains: 'domain/fetch',
}),
copyToClipboard(value) {
copyToClipboard(value)
},
onHide() {
this.checkValidity()
},
checkValidity() {
if (
!this.$refs.samlForm.isFormValid() &&
this.$refs.samlForm.$v.$anyDirty
) {
this.$emit('form-valid', false)
} else {
this.$emit('form-valid', true)
}
},
getDomainUrl(domain) {
const url = new URL(this.$config.PUBLIC_WEB_FRONTEND_URL)
return `${url.protocol}//${domain.domain_name}${
url.port ? `:${url.port}` : ''
}`
},
handleServerError(error) {
return this.$refs.samlForm.handleServerError(error)
},
},
validations() {
// Keep this to get the `$v` property
return {}
},
}
</script>

View file

@ -0,0 +1,135 @@
<template>
<div class="saml-auth-link">
<ABButton @click.prevent="onClick()">
{{ buttonLabel }}
</ABButton>
<Modal ref="modal">
<ThemeProvider>
<form @submit.prevent="">
<ABFormGroup
v-if="hasMultipleSamlProvider"
:label="$t('samlAuthLink.provideEmail')"
:error-message="
$v.values.email.$dirty
? !$v.values.email.required
? $t('error.requiredField')
: !$v.values.email.email
? $t('error.invalidEmail')
: ''
: ''
"
:autocomplete="isEditMode ? 'off' : ''"
required
>
<ABInput
v-model="values.email"
:placeholder="$t('samlAuthLink.emailPlaceholder')"
@blur="$v.values.email.$touch()"
/>
</ABFormGroup>
<div class="saml-auth-link__modal-footer">
<ABButton class="margin-top-2" @click.prevent.stop="login()">
{{ buttonLabel }}
</ABButton>
</div>
</form>
</ThemeProvider>
</Modal>
</div>
</template>
<script>
import form from '@baserow/modules/core/mixins/form'
import error from '@baserow/modules/core/mixins/error'
import { required, email } from 'vuelidate/lib/validators'
import ThemeProvider from '@baserow/modules/builder/components/theme/ThemeProvider'
export default {
components: { ThemeProvider },
mixins: [form, error],
inject: ['builder', 'mode'],
props: {
userSource: { type: Object, required: true },
authProviders: {
type: Array,
required: true,
},
loginButtonLabel: {
type: String,
required: true,
},
},
data() {
return {
loading: false,
values: { email: '' },
}
},
computed: {
isAuthenticated() {
return this.$store.getters['userSourceUser/isAuthenticated'](this.builder)
},
hasMultipleSamlProvider() {
return this.authProviders.length > 1
},
isEditMode() {
return this.mode === 'editing'
},
buttonLabel() {
return this.$t('samlAuthLink.placeholderWithSaml', {
login: this.loginButtonLabel,
})
},
},
methods: {
async onClick() {
if (this.hasMultipleSamlProvider) {
this.$refs.modal.show()
} else {
await this.login()
}
},
async login() {
if (this.isAuthenticated) {
await this.$store.dispatch('userSourceUser/logoff', {
application: this.builder,
})
}
if (this.hasMultipleSamlProvider) {
this.$v.$touch()
if (this.$v.$invalid) {
this.focusOnFirstError()
return
}
}
this.loading = true
this.hideError()
const dest = `${
this.$config.PUBLIC_BACKEND_URL
}/api/user-source/${encodeURIComponent(
this.userSource.uid
)}/sso/saml/login/`
const urlWithParams = new URL(dest)
if (this.hasMultipleSamlProvider) {
urlWithParams.searchParams.append('email', this.values.email)
}
// Add the current url as get parameter to be redirected here after the login.
urlWithParams.searchParams.append('original', window.location)
window.location = urlWithParams.toString()
},
},
validations: {
values: {
email: { required, email },
},
},
}
</script>

View file

@ -0,0 +1,148 @@
<template>
<form class="auth-form-element" @submit.prevent="onLogin">
<Error :error="error"></Error>
<ABFormGroup
:label="$t('authFormElement.email')"
:error-message="
$v.values.email.$dirty
? !$v.values.email.required
? $t('error.requiredField')
: !$v.values.email.email
? $t('error.invalidEmail')
: ''
: ''
"
:autocomplete="isEditMode ? 'off' : ''"
required
>
<ABInput
v-model="values.email"
:placeholder="$t('authFormElement.emailPlaceholder')"
@blur="$v.values.email.$touch()"
/>
</ABFormGroup>
<ABFormGroup
:label="$t('authFormElement.password')"
:error-message="
$v.values.password.$dirty
? !$v.values.password.required
? $t('error.requiredField')
: ''
: ''
"
required
>
<ABInput
ref="passwordRef"
v-model="values.password"
type="password"
:placeholder="$t('authFormElement.passwordPlaceholder')"
@blur="$v.values.password.$touch()"
/>
</ABFormGroup>
<div class="auth-form__footer">
<ABButton :disabled="$v.$error" :loading="loading" size="large">
{{ loginButtonLabel }}
</ABButton>
</div>
</form>
</template>
<script>
import form from '@baserow/modules/core/mixins/form'
import error from '@baserow/modules/core/mixins/error'
import { required, email } from 'vuelidate/lib/validators'
export default {
mixins: [form, error],
inject: ['builder', 'mode'],
props: {
userSource: { type: Object, required: true },
authProviders: {
type: Array,
required: true,
},
loginButtonLabel: {
type: String,
required: true,
},
},
data() {
return {
loading: false,
values: { email: '', password: '' },
}
},
computed: {
isAuthenticated() {
return this.$store.getters['userSourceUser/isAuthenticated'](this.builder)
},
isEditMode() {
return this.mode === 'editing'
},
},
methods: {
async onLogin(event) {
if (this.isAuthenticated) {
await this.$store.dispatch('userSourceUser/logoff', {
application: this.builder,
})
}
this.$v.$touch()
if (this.$v.$invalid) {
this.focusOnFirstError()
return
}
this.loading = true
this.hideError()
try {
await this.$store.dispatch('userSourceUser/authenticate', {
application: this.builder,
userSource: this.userSource,
credentials: {
email: this.values.email,
password: this.values.password,
},
setCookie: this.mode === 'public',
})
this.values.password = ''
this.values.email = ''
this.$v.$reset()
this.$emit('after-login')
} catch (error) {
if (error.handler) {
const response = error.handler.response
if (response && response.status === 401) {
this.values.password = ''
this.$v.$reset()
this.$v.$touch()
this.$refs.passwordRef.focus()
if (response.data?.error === 'ERROR_INVALID_CREDENTIALS') {
this.showError(
this.$t('error.incorrectCredentialTitle'),
this.$t('error.incorrectCredentialMessage')
)
}
} else {
const message = error.handler.getMessage('login')
this.showError(message)
}
error.handler.handled()
} else {
throw error
}
}
this.loading = false
},
},
validations: {
values: {
email: { required, email },
password: { required },
},
},
}
</script>

View file

@ -5,7 +5,6 @@
small-label
horizontal
horizontal-variable
class="margin-top-2"
required
>
<Dropdown
@ -32,16 +31,16 @@
</template>
<script>
import form from '@baserow/modules/core/mixins/form'
import authProviderForm from '@baserow/modules/core/mixins/authProviderForm'
export default {
mixins: [form],
mixins: [authProviderForm],
props: {
integration: {
type: Object,
required: true,
},
currentUserSource: {
userSource: {
type: Object,
required: true,
},
@ -55,9 +54,6 @@ export default {
}
},
computed: {
authProviderType() {
return this.$registry.get('appAuthProvider', 'local_baserow_password')
},
databases() {
return this.integration.context_data.databases
},
@ -65,12 +61,12 @@ export default {
return this.$registry.getAll('field')
},
selectedTable() {
if (!this.currentUserSource.table_id) {
if (!this.userSource.table_id) {
return null
}
for (const database of this.databases) {
for (const table of database.tables) {
if (table.id === this.currentUserSource.table_id) {
if (table.id === this.userSource.table_id) {
return table
}
}
@ -91,7 +87,7 @@ export default {
},
},
watch: {
'currentUserSource.table_id'() {
'userSource.table_id'() {
this.values.password_field_id = null
},
},

View file

@ -117,17 +117,22 @@ export class LocalBaserowUserSourceType extends UserSourceType {
if (!userSource.email_field_id || !userSource.name_field_id) {
return {}
}
if (userSource.auth_providers.length !== 1) {
return {}
}
const authProvider = userSource.auth_providers[0]
if (
authProvider.type !== 'local_baserow_password' ||
!authProvider.password_field_id
) {
return {}
}
return { password: {} }
return userSource.auth_providers.reduce((acc, authProvider) => {
if (!acc[authProvider.type]) {
acc[authProvider.type] = []
}
const loginOptions = this.app.$registry
.get('appAuthProvider', authProvider.type)
.getLoginOptions(authProvider)
if (loginOptions) {
acc[authProvider.type].push(loginOptions)
}
return acc
}, {})
}
getOrder() {

View file

@ -95,14 +95,17 @@
"addProvider": "Add provider"
},
"authProviderTypes": {
"password": "Email and password authentication"
"password": "Email and password authentication",
"saml": "SSO SAML provider",
"ssoSamlProviderName": "SSO SAML: {domain}",
"ssoSamlProviderNameUnconfigured": "Unconfigured SSO SAML"
},
"editAuthProviderMenuContext": {
"edit": "Edit",
"delete": "Delete"
},
"samlSettingsForm": {
"domain": "Domain",
"domain": "SAML Domain",
"domainPlaceholder": "Insert the company domain name...",
"invalidDomain": "Invalid domain name",
"domainAlreadyExists": "A SAML provider for this domain already exists",
@ -311,12 +314,22 @@
"roleFieldPlaceholder": "Select a field..."
},
"appAuthProviderType": {
"localBaserowPassword": "Email/Password"
"localBaserowPassword": "Email/Password",
"commonSaml": "Saml SSO"
},
"localBaserowPasswordAppAuthProviderForm": {
"passwordFieldLabel": "Select password field",
"noFields": "No compatible fields"
},
"commonSamlSettingModal": {
"title": "Edit provider",
"relayStateTitle": "Default Relay State URL per domain (Click to copy)",
"acsTitle": "Single Sign On URL per domain (Click to copy)",
"preview": "Preview"
},
"commonSamlSettingForm": {
"authProviderInError": "Please edit this provider to fix the error."
},
"enterpriseSettings": {
"branding": "Branding",
"showBaserowHelpMessage": "Show help message",
@ -383,5 +396,11 @@
"projectIdHelper": "The ID of the project where you want to sync the issues from. Can be found by going to your project page (e.g. https://gitlab.com/baserow/baserow), click on the three dots in the top right corner, and then on 'Copy project ID: 12345678'.",
"accessToken": "Access token",
"accessTokenHelper": "Can be generated here https://gitlab.com/-/user_settings/personal_access_tokens by clicking on `Add new token`, select `read_api`."
},
"samlAuthLink": {
"loginWithSaml": "Login with SAML",
"placeholderWithSaml": "{login} with SAML",
"provideEmail": "Provide your SAML account email",
"emailPlaceholder": "Enter your email..."
}
}

View file

@ -13,6 +13,7 @@
@create="showCreateModal($event)"
/>
<CreateAuthProviderModal
v-if="authProviderTypeToCreate"
ref="createModal"
:auth-provider-type="authProviderTypeToCreate"
@created="$refs.createModal.hide()"
@ -57,9 +58,15 @@ export default {
},
computed: {
...mapGetters({
authProviderMap: 'authProviderAdmin/getAll',
authProviders: 'authProviderAdmin/getAllOrdered',
authProviderTypesCanBeCreated: 'authProviderAdmin/getCreatableTypes',
}),
authProviderTypesCanBeCreated() {
return Object.values(this.$registry.getAll('authProvider')).filter(
(authProviderType) =>
authProviderType.canCreateNew(this.authProviderMap)
)
},
},
methods: {
getAdminListComponent(authProvider) {
@ -75,8 +82,10 @@ export default {
4
)
},
showCreateModal(authProviderType) {
this.authProviderTypeToCreate = authProviderType.type
async showCreateModal(authProviderType) {
this.authProviderTypeToCreate = authProviderType
// Wait for the modal to appear in DOM
await this.$nextTick()
this.$refs.createModal.show()
this.$refs.createContext.hide()
},

View file

@ -27,7 +27,10 @@ import {
} from '@baserow_enterprise/licenseTypes'
import { EnterprisePlugin } from '@baserow_enterprise/plugins'
import { LocalBaserowUserSourceType } from '@baserow_enterprise/integrations/userSourceTypes'
import { LocalBaserowPasswordAppAuthProviderType } from '@baserow_enterprise/integrations/appAuthProviderTypes'
import {
LocalBaserowPasswordAppAuthProviderType,
SamlAppAuthProviderType,
} from '@baserow_enterprise/integrations/appAuthProviderTypes'
import { AuthFormElementType } from '@baserow_enterprise/builder/elementTypes'
import {
EnterpriseAdminRoleType,
@ -46,6 +49,8 @@ import {
GitLabIssuesDataSyncType,
} from '@baserow_enterprise/dataSyncTypes'
import { FF_AB_SSO } from '@baserow/modules/core/plugins/featureFlags'
export default (context) => {
const { app, isDev, store } = context
@ -115,6 +120,13 @@ export default (context) => {
new LocalBaserowPasswordAppAuthProviderType(context)
)
if (app.$featureFlagIsEnabled(FF_AB_SSO)) {
app.$registry.register(
'appAuthProvider',
new SamlAppAuthProviderType(context)
)
}
app.$registry.register('roles', new EnterpriseAdminRoleType(context))
app.$registry.register('roles', new EnterpriseMemberRoleType(context))
app.$registry.register('roles', new EnterpriseBuilderRoleType(context))

View file

@ -118,15 +118,6 @@ export const getters = {
}
return authProviders
},
getCreatableTypes: (state) => {
const items = []
for (const authProviderType of Object.values(state.items)) {
if (authProviderType.canCreateNewProviders) {
items.push(authProviderType)
}
}
return items
},
getNextProviderId: (state) => {
return state.nextProviderId
},

View file

@ -16,17 +16,7 @@
</p>
<p v-else class="placeholder__content">{{ content }}</p>
<div class="placeholder__action">
<Button
type="primary"
icon="iconoir-home"
size="large"
@click="
$router.go({
name: 'application-builder-page',
params: { pathMatch: '/' },
})
"
>
<Button type="primary" icon="iconoir-home" size="large" @click="onHome()">
{{ $t('action.backToHome') }}
</Button>
</div>
@ -62,5 +52,15 @@ export default {
return this.error.content || this.$t('errorLayout.error')
},
},
methods: {
onHome() {
this.$router.push({
name: 'application-builder-page',
params: { pathMatch: '/' },
// We remove the query parameters. Important if we have some with error
query: {},
})
},
},
}
</script>

View file

@ -11,10 +11,7 @@
<slot name="title">{{
$t('pageVisibilitySettingsTypes.logInPageWarningTitle')
}}</slot>
<!-- eslint-disable-next-line vue/no-v-html vue/no-v-text-v-html-on-component -->
<p
v-html="$t('pageVisibilitySettingsTypes.logInPagewarningMessage')"
></p>
<p>{{ $t('pageVisibilitySettingsTypes.logInPagewarningMessage') }}</p>
</Alert>
<Alert
v-else-if="showLoginPageAlert && !showLogInPageWarning"
@ -23,14 +20,13 @@
<slot name="title">{{
$t('pageVisibilitySettingsTypes.logInPageInfoTitle')
}}</slot>
<!-- eslint-disable-next-line vue/no-v-html vue/no-v-text-v-html-on-component -->
<p
v-html="
<p>
{{
$t('pageVisibilitySettingsTypes.logInPageInfoMessage', {
logInPageName: loginPageName,
})
"
></p>
}}
</p>
</Alert>
</div>
<div class="margin-top-1 visibility-form__visibility-all">

View file

@ -53,7 +53,6 @@
<UpdateUserSourceForm
ref="userSourceForm"
:builder="builder"
:integrations="integrations"
:user-source-type="getUserSourceType(editedUserSource)"
:default-values="editedUserSource"
@submitted="updateUserSource"
@ -71,7 +70,7 @@
:disabled="actionInProgress || invalidForm"
:loading="actionInProgress"
size="large"
@click="$refs.userSourceForm.submit()"
@click="$refs.userSourceForm.submit(true)"
>
{{ $t('action.save') }}
</Button>
@ -86,7 +85,6 @@
<CreateUserSourceForm
ref="userSourceForm"
:builder="builder"
:integrations="integrations"
@submitted="createUserSource"
@values-changed="onValueChange"
/>
@ -122,6 +120,9 @@ export default {
name: 'UserSourceSettings',
components: { CreateUserSourceForm, UpdateUserSourceForm },
mixins: [error],
provide() {
return { builder: this.builder }
},
props: {
builder: {
type: Object,
@ -169,7 +170,7 @@ export default {
return this.$registry.get('userSource', userSource.type)
},
onValueChange() {
this.invalidForm = !this.$refs.userSourceForm.isFormValid()
this.invalidForm = !this.$refs.userSourceForm.isFormValid(true)
},
async showForm(userSourceToEdit) {
if (userSourceToEdit) {
@ -205,6 +206,10 @@ export default {
this.actionInProgress = false
},
async updateUserSource(newValues) {
if (!this.$refs.userSourceForm.isFormValid(true)) {
return
}
this.actionInProgress = true
try {
await this.actionUpdateUserSource({
@ -215,8 +220,10 @@ export default {
this.hideForm()
} catch (error) {
// Restore the previously saved values from the store
this.$refs.userSourceForm.reset()
this.handleError(error)
if (!this.$refs.userSourceForm.handleServerError(error)) {
this.$refs.userSourceForm.reset()
this.handleError(error)
}
}
this.actionInProgress = false
},

View file

@ -67,10 +67,6 @@ export default {
type: Object,
required: true,
},
integrations: {
type: Array,
required: true,
},
},
data() {
return {
@ -78,6 +74,9 @@ export default {
}
},
computed: {
integrations() {
return this.$store.getters['integration/getIntegrations'](this.builder)
},
userSources() {
return this.$store.getters['userSource/getUserSources'](this.builder)
},
@ -117,6 +116,10 @@ export default {
}
return ''
},
handleServerError() {
return false
},
},
validations: {
values: {

View file

@ -47,13 +47,28 @@
<div
v-for="appAuthType in appAuthProviderTypes"
:key="appAuthType.type"
class="update-user-source-form__auth-provider"
>
<Checkbox
:checked="hasAtLeastOneOfThisType(appAuthType)"
@input="onSelect(appAuthType)"
>
{{ appAuthType.name }}
</Checkbox>
<div class="update-user-source-form__auth-provider-header">
<Checkbox
:checked="hasAtLeastOneOfThisType(appAuthType)"
@input="onSelect(appAuthType)"
>
{{ appAuthType.name }}
</Checkbox>
<ButtonText
v-if="
hasAtLeastOneOfThisType(appAuthType) &&
appAuthType.canCreateNew(appAuthProviderPerTypes)
"
icon="iconoir-plus"
type="secondary"
@click.prevent="addNew(appAuthType)"
>
{{ $t('updateUserSourceForm.addProvider') }}
</ButtonText>
</div>
<div
v-for="appAuthProvider in appAuthProviderPerTypes[appAuthType.type]"
@ -63,11 +78,16 @@
<component
:is="appAuthType.formComponent"
v-if="hasAtLeastOneOfThisType(appAuthType)"
:integration="integration"
:current-user-source="fullValues"
:default-values="appAuthProvider"
:ref="`authProviderForm`"
excluded-form
:integration="integration"
:user-source="fullValues"
:auth-providers="appAuthProviderPerTypes"
:auth-provider="appAuthProvider"
:default-values="appAuthProvider"
:auth-provider-type="appAuthType"
@values-changed="updateAuthProvider(appAuthProvider, $event)"
@delete="remove(appAuthProvider)"
/>
</div>
</div>
@ -82,7 +102,6 @@
import form from '@baserow/modules/core/mixins/form'
import IntegrationDropdown from '@baserow/modules/core/components/integrations/IntegrationDropdown'
import { required, maxLength } from 'vuelidate/lib/validators'
import { uuid } from '@baserow/modules/core/utils/string'
export default {
components: { IntegrationDropdown },
@ -97,10 +116,6 @@ export default {
required: false,
default: null,
},
integrations: {
type: Array,
required: true,
},
},
data() {
return {
@ -113,6 +128,9 @@ export default {
}
},
computed: {
integrations() {
return this.$store.getters['integration/getIntegrations'](this.builder)
},
integration() {
if (!this.values.integration_id) {
return null
@ -140,6 +158,8 @@ export default {
methods: {
// Override the default getChildFormValues to exclude the provider forms from
// final values as they are handled directly by this component
// The problem is that the child provider forms are not handled as a sub array
// so they override the userSource configuration
getChildFormsValues() {
return Object.assign(
{},
@ -157,6 +177,14 @@ export default {
hasAtLeastOneOfThisType(appAuthProviderType) {
return this.appAuthProviderPerTypes[appAuthProviderType.type]?.length > 0
},
/** Return an integer bigger than any of the current auth_provider id to
* keep the right order when we want to map the error coming back from the server.
*/
nextID() {
return (
Math.max(1, ...this.values.auth_providers.map(({ id }) => id)) + 100
)
},
onSelect(appAuthProviderType) {
if (this.hasAtLeastOneOfThisType(appAuthProviderType)) {
this.values.auth_providers = this.values.auth_providers.filter(
@ -165,10 +193,21 @@ export default {
} else {
this.values.auth_providers.push({
type: appAuthProviderType.type,
id: uuid(),
id: this.nextID(),
})
}
},
addNew(appAuthProviderType) {
this.values.auth_providers.push({
type: appAuthProviderType.type,
id: this.nextID(),
})
},
remove(appAuthProvider) {
this.values.auth_providers = this.values.auth_providers.filter(
({ id }) => id !== appAuthProvider.id
)
},
updateAuthProvider(authProviderToChange, values) {
this.values.auth_providers = this.values.auth_providers.map(
(authProvider) => {
@ -182,6 +221,16 @@ export default {
emitChange() {
this.fullValues = this.getFormValues()
},
handleServerError(error) {
if (
this.$refs.authProviderForm
.map((form) => form.handleServerError(error))
.some((result) => result)
) {
return true
}
return false
},
getError(fieldName) {
if (!this.$v.values[fieldName].$dirty) {
return ''

View file

@ -732,7 +732,8 @@
"nameFieldLabel": "Name",
"nameFieldPlaceholder": "Enter a name...",
"authTitle": "Authentication",
"integrationFieldLabel": "Integration"
"integrationFieldLabel": "Integration",
"addProvider": "Add provider"
},
"builderLoginPageForm": {
"pageDropdownLabel": "Login Page",

View file

@ -24,8 +24,8 @@ import { userCanViewPage } from '@baserow/modules/builder/utils/visibility'
import {
getTokenIfEnoughTimeLeft,
setToken,
userSourceCookieTokenName,
setToken,
} from '@baserow/modules/core/utils/auth'
const logOffAndReturnToLogin = async ({ builder, store, redirect }) => {
@ -59,8 +59,8 @@ export default {
$registry,
app,
req,
route,
redirect,
route,
}) {
let mode = 'public'
const builderId = params.builderId ? parseInt(params.builderId, 10) : null
@ -71,6 +71,7 @@ export default {
}
let builder = store.getters['application/getSelected']
let needPostBuilderLoading = false
if (!builder || (builderId && builderId !== builder.id)) {
try {
@ -105,7 +106,40 @@ export default {
})
}
needPostBuilderLoading = true
}
store.dispatch('userSourceUser/setCurrentApplication', {
application: builder,
})
if (
(!process.server || req) &&
!store.getters['userSourceUser/isAuthenticated'](builder)
) {
const refreshToken = getTokenIfEnoughTimeLeft(
app,
userSourceCookieTokenName
)
if (refreshToken) {
try {
await store.dispatch('userSourceUser/refreshAuth', {
application: builder,
token: refreshToken,
})
} catch (error) {
if (error.response?.status === 401) {
// We logoff as the token has probably expired or became invalid
logOffAndReturnToLogin({ builder, store, redirect })
} else {
throw error
}
}
}
}
if (needPostBuilderLoading) {
// Post builder loading task executed once per application
// It's executed here to make sure we are authenticated at that point
const sharedPage = await store.getters['page/getSharedPage'](builder)
await Promise.all([
store.dispatch('dataSource/fetchPublished', {
@ -127,37 +161,18 @@ export default {
}
)
}
store.dispatch('userSourceUser/setCurrentApplication', {
application: builder,
})
if (
(!process.server || req) &&
!store.getters['userSourceUser/isAuthenticated'](builder)
) {
// token can be in the query string (SSO) or in the cookies (previous session)
let refreshToken = route.query.token
if (refreshToken) {
setToken(app, refreshToken, userSourceCookieTokenName, {
sameSite: 'Lax',
})
} else {
refreshToken = getTokenIfEnoughTimeLeft(app, userSourceCookieTokenName)
}
if (refreshToken) {
try {
await store.dispatch('userSourceUser/refreshAuth', {
application: builder,
token: refreshToken,
// Auth providers can get error code from the URL parameters
for (const userSource of builder.user_sources) {
for (const authProvider of userSource.auth_providers) {
const authError = $registry
.get('appAuthProvider', authProvider.type)
.handleError(userSource, authProvider, route)
if (authError) {
return error({
statusCode: authError.code,
message: authError.message,
})
} catch (error) {
if (error.response?.status === 401) {
// We logoff as the token has probably expired or became invalid
logOffAndReturnToLogin({ builder, store, redirect })
} else {
throw error
}
}
}
}
@ -348,7 +363,7 @@ export default {
}
},
},
async isAuthenticated() {
async isAuthenticated(newIsAuthenticated) {
// When the user login or logout, we need to refetch the elements and actions
// as they might have changed
await this.$store.dispatch('element/fetchPublished', {
@ -364,12 +379,18 @@ export default {
page: this.sharedPage,
})
// If the user is on a hidden page, redirect them to the Login page if possible.
await this.maybeRedirectUserToLoginPage()
if (newIsAuthenticated) {
// If the user has just logged in, we redirect him to the next page.
await this.maybeRedirectToNextPage()
} else {
// If the user is on a hidden page, redirect them to the Login page if possible.
await this.maybeRedirectUserToLoginPage()
}
},
},
async mounted() {
await this.maybeRedirectUserToLoginPage()
await this.checkProviderAuthentication()
},
methods: {
/**
@ -389,8 +410,57 @@ export default {
this.mode
)
if (url !== this.$router.history.current?.fullPath) {
this.$router.push(url)
const currentPath = this.$route.fullPath
if (url !== currentPath) {
const nextPath = encodeURIComponent(currentPath)
this.$router.push({ path: url, query: { next: nextPath } })
}
}
},
maybeRedirectToNextPage() {
if (this.$route.query.next) {
const decodedNext = decodeURIComponent(this.$route.query.next)
this.$router.push(decodedNext)
}
},
async checkProviderAuthentication() {
// Iterate over all auth providers to check if one can get a refresh token
let refreshTokenFromProvider = null
for (const userSource of this.builder.user_sources) {
for (const authProvider of userSource.auth_providers) {
refreshTokenFromProvider = this.$registry
.get('appAuthProvider', authProvider.type)
.getAuthToken(userSource, authProvider, this.$route)
if (refreshTokenFromProvider) {
break
}
}
if (refreshTokenFromProvider) {
break
}
}
if (refreshTokenFromProvider) {
setToken(this, refreshTokenFromProvider, userSourceCookieTokenName, {
sameSite: 'Lax',
})
try {
await this.$store.dispatch('userSourceUser/refreshAuth', {
application: this.builder,
token: refreshTokenFromProvider,
})
} catch (error) {
if (error.response?.status === 401) {
// We logoff as the token has probably expired or became invalid
logOffAndReturnToLogin({
builder: this.builder,
store: this.$store,
redirect: (...args) => this.$router.push(...args),
})
} else {
throw error
}
}
}
},

View file

@ -1,18 +1,30 @@
import { Registerable } from '@baserow/modules/core/registry'
import { BaseAuthProviderType } from '@baserow/modules/core/authProviderTypes'
export class AppAuthProviderType extends Registerable {
export class AppAuthProviderType extends BaseAuthProviderType {
get name() {
throw new Error('Must be set on the type.')
return this.getName()
}
getLoginOptions(authProvider) {
return null
}
get component() {
return null
}
/**
* The form to edit this user source.
*/
get formComponent() {
return this.getAdminSettingsFormComponent()
}
getAuthToken(userSource, authProvider, route) {
return null
}
getOrder() {
return 0
handleError(userSource, authProvider, route) {
return null
}
}

View file

@ -2,8 +2,6 @@
display: flex;
flex-direction: column;
gap: 20px;
padding: 12px;
padding-bottom: 0;
.tabs__header {
border-bottom: none;

View file

@ -1,4 +1,16 @@
.update-user-source-form__auth-provider {
margin-top: 16px;
}
.update-user-source-form__auth-provider-header {
display: flex;
justify-content: space-between;
}
.update-user-source-form__auth-provider-form {
margin-left: 24px;
margin-top: 12px;
display: flex;
flex-direction: column;
align-items: stretch;
}

View file

@ -2,10 +2,9 @@ import { Registerable } from '@baserow/modules/core/registry'
import PasswordAuthIcon from '@baserow/modules/core/assets/images/providers/Key.svg'
/**
* The authorization provider type base class that can be extended when creating
* a plugin for the frontend.
* Base class for authorization provider types
*/
export class AuthProviderType extends Registerable {
export class BaseAuthProviderType extends Registerable {
/**
* The icon for the provider
*/
@ -14,14 +13,14 @@ export class AuthProviderType extends Registerable {
}
/**
* A human readable name of the application type.
* A human readable name of the authentication provider.
*/
getName() {
return null
}
/**
* A human readable name of the application type.
* A human readable name of the authentication provider.
*/
getProviderName(provider) {
return null
@ -81,6 +80,35 @@ export class AuthProviderType extends Registerable {
}
}
/**
* Whether we can create new providers on this type. Sometimes providers can't be
* created because of permissions reasons or because of unicity constraints.
* @returns a boolean saying if you can create new providers of this type.
*/
canCreateNew(authProviders) {
return true
}
/**
*
* @param {Object} err to handle
* @param {VueInstance} vueComponentInstance the vue component instance
* @returns true if the error is handled else false.
*/
handleServerError(vueComponentInstance, error) {
return false
}
getOrder() {
throw new Error('The order of the authentication provider must be set.')
}
}
/**
* The authorization provider type base class that can be extended when creating
* a plugin for the frontend.
*/
export class AuthProviderType extends BaseAuthProviderType {
constructor(...args) {
super(...args)
this.type = this.getType()
@ -108,10 +136,6 @@ export class AuthProviderType extends Registerable {
routeName: this.routeName,
}
}
getOrder() {
throw new Error('The order of an application type must be set.')
}
}
export class PasswordAuthProviderType extends AuthProviderType {
@ -139,6 +163,16 @@ export class PasswordAuthProviderType extends AuthProviderType {
return null
}
/**
* We can create only one password provider.
*/
canCreateNew(authProviders) {
return (
!authProviders[this.getType()] ||
authProviders[this.getType()].length === 0
)
}
getOrder() {
return 1
}

View file

@ -1,6 +1,7 @@
<template>
<div
v-if="open"
v-if="open || keepContent"
v-show="(keepContent && open) || !keepContent"
ref="modalWrapper"
class="modal__wrapper"
@click="outside($event)"
@ -172,6 +173,14 @@ export default {
default: false,
required: false,
},
// This flag allow to keep the modal content in case you want it to be available.
// Useful if you have a form inside the modal that is a sub part of the current
// form.
keepContent: {
type: Boolean,
default: false,
required: false,
},
},
data() {
return {

View file

@ -0,0 +1,40 @@
import form from '@baserow/modules/core/mixins/form'
export default {
mixins: [form],
props: {
authProviders: {
type: Object,
required: true,
},
authProvider: {
type: Object,
required: false,
default: () => ({}),
},
authProviderType: {
type: Object,
required: true,
},
},
data() {
return { serverErrors: {} }
},
computed: {
providerName() {
return this.authProviderType.getProviderName(this.authProvider)
},
},
methods: {
submit() {
this.$v.$touch()
if (this.$v.$invalid) {
return
}
this.$emit('submit', this.values)
},
handleServerError(error) {
return this.authProviderType.handleServerError(this, error)
},
},
}

View file

@ -85,28 +85,54 @@ export default {
firstError.scrollIntoView({ behavior: 'smooth' })
}
},
touch() {
/**
* Select all children that match the given predicate.
* @param {Function} predicate a function that receive the current child as parameter and
* should return true if the child should be accepted.
* @param {Boolean} deep true if you want to deeply search for child
* @returns
*/
getChildForms(predicate = (child) => 'isFormValid' in child, deep = false) {
const children = []
const getDeep = (child) => {
if (predicate(child)) {
children.push(child)
}
if (deep) {
// Search into children of children
child.$children.forEach(getDeep)
}
}
for (const child of this.$children) {
getDeep(child)
}
return children
},
touch(deep = false) {
if ('$v' in this) {
this.$v.$touch()
}
// Also touch all the child forms so that all the error messages are going to
// be displayed.
for (const child of this.$children) {
if ('isFormValid' in child && '$v' in child) {
child.touch()
}
for (const child of this.getChildForms(
(child) => 'touch' in child,
deep
)) {
child.touch(deep)
}
},
submit() {
submit(deep = false) {
if (this.selectedFieldIsDeactivated) {
return
}
this.touch()
this.touch(deep)
if (this.isFormValid()) {
this.$emit('submitted', this.getFormValues())
if (this.isFormValid(deep)) {
this.$emit('submitted', this.getFormValues(deep))
} else {
this.$nextTick(() => this.focusOnFirstError())
}
@ -120,21 +146,6 @@ export default {
? this.$v.values[fieldName].$error
: false
},
getChildForms(deep = false) {
const children = []
const getDeep = (child) => {
if ('isFormValid' in child) {
children.push(child)
} else if (deep) {
child.$children.forEach(getDeep)
}
}
for (const child of this.$children) {
getDeep(child)
}
return children
},
/**
* Returns true is everything is valid.
*
@ -151,8 +162,11 @@ export default {
* Returns true if all the child form components are valid.
*/
areChildFormsValid(deep = false) {
for (const child of this.getChildForms(deep)) {
if ('isFormValid' in child && !child.isFormValid()) {
for (const child of this.getChildForms(
(child) => 'isFormValid' in child,
deep
)) {
if (!child.isFormValid(deep)) {
return false
}
}
@ -162,17 +176,21 @@ export default {
* A method that can be overridden to do some mutations on the values before
* calling the submitted event.
*/
getFormValues() {
return Object.assign({}, this.values, this.getChildFormsValues())
getFormValues(deep = false) {
return Object.assign({}, this.values, this.getChildFormsValues(deep))
},
/**
* Returns an object containing the values of the child forms.
*/
getChildFormsValues() {
getChildFormsValues(deep = false) {
const children = this.getChildForms(
(child) => 'getChildFormsValues' in child,
deep
)
return Object.assign(
{},
...this.$children.map((child) => {
return 'getChildFormsValues' in child ? child.getFormValues() : {}
...children.map((child) => {
return child.getFormValues(deep)
})
)
},
@ -196,16 +214,22 @@ export default {
await this.$nextTick()
// Also reset the child forms after a tick to allow default values to be updated.
this.getChildForms(deep).forEach((child) => child.reset())
this.getChildForms((child) => 'reset' in child, deep).forEach((child) =>
child.reset()
)
},
/**
* Returns if a child form has indicated it handled the error, false otherwise.
*/
handleErrorByForm(error) {
handleErrorByForm(error, deep = false) {
let childHandledIt = false
for (const child of this.$children) {
if ('handleErrorByForm' in child && child.handleErrorByForm(error)) {
const children = this.getChildForms(
(child) => 'handleErrorByForm' in child,
deep
)
for (const child of children) {
if (child.handleErrorByForm(error)) {
childHandledIt = true
}
}

View file

@ -1,5 +1,6 @@
const FF_ENABLE_ALL = '*'
export const FF_DASHBOARDS = 'dashboards'
export const FF_AB_SSO = 'ab_sso'
/**
* A comma separated list of feature flags used to enable in-progress or not ready