0
0
Fork 0
mirror of https://github.com/healthchecks/healthchecks.git synced 2025-04-11 15:51:19 +00:00

Improve OAuth error handling in hc.front.views.add_discord_complete

This commit is contained in:
Pēteris Caune 2023-10-25 13:49:14 +03:00
parent 8b006e0d58
commit da22899cd6
No known key found for this signature in database
GPG key ID: E28D7679E9A9EDE2
3 changed files with 36 additions and 12 deletions

View file

@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
from django.test.utils import override_settings
from hc.api.models import Channel
from hc.test import BaseTestCase
from hc.test import BaseTestCase, nolog
@override_settings(DISCORD_CLIENT_ID="t1", DISCORD_CLIENT_SECRET="s1")
@ -41,6 +41,24 @@ class AddDiscordCompleteTestCase(BaseTestCase):
# Session should now be clean
self.assertFalse("add_discord" in self.client.session)
@nolog
@patch("hc.front.views.curl.post", autospec=True)
def test_it_handles_unexpected_oauth_response(self, mock_post: Mock) -> None:
session = self.client.session
session["add_discord"] = ("foo", str(self.project.code))
session.save()
oauth_response = "surprise"
mock_post.return_value.text = json.dumps(oauth_response)
mock_post.return_value.json.return_value = oauth_response
url = self.url + "?code=12345678&state=foo"
self.client.login(username="alice@example.org", password="password")
r = self.client.get(url, follow=True)
self.assertRedirects(r, self.channels_url)
self.assertContains(r, "Received an unexpected response from Discord.")
def test_it_avoids_csrf(self) -> None:
session = self.client.session
session["add_discord"] = ("foo", str(self.project.code))

View file

@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
from django.test.utils import override_settings
from hc.api.models import Channel
from hc.test import BaseTestCase
from hc.test import BaseTestCase, nolog
@override_settings(SLACK_CLIENT_ID="fake-client-id")
@ -53,6 +53,7 @@ class AddSlackCompleteTestCase(BaseTestCase):
r = self.client.get(url)
self.assertEqual(r.status_code, 403)
@nolog
@patch("hc.front.views.curl.post", autospec=True)
def test_it_handles_oauth_error(self, mock_post: Mock) -> None:
session = self.client.session
@ -69,8 +70,9 @@ class AddSlackCompleteTestCase(BaseTestCase):
self.client.login(username="alice@example.org", password="password")
r = self.client.get(url, follow=True)
self.assertRedirects(r, self.channels_url)
self.assertContains(r, "Slack returned an unexpected response.")
self.assertContains(r, "Received an unexpected response from Slack")
@nolog
@patch("hc.front.views.curl.post", autospec=True)
def test_it_handles_unexpected_oauth_response(self, mock_post: Mock) -> None:
session = self.client.session
@ -87,7 +89,7 @@ class AddSlackCompleteTestCase(BaseTestCase):
self.client.login(username="alice@example.org", password="password")
r = self.client.get(url, follow=True)
self.assertRedirects(r, self.channels_url)
self.assertContains(r, "Slack returned an unexpected response.")
self.assertContains(r, "Received an unexpected response from Slack")
@override_settings(SLACK_CLIENT_ID=None)
def test_it_requires_client_id(self) -> None:

View file

@ -1766,15 +1766,19 @@ def add_discord_complete(request: AuthenticatedHttpRequest) -> HttpResponse:
result = curl.post("https://discordapp.com/api/oauth2/token", data)
doc = result.json()
if "access_token" in doc:
channel = Channel(kind="discord", project=project)
channel.value = result.text
channel.save()
channel.assign_all_checks()
messages.success(request, "The Discord integration has been added!")
else:
messages.warning(request, "Something went wrong.")
if not isinstance(doc, dict) or "access_token" not in doc:
messages.warning(
request,
"Received an unexpected response from Discord. Integration not added.",
)
logger.warning("Unexpected Discord OAuth response: %s", result.text)
return redirect("hc-channels", project.code)
channel = Channel(kind="discord", project=project)
channel.value = result.text
channel.save()
channel.assign_all_checks()
messages.success(request, "The Discord integration has been added!")
return redirect("hc-channels", project.code)