# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for the 'hello' and legacy hello commands."""
from __future__ import annotations

import copy
import datetime
import itertools
from typing import Any, Generic, Mapping, Optional

from bson.objectid import ObjectId
from pymongo import common
from pymongo.server_type import SERVER_TYPE
from pymongo.typings import ClusterTime, _DocumentType


class HelloCompat:
    """String constants for the current and legacy forms of the hello command."""

    # Current names: the command itself and the response key that
    # `_get_server_type` checks first when looking for a primary.
    CMD = "hello"
    PRIMARY = "isWritablePrimary"

    # Legacy names: the older command and the response key that
    # `_get_server_type` falls back to when the current key is not set.
    LEGACY_CMD = "ismaster"
    LEGACY_PRIMARY = "ismaster"
    # NOTE(review): presumably the error text associated with the legacy
    # command; it is not referenced in this module -- confirm against callers.
    LEGACY_ERROR = "not master"


def _get_server_type(doc: Mapping[str, Any]) -> int:
    """Classify a hello response document as one of the ``SERVER_TYPE`` values.

    The checks run in a fixed priority order and the first match wins:

    1. falsy or missing ``ok`` -> ``Unknown``
    2. truthy ``serviceId`` -> ``LoadBalancer``
    3. truthy ``isreplicaset`` -> ``RSGhost``
    4. truthy ``setName`` -> one of the ``RS*`` member types (see below)
    5. ``msg == "isdbgrid"`` -> ``Mongos``
    6. anything else -> ``Standalone``

    :param doc: The hello (or legacy hello) response document.
    :return: The matching ``SERVER_TYPE`` value.
    """
    if not doc.get("ok"):
        return SERVER_TYPE.Unknown
    if doc.get("serviceId"):
        return SERVER_TYPE.LoadBalancer
    if doc.get("isreplicaset"):
        return SERVER_TYPE.RSGhost

    if not doc.get("setName"):
        # No replica set name: only the "isdbgrid" message tells a mongos
        # apart from a standalone server.
        if doc.get("msg") == "isdbgrid":
            return SERVER_TYPE.Mongos
        return SERVER_TYPE.Standalone

    # From here on the response carries a replica set name.
    if doc.get("hidden"):
        # ``hidden`` is checked before the role flags, so it takes precedence.
        return SERVER_TYPE.RSOther
    # Either the current or the legacy primary flag marks a primary.
    if doc.get(HelloCompat.PRIMARY) or doc.get(HelloCompat.LEGACY_PRIMARY):
        return SERVER_TYPE.RSPrimary
    if doc.get("secondary"):
        return SERVER_TYPE.RSSecondary
    if doc.get("arbiterOnly"):
        return SERVER_TYPE.RSArbiter
    return SERVER_TYPE.RSOther


