1
0
Fork 0
mirror of https://gitlab.com/bramw/baserow.git synced 2025-04-17 18:32:35 +00:00

Add tests for oauth providers and views

This commit is contained in:
Petr Stribny 2022-11-21 17:13:53 +00:00
parent 6d5917ba84
commit 3d517fd1a9
6 changed files with 939 additions and 31 deletions
enterprise/backend
src/baserow_enterprise
api/sso/oauth2
sso/oauth2
tests/baserow_enterprise_tests

View file

@ -18,6 +18,7 @@ from baserow.core.user.exceptions import DeactivatedUserException, DisabledSignu
from baserow_enterprise.api.sso.serializers import SsoLoginRequestSerializer
from baserow_enterprise.api.sso.utils import (
SsoErrorCode,
map_sso_exceptions,
redirect_to_sign_in_error_page,
redirect_user_on_success,
)
@ -63,6 +64,11 @@ class OAuth2LoginView(APIView):
auth=[],
)
@validate_query_parameters(SsoLoginRequestSerializer, return_validated=True)
@map_sso_exceptions(
{
AuthProviderModelNotFound: SsoErrorCode.PROVIDER_DOES_NOT_EXIST,
}
)
@transaction.atomic
def get(
self, request: Request, provider_id: int, query_params: Dict[str, Any]
@ -75,11 +81,8 @@ class OAuth2LoginView(APIView):
if not is_sso_feature_active():
return redirect_to_sign_in_error_page(SsoErrorCode.FEATURE_NOT_ACTIVE)
try:
provider = AuthProviderHandler.get_auth_provider(provider_id)
provider_type = auth_provider_type_registry.get_by_model(provider)
except AuthProviderModelNotFound:
return redirect_to_sign_in_error_page(SsoErrorCode.PROVIDER_DOES_NOT_EXIST)
provider = AuthProviderHandler.get_auth_provider(provider_id)
provider_type = auth_provider_type_registry.get_by_model(provider)
redirect_url = provider_type.get_authorization_url(
provider.specific_class.objects.get(id=provider_id),
@ -119,6 +122,16 @@ class OAuth2CallbackView(APIView):
},
auth=[],
)
@map_sso_exceptions(
{
AuthProviderModelNotFound: SsoErrorCode.PROVIDER_DOES_NOT_EXIST,
AuthFlowError: SsoErrorCode.AUTH_FLOW_ERROR,
DeactivatedUserException: SsoErrorCode.USER_DEACTIVATED,
DifferentAuthProvider: SsoErrorCode.DIFFERENT_PROVIDER,
GroupInvitationEmailMismatch: SsoErrorCode.GROUP_INVITATION_EMAIL_MISMATCH,
DisabledSignupError: SsoErrorCode.SIGNUP_DISABLED,
}
)
@transaction.atomic
def get(self, request: Request, provider_id: int) -> HttpResponseRedirect:
"""
@ -131,29 +144,14 @@ class OAuth2CallbackView(APIView):
if not is_sso_feature_active():
return redirect_to_sign_in_error_page(SsoErrorCode.FEATURE_NOT_ACTIVE)
try:
provider = AuthProviderHandler.get_auth_provider(provider_id)
provider_type = auth_provider_type_registry.get_by_model(provider)
code = request.query_params.get("code", None)
user_info, original_url = provider_type.get_user_info(
provider, code, request.session
)
user = AuthProviderHandler.get_or_create_user_and_sign_in_via_auth_provider(
user_info, provider
)
except AuthProviderModelNotFound:
return redirect_to_sign_in_error_page(SsoErrorCode.PROVIDER_DOES_NOT_EXIST)
except AuthFlowError:
return redirect_to_sign_in_error_page(SsoErrorCode.AUTH_FLOW_ERROR)
except DeactivatedUserException:
return redirect_to_sign_in_error_page(SsoErrorCode.USER_DEACTIVATED)
except DifferentAuthProvider:
return redirect_to_sign_in_error_page(SsoErrorCode.DIFFERENT_PROVIDER)
except GroupInvitationEmailMismatch:
return redirect_to_sign_in_error_page(
SsoErrorCode.GROUP_INVITATION_EMAIL_MISMATCH
)
except DisabledSignupError:
return redirect_to_sign_in_error_page(SsoErrorCode.SIGNUP_DISABLED)
provider = AuthProviderHandler.get_auth_provider(provider_id)
provider_type = auth_provider_type_registry.get_by_model(provider)
code = request.query_params.get("code", None)
user_info, original_url = provider_type.get_user_info(
provider, code, request.session
)
user = AuthProviderHandler.get_or_create_user_and_sign_in_via_auth_provider(
user_info, provider
)
return redirect_user_on_success(user, original_url)

View file

@ -160,6 +160,9 @@ class OAuth2AuthProviderMixin:
def get_user_info_url(self, instance: AuthProviderModel) -> str:
return self.USER_INFO_URL
def get_access_token_url(self, instance: AuthProviderModel) -> str:
return self.ACCESS_TOKEN_URL
def before_fetch_token(self, oauth: OAuth2Session) -> None:
pass
@ -170,7 +173,7 @@ class OAuth2AuthProviderMixin:
oauth = self.get_oauth_session(instance, session)
self.before_fetch_token(oauth)
token = oauth.fetch_token(
self.ACCESS_TOKEN_URL,
self.get_access_token_url(instance),
code=code,
client_secret=instance.secret,
)
@ -339,6 +342,9 @@ class GitLabAuthProviderType(OAuth2AuthProviderMixin, AuthProviderType):
def get_user_info_url(self, instance: AuthProviderModel) -> str:
return f"{instance.base_url}{self.USER_INFO_PATH}"
def get_access_token_url(self, instance: AuthProviderModel) -> str:
return f"{instance.base_url}{self.ACCESS_TOKEN_PATH}"
class FacebookAuthProviderType(OAuth2AuthProviderMixin, AuthProviderType):
"""
@ -367,7 +373,7 @@ class FacebookAuthProviderType(OAuth2AuthProviderMixin, AuthProviderType):
oauth = facebook_compliance_fix(oauth)
authorization_url, state = oauth.authorization_url(self.AUTHORIZATION_URL)
session["oauth_state"] = state
self.add_params_to_session(session, query_params)
self.push_request_data_to_session(session, query_params)
return authorization_url
def before_fetch_token(self, oauth: OAuth2Session) -> None:
@ -413,6 +419,12 @@ class OpenIdConnectAuthProviderType(OAuth2AuthProviderMixin, AuthProviderType):
def get_base_url(self, instance: AuthProviderModel) -> str:
return instance.authorization_url
def get_access_token_url(self, instance: AuthProviderModel) -> str:
return instance.access_token_url
def get_user_info_url(self, instance: AuthProviderModel) -> str:
return instance.user_info_url
def get_wellknown_urls(self, base_url: str) -> WellKnownUrls:
"""
Queries the provider "wellknown URL endpoint" to retrieve OpenId Connect

View file

@ -0,0 +1,558 @@
import json
from unittest.mock import patch
from urllib.parse import parse_qsl, urlparse
from django.conf import settings
from django.shortcuts import reverse
from django.test.utils import override_settings
import pytest
from rest_framework.status import HTTP_302_FOUND
from rest_framework_simplejwt.tokens import RefreshToken
from baserow.api.user.jwt import get_user_from_jwt_token
from baserow.core.handler import CoreHandler
from baserow.core.models import GroupUser, Settings
from baserow_enterprise.auth_provider.handler import UserInfo
from baserow_enterprise.sso.exceptions import AuthFlowError
GET_USER_INFO = (
"baserow_enterprise.sso.oauth2.auth_provider_types."
"GoogleAuthProviderType.get_user_info"
)
def create_get_user_info_stub(provider):
def get_user_info_stub(self, instance, code, session):
assert instance == provider
assert code == "validcode"
data = json.loads(session.pop("oauth_request_data"))
return (
UserInfo(
email="testuser@example.com",
name="Test User",
language=data.get("language", "en"),
group_invitation_token=data.get("group_invitation_token", None),
),
data.get("original", ""),
)
return get_user_info_stub
@pytest.mark.django_db
def test_oauth2_login_feature_not_active(api_client, enterprise_data_fixture):
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
response = api_client.get(
reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": provider.id}),
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"] == (
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
"error?error=errorSsoFeatureNotActive"
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_login_provider_doesnt_exist(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise()
response = api_client.get(
reverse("api:enterprise:sso:oauth2:login", kwargs={"provider_id": 50}),
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"] == (
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
"error?error=errorProviderDoesNotExist"
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_login_with_url_param(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:login",
kwargs={"provider_id": provider.id},
)
+ "?original=templates&group_invitation_token=t&language=en",
format="json",
)
assert response.status_code == HTTP_302_FOUND
location = response.headers["Location"]
assert location.startswith("https://accounts.google.com/o/oauth2/v2/auth")
assert "client_id=g_client_id" in location
session_data = json.loads(api_client.session.pop("oauth_request_data"))
assert session_data == {
"original": "templates",
"group_invitation_token": "t",
"language": "en",
}
@pytest.mark.django_db
def test_oauth2_callback_feature_not_active(api_client, enterprise_data_fixture):
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback", kwargs={"provider_id": provider.id}
),
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"] == (
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
"error?error=errorSsoFeatureNotActive"
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_provider_doesnt_exist(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise()
response = api_client.get(
reverse("api:enterprise:sso:oauth2:callback", kwargs={"provider_id": 50}),
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"] == (
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
"error?error=errorProviderDoesNotExist"
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_signup_success(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
with patch(
GET_USER_INFO,
create_get_user_info_stub(provider),
):
session = api_client.session
session["oauth_request_data"] = json.dumps(
{"original": "templates", "oauth_state": "generated_oauth_state"}
)
session.save()
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"].startswith(
f"{settings.PUBLIC_WEB_FRONTEND_URL}/templates?token="
)
parsed_url = urlparse(response.headers["Location"])
params = dict(parse_qsl(parsed_url.query))
user = get_user_from_jwt_token(params["token"], token_class=RefreshToken)
assert user.email == "testuser@example.com"
assert user.first_name == "Test User"
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_signup_set_language(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
with patch(
GET_USER_INFO,
create_get_user_info_stub(provider),
):
session = api_client.session
session["oauth_request_data"] = json.dumps(
{
"original": "templates",
"language": "es",
"oauth_state": "generated_oauth_state",
}
)
session.save()
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"].startswith(
f"{settings.PUBLIC_WEB_FRONTEND_URL}/templates?token="
)
parsed_url = urlparse(response.headers["Location"])
params = dict(parse_qsl(parsed_url.query))
user = get_user_from_jwt_token(params["token"], token_class=RefreshToken)
assert user.email == "testuser@example.com"
assert user.first_name == "Test User"
assert user.profile.language == "es"
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_signup_group_invitation(
api_client, data_fixture, enterprise_data_fixture
):
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
invitation = data_fixture.create_group_invitation(email="testuser@example.com")
core_handler = CoreHandler()
signer = core_handler.get_group_invitation_signer()
group_invitation_token = signer.dumps(invitation.id)
with patch(
GET_USER_INFO,
create_get_user_info_stub(provider),
):
session = api_client.session
session["oauth_request_data"] = json.dumps(
{
"original": "templates",
"group_invitation_token": group_invitation_token,
"oauth_state": "generated_oauth_state",
}
)
session.save()
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"].startswith(
f"{settings.PUBLIC_WEB_FRONTEND_URL}/templates?token="
)
parsed_url = urlparse(response.headers["Location"])
params = dict(parse_qsl(parsed_url.query))
user = get_user_from_jwt_token(params["token"], token_class=RefreshToken)
assert user.email == "testuser@example.com"
assert user.first_name == "Test User"
assert GroupUser.objects.get(user=user, group=invitation.group)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_signup_group_invitation_email_mismatch(
api_client, data_fixture, enterprise_data_fixture
):
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
invitation = data_fixture.create_group_invitation(
email="differenttestuser@example.com"
)
core_handler = CoreHandler()
signer = core_handler.get_group_invitation_signer()
group_invitation_token = signer.dumps(invitation.id)
with patch(
GET_USER_INFO,
create_get_user_info_stub(provider),
):
session = api_client.session
session["oauth_request_data"] = json.dumps(
{
"original": "templates",
"group_invitation_token": group_invitation_token,
"oauth_state": "generated_oauth_state",
}
)
session.save()
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"].startswith(
(
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
"error?error=errorGroupInvitationEmailMismatch"
)
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_signup_disabled(api_client, enterprise_data_fixture):
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
# disable signups
instance_settings = Settings.objects.all()[:1].get()
instance_settings.allow_new_signups = False
instance_settings.save()
with patch(
GET_USER_INFO,
create_get_user_info_stub(provider),
):
session = api_client.session
session["oauth_request_data"] = json.dumps(
{"original": "templates", "oauth_state": "generated_oauth_state"}
)
session.save()
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"] == (
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
"error?error=errorSignupDisabled"
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_login_success(
api_client, data_fixture, enterprise_data_fixture
):
"""
When a user already have an account associated with the specific provider,
he can log in.
"""
user, token = data_fixture.create_user_and_token(
first_name="Test User", email="testuser@example.com"
)
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
provider.users.add(user)
provider.save()
with patch(
GET_USER_INFO,
create_get_user_info_stub(provider),
):
session = api_client.session
session["oauth_request_data"] = json.dumps(
{"original": "templates", "oauth_state": "generated_oauth_state"}
)
session.save()
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"].startswith(
f"{settings.PUBLIC_WEB_FRONTEND_URL}/templates?token="
)
parsed_url = urlparse(response.headers["Location"])
params = dict(parse_qsl(parsed_url.query))
user = get_user_from_jwt_token(params["token"], token_class=RefreshToken)
assert user.email == "testuser@example.com"
assert user.first_name == "Test User"
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_login_deactivated_user(
api_client, data_fixture, enterprise_data_fixture
):
"""
Deactivated user can't log in anymore.
"""
user, token = data_fixture.create_user_and_token(
first_name="Test User", email="testuser@example.com", is_active=False
)
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
provider.users.add(user)
provider.save()
with patch(
GET_USER_INFO,
create_get_user_info_stub(provider),
):
session = api_client.session
session["oauth_request_data"] = json.dumps(
{"original": "templates", "oauth_state": "generated_oauth_state"}
)
session.save()
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"] == (
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
"error?error=errorUserDeactivated"
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_login_different_provider(
api_client, data_fixture, enterprise_data_fixture
):
"""
Existing user account can't log in through a different auth provider.
"""
user, token = data_fixture.create_user_and_token(
first_name="Test User", email="testuser@example.com"
)
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
with patch(
GET_USER_INFO,
create_get_user_info_stub(provider),
):
session = api_client.session
session["oauth_request_data"] = json.dumps(
{"original": "templates", "oauth_state": "generated_oauth_state"}
)
session.save()
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"] == (
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
"error?error=errorDifferentProvider"
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_oauth2_callback_login_auth_flow_error(
api_client, data_fixture, enterprise_data_fixture
):
user, token = data_fixture.create_user_and_token(
first_name="Test User", email="testuser@example.com"
)
enterprise_data_fixture.enable_enterprise()
provider = enterprise_data_fixture.create_oauth_provider(
type="google", client_id="g_client_id", secret="g_secret"
)
def get_user_info_raise_error(self, instance, code, session):
raise AuthFlowError()
with patch(
GET_USER_INFO,
get_user_info_raise_error,
):
session = api_client.session
session["oauth_request_data"] = json.dumps(
{"original": "templates", "oauth_state": "generated_oauth_state"}
)
session.save()
api_client.cookies[settings.SESSION_COOKIE_NAME] = session.session_key
response = api_client.get(
reverse(
"api:enterprise:sso:oauth2:callback",
kwargs={"provider_id": provider.id},
)
+ "?code=validcode",
format="json",
)
assert response.status_code == HTTP_302_FOUND
assert response.headers["Location"] == (
f"{settings.PUBLIC_WEB_FRONTEND_URL}/login/"
"error?error=errorAuthFlowError"
)

View file

@ -0,0 +1,340 @@
import json
from urllib.parse import parse_qsl, urlparse
from django.conf import settings
from django.contrib.sessions.backends.base import SessionBase
from django.test.utils import override_settings
import pytest
import responses
from baserow.core.registries import auth_provider_type_registry
from baserow_enterprise.sso.oauth2.auth_provider_types import (
OAuth2AuthProviderMixin,
OpenIdConnectAuthProviderType,
WellKnownUrls,
)
@pytest.mark.parametrize(
"provider_type,extra_params",
[
("google", {}),
("facebook", {}),
("github", {}),
("gitlab", {"base_url": "https://gitlab.com"}),
("openid_connect", {"base_url": "https://gitlab.com"}),
],
)
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_get_login_options(provider_type, extra_params, enterprise_data_fixture):
provider = enterprise_data_fixture.create_oauth_provider(
type=provider_type,
client_id="test_client_id",
secret="test_secret",
name="provider1",
**extra_params,
)
provider2 = enterprise_data_fixture.create_oauth_provider(
type=provider_type,
client_id="g_client_id",
secret="g_secret",
name="provider2",
enabled=False,
**extra_params,
)
enterprise_data_fixture.enable_enterprise()
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
assert provider_type_instance.get_login_options() == {
"type": provider_type,
"items": [
{
"redirect_url": (
f"{settings.PUBLIC_BACKEND_URL}"
f"/api/sso/oauth2/login/{provider.id}/"
),
"name": provider.name,
"type": provider_type,
}
],
}
@pytest.mark.parametrize(
"provider_type,extra_params",
[
("google", {}),
("facebook", {}),
("github", {}),
("gitlab", {"base_url": "https://gitlab.com"}),
(
"openid_connect",
{
"base_url": "https://gitlab.com",
"authorization_url": "https://gitlab.com/oauth/authorize",
},
),
],
)
@pytest.mark.django_db
def test_get_authorization_url(provider_type, extra_params, enterprise_data_fixture):
provider = enterprise_data_fixture.create_oauth_provider(
type=provider_type,
client_id="test_client_id",
secret="test_secret",
name="provider1",
**extra_params,
)
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
session = SessionBase()
query_params = {"query_param": "param_value"}
auth_url = provider_type_instance.get_authorization_url(
provider, session, query_params
)
parsed_url = urlparse(auth_url)
params = dict(parse_qsl(parsed_url.query))
assert params["response_type"] == "code"
assert params["client_id"] == "test_client_id"
assert f"/api/sso/oauth2/callback/{provider.id}/" in params["redirect_uri"]
assert params["state"] == session["oauth_state"]
stored_query_params = json.loads(session["oauth_request_data"])
assert stored_query_params["query_param"] == "param_value"
@pytest.mark.parametrize(
"provider_type,extra_params",
[
("google", {}),
("facebook", {}),
("github", {}),
("gitlab", {"base_url": "https://gitlab.com"}),
(
"openid_connect",
{
"base_url": "https://gitlab.com",
"authorization_url": "https://gitlab.com/oauth/authorize",
"access_token_url": "https://gitlab.com/oauth/token",
"user_info_url": "https://gitlab.com/api/v4/user",
},
),
],
)
@pytest.mark.django_db
@responses.activate
def test_get_oauth_token_and_response(
provider_type, extra_params, enterprise_data_fixture
):
provider = enterprise_data_fixture.create_oauth_provider(
type=provider_type,
client_id="test_client_id",
secret="test_secret",
name="provider1",
**extra_params,
)
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
session = SessionBase()
code = "testcode"
# mock access token response
access_token_response = {
"access_token": "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3",
"token_type": "Bearer",
}
responses.add(
responses.POST,
provider_type_instance.get_access_token_url(provider),
json=access_token_response,
status=200,
)
# mock get user info response
oauth_response_data = {
"email": "testuser@example.com",
"name": "Test User",
}
responses.add(
responses.GET,
provider_type_instance.get_user_info_url(provider),
json=oauth_response_data,
status=200,
)
token, json_response = provider_type_instance.get_oauth_token_and_response(
provider, code, session
)
assert token == access_token_response
assert json_response["email"] == oauth_response_data["email"]
assert json_response["name"] == oauth_response_data["name"]
@pytest.mark.parametrize(
"provider_type,extra_params",
[
("google", {}),
("facebook", {}),
("github", {}),
("gitlab", {"base_url": "https://gitlab.com"}),
(
"openid_connect",
{
"base_url": "https://gitlab.com",
"authorization_url": "https://gitlab.com/oauth/authorize",
"access_token_url": "https://gitlab.com/oauth/token",
"user_info_url": "https://gitlab.com/api/v4/user",
},
),
],
)
@pytest.mark.django_db
@responses.activate
def test_get_user_info(provider_type, extra_params, enterprise_data_fixture):
provider = enterprise_data_fixture.create_oauth_provider(
type=provider_type,
client_id="test_client_id",
secret="test_secret",
name="provider1",
**extra_params,
)
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
session = SessionBase()
code = "testcode"
query_params = {
"group_invitation_token": "testgrouptoken",
"language": "es",
"original": "templates",
}
provider_type_instance.push_request_data_to_session(session, query_params)
# mock access token response
access_token_response = {
"access_token": "MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3",
"token_type": "Bearer",
}
responses.add(
responses.POST,
provider_type_instance.get_access_token_url(provider),
json=access_token_response,
status=200,
)
# mock get user info response
oauth_response_data = {
"email": "testuser@example.com",
"name": "Test User",
}
responses.add(
responses.GET,
provider_type_instance.get_user_info_url(provider),
json=oauth_response_data,
status=200,
)
# mock emails endpoint for github
if provider_type == "github":
responses.add(
responses.GET,
provider_type_instance.EMAILS_URL,
json=[{"email": "testuser@example.com"}],
status=200,
)
user_info, original = provider_type_instance.get_user_info(provider, code, session)
assert user_info.email == oauth_response_data["email"]
assert user_info.name == oauth_response_data["name"]
assert user_info.group_invitation_token == query_params["group_invitation_token"]
assert user_info.language == query_params["language"]
assert original == query_params["original"]
@pytest.mark.parametrize(
"provider_type,extra_params",
[
("google", {}),
("facebook", {}),
("github", {}),
("gitlab", {"base_url": "https://gitlab.com"}),
(
"openid_connect",
{
"base_url": "https://gitlab.com",
"authorization_url": "https://gitlab.com/oauth/authorize",
},
),
],
)
@pytest.mark.django_db
def test_get_user_info_from_oauth_json_response(
provider_type, extra_params, enterprise_data_fixture
):
provider = enterprise_data_fixture.create_oauth_provider(
type=provider_type,
client_id="test_client_id",
secret="test_secret",
name="provider1",
**extra_params,
)
provider_type_instance = auth_provider_type_registry.get_by_model(provider)
session = SessionBase()
query_params = {
"group_invitation_token": "testgrouptoken",
"language": "es",
"original": "templates",
}
provider_type_instance.push_request_data_to_session(session, query_params)
oauth_response_data = {
"email": "testuser@example.com",
"name": "Test User",
}
user_info, original = provider_type_instance.get_user_info_from_oauth_json_response(
oauth_response_data, session
)
assert user_info.email == oauth_response_data["email"]
assert user_info.name == oauth_response_data["name"]
assert user_info.group_invitation_token == query_params["group_invitation_token"]
assert user_info.language == query_params["language"]
assert original == query_params["original"]
def test_push_pop_request_data_to_session():
session = SessionBase()
query_params = {
"original": "templates",
"oauth_state": "state",
"group_invitation_token": "fjkldsfj",
"language": "es",
}
mixin = OAuth2AuthProviderMixin()
mixin.push_request_data_to_session(session, query_params)
retrieved_params = mixin.pop_request_data_from_session(session)
assert query_params == retrieved_params
@responses.activate
def test_openid_get_wellknown_urls():
base_url = "http://example.com"
endpoint_response = {
"authorization_endpoint": "http://example.com/authorization",
"token_endpoint": "http://example.com/accesstoken",
"userinfo_endpoint": "http://example.com/userinfo",
}
responses.add(
responses.GET,
f"{base_url}/.well-known/openid-configuration",
json=endpoint_response,
status=200,
)
wellknown_urls = OpenIdConnectAuthProviderType().get_wellknown_urls(base_url)
assert wellknown_urls == WellKnownUrls(
authorization_url="http://example.com/authorization",
access_token_url="http://example.com/accesstoken",
user_info_url="http://example.com/userinfo",
)