# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MONGODB-OIDC Authentication helpers."""
from __future__ import annotations

import abc
import os
import threading
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
from urllib.parse import quote

import bson
from bson.binary import Binary
from pymongo._azure_helpers import _get_azure_response
from pymongo._csot import remaining
from pymongo._gcp_helpers import _get_gcp_response
from pymongo.errors import ConfigurationError, OperationFailure
from pymongo.helpers import _AUTHENTICATION_FAILURE_CODE

if TYPE_CHECKING:
    from pymongo.auth import MongoCredential
    from pymongo.pool import Connection


@dataclass
class OIDCIdPInfo:
    """Identity-provider details handed to OIDC callbacks.

    In this module an instance is built from the decoded ``saslStart`` reply
    payload whenever that payload contains an ``issuer`` key, and is then
    passed to the callback through :class:`OIDCCallbackContext`.
    """

    # The only mandatory field.
    issuer: str
    # Optional fields; both default to ``None`` when absent.
    clientId: Optional[str] = None
    requestScopes: Optional[list[str]] = None


@dataclass
class OIDCCallbackContext:
    """Arguments supplied to :meth:`OIDCCallback.fetch`.

    The authenticator in this module fills every field each time it invokes
    a callback.
    """

    # How long the callback is given, in seconds.
    timeout_seconds: float
    # Username taken from the mechanism properties.
    username: str
    # Callback API version (``CALLBACK_VERSION`` in this module).
    version: int
    # Previously cached refresh token, if there is one.
    refresh_token: Optional[str] = None
    # Identity-provider details, if they have been obtained.
    idp_info: Optional[OIDCIdPInfo] = None


@dataclass
class OIDCCallbackResult:
    """Value an :class:`OIDCCallback` must return from ``fetch``.

    The authenticator in this module caches ``access_token`` and
    ``refresh_token`` from the result.
    """

    # The token sent to the server as the ``jwt`` payload field.
    access_token: str
    # Optional token lifetime; not read by the authenticator in this module.
    expires_in_seconds: Optional[float] = None
    # Optional refresh token, passed back to the callback on later calls.
    refresh_token: Optional[str] = None


class OIDCCallback(abc.ABC):
    """A base class for defining OIDC callbacks."""

    @abc.abstractmethod
    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Obtain OIDC credentials for the given request context.

        Subclasses must override this method.

        :param context: An :class:`OIDCCallbackContext` describing the request
            (timeout, username, callback version, and any cached refresh token
            or identity-provider information).
        :return: An :class:`OIDCCallbackResult` carrying the access token.
            The authenticator in this module rejects any other return type
            with a ``ValueError``.
        """


@dataclass
class _OIDCProperties:
    """Mechanism properties for MONGODB-OIDC authentication."""

    # Callback used for the machine flow; when set, the authenticator takes
    # the machine authentication path.
    callback: Optional[OIDCCallback] = field(default=None)
    # Callback used for the human flow; when both callbacks are set this one
    # is the one invoked to fetch tokens.
    human_callback: Optional[OIDCCallback] = field(default=None)
    # When set, the allowed-hosts check in ``_get_authenticator`` is skipped.
    environment: Optional[str] = field(default=None)
    # Host patterns the driver may connect to: exact names or "*.suffix".
    allowed_hosts: list[str] = field(default_factory=list)
    # NOTE(review): not read in this module beyond being stored here;
    # presumably forwarded to the built-in provider callbacks -- confirm.
    token_resource: Optional[str] = field(default=None)
    # Username forwarded to callbacks through ``OIDCCallbackContext``.
    username: str = ""

# Version number reported to callbacks in ``OIDCCallbackContext.version``.
CALLBACK_VERSION = 1

# Timeout handed to a human callback (five minutes).
HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
# Timeout handed to a machine callback when ``remaining()`` yields no value.
MACHINE_CALLBACK_TIMEOUT_SECONDS = 60

# Minimum spacing enforced between two consecutive callback invocations.
TIME_BETWEEN_CALLS_SECONDS = 0.1

# NOTE(review): not referenced by the code in this module; presumably an
# expiry safety margin for cached tokens -- confirm before relying on it.
TOKEN_BUFFER_MINUTES = 5


def _get_authenticator(
    credentials: MongoCredential, address: tuple[str, int]
) -> _OIDCAuthenticator:
    """Return the authenticator cached on ``credentials``, creating it if needed.

    :param credentials: The credential whose ``cache.data`` slot holds the
        authenticator.
    :param address: ``(host, port)`` of the server being connected to; only
        the host is inspected.
    :return: The cached or newly created ``_OIDCAuthenticator``.
    :raises ConfigurationError: If no ``environment`` is configured and the
        host matches none of the ``allowed_hosts`` patterns.
    """
    existing = credentials.cache.data
    if existing:
        return existing

    props = credentials.mechanism_properties
    host = address[0]

    # The host allow-list is only enforced when no environment is configured.
    if not props.environment:
        allowed_hosts = props.allowed_hosts
        # A pattern matches either exactly or, when written as "*.suffix",
        # if the host ends with ".suffix".
        permitted = any(
            patt == host or (patt.startswith("*.") and host.endswith(patt[1:]))
            for patt in allowed_hosts
        )
        if not permitted:
            raise ConfigurationError(
                f"Refusing to connect to {address[0]}, which is not in authOIDCAllowedHosts: {allowed_hosts}"
            )

    # Nothing cached yet: build the authenticator and remember it.
    authenticator = _OIDCAuthenticator(username=credentials.username, properties=props)
    credentials.cache.data = authenticator
    return authenticator


class _OIDCTestCallback(OIDCCallback):
    """Callback that reads the access token from a file on disk."""

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Return the stripped contents of the file named by ``OIDC_TOKEN_FILE``.

        :param context: Unused.
        :return: An ``OIDCCallbackResult`` holding only the access token.
        :raises RuntimeError: If ``OIDC_TOKEN_FILE`` is unset or empty.
        """
        path = os.environ.get("OIDC_TOKEN_FILE")
        if path:
            with open(path) as stream:
                token = stream.read().strip()
            return OIDCCallbackResult(access_token=token)
        raise RuntimeError(
            'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
        )


