# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License.  You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.  See the License for the specific language governing
# permissions and limitations under the License.

"""A CPython compatible SSLContext implementation wrapping PyOpenSSL's
context.
"""
from __future__ import annotations

import socket as _socket
import ssl as _stdlibssl
import sys as _sys
import time as _time
from errno import EINTR as _EINTR
from ipaddress import ip_address as _ip_address
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union

import cryptography.x509 as x509
import service_identity
from OpenSSL import SSL as _SSL
from OpenSSL import crypto as _crypto

from pymongo.errors import ConfigurationError as _ConfigurationError
from pymongo.errors import _CertificateError  # type:ignore[attr-defined]
from pymongo.ocsp_cache import _OCSPCache
from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback
from pymongo.socket_checker import SocketChecker as _SocketChecker
from pymongo.socket_checker import _errno_from_exception
from pymongo.write_concern import validate_boolean

if TYPE_CHECKING:
    from ssl import VerifyMode


_T = TypeVar("_T")

try:
    import certifi

    _HAVE_CERTIFI = True
except ImportError:
    _HAVE_CERTIFI = False

# Re-export PyOpenSSL constants under the names the stdlib ssl module uses.
PROTOCOL_SSLv23 = _SSL.SSLv23_METHOD
# Always available
OP_NO_SSLv2 = _SSL.OP_NO_SSLv2
OP_NO_SSLv3 = _SSL.OP_NO_SSLv3
OP_NO_COMPRESSION = _SSL.OP_NO_COMPRESSION
# This isn't currently documented for PyOpenSSL
# Falls back to 0 when the attribute is missing from OpenSSL.SSL.
OP_NO_RENEGOTIATION = getattr(_SSL, "OP_NO_RENEGOTIATION", 0)

# Always available
HAS_SNI = True
IS_PYOPENSSL = True

# Base Exception class
SSLError = _SSL.Error

# Maps stdlib ssl.CERT_* verify modes to the PyOpenSSL VERIFY_* flags passed
# to Context.set_verify by SSLContext.verify_mode.
# https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L2995-L3002
_VERIFY_MAP = {
    _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE,
    _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER,
    _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}

# Inverse of _VERIFY_MAP: PyOpenSSL verify flags back to the stdlib CERT_* value.
_REVERSE_VERIFY_MAP = {value: key for key, value in _VERIFY_MAP.items()}


# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
# not permitted for SNI hostname.
def _is_ip_address(address: Any) -> bool:
    try:
        _ip_address(address)
        return True
    except (ValueError, UnicodeError):
        return False


# According to the docs for socket.send it can raise
# WantX509LookupError and should be retried.
# These are the exception types that _sslConn._call catches and retries after
# waiting on the socket.
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)


def _ragged_eof(exc: BaseException) -> bool:
    """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF."""
    return exc.args == (-1, "Unexpected EOF")


