# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Tests for scope handling in views."""

import base64
import datetime as dt
from collections.abc import Awaitable, Callable
from typing import Any, Protocol, cast
from unittest import mock

import django.http
from django.conf import settings
from django.contrib.auth.models import AnonymousUser
from django.test import RequestFactory
from django.utils import timezone
from django.utils.functional import SimpleLazyObject
from rest_framework import HTTP_HEADER_ENCODING

from debusine.db.context import ContextConsistencyError, context
from debusine.db.models import Token, User
from debusine.db.playground import scenarios
from debusine.server.middlewares.scopes import (
    AuthorizationMiddleware,
    ScopeMiddleware,
)
from debusine.test.django import TestCase

#: Singleton response returned by the stub ``get_response`` that
#: ``MiddlewareTestMixin.get_middleware`` wraps; tests compare a middleware's
#: result against it with ``assertIs`` to check whether the middleware called
#: through to get_response()
MOCK_RESPONSE = django.http.HttpResponse()


class MiddlewareProtocol(Protocol):
    """Structural type describing what these tests require of a middleware."""

    def __init__(
        self,
        get_response: Callable[
            [django.http.HttpRequest], django.http.HttpResponse
        ],
    ) -> None:
        """Wrap the next ``get_response`` callable in the chain."""
        ...

    def __call__(
        self, request: django.http.HttpRequest
    ) -> django.http.HttpResponseBase | Awaitable[django.http.HttpResponseBase]:
        """Handle ``request``, returning a response or an awaitable of one."""
        ...


class MiddlewareTestMixin[M: MiddlewareProtocol](TestCase):
    """Common functions to test middlewares."""

    middleware_class: type[M]

    scenario = scenarios.DefaultScopeUser()

    def get_middleware(self) -> M:
        """Instantiate a test middleware."""
        return self.middleware_class(
            cast(
                Callable[[django.http.HttpRequest], django.http.HttpResponse],
                lambda _: MOCK_RESPONSE,
            )
        )

    def request_for_path(
        self, path_info: str, **kwargs: Any
    ) -> django.http.HttpRequest:
        """Configure a request for path_info with the middleware."""
        request = RequestFactory().get(path_info, **kwargs)
        self.get_middleware()(request)
        return request


class ScopeMiddlewareTests(MiddlewareTestMixin[ScopeMiddleware]):
    """Test ScopeMiddleware."""

    middleware_class = ScopeMiddleware

    def _assert_request_setup(
        self,
        request: django.http.HttpRequest,
        urlconf: str,
        scope: object,
    ) -> None:
        """Check the urlconf stored on request and the scope in context."""
        self.assertEqual(getattr(request, "urlconf"), urlconf)
        self.assertEqual(context.scope, scope)

    def test_get_scope(self) -> None:
        """Test get_scope."""
        other_scope = self.playground.get_or_create_scope("other")
        middleware = self.get_middleware()
        default_scope = self.scenario.scope
        self.assertEqual(
            middleware.get_scope(default_scope.name), default_scope
        )
        self.assertEqual(middleware.get_scope("other"), other_scope)
        with self.assertRaises(django.http.Http404):
            middleware.get_scope("nonexistent")

    def test_request_setup_debusine(self) -> None:
        """Test request setup with the fallback scope."""
        path = f"/{settings.DEBUSINE_DEFAULT_SCOPE}/"
        self._assert_request_setup(
            self.request_for_path(path),
            "debusine.project.urls",
            self.scenario.scope,
        )

    def test_request_setup_api(self) -> None:
        """Test request setup for API calls."""
        self._assert_request_setup(
            self.request_for_path("/api"),
            "debusine.project.urls",
            self.scenario.scope,
        )

    def test_request_setup_api_scope_in_header(self) -> None:
        """Test request setup for API calls with scope in header."""
        other_scope = self.playground.get_or_create_scope("other")
        scope_header = {"X-Debusine-Scope": "other"}
        self._assert_request_setup(
            self.request_for_path("/api", headers=scope_header),
            "debusine.server._urlconfs.other",
            other_scope,
        )

    def test_request_setup_scope(self) -> None:
        """Test request setup with a valid scope."""
        other_scope = self.playground.get_or_create_scope("other")
        self._assert_request_setup(
            self.request_for_path("/other/"),
            "debusine.server._urlconfs.other",
            other_scope,
        )

    def test_request_setup_wrong_scope(self) -> None:
        """Test request setup with an invalid scope."""
        with self.assertRaises(django.http.Http404):
            self.request_for_path("/wrongscope/")


