mirror of
https://gitlab.com/bramw/baserow.git
synced 2025-04-17 18:32:35 +00:00
Add tests for oauth providers and views
This commit is contained in:
parent
6d5917ba84
commit
3d517fd1a9
6 changed files with 939 additions and 31 deletions
enterprise/backend
src/baserow_enterprise
tests/baserow_enterprise_tests
|
@ -18,6 +18,7 @@ from baserow.core.user.exceptions import DeactivatedUserException, DisabledSignu
|
|||
from baserow_enterprise.api.sso.serializers import SsoLoginRequestSerializer
|
||||
from baserow_enterprise.api.sso.utils import (
|
||||
SsoErrorCode,
|
||||
map_sso_exceptions,
|
||||
redirect_to_sign_in_error_page,
|
||||
redirect_user_on_success,
|
||||
)
|
||||
|
@ -63,6 +64,11 @@ class OAuth2LoginView(APIView):
|
|||
auth=[],
|
||||
)
|
||||
@validate_query_parameters(SsoLoginRequestSerializer, return_validated=True)
|
||||
@map_sso_exceptions(
|
||||
{
|
||||
AuthProviderModelNotFound: SsoErrorCode.PROVIDER_DOES_NOT_EXIST,
|
||||
}
|
||||
)
|
||||
@transaction.atomic
|
||||
def get(
|
||||
self, request: Request, provider_id: int, query_params: Dict[str, Any]
|
||||
|
@ -75,11 +81,8 @@ class OAuth2LoginView(APIView):
|
|||
if not is_sso_feature_active():
|
||||
return redirect_to_sign_in_error_page(SsoErrorCode.FEATURE_NOT_ACTIVE)
|
||||
|
||||
try:
|
||||
provider = AuthProviderHandler.get_auth_provider(provider_id)
|
||||
provider_type = auth_provider_type_registry.get_by_model(provider)
|
||||
except AuthProviderModelNotFound:
|
||||
return redirect_to_sign_in_error_page(SsoErrorCode.PROVIDER_DOES_NOT_EXIST)
|
||||
provider = AuthProviderHandler.get_auth_provider(provider_id)
|
||||
provider_type = auth_provider_type_registry.get_by_model(provider)
|
||||
|
||||
redirect_url = provider_type.get_authorization_url(
|
||||
provider.specific_class.objects.get(id=provider_id),
|
||||
|
@ -119,6 +122,16 @@ class OAuth2CallbackView(APIView):
|
|||
},
|
||||
auth=[],
|
||||
)
|
||||
@map_sso_exceptions(
|
||||
{
|
||||
AuthProviderModelNotFound: SsoErrorCode.PROVIDER_DOES_NOT_EXIST,
|
||||
AuthFlowError: SsoErrorCode.AUTH_FLOW_ERROR,
|
||||
DeactivatedUserException: SsoErrorCode.USER_DEACTIVATED,
|
||||
DifferentAuthProvider: SsoErrorCode.DIFFERENT_PROVIDER,
|
||||
GroupInvitationEmailMismatch: SsoErrorCode.GROUP_INVITATION_EMAIL_MISMATCH,
|
||||
DisabledSignupError: SsoErrorCode.SIGNUP_DISABLED,
|
||||
}
|
||||
)
|
||||
@transaction.atomic
|
||||
def get(self, request: Request, provider_id: int) -> HttpResponseRedirect:
|
||||
"""
|
||||
|
@ -131,29 +144,14 @@ class OAuth2CallbackView(APIView):
|
|||
if not is_sso_feature_active():
|
||||
return redirect_to_sign_in_error_page(SsoErrorCode.FEATURE_NOT_ACTIVE)
|
||||
|
||||
try:
|
||||
provider = AuthProviderHandler.get_auth_provider(provider_id)
|
||||
provider_type = auth_provider_type_registry.get_by_model(provider)
|
||||
code = request.query_params.get("code", None)
|
||||
user_info, original_url = provider_type.get_user_info(
|
||||
provider, code, request.session
|
||||
)
|
||||
user = AuthProviderHandler.get_or_create_user_and_sign_in_via_auth_provider(
|
||||
user_info, provider
|
||||
)
|
||||
except AuthProviderModelNotFound:
|
||||
return redirect_to_sign_in_error_page(SsoErrorCode.PROVIDER_DOES_NOT_EXIST)
|
||||
except AuthFlowError:
|
||||
return redirect_to_sign_in_error_page(SsoErrorCode.AUTH_FLOW_ERROR)
|
||||
except DeactivatedUserException:
|
||||
return redirect_to_sign_in_error_page(SsoErrorCode.USER_DEACTIVATED)
|
||||
except DifferentAuthProvider:
|
||||
return redirect_to_sign_in_error_page(SsoErrorCode.DIFFERENT_PROVIDER)
|
||||
except GroupInvitationEmailMismatch:
|
||||
return redirect_to_sign_in_error_page(
|
||||
SsoErrorCode.GROUP_INVITATION_EMAIL_MISMATCH
|
||||
)
|
||||
except DisabledSignupError:
|
||||
return redirect_to_sign_in_error_page(SsoErrorCode.SIGNUP_DISABLED)
|
||||
provider = AuthProviderHandler.get_auth_provider(provider_id)
|
||||
provider_type = auth_provider_type_registry.get_by_model(provider)
|
||||
code = request.query_params.get("code", None)
|
||||
user_info, original_url = provider_type.get_user_info(
|
||||
provider, code, request.session
|
||||
)
|
||||
user = AuthProviderHandler.get_or_create_user_and_sign_in_via_auth_provider(
|
||||
user_info, provider
|
||||
)
|
||||
|
||||
return redirect_user_on_success(user, original_url)
|
||||
|
|
|
@ -160,6 +160,9 @@ class OAuth2AuthProviderMixin:
|
|||
def get_user_info_url(self, instance: AuthProviderModel) -> str:
|
||||
return self.USER_INFO_URL
|
||||
|
||||
def get_access_token_url(self, instance: AuthProviderModel) -> str:
|
||||
return self.ACCESS_TOKEN_URL
|
||||
|
||||
def before_fetch_token(self, oauth: OAuth2Session) -> None:
|
||||
pass
|
||||
|
||||
|
@ -170,7 +173,7 @@ class OAuth2AuthProviderMixin:
|
|||
oauth = self.get_oauth_session(instance, session)
|
||||
self.before_fetch_token(oauth)
|
||||
token = oauth.fetch_token(
|
||||
self.ACCESS_TOKEN_URL,
|
||||
self.get_access_token_url(instance),
|
||||
code=code,
|
||||
client_secret=instance.secret,
|
||||
)
|
||||
|
@ -339,6 +342,9 @@ class GitLabAuthProviderType(OAuth2AuthProviderMixin, AuthProviderType):
|
|||
def get_user_info_url(self, instance: AuthProviderModel) -> str:
|
||||
return f"{instance.base_url}{self.USER_INFO_PATH}"
|
||||
|
||||
def get_access_token_url(self, instance: AuthProviderModel) -> str:
|
||||
return f"{instance.base_url}{self.ACCESS_TOKEN_PATH}"
|
||||
|
||||
|
||||
class FacebookAuthProviderType(OAuth2AuthProviderMixin, AuthProviderType):
|
||||
"""
|
||||
|
@ -367,7 +373,7 @@ class FacebookAuthProviderType(OAuth2AuthProviderMixin, AuthProviderType):
|
|||
oauth = facebook_compliance_fix(oauth)
|
||||
authorization_url, state = oauth.authorization_url(self.AUTHORIZATION_URL)
|
||||
session["oauth_state"] = state
|
||||
self.add_params_to_session(session, query_params)
|
||||
self.push_request_data_to_session(session, query_params)
|
||||
return authorization_url
|
||||
|
||||
def before_fetch_token(self, oauth: OAuth2Session) -> None:
|
||||
|
@ -413,6 +419,12 @@ class OpenIdConnectAuthProviderType(OAuth2AuthProviderMixin, AuthProviderType):
|
|||
def get_base_url(self, instance: AuthProviderModel) -> str:
|
||||
return instance.authorization_url
|
||||
|
||||
def get_access_token_url(self, instance: AuthProviderModel) -> str:
|
||||
return instance.access_token_url
|
||||
|
||||
def get_user_info_url(self, instance: AuthProviderModel) -> str:
|
||||
return instance.user_info_url
|
||||
|
||||
def get_wellknown_urls(self, base_url: str) -> WellKnownUrls:
|
||||
"""
|
||||
Queries the provider "wellknown URL endpoint" to retrieve OpenId Connect
|
||||
|
|
|
@ -0,0 +1,558 @@
|
|||
import json
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import parse_qsl, urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.shortcuts import reverse
|
||||
from django.test.utils import override_settings
|
||||
|
||||
import pytest
|
||||
from rest_framework.status import HTTP_302_FOUND
|
||||
from rest_framework_simplejwt.tokens import RefreshToken
|
||||
|
||||
from baserow.api.user.jwt import get_user_from_jwt_token
|
||||
from baserow.core.handler import CoreHandler
|
||||
from baserow.core.models import GroupUser, Settings
|
||||
from baserow_enterprise.auth_provider.handler import UserInfo
|
||||
from baserow_enterprise.sso.exceptions import AuthFlowError
|
||||
|
||||
GET_USER_INFO = (
|
||||
"baserow_enterprise.sso.oauth2.auth_provider_types."
|
||||
"GoogleAuthProviderType.get_user_info"
|
||||
)
|
||||
|
||||
|
||||
def create_get_user_info_stub(provider):
|
||||
def get_user_info_stub(self, instance, code, session):
|
||||
assert instance == provider
|
||||
assert code == "validcode"
|
||||
data = json.loads(session.pop("oauth_request_data"))
|
||||
return (
|
||||
UserInfo(
|
||||
email="testuser@example.com",
|
||||
name="Test User",
|
||||
language=data.get("language", "en"),
|
||||
group_invitation_token=data.get("group_invitation_token", None),
|
||||
),
|
||||
data.get("original", ""),
|
||||
)
|
||||
|
||||
return get_user_info_stub
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_oauth2_login_feature_not_active(api_client, enterprise_data_fixture):
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
|
||||
response = api_client.get(
|
||||
reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": provider.id}),
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"] == (
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
|
||||
"error?error=errorSsoFeatureNotActive"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_login_provider_doesnt_exist(api_client, enterprise_data_fixture):
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
response = api_client.get(
|
||||
reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": 50}),
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"] == (
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
|
||||
"error?error=errorProviderDoesNotExist"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_login_with_url_param(api_client, enterprise_data_fixture):
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:login",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?original=templates&group_invitation_token=t&language=en",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
|
||||
location = response.headers["Location"]
|
||||
|
||||
assert location.startswith("https://accounts.google.com/o/oauth2/v2/auth")
|
||||
assert "client_id=g_client_id" in location
|
||||
|
||||
session_data = json.loads(api_client.session.pop("oauth_request_data"))
|
||||
assert session_data == {
|
||||
"original": "templates",
|
||||
"group_invitation_token": "t",
|
||||
"language": "en",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_oauth2_callback_feature_not_active(api_client, enterprise_data_fixture):
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback", kwargs={"provider_id": provider.id}
|
||||
),
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"] == (
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
|
||||
"error?error=errorSsoFeatureNotActive"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_provider_doesnt_exist(api_client, enterprise_data_fixture):
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
response = api_client.get(
|
||||
reverse("api:enterprise:sso:oauth2:callback", kwargs={"provider_id": 50}),
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"] == (
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
|
||||
"error?error=errorProviderDoesNotExist"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_signup_success(api_client, enterprise_data_fixture):
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
|
||||
with patch(
|
||||
GET_USER_INFO,
|
||||
create_get_user_info_stub(provider),
|
||||
):
|
||||
session = api_client.session
|
||||
session["oauth_request_data"] = json.dumps(
|
||||
{"original": "templates", "oauth_state": "generated_oauth_state"}
|
||||
)
|
||||
session.save()
|
||||
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?code=validcode",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"].startswith(
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/templates?token="
|
||||
)
|
||||
|
||||
parsed_url = urlparse(response.headers["Location"])
|
||||
params = dict(parse_qsl(parsed_url.query))
|
||||
user = get_user_from_jwt_token(params["token"], token_class=RefreshToken)
|
||||
assert user.email == "testuser@example.com"
|
||||
assert user.first_name == "Test User"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_signup_set_language(api_client, enterprise_data_fixture):
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
|
||||
with patch(
|
||||
GET_USER_INFO,
|
||||
create_get_user_info_stub(provider),
|
||||
):
|
||||
session = api_client.session
|
||||
session["oauth_request_data"] = json.dumps(
|
||||
{
|
||||
"original": "templates",
|
||||
"language": "es",
|
||||
"oauth_state": "generated_oauth_state",
|
||||
}
|
||||
)
|
||||
session.save()
|
||||
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?code=validcode",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"].startswith(
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/templates?token="
|
||||
)
|
||||
|
||||
parsed_url = urlparse(response.headers["Location"])
|
||||
params = dict(parse_qsl(parsed_url.query))
|
||||
user = get_user_from_jwt_token(params["token"], token_class=RefreshToken)
|
||||
assert user.email == "testuser@example.com"
|
||||
assert user.first_name == "Test User"
|
||||
assert user.profile.language == "es"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_signup_group_invitation(
|
||||
api_client, data_fixture, enterprise_data_fixture
|
||||
):
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
|
||||
invitation = data_fixture.create_group_invitation(email="testuser@example.com")
|
||||
core_handler = CoreHandler()
|
||||
signer = core_handler.get_group_invitation_signer()
|
||||
group_invitation_token = signer.dumps(invitation.id)
|
||||
|
||||
with patch(
|
||||
GET_USER_INFO,
|
||||
create_get_user_info_stub(provider),
|
||||
):
|
||||
session = api_client.session
|
||||
session["oauth_request_data"] = json.dumps(
|
||||
{
|
||||
"original": "templates",
|
||||
"group_invitation_token": group_invitation_token,
|
||||
"oauth_state": "generated_oauth_state",
|
||||
}
|
||||
)
|
||||
session.save()
|
||||
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?code=validcode",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"].startswith(
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/templates?token="
|
||||
)
|
||||
|
||||
parsed_url = urlparse(response.headers["Location"])
|
||||
params = dict(parse_qsl(parsed_url.query))
|
||||
user = get_user_from_jwt_token(params["token"], token_class=RefreshToken)
|
||||
assert user.email == "testuser@example.com"
|
||||
assert user.first_name == "Test User"
|
||||
|
||||
assert GroupUser.objects.get(user=user, group=invitation.group)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_signup_group_invitation_email_mismatch(
|
||||
api_client, data_fixture, enterprise_data_fixture
|
||||
):
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
|
||||
invitation = data_fixture.create_group_invitation(
|
||||
email="differenttestuser@example.com"
|
||||
)
|
||||
core_handler = CoreHandler()
|
||||
signer = core_handler.get_group_invitation_signer()
|
||||
group_invitation_token = signer.dumps(invitation.id)
|
||||
|
||||
with patch(
|
||||
GET_USER_INFO,
|
||||
create_get_user_info_stub(provider),
|
||||
):
|
||||
session = api_client.session
|
||||
session["oauth_request_data"] = json.dumps(
|
||||
{
|
||||
"original": "templates",
|
||||
"group_invitation_token": group_invitation_token,
|
||||
"oauth_state": "generated_oauth_state",
|
||||
}
|
||||
)
|
||||
session.save()
|
||||
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?code=validcode",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"].startswith(
|
||||
(
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
|
||||
"error?error=errorGroupInvitationEmailMismatch"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_signup_disabled(api_client, enterprise_data_fixture):
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
|
||||
# disable signups
|
||||
instance_settings = Settings.objects.all()[:1].get()
|
||||
instance_settings.allow_new_signups = False
|
||||
instance_settings.save()
|
||||
|
||||
with patch(
|
||||
GET_USER_INFO,
|
||||
create_get_user_info_stub(provider),
|
||||
):
|
||||
session = api_client.session
|
||||
session["oauth_request_data"] = json.dumps(
|
||||
{"original": "templates", "oauth_state": "generated_oauth_state"}
|
||||
)
|
||||
session.save()
|
||||
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?code=validcode",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"] == (
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
|
||||
"error?error=errorSignupDisabled"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_login_success(
|
||||
api_client, data_fixture, enterprise_data_fixture
|
||||
):
|
||||
"""
|
||||
When a user already have an account associated with the specific provider,
|
||||
he can log in.
|
||||
"""
|
||||
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
first_name="Test User", email="testuser@example.com"
|
||||
)
|
||||
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
provider.users.add(user)
|
||||
provider.save()
|
||||
|
||||
with patch(
|
||||
GET_USER_INFO,
|
||||
create_get_user_info_stub(provider),
|
||||
):
|
||||
session = api_client.session
|
||||
session["oauth_request_data"] = json.dumps(
|
||||
{"original": "templates", "oauth_state": "generated_oauth_state"}
|
||||
)
|
||||
session.save()
|
||||
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?code=validcode",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"].startswith(
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/templates?token="
|
||||
)
|
||||
|
||||
parsed_url = urlparse(response.headers["Location"])
|
||||
params = dict(parse_qsl(parsed_url.query))
|
||||
user = get_user_from_jwt_token(params["token"], token_class=RefreshToken)
|
||||
assert user.email == "testuser@example.com"
|
||||
assert user.first_name == "Test User"
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_login_deactivated_user(
|
||||
api_client, data_fixture, enterprise_data_fixture
|
||||
):
|
||||
"""
|
||||
Deactivated user can't log in anymore.
|
||||
"""
|
||||
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
first_name="Test User", email="testuser@example.com", is_active=False
|
||||
)
|
||||
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
provider.users.add(user)
|
||||
provider.save()
|
||||
|
||||
with patch(
|
||||
GET_USER_INFO,
|
||||
create_get_user_info_stub(provider),
|
||||
):
|
||||
session = api_client.session
|
||||
session["oauth_request_data"] = json.dumps(
|
||||
{"original": "templates", "oauth_state": "generated_oauth_state"}
|
||||
)
|
||||
session.save()
|
||||
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?code=validcode",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"] == (
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
|
||||
"error?error=errorUserDeactivated"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_login_different_provider(
|
||||
api_client, data_fixture, enterprise_data_fixture
|
||||
):
|
||||
"""
|
||||
Existing user account can't log in through a different auth provider.
|
||||
"""
|
||||
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
first_name="Test User", email="testuser@example.com"
|
||||
)
|
||||
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
|
||||
with patch(
|
||||
GET_USER_INFO,
|
||||
create_get_user_info_stub(provider),
|
||||
):
|
||||
session = api_client.session
|
||||
session["oauth_request_data"] = json.dumps(
|
||||
{"original": "templates", "oauth_state": "generated_oauth_state"}
|
||||
)
|
||||
session.save()
|
||||
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?code=validcode",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"] == (
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
|
||||
"error?error=errorDifferentProvider"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_oauth2_callback_login_auth_flow_error(
|
||||
api_client, data_fixture, enterprise_data_fixture
|
||||
):
|
||||
user, token = data_fixture.create_user_and_token(
|
||||
first_name="Test User", email="testuser@example.com"
|
||||
)
|
||||
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type="google", client_id="g_client_id", secret="g_secret"
|
||||
)
|
||||
|
||||
def get_user_info_raise_error(self, instance, code, session):
|
||||
raise AuthFlowError()
|
||||
|
||||
with patch(
|
||||
GET_USER_INFO,
|
||||
get_user_info_raise_error,
|
||||
):
|
||||
session = api_client.session
|
||||
session["oauth_request_data"] = json.dumps(
|
||||
{"original": "templates", "oauth_state": "generated_oauth_state"}
|
||||
)
|
||||
session.save()
|
||||
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
|
||||
|
||||
response = api_client.get(
|
||||
reverse(
|
||||
"api:enterprise:sso:oauth2:callback",
|
||||
kwargs={"provider_id": provider.id},
|
||||
)
|
||||
+ "?code=validcode",
|
||||
format="json",
|
||||
)
|
||||
|
||||
assert response.status_code == HTTP_302_FOUND
|
||||
assert response.headers["Location"] == (
|
||||
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
|
||||
"error?error=errorAuthFlowError"
|
||||
)
|
|
@ -0,0 +1,340 @@
|
|||
import json
|
||||
from urllib.parse import parse_qsl, urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.sessions.backends.base import SessionBase
|
||||
from django.test.utils import override_settings
|
||||
|
||||
import pytest
|
||||
import responses
|
||||
|
||||
from baserow.core.registries import auth_provider_type_registry
|
||||
from baserow_enterprise.sso.oauth2.auth_provider_types import (
|
||||
OAuth2AuthProviderMixin,
|
||||
OpenIdConnectAuthProviderType,
|
||||
WellKnownUrls,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_type,extra_params",
|
||||
[
|
||||
("google", {}),
|
||||
("facebook", {}),
|
||||
("github", {}),
|
||||
("gitlab", {"base_url": "https://gitlab.com"}),
|
||||
("openid_connect", {"base_url": "https://gitlab.com"}),
|
||||
],
|
||||
)
|
||||
@pytest.mark.django_db
|
||||
@override_settings(DEBUG=True)
|
||||
def test_get_login_options(provider_type, extra_params, enterprise_data_fixture):
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type=provider_type,
|
||||
client_id="test_client_id",
|
||||
secret="test_secret",
|
||||
name="provider1",
|
||||
**extra_params,
|
||||
)
|
||||
provider2 = enterprise_data_fixture.create_oauth_provider(
|
||||
type=provider_type,
|
||||
client_id="g_client_id",
|
||||
secret="g_secret",
|
||||
name="provider2",
|
||||
enabled=False,
|
||||
**extra_params,
|
||||
)
|
||||
enterprise_data_fixture.enable_enterprise()
|
||||
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
|
||||
|
||||
assert provider_type_instance.get_login_options() == {
|
||||
"type": provider_type,
|
||||
"items": [
|
||||
{
|
||||
"redirect_url": (
|
||||
f"{settings.PUBLIC_BACKEND_URL}"
|
||||
f"/api/sso/oauth2/login/{provider.id}/"
|
||||
),
|
||||
"name": provider.name,
|
||||
"type": provider_type,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_type,extra_params",
|
||||
[
|
||||
("google", {}),
|
||||
("facebook", {}),
|
||||
("github", {}),
|
||||
("gitlab", {"base_url": "https://gitlab.com"}),
|
||||
(
|
||||
"openid_connect",
|
||||
{
|
||||
"base_url": "https://gitlab.com",
|
||||
"authorization_url": "https://gitlab.com/oauth/authorize",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.django_db
|
||||
def test_get_authorization_url(provider_type, extra_params, enterprise_data_fixture):
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type=provider_type,
|
||||
client_id="test_client_id",
|
||||
secret="test_secret",
|
||||
name="provider1",
|
||||
**extra_params,
|
||||
)
|
||||
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
|
||||
session = SessionBase()
|
||||
query_params = {"query_param": "param_value"}
|
||||
|
||||
auth_url = provider_type_instance.get_authorization_url(
|
||||
provider, session, query_params
|
||||
)
|
||||
|
||||
parsed_url = urlparse(auth_url)
|
||||
params = dict(parse_qsl(parsed_url.query))
|
||||
assert params["response_type"] == "code"
|
||||
assert params["client_id"] == "test_client_id"
|
||||
assert f"/api/sso/oauth2/callback/{provider.id}/" in params["redirect_uri"]
|
||||
assert params["state"] == session["oauth_state"]
|
||||
stored_query_params = json.loads(session["oauth_request_data"])
|
||||
assert stored_query_params["query_param"] == "param_value"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_type,extra_params",
|
||||
[
|
||||
("google", {}),
|
||||
("facebook", {}),
|
||||
("github", {}),
|
||||
("gitlab", {"base_url": "https://gitlab.com"}),
|
||||
(
|
||||
"openid_connect",
|
||||
{
|
||||
"base_url": "https://gitlab.com",
|
||||
"authorization_url": "https://gitlab.com/oauth/authorize",
|
||||
"access_token_url": "https://gitlab.com/oauth/token",
|
||||
"user_info_url": "https://gitlab.com/api/v4/user",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.django_db
|
||||
@responses.activate
|
||||
def test_get_oauth_token_and_response(
|
||||
provider_type, extra_params, enterprise_data_fixture
|
||||
):
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type=provider_type,
|
||||
client_id="test_client_id",
|
||||
secret="test_secret",
|
||||
name="provider1",
|
||||
**extra_params,
|
||||
)
|
||||
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
|
||||
session = SessionBase()
|
||||
code = "testcode"
|
||||
|
||||
# mock access token response
|
||||
access_token_response = {
|
||||
"access_token": "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3",
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
responses.add(
|
||||
responses.POST,
|
||||
provider_type_instance.get_access_token_url(provider),
|
||||
json=access_token_response,
|
||||
status=200,
|
||||
)
|
||||
|
||||
# mock get user info response
|
||||
oauth_response_data = {
|
||||
"email": "testuser@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
responses.add(
|
||||
responses.GET,
|
||||
provider_type_instance.get_user_info_url(provider),
|
||||
json=oauth_response_data,
|
||||
status=200,
|
||||
)
|
||||
|
||||
token, json_response = provider_type_instance.get_oauth_token_and_response(
|
||||
provider, code, session
|
||||
)
|
||||
assert token == access_token_response
|
||||
assert json_response["email"] == oauth_response_data["email"]
|
||||
assert json_response["name"] == oauth_response_data["name"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_type,extra_params",
|
||||
[
|
||||
("google", {}),
|
||||
("facebook", {}),
|
||||
("github", {}),
|
||||
("gitlab", {"base_url": "https://gitlab.com"}),
|
||||
(
|
||||
"openid_connect",
|
||||
{
|
||||
"base_url": "https://gitlab.com",
|
||||
"authorization_url": "https://gitlab.com/oauth/authorize",
|
||||
"access_token_url": "https://gitlab.com/oauth/token",
|
||||
"user_info_url": "https://gitlab.com/api/v4/user",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.django_db
|
||||
@responses.activate
|
||||
def test_get_user_info(provider_type, extra_params, enterprise_data_fixture):
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type=provider_type,
|
||||
client_id="test_client_id",
|
||||
secret="test_secret",
|
||||
name="provider1",
|
||||
**extra_params,
|
||||
)
|
||||
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
|
||||
session = SessionBase()
|
||||
code = "testcode"
|
||||
query_params = {
|
||||
"group_invitation_token": "testgrouptoken",
|
||||
"language": "es",
|
||||
"original": "templates",
|
||||
}
|
||||
provider_type_instance.push_request_data_to_session(session, query_params)
|
||||
|
||||
# mock access token response
|
||||
access_token_response = {
|
||||
"access_token": "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3",
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
responses.add(
|
||||
responses.POST,
|
||||
provider_type_instance.get_access_token_url(provider),
|
||||
json=access_token_response,
|
||||
status=200,
|
||||
)
|
||||
|
||||
# mock get user info response
|
||||
oauth_response_data = {
|
||||
"email": "testuser@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
responses.add(
|
||||
responses.GET,
|
||||
provider_type_instance.get_user_info_url(provider),
|
||||
json=oauth_response_data,
|
||||
status=200,
|
||||
)
|
||||
|
||||
# mock emails endpoint for github
|
||||
if provider_type == "github":
|
||||
responses.add(
|
||||
responses.GET,
|
||||
provider_type_instance.EMAILS_URL,
|
||||
json=[{"email": "testuser@example.com"}],
|
||||
status=200,
|
||||
)
|
||||
|
||||
user_info, original = provider_type_instance.get_user_info(provider, code, session)
|
||||
assert user_info.email == oauth_response_data["email"]
|
||||
assert user_info.name == oauth_response_data["name"]
|
||||
assert user_info.group_invitation_token == query_params["group_invitation_token"]
|
||||
assert user_info.language == query_params["language"]
|
||||
assert original == query_params["original"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider_type,extra_params",
|
||||
[
|
||||
("google", {}),
|
||||
("facebook", {}),
|
||||
("github", {}),
|
||||
("gitlab", {"base_url": "https://gitlab.com"}),
|
||||
(
|
||||
"openid_connect",
|
||||
{
|
||||
"base_url": "https://gitlab.com",
|
||||
"authorization_url": "https://gitlab.com/oauth/authorize",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.django_db
|
||||
def test_get_user_info_from_oauth_json_response(
|
||||
provider_type, extra_params, enterprise_data_fixture
|
||||
):
|
||||
provider = enterprise_data_fixture.create_oauth_provider(
|
||||
type=provider_type,
|
||||
client_id="test_client_id",
|
||||
secret="test_secret",
|
||||
name="provider1",
|
||||
**extra_params,
|
||||
)
|
||||
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
|
||||
session = SessionBase()
|
||||
query_params = {
|
||||
"group_invitation_token": "testgrouptoken",
|
||||
"language": "es",
|
||||
"original": "templates",
|
||||
}
|
||||
provider_type_instance.push_request_data_to_session(session, query_params)
|
||||
oauth_response_data = {
|
||||
"email": "testuser@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
user_info, original = provider_type_instance.get_user_info_from_oauth_json_response(
|
||||
oauth_response_data, session
|
||||
)
|
||||
assert user_info.email == oauth_response_data["email"]
|
||||
assert user_info.name == oauth_response_data["name"]
|
||||
assert user_info.group_invitation_token == query_params["group_invitation_token"]
|
||||
assert user_info.language == query_params["language"]
|
||||
assert original == query_params["original"]
|
||||
|
||||
|
||||
def test_push_pop_request_data_to_session():
|
||||
session = SessionBase()
|
||||
query_params = {
|
||||
"original": "templates",
|
||||
"oauth_state": "state",
|
||||
"group_invitation_token": "fjkldsfj",
|
||||
"language": "es",
|
||||
}
|
||||
mixin = OAuth2AuthProviderMixin()
|
||||
|
||||
mixin.push_request_data_to_session(session, query_params)
|
||||
retrieved_params = mixin.pop_request_data_from_session(session)
|
||||
|
||||
assert query_params == retrieved_params
|
||||
|
||||
|
||||
@responses.activate
|
||||
def test_openid_get_wellknown_urls():
|
||||
base_url = "http://example.com"
|
||||
endpoint_response = {
|
||||
"authorization_endpoint": "http://example.com/authorization",
|
||||
"token_endpoint": "http://example.com/accesstoken",
|
||||
"userinfo_endpoint": "http://example.com/userinfo",
|
||||
}
|
||||
responses.add(
|
||||
responses.GET,
|
||||
f"{base_url}/.well-known/openid-configuration",
|
||||
json=endpoint_response,
|
||||
status=200,
|
||||
)
|
||||
|
||||
wellknown_urls = OpenIdConnectAuthProviderType().get_wellknown_urls(base_url)
|
||||
|
||||
assert wellknown_urls == WellKnownUrls(
|
||||
authorization_url="http://example.com/authorization",
|
||||
access_token_url="http://example.com/accesstoken",
|
||||
user_info_url="http://example.com/userinfo",
|
||||
)
|
Loading…
Add table
Reference in a new issue