# https://github.com/pyca/pyopenssl/issues/168
# https://github.com/pyca/pyopenssl/issues/176
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
class _sslConn(_SSL.Connection):
    """PyOpenSSL connection whose I/O calls retry on blocking-I/O errors.

    Handshake, receive and send operations are routed through :meth:`_call`,
    which waits for the socket to become ready and retries whenever PyOpenSSL
    raises one of ``BLOCKING_IO_ERRORS``, giving up once the socket timeout
    has elapsed.
    """

    def __init__(
        self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool
    ):
        """Wrap ``sock`` using the PyOpenSSL context ``ctx``.

        ``suppress_ragged_eofs`` controls whether :meth:`recv` and
        :meth:`recv_into` turn a ragged EOF into an empty read.
        """
        self.socket_checker = _SocketChecker()
        self.suppress_ragged_eofs = suppress_ragged_eofs
        super().__init__(ctx, sock)

    def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
        """Invoke ``call(*args, **kwargs)``, retrying on blocking-I/O errors.

        Returns whatever ``call`` returns. Raises ``socket.timeout`` once more
        than ``self.gettimeout()`` seconds have elapsed (only when that timeout
        is truthy), and ``SSLError`` if the file descriptor is found to be
        closed (``fileno() == -1``) before the timeout has been exceeded.
        """
        timeout = self.gettimeout()
        start = _time.monotonic()
        while True:
            try:
                return call(*args, **kwargs)
            except BLOCKING_IO_ERRORS as exc:
                # Check for closed socket.
                if self.fileno() == -1:
                    if timeout and _time.monotonic() - start > timeout:
                        raise _socket.timeout("timed out") from None
                    raise SSLError("Underlying socket has been closed") from None
                if isinstance(exc, _SSL.WantReadError):
                    want_read = True
                    want_write = False
                elif isinstance(exc, _SSL.WantWriteError):
                    want_read = False
                    want_write = True
                else:
                    want_read = True
                    want_write = True
                # Wait only for the time that is left. Passing the full
                # timeout on every retry would let a call that needs several
                # rounds block for much longer than the configured timeout.
                # Clamp at 0 so select never receives a negative value.
                wait = timeout
                if timeout:
                    wait = max(timeout - (_time.monotonic() - start), 0)
                self.socket_checker.select(self, want_read, want_write, wait)
                if timeout and _time.monotonic() - start > timeout:
                    raise _socket.timeout("timed out") from None

    def do_handshake(self, *args: Any, **kwargs: Any) -> None:
        """Run the base class handshake with blocking-I/O retries."""
        return self._call(super().do_handshake, *args, **kwargs)

    def recv(self, *args: Any, **kwargs: Any) -> bytes:
        """Receive data with retries; a suppressed ragged EOF returns ``b""``."""
        try:
            return self._call(super().recv, *args, **kwargs)
        except _SSL.SysCallError as exc:
            # Suppress ragged EOFs to match the stdlib.
            if self.suppress_ragged_eofs and _ragged_eof(exc):
                return b""
            raise

    def recv_into(self, *args: Any, **kwargs: Any) -> int:
        """Receive into a buffer with retries; a suppressed ragged EOF returns 0."""
        try:
            return self._call(super().recv_into, *args, **kwargs)
        except _SSL.SysCallError as exc:
            # Suppress ragged EOFs to match the stdlib.
            if self.suppress_ragged_eofs and _ragged_eof(exc):
                return 0
            raise

    def sendall(self, buf: bytes, flags: int = 0) -> None:  # type: ignore[override]
        """Send all of ``buf``, looping over partial sends.

        An ``OSError`` whose errno is ``EINTR`` causes the send to be retried.
        Raises ``OSError("connection closed")`` if a send reports that zero or
        fewer bytes were written.
        """
        view = memoryview(buf)
        total_length = len(buf)
        total_sent = 0
        while total_sent < total_length:
            try:
                sent = self._call(super().send, view[total_sent:], flags)
            # XXX: It's not clear if this can actually happen. PyOpenSSL
            # doesn't appear to have any interrupt handling, nor any interrupt
            # errors for OpenSSL connections.
            except OSError as exc:
                if _errno_from_exception(exc) == _EINTR:
                    continue
                raise
            # https://github.com/pyca/pyopenssl/blob/19.1.0/src/OpenSSL/SSL.py#L1756
            # https://www.openssl.org/docs/man1.0.2/man3/SSL_write.html
            if sent <= 0:
                raise OSError("connection closed")
            total_sent += sent


class _CallbackData:
    """Container for the state handed to the OCSP callback."""

    def __init__(self) -> None:
        # A fresh OCSP response cache for this instance.
        self.ocsp_response_cache = _OCSPCache()
        # Both settings start unset; SSLContext assigns them later.
        self.check_ocsp_endpoint: Optional[bool] = None
        self.trusted_ca_certs: Optional[list[x509.Certificate]] = None