class AuthorizationMiddlewareTests(
    MiddlewareTestMixin[AuthorizationMiddleware]
):
    """Test AuthorizationMiddleware."""

    middleware_class = AuthorizationMiddleware

    def setUp(self) -> None:
        """Set the scope in context."""
        super().setUp()
        context.set_scope(self.scenario.scope)

    def make_request(
        self,
        user: User | None = None,
        token: Token | str | None = None,
        *,
        secure: bool = True,
    ) -> django.http.HttpRequest:
        """Create a request."""
        headers = {}
        match token:
            case Token():
                headers["token"] = token.key
            case str():
                headers["token"] = token
        request = RequestFactory().get("/", headers=headers, secure=secure)
        request.user = user or AnonymousUser()
        return request

    def make_request_archive_view(
        self,
        user: User | None = None,
        token: Token | str | None = None,
        *,
        secure: bool = True,
    ) -> django.http.HttpRequest:
        """Create a request for archive views."""
        request = self.make_request(user, token, secure=secure)
        setattr(request, "is_archive_view", True)
        return request

    def add_basic_auth_header(
        self, request: django.http.HttpRequest, username: str, password: str
    ) -> None:
        """Add an http basic auth header to the request."""
        credentials = f"{username}:{password}"
        base64_credentials = base64.b64encode(
            credentials.encode(HTTP_HEADER_ENCODING)
        ).decode("ascii")
        request.META["HTTP_AUTHORIZATION"] = f"Basic {base64_credentials}"

    def add_basic_auth_token(
        self, request: django.http.HttpRequest, token: Token
    ) -> None:
        """Pass the user token as a basic auth header."""
        assert token.user is not None
        self.add_basic_auth_header(request, token.user.username, token.key)

    def assertSetsUserNone(
        self, request: django.http.HttpRequest
    ) -> django.http.HttpResponse:
        """
        Ensure that the middleware sets user context to None.

        :returns: the response from the middleware.
        """
        mw = self.get_middleware()
        response = mw(request)
        assert isinstance(response, django.http.HttpResponse)
        self.assertIsNone(context.user)
        return response

    def assertSetsUser(
        self, request: django.http.HttpRequest, user: User | AnonymousUser
    ) -> django.http.HttpResponse:
        """
        Ensure that the middleware sets user context to the given user.

        :returns: the response from the middleware.
        """
        mw = self.get_middleware()
        response = mw(request)
        assert isinstance(response, django.http.HttpResponse)
        self.assertEqual(context.user, user)
        return response

    def assertResponse403(
        self, response: django.http.HttpResponse, message: str
    ) -> None:
        """Ensure the response is a 403 with the given message."""
        assert isinstance(response, django.http.HttpResponse)
        self.assertEqual(response.status_code, 403)
        self.assertEqual(response.content.decode(), message)

    def test_allowed(self) -> None:
        """Test the allowed case."""
        # There currently is no permission predicate implemented for scope
        # visibility, only a comment placeholder in Context.set_user.
        request = self.make_request(user=self.scenario.user)
        response = self.assertSetsUser(request, self.scenario.user)
        self.assertIs(response, MOCK_RESPONSE)

    def test_forbidden(self) -> None:
        """Test the forbidden case."""
        # There currently is no permission predicate implemented for scope
        # visibility, only a comment placeholder in Context.set_user.
        # Mock Context.set_user instead of a permission, to test handling of
        # failure to set it
        request = self.make_request(user=self.scenario.user)
        with mock.patch(
            "debusine.db.context.Context.set_user",
            side_effect=ContextConsistencyError("expected fail"),
        ):
            response = self.assertSetsUserNone(request)
        self.assertResponse403(response, "expected fail")

    def test_evaluates_lazy_object(self) -> None:
        """Lazy objects are evaluated before storing them in the context."""
        mw = self.get_middleware()
        request = RequestFactory().get("/")
        request.user = SimpleLazyObject(lambda: self.scenario.user)  # type: ignore[assignment]
        self.assertIs(mw(request), MOCK_RESPONSE)
        with mock.patch.object(User.objects, "get", side_effect=RuntimeError):
            self.assertEqual(context.user, self.scenario.user)

    def test_bare_token(self) -> None:
        """Check handling of bare tokens."""
        token = self.playground.create_bare_token()
        request = self.make_request(token=token)
        response = self.assertSetsUser(request, AnonymousUser())
        self.assertIs(response, MOCK_RESPONSE)
        self.assertEqual(getattr(request, "_debusine_token"), token)

    def test_user_token(self) -> None:
        """Check handling of user tokens."""
        token = self.scenario.create_user_token()
        request = self.make_request(token=token)
        response = self.assertSetsUser(request, self.scenario.user)
        self.assertIs(response, MOCK_RESPONSE)
        self.assertEqual(getattr(request, "_debusine_token"), token)
        self.assertIsNone(context.worker_token)

    def test_user_token_inactive_user(self) -> None:
        """Inactive users are not allowed."""
        token = self.scenario.create_user_token()
        assert token.user is not None
        token.user.is_active = False
        token.user.save()
        request = self.make_request(token=token)
        response = self.assertSetsUserNone(request)
        assert isinstance(response, django.http.HttpResponse)
        self.assertResponse403(response, "user token has an inactive user")

    def test_user_token_logged_in(self) -> None:
        """Check handling of user tokens while logged in."""
        token = self.scenario.create_user_token()
        request = self.make_request(token=token, user=token.user)
        response = self.assertSetsUserNone(request)
        self.assertResponse403(
            response, "cannot use both Django and user token authentication"
        )
        self.assertIsNone(context.worker_token)

    def test_user_token_forbidden(self) -> None:
        """Test the forbidden case for user token."""
        token = self.scenario.create_user_token()
        request = self.make_request(token=token)

        # There currently is no permission predicate implemented for scope
        # visibility, only a comment placeholder in Context.set_user.
        # Mock Context.set_user instead of a permission, to test handling of
        # failure to set it
        with mock.patch(
            "debusine.db.context.Context.set_user",
            side_effect=ContextConsistencyError("expected fail"),
        ):
            response = self.assertSetsUserNone(request)
        self.assertResponse403(response, "expected fail")

    def test_worker_token(self) -> None:
        """Check handling of worker tokens."""
        token = self.playground.create_worker_token()
        request = self.make_request(token=token)
        response = self.assertSetsUserNone(request)
        self.assertIs(response, MOCK_RESPONSE)
        self.assertEqual(getattr(request, "_debusine_token"), token)
        self.assertEqual(context.worker_token, token)

    def test_worker_token_disabled(self) -> None:
        """Check handling of disabled tokens."""
        token = self.playground.create_worker_token()
        token.disable()
        request = self.make_request(token=token)
        response = self.assertSetsUser(request, AnonymousUser())
        self.assertIs(response, MOCK_RESPONSE)
        self.assertIsNone(getattr(request, "_debusine_token"))
        self.assertIsNone(context.worker_token)

    def test_token_invalid(self) -> None:
        """Check handling of invalid tokens."""
        mw = self.get_middleware()
        request = self.make_request(token="invalid")
        response = mw(request)
        self.assertIs(response, MOCK_RESPONSE)
        self.assertIsNone(getattr(request, "_debusine_token"))
        self.assertIsNone(context.worker_token)

    def test_user_token_and_django_user(self) -> None:
        """Check handling of user tokens and django users."""
        token = self.playground.create_worker_token()
        token.user = self.scenario.user
        token.save()
        request = self.make_request(token=token, user=self.scenario.user)
        response = self.assertSetsUserNone(request)
        self.assertResponse403(
            response, "a token cannot be both a user and a worker token"
        )

    def test_count_queries(self) -> None:
        """Check the number of queries done to authenticate."""
        token = self.playground.create_worker_token()
        request = self.make_request(token=token)
        with self.assertNumQueries(2):
            response = self.assertSetsUserNone(request)
        self.assertIs(response, MOCK_RESPONSE)

    def test_update_last_seen_at(self) -> None:
        """Request with a token updates last_seen_at."""
        before = timezone.now()
        token = self.playground.create_bare_token()
        self.assertIsNone(token.last_seen_at)
        request = self.make_request(token=token)
        response = self.assertSetsUser(request, AnonymousUser())
        self.assertIs(response, MOCK_RESPONSE)

        # token.last_seen_at was updated
        token.refresh_from_db()
        assert token.last_seen_at is not None
        self.assertGreaterEqual(token.last_seen_at, before)
        self.assertLessEqual(token.last_seen_at, timezone.now())

    def test_archive_views_http_no_basic_auth(self) -> None:
        request = self.make_request_archive_view(secure=False)
        response = self.assertSetsUser(request, AnonymousUser())
        self.assertIs(response, MOCK_RESPONSE)
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_not_debian_archive_view(self) -> None:
        token = self.scenario.create_user_token()
        request = self.make_request()
        self.add_basic_auth_token(request, token)
        response = self.assertSetsUser(request, AnonymousUser())
        self.assertIs(response, MOCK_RESPONSE)
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_authorization_not_basic(self) -> None:
        request = self.make_request_archive_view()
        request.META["HTTP_AUTHORIZATION"] = "Token foo"
        response = self.assertSetsUser(request, AnonymousUser())
        self.assertIs(response, MOCK_RESPONSE)
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_no_basic_credentials(self) -> None:
        request = self.make_request_archive_view()
        request.META["HTTP_AUTHORIZATION"] = "Basic"
        response = self.assertSetsUserNone(request)
        self.assertResponse403(response, "No credentials provided")
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_credentials_malformed(self) -> None:
        for credentials in (
            "foo bar",
            "foo",
            # No colon in credentials
            base64.b64encode("foobar".encode(HTTP_HEADER_ENCODING)).decode(
                "ascii"
            ),
            # Invalid credentials, encoded in invalid UTF-8
            base64.b64encode(b"\x80").decode("ascii"),
        ):
            with self.subTest(credentials=repr(credentials)):
                request = self.make_request_archive_view()
                request.META["HTTP_AUTHORIZATION"] = f"Basic {credentials}"
                response = self.assertSetsUserNone(request)
                self.assertResponse403(
                    response, "Credentials string is malformed"
                )
                self.assertIsNone(context.worker_token)

    def test_token_basic_auth_not_https(self) -> None:
        token = self.scenario.create_user_token()
        request = self.make_request_archive_view(secure=False)
        self.add_basic_auth_token(request, token)
        response = self.assertSetsUserNone(request)
        self.assertResponse403(
            response,
            "Authenticated access is only allowed using HTTPS",
        )
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_nonexistent_token(self) -> None:
        request = self.make_request_archive_view()
        self.add_basic_auth_header(
            request, self.scenario.user.username, "invalid-token"
        )
        response = self.assertSetsUserNone(request)
        self.assertResponse403(response, "Invalid token")
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_bare_token(self) -> None:
        token = self.playground.create_bare_token()
        request = self.make_request_archive_view()
        self.add_basic_auth_header(
            request, self.scenario.user.username, token.key
        )
        response = self.assertSetsUserNone(request)
        self.assertResponse403(response, "Invalid token")
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_worker_token(self) -> None:
        token = self.playground.create_worker_token()
        request = self.make_request_archive_view()
        self.add_basic_auth_header(
            request, self.scenario.user.username, token.key
        )
        response = self.assertSetsUserNone(request)
        self.assertResponse403(response, "Invalid token")
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_wrong_user(self) -> None:
        user = self.playground.create_user("test")
        token = self.playground.create_user_token(user=user)
        request = self.make_request_archive_view()
        self.add_basic_auth_header(
            request, self.scenario.user.username, token.key
        )
        response = self.assertSetsUserNone(request)
        self.assertResponse403(response, "Invalid token")
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_expired_user_token(self) -> None:
        token = self.playground.create_user_token(
            expire_at=timezone.now() - dt.timedelta(seconds=1)
        )
        request = self.make_request_archive_view()
        self.add_basic_auth_token(request, token)
        response = self.assertSetsUserNone(request)
        self.assertResponse403(response, "Invalid token")
        self.assertIsNone(context.worker_token)

    def test_token_basic_auth_valid_user_token(self) -> None:
        token = self.scenario.create_user_token()
        request = self.make_request_archive_view()
        self.add_basic_auth_token(request, token)
        response = self.assertSetsUser(request, self.scenario.user)
        self.assertIs(response, MOCK_RESPONSE)
        self.assertIsNone(context.worker_token)