class _OIDCAzureCallback(OIDCCallback):
    """Callback that obtains tokens through ``_get_azure_response``."""

    def __init__(self, token_resource: str) -> None:
        # Percent-encode once up front; the encoded form is what gets passed
        # to the Azure helper on every fetch.
        self.token_resource = quote(token_resource)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Request a token from Azure and wrap it in an ``OIDCCallbackResult``.

        :param context: Supplies the username and timeout passed to the helper.
        :return: Result carrying the response's ``access_token`` and
            ``expires_in`` values.
        """
        azure_resp = _get_azure_response(
            self.token_resource, context.username, context.timeout_seconds
        )
        token = azure_resp["access_token"]
        lifetime = azure_resp["expires_in"]
        return OIDCCallbackResult(access_token=token, expires_in_seconds=lifetime)


class _OIDCGCPCallback(OIDCCallback):
    """Callback that obtains tokens through ``_get_gcp_response``."""

    def __init__(self, token_resource: str) -> None:
        # Percent-encode once up front; the encoded form is what gets passed
        # to the GCP helper on every fetch.
        self.token_resource = quote(token_resource)

    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
        """Request a token from GCP and wrap it in an ``OIDCCallbackResult``.

        :param context: Supplies the timeout passed to the helper.
        :return: Result carrying only the response's ``access_token``.
        """
        gcp_resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
        token = gcp_resp["access_token"]
        return OIDCCallbackResult(access_token=token)


@dataclass
class _OIDCAuthenticator:
    """Token cache and SASL conversation logic for MONGODB-OIDC.

    An instance is stored on ``credentials.cache.data`` by
    ``_get_authenticator``. It caches the access and refresh tokens returned
    by the configured callback and runs the ``saslStart``/``saslContinue``
    commands on a connection.
    """

    # Principal name sent in the initial ``saslStart`` payload (key "n").
    username: str
    # Mechanism properties (callbacks, username handed to callbacks, ...).
    properties: _OIDCProperties
    # Tokens most recently returned by the callback, if any.
    refresh_token: Optional[str] = field(default=None)
    access_token: Optional[str] = field(default=None)
    # Identity-provider details decoded from the server's saslStart reply.
    idp_info: Optional[OIDCIdPInfo] = field(default=None)
    # Incremented every time the callback yields a new access token.
    token_gen_id: int = field(default=0)
    # Held while the callback is being invoked.
    lock: threading.Lock = field(default_factory=threading.Lock)
    # ``time.time()`` of the most recent callback invocation.
    last_call_time: float = field(default=0)

    def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Handle a reauthenticate from the server.

        :param conn: The connection to reauthenticate.
        :return: The server's reply to the final SASL command.
        """
        # Invalidate the token for the connection.
        self._invalidate(conn)
        # Call the appropriate auth logic for the callback type.
        if self.properties.callback:
            return self._authenticate_machine(conn)
        return self._authenticate_human(conn)

    def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Handle an initial authenticate request.

        :param conn: The connection to authenticate.
        :return: The speculative-auth reply when it completed the
            conversation, otherwise the server's reply to the final SASL
            command.
        """
        # First handle speculative auth.
        # If it succeeded, we are done.
        ctx = conn.auth_ctx
        if ctx and ctx.speculate_succeeded():
            resp = ctx.speculative_authenticate
            if resp and resp["done"]:
                conn.oidc_token_gen_id = self.token_gen_id
                return resp

        # If spec auth failed, call the appropriate auth logic for the callback type.
        # We cannot assume that the token is invalid, because a proxy may have been
        # involved that stripped the speculative auth information.
        if self.properties.callback:
            return self._authenticate_machine(conn)
        return self._authenticate_human(conn)

    def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
        """Get the appropriate speculative auth command.

        :return: A ``saslStart`` command carrying the cached access token, or
            ``None`` when no access token is cached.
        """
        if not self.access_token:
            return None
        return self._get_start_command({"jwt": self.access_token})

    def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
        """Authenticate ``conn`` using the machine callback flow."""
        # If there is a cached access token, try to authenticate with it. If
        # authentication fails with error code 18, invalidate the access token,
        # fetch a new access token, and try to authenticate again. If authentication
        # fails for any other reason, raise the error to the user.
        if self.access_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    # _run_command has already attempted to invalidate the
                    # cached token, so the retry works with the current state.
                    return self._authenticate_machine(conn)
                raise
        return self._sasl_start_jwt(conn)

    def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
        """Authenticate ``conn`` using the human callback flow."""
        # If we have a cached access token, try a JwtStepRequest. If
        # authentication fails with error code 18, invalidate the access token,
        # and try to authenticate again.  If authentication fails for any other
        # reason, raise the error to the user.
        if self.access_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    return self._authenticate_human(conn)
                raise

        # If we have a cached refresh token, try a JwtStepRequest with that.
        # If authentication fails with error code 18, invalidate the access and
        # refresh tokens, and try to authenticate again. If authentication fails for
        # any other reason, raise the error to the user.
        if self.refresh_token:
            try:
                return self._sasl_start_jwt(conn)
            except OperationFailure as e:
                if self._is_auth_error(e):
                    self.refresh_token = None
                    return self._authenticate_human(conn)
                raise

        # Start a new Two-Step SASL conversation.
        # Run a PrincipalStepRequest to get the IdpInfo.
        cmd = self._get_start_command(None)
        start_resp = self._run_command(conn, cmd)
        # Attempt to authenticate with a JwtStepRequest.
        return self._sasl_continue_jwt(conn, start_resp)

    def _get_access_token(self) -> Optional[str]:
        """Return the cached access token, invoking the callback if needed.

        :return: The access token, or ``None`` when the human flow has no
            identity-provider info yet or when no callback is configured.
        :raises ValueError: If the callback does not return an
            ``OIDCCallbackResult``.
        """
        properties = self.properties

        is_human = properties.human_callback is not None
        if is_human and self.idp_info is None:
            return None

        # Always bind ``cb`` so the ``None`` check below cannot raise
        # UnboundLocalError when neither callback is configured. The human
        # callback takes precedence when both are set.
        cb: Optional[OIDCCallback] = None
        if properties.callback:
            cb = properties.callback
        if properties.human_callback:
            cb = properties.human_callback

        prev_token = self.access_token
        if prev_token:
            return prev_token

        if cb is None:
            return None

        with self.lock:
            # See if the token was changed while we were waiting for the
            # lock.
            new_token = self.access_token
            if new_token != prev_token:
                return new_token

            # Ensure that we are waiting a min time between callback invocations.
            delta = time.time() - self.last_call_time
            if delta < TIME_BETWEEN_CALLS_SECONDS:
                time.sleep(TIME_BETWEEN_CALLS_SECONDS - delta)
            self.last_call_time = time.time()

            if is_human:
                timeout = HUMAN_CALLBACK_TIMEOUT_SECONDS
                assert self.idp_info is not None
            else:
                # Fall back to the fixed machine timeout when remaining()
                # yields no (or a zero) value.
                timeout = int(remaining() or MACHINE_CALLBACK_TIMEOUT_SECONDS)
            context = OIDCCallbackContext(
                timeout_seconds=timeout,
                version=CALLBACK_VERSION,
                refresh_token=self.refresh_token,
                idp_info=self.idp_info,
                username=self.properties.username,
            )
            resp = cb.fetch(context)
            if not isinstance(resp, OIDCCallbackResult):
                raise ValueError("Callback result must be of type OIDCCallbackResult")
            self.refresh_token = resp.refresh_token
            self.access_token = resp.access_token
            # A new token generation lets _invalidate ignore requests coming
            # from connections that authenticated with an older token.
            self.token_gen_id += 1

        return self.access_token

    def _run_command(self, conn: Connection, cmd: MutableMapping[str, Any]) -> Mapping[str, Any]:
        """Run ``cmd`` against ``$external``, invalidating the token on auth failure.

        :raises OperationFailure: Re-raised unchanged after the invalidation
            attempt.
        """
        try:
            return conn.command("$external", cmd, no_reauth=True)  # type: ignore[call-arg]
        except OperationFailure as e:
            if self._is_auth_error(e):
                self._invalidate(conn)
            raise

    def _is_auth_error(self, err: Exception) -> bool:
        """Return True if ``err`` is an OperationFailure with the auth-failure code."""
        if not isinstance(err, OperationFailure):
            return False
        return err.code == _AUTHENTICATION_FAILURE_CODE

    def _invalidate(self, conn: Connection) -> None:
        """Drop the cached access token unless ``conn`` used an older token."""
        # Ignore the invalidation if the connection's token gen id is less
        # than our current token gen id. A missing id is treated as 0.
        token_gen_id = conn.oidc_token_gen_id or 0
        if token_gen_id < self.token_gen_id:
            return
        self.access_token = None

    def _sasl_continue_jwt(
        self, conn: Connection, start_resp: Mapping[str, Any]
    ) -> Mapping[str, Any]:
        """Finish a two-step conversation by sending a freshly fetched token.

        :param conn: The connection being authenticated.
        :param start_resp: The server's reply to the initial ``saslStart``.
        :return: The server's reply to ``saslContinue``.
        """
        # Discard cached tokens so that the callback is invoked again.
        self.access_token = None
        self.refresh_token = None
        start_payload: dict = bson.decode(start_resp["payload"])
        if "issuer" in start_payload:
            self.idp_info = OIDCIdPInfo(**start_payload)
        access_token = self._get_access_token()
        conn.oidc_token_gen_id = self.token_gen_id
        cmd = self._get_continue_command({"jwt": access_token}, start_resp)
        return self._run_command(conn, cmd)

    def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]:
        """Send a ``saslStart`` carrying the (cached or newly fetched) token."""
        access_token = self._get_access_token()
        # Record which token generation this connection authenticated with.
        conn.oidc_token_gen_id = self.token_gen_id
        cmd = self._get_start_command({"jwt": access_token})
        return self._run_command(conn, cmd)

    def _get_start_command(self, payload: Optional[Mapping[str, Any]]) -> MutableMapping[str, Any]:
        """Build a ``saslStart`` command.

        :param payload: The document to send; when ``None`` the payload is
            ``{"n": username}`` if a username is set, otherwise empty.
        """
        if payload is None:
            principal_name = self.username
            if principal_name:
                payload = {"n": principal_name}
            else:
                payload = {}
        bin_payload = Binary(bson.encode(payload))
        return {"saslStart": 1, "mechanism": "MONGODB-OIDC", "payload": bin_payload}

    def _get_continue_command(
        self, payload: Mapping[str, Any], start_resp: Mapping[str, Any]
    ) -> MutableMapping[str, Any]:
        """Build a ``saslContinue`` command for the conversation in ``start_resp``."""
        bin_payload = Binary(bson.encode(payload))
        return {
            "saslContinue": 1,
            "payload": bin_payload,
            "conversationId": start_resp["conversationId"],
        }


def _authenticate_oidc(
    credentials: MongoCredential, conn: Connection, reauthenticate: bool
) -> Optional[Mapping[str, Any]]:
    """Authenticate using MONGODB-OIDC.

    :param credentials: Credential from which the authenticator is obtained.
    :param conn: The connection to authenticate.
    :param reauthenticate: When true, run the reauthentication path instead
        of the initial authentication path.
    :return: Whatever the selected authenticator method returns.
    """
    authenticator = _get_authenticator(credentials, conn.address)
    handler = authenticator.reauthenticate if reauthenticate else authenticator.authenticate
    return handler(conn)