class SSLContext:
    """A CPython compatible SSLContext implementation wrapping PyOpenSSL's
    context.
    """

    __slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname")

    def __init__(self, protocol: int):
        """Create the underlying PyOpenSSL context for ``protocol``.

        Hostname checking and OCSP endpoint checking both default to True,
        and the OCSP client callback is registered on the context.
        """
        self._protocol = protocol
        self._ctx = _SSL.Context(self._protocol)
        self._callback_data = _CallbackData()
        self._check_hostname = True
        # OCSP
        # XXX: Find a better place to do this someday, since this is client
        # side configuration and wrap_socket tries to support both client and
        # server side sockets.
        self._callback_data.check_ocsp_endpoint = True
        self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data)

    @property
    def protocol(self) -> int:
        """The protocol version chosen when constructing the context.
        This attribute is read-only.
        """
        return self._protocol

    def __get_verify_mode(self) -> VerifyMode:
        """Whether to try to verify other peers' certificates and how to
        behave if verification fails. This attribute must be one of
        ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
        """
        # Translate PyOpenSSL's verify flags back to the stdlib constant.
        return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()]

    def __set_verify_mode(self, value: VerifyMode) -> None:
        """Setter for verify_mode."""

        def _cb(
            _connobj: _SSL.Connection,
            _x509obj: _crypto.X509,
            _errnum: int,
            _errdepth: int,
            retcode: int,
        ) -> bool:
            # It seems we don't need to do anything here. Twisted doesn't,
            # and OpenSSL's SSL_CTX_set_verify lets you pass NULL
            # for the callback option. It's weird that PyOpenSSL requires
            # this.
            # This is optional in pyopenssl >= 20 and can be removed once minimum
            # supported version is bumped
            # See: pyopenssl.org/en/latest/changelog.html#id47
            return bool(retcode)

        # Raises KeyError if value is not one of the keys of _VERIFY_MAP.
        self._ctx.set_verify(_VERIFY_MAP[value], _cb)

    verify_mode = property(__get_verify_mode, __set_verify_mode)

    def __get_check_hostname(self) -> bool:
        """Whether wrap_socket verifies the peer certificate's identity."""
        return self._check_hostname

    def __set_check_hostname(self, value: Any) -> None:
        """Setter for check_hostname; the value is passed through validate_boolean."""
        validate_boolean("check_hostname", value)
        self._check_hostname = value

    check_hostname = property(__get_check_hostname, __set_check_hostname)

    def __get_check_ocsp_endpoint(self) -> Optional[bool]:
        """The check_ocsp_endpoint flag stored on the OCSP callback data."""
        return self._callback_data.check_ocsp_endpoint

    def __set_check_ocsp_endpoint(self, value: bool) -> None:
        """Setter for check_ocsp_endpoint; the value is passed through validate_boolean."""
        validate_boolean("check_ocsp", value)
        self._callback_data.check_ocsp_endpoint = value

    check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint)

    def __get_options(self) -> int:
        """Return the context's current option bitmask."""
        # Calling set_options adds the option to the existing bitmask and
        # returns the new bitmask.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
        return self._ctx.set_options(0)

    def __set_options(self, value: int) -> None:
        """Add ``value`` to the context's option bitmask."""
        # Explicitly convert to int, since newer CPython versions
        # use enum.IntFlag for options. The values are the same
        # regardless of implementation.
        self._ctx.set_options(int(value))

    options = property(__get_options, __set_options)

    def load_cert_chain(
        self,
        certfile: Union[str, bytes],
        keyfile: Union[str, bytes, None] = None,
        password: Optional[str] = None,
    ) -> None:
        """Load a private key and the corresponding certificate. The certfile
        string must be the path to a single file in PEM format containing the
        certificate as well as any number of CA certificates needed to
        establish the certificate's authenticity. The keyfile string, if
        present, must point to a file containing the private key. Otherwise
        the private key will be taken from certfile as well.

        A truthy ``password`` is UTF-8 encoded and supplied through a password
        callback; a falsy one (None or the empty string) installs no callback.
        """
        # Match CPython behavior
        # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971
        # Password callback MUST be set first or it will be ignored.
        if password:

            def _pwcb(_max_length: int, _prompt_twice: bool, _user_data: bytes) -> bytes:
                # XXX:We could check the password length against what OpenSSL
                # tells us is the max, but we can't raise an exception, so...
                # warn?
                assert password is not None
                return password.encode("utf-8")

            self._ctx.set_passwd_cb(_pwcb)
        self._ctx.use_certificate_chain_file(certfile)
        # Fall back to certfile when no separate key file is given.
        self._ctx.use_privatekey_file(keyfile or certfile)
        self._ctx.check_privatekey()

    def load_verify_locations(
        self, cafile: Optional[str] = None, capath: Optional[str] = None
    ) -> None:
        """Load a set of "certification authority"(CA) certificates used to
        validate other peers' certificates when `~verify_mode` is other than
        ssl.CERT_NONE.
        """
        self._ctx.load_verify_locations(cafile, capath)
        # Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
        if not hasattr(_SSL.Connection, "get_verified_chain"):
            # NOTE(review): on this path a capath-only call fails the assert
            # below (which is stripped under -O); confirm callers always pass
            # cafile.
            assert cafile is not None
            self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile)

    def _load_certifi(self) -> None:
        """Attempt to load CA certs from certifi.

        Raises ConfigurationError when certifi could not be imported.
        """
        if _HAVE_CERTIFI:
            self.load_verify_locations(certifi.where())
        else:
            raise _ConfigurationError(
                "tlsAllowInvalidCertificates is False but no system "
                "CA certificates could be loaded. Please install the "
                "certifi package, or provide a path to a CA file using "
                "the tlsCAFile option"
            )

    def _load_wincerts(self, store: str) -> None:
        """Attempt to load CA certs from Windows trust store."""
        cert_store = self._ctx.get_cert_store()
        oid = _stdlibssl.Purpose.SERVER_AUTH.oid

        for cert, encoding, trust in _stdlibssl.enum_certificates(store):  # type: ignore
            # Only "x509_asn" entries are added, and only when trust is True
            # or contains the SERVER_AUTH purpose OID.
            if encoding == "x509_asn":
                if trust is True or oid in trust:
                    cert_store.add_cert(
                        _crypto.X509.from_cryptography(x509.load_der_x509_certificate(cert))
                    )

    def load_default_certs(self) -> None:
        """A PyOpenSSL version of load_default_certs from CPython."""
        # PyOpenSSL is incapable of loading CA certs from Windows, and mostly
        # incapable on macOS.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_default_verify_paths
        if _sys.platform == "win32":
            try:
                for storename in ("CA", "ROOT"):
                    self._load_wincerts(storename)
            except PermissionError:
                # Fall back to certifi
                self._load_certifi()
        elif _sys.platform == "darwin":
            self._load_certifi()
        # Runs on every platform, after any platform-specific loading above.
        self._ctx.set_default_verify_paths()

    def set_default_verify_paths(self) -> None:
        """Specify that the platform provided CA certificates are to be used
        for verification purposes.
        """
        # Note: See PyOpenSSL's docs for limitations, which are similar
        # but not the same as CPython's.
        self._ctx.set_default_verify_paths()

    def wrap_socket(
        self,
        sock: _socket.socket,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: Optional[str] = None,
        session: Optional[_SSL.Session] = None,
    ) -> _sslConn:
        """Wrap an existing Python socket connection and return a TLS socket
        object.

        When ``server_side`` is True the connection is put in accept state.
        Otherwise it is put in connect state, with SNI set for a non-IP
        ``server_hostname`` and a stapled OCSP response requested unless
        ``verify_mode`` is ssl.CERT_NONE.

        The handshake and the hostname/IP verification are only performed
        here when ``do_handshake_on_connect`` is true; verification also
        requires ``check_hostname`` to be true and ``server_hostname`` to be
        given. A verification failure is raised as ``_CertificateError``.
        """
        ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
        if session:
            ssl_conn.set_session(session)
        if server_side is True:
            ssl_conn.set_accept_state()
        else:
            # SNI
            if server_hostname and not _is_ip_address(server_hostname):
                # XXX: Do this in a callback registered with
                # SSLContext.set_info_callback? See Twisted for an example.
                ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
            if self.verify_mode != _stdlibssl.CERT_NONE:
                # Request a stapled OCSP response.
                ssl_conn.request_ocsp()
            ssl_conn.set_connect_state()
        # If this wasn't true the caller of wrap_socket would call
        # do_handshake()
        if do_handshake_on_connect:
            # XXX: If we do hostname checking in a callback we can get rid
            # of this call to do_handshake() since the handshake
            # will happen automatically later.
            ssl_conn.do_handshake()
            # XXX: Do this in a callback registered with
            # SSLContext.set_info_callback? See Twisted for an example.
            if self.check_hostname and server_hostname is not None:
                from service_identity import pyopenssl

                try:
                    # IP literals and DNS names use different verifiers.
                    if _is_ip_address(server_hostname):
                        pyopenssl.verify_ip_address(ssl_conn, server_hostname)
                    else:
                        pyopenssl.verify_hostname(ssl_conn, server_hostname)
                except (  # type:ignore[misc]
                    service_identity.SICertificateError,
                    service_identity.SIVerificationError,
                ) as exc:
                    raise _CertificateError(str(exc)) from None
        return ssl_conn