class Hello(Generic[_DocumentType]):
    """Read-only view over a hello (or legacy hello) response document.

    The server type and the readable/writable flags are computed once, at
    construction time; every other property reads the wrapped document each
    time it is accessed.

    .. versionadded:: 3.12
    """

    __slots__ = ("_doc", "_server_type", "_is_writable", "_is_readable", "_awaitable")

    def __init__(self, doc: _DocumentType, awaitable: bool = False) -> None:
        """Wrap ``doc`` and derive its server type and read/write flags.

        :param doc: The hello response document. It is stored by reference,
            not copied.
        :param awaitable: Stored unchanged and exposed via :attr:`awaitable`.
        """
        self._doc: _DocumentType = doc
        self._awaitable = awaitable
        self._server_type = _get_server_type(doc)

        writable_types = (
            SERVER_TYPE.RSPrimary,
            SERVER_TYPE.Standalone,
            SERVER_TYPE.Mongos,
            SERVER_TYPE.LoadBalancer,
        )
        self._is_writable = self._server_type in writable_types
        # Every writable server is readable; secondaries are readable as well.
        self._is_readable = self._is_writable or self.server_type == SERVER_TYPE.RSSecondary

    @property
    def document(self) -> _DocumentType:
        """A shallow copy of the complete hello command response document.

        .. versionadded:: 3.4
        """
        snapshot = copy.copy(self._doc)
        return snapshot

    @property
    def server_type(self) -> int:
        """The ``SERVER_TYPE`` value computed from the response at construction."""
        return self._server_type

    @property
    def all_hosts(self) -> set[tuple[str, int]]:
        """Set of hosts, passives, and arbiters known to this server.

        Each entry of the response's ``hosts``, ``passives`` and ``arbiters``
        fields is passed through ``common.clean_node``; a missing field
        contributes no members.
        """
        doc = self._doc
        members = itertools.chain(
            doc.get("hosts", []),
            doc.get("passives", []),
            doc.get("arbiters", []),
        )
        return {common.clean_node(member) for member in members}

    @property
    def tags(self) -> Mapping[str, Any]:
        """The response's ``tags`` field, or an empty dict when it is absent."""
        return self._doc.get("tags", {})

    @property
    def primary(self) -> Optional[tuple[str, int]]:
        """The ``primary`` field split by ``common.partition_node``, or None.

        None is returned when the field is missing or falsy.
        """
        address = self._doc.get("primary")
        if not address:
            return None
        return common.partition_node(address)

    @property
    def replica_set_name(self) -> Optional[str]:
        """The response's ``setName`` field, or None when it is absent."""
        return self._doc.get("setName")

    @property
    def max_bson_size(self) -> int:
        """``maxBsonObjectSize`` from the response, else ``common.MAX_BSON_SIZE``."""
        return self._doc.get("maxBsonObjectSize", common.MAX_BSON_SIZE)

    @property
    def max_message_size(self) -> int:
        """``maxMessageSizeBytes`` from the response, else twice :attr:`max_bson_size`."""
        fallback = 2 * self.max_bson_size
        return self._doc.get("maxMessageSizeBytes", fallback)

    @property
    def max_write_batch_size(self) -> int:
        """``maxWriteBatchSize`` from the response, else ``common.MAX_WRITE_BATCH_SIZE``."""
        return self._doc.get("maxWriteBatchSize", common.MAX_WRITE_BATCH_SIZE)

    @property
    def min_wire_version(self) -> int:
        """``minWireVersion`` from the response, else ``common.MIN_WIRE_VERSION``."""
        return self._doc.get("minWireVersion", common.MIN_WIRE_VERSION)

    @property
    def max_wire_version(self) -> int:
        """``maxWireVersion`` from the response, else ``common.MAX_WIRE_VERSION``."""
        return self._doc.get("maxWireVersion", common.MAX_WIRE_VERSION)

    @property
    def set_version(self) -> Optional[int]:
        """The response's ``setVersion`` field, or None when it is absent."""
        return self._doc.get("setVersion")

    @property
    def election_id(self) -> Optional[ObjectId]:
        """The response's ``electionId`` field, or None when it is absent."""
        return self._doc.get("electionId")

    @property
    def cluster_time(self) -> Optional[ClusterTime]:
        """The response's ``$clusterTime`` field, or None when it is absent."""
        return self._doc.get("$clusterTime")

    @property
    def logical_session_timeout_minutes(self) -> Optional[int]:
        """The response's ``logicalSessionTimeoutMinutes`` field, or None."""
        return self._doc.get("logicalSessionTimeoutMinutes")

    @property
    def is_writable(self) -> bool:
        """Whether the server type is RSPrimary, Standalone, Mongos or LoadBalancer."""
        return self._is_writable

    @property
    def is_readable(self) -> bool:
        """Whether the server is writable or its type is RSSecondary."""
        return self._is_readable

    @property
    def me(self) -> Optional[tuple[str, int]]:
        """The ``me`` field passed through ``common.clean_node``, or None.

        None is returned when the field is missing or falsy.
        """
        address = self._doc.get("me")
        return common.clean_node(address) if address else None

    @property
    def last_write_date(self) -> Optional[datetime.datetime]:
        """``lastWriteDate`` from the ``lastWrite`` sub-document, or None.

        A missing ``lastWrite`` field is treated as an empty sub-document.
        """
        last_write = self._doc.get("lastWrite", {})
        return last_write.get("lastWriteDate")

    @property
    def compressors(self) -> Optional[list[str]]:
        """The response's ``compression`` field, or None when it is absent."""
        return self._doc.get("compression")

    @property
    def sasl_supported_mechs(self) -> list[str]:
        """Supported authentication mechanisms for the current user.

        Read from the ``saslSupportedMechs`` field; an empty list is returned
        when the field is absent. For example::

            >>> hello.sasl_supported_mechs
            ["SCRAM-SHA-1", "SCRAM-SHA-256"]

        """
        return self._doc.get("saslSupportedMechs", [])

    @property
    def speculative_authenticate(self) -> Optional[Mapping[str, Any]]:
        """The response's ``speculativeAuthenticate`` field, or None when absent."""
        return self._doc.get("speculativeAuthenticate")

    @property
    def topology_version(self) -> Optional[Mapping[str, Any]]:
        """The response's ``topologyVersion`` field, or None when it is absent."""
        return self._doc.get("topologyVersion")

    @property
    def awaitable(self) -> bool:
        """The ``awaitable`` flag supplied at construction."""
        return self._awaitable

    @property
    def service_id(self) -> Optional[ObjectId]:
        """The response's ``serviceId`` field, or None when it is absent."""
        return self._doc.get("serviceId")

    @property
    def hello_ok(self) -> bool:
        """The response's ``helloOk`` field, or False when it is absent."""
        return self._doc.get("helloOk", False)

    @property
    def connection_id(self) -> Optional[int]:
        """The response's ``connectionId`` field, or None when it is absent."""
        return self._doc.get("connectionId")