# Copyright 2017 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of gRPC Python interceptors."""

import collections
import sys
import types
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import grpc

from ._typing import DeserializingFunction
from ._typing import DoneCallbackType
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import SerializingFunction


class _ServicePipeline(object):
    """Chains server interceptors around a terminal handler lookup.

    ``execute`` calls the first interceptor's ``intercept_service`` with a
    continuation that resumes at the next interceptor; after the last
    interceptor, the continuation calls the thunk given to ``execute``.
    """

    # Variable-length tuple of interceptors, applied first-to-last.
    interceptors: Tuple[grpc.ServerInterceptor, ...]

    def __init__(self, interceptors: Sequence[grpc.ServerInterceptor]):
        # Snapshot into a tuple so later mutation of the caller's sequence
        # cannot alter this pipeline.
        self.interceptors = tuple(interceptors)

    def _continuation(self, thunk: Callable, index: int) -> Callable:
        """Return a callable that resumes interception at ``index``."""
        return lambda context: self._intercept_at(thunk, index, context)

    def _intercept_at(
        self, thunk: Callable, index: int, context: grpc.HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        """Run the interceptor at ``index``, or ``thunk`` once all have run.

        Returns whatever the interceptor (or the final thunk) returns; per the
        ``ServerInterceptor`` contract this may be ``None``.
        """
        if index < len(self.interceptors):
            interceptor = self.interceptors[index]
            thunk = self._continuation(thunk, index + 1)
            return interceptor.intercept_service(thunk, context)
        else:
            return thunk(context)

    def execute(
        self, thunk: Callable, context: grpc.HandlerCallDetails
    ) -> Optional[grpc.RpcMethodHandler]:
        """Run the full interceptor chain for ``context``, ending in ``thunk``."""
        return self._intercept_at(thunk, 0, context)


def service_pipeline(
    interceptors: Optional[Sequence[grpc.ServerInterceptor]],
) -> Optional[_ServicePipeline]:
    """Build a ``_ServicePipeline`` for ``interceptors``.

    Returns ``None`` when ``interceptors`` is ``None`` or empty, so callers
    can skip interception entirely.
    """
    if not interceptors:
        return None
    return _ServicePipeline(interceptors)


class _ClientCallDetails(
    collections.namedtuple(
        "_ClientCallDetails",
        "method timeout metadata credentials wait_for_ready compression",
    ),
    grpc.ClientCallDetails,
):
    """Immutable record of per-call options handed to client interceptors.

    Positional field order: method, timeout, metadata, credentials,
    wait_for_ready, compression.
    """


def _unwrap_client_call_details(
    call_details: grpc.ClientCallDetails,
    default_details: grpc.ClientCallDetails,
) -> Tuple[
    str,
    Optional[float],
    Optional[MetadataType],
    Optional[grpc.CallCredentials],
    Optional[bool],
    Optional[grpc.Compression],
]:
    """Extract call options, falling back to ``default_details`` per field.

    An interceptor may forward any object as the new call details, including
    one lacking some attributes. Each attribute whose access on
    ``call_details`` raises ``AttributeError`` is read from
    ``default_details`` instead.

    Returns:
      A ``(method, timeout, metadata, credentials, wait_for_ready,
      compression)`` tuple.
    """

    def _field(name: str) -> Any:
        # An explicit try/except (rather than getattr(obj, name, default))
        # keeps the fallback lazy: default_details is only consulted when the
        # primary lookup fails.
        try:
            return getattr(call_details, name)
        except AttributeError:
            return getattr(default_details, name)

    return (
        _field("method"),
        _field("timeout"),
        _field("metadata"),
        _field("credentials"),
        _field("wait_for_ready"),
        _field("compression"),
    )


class _FailureOutcome(
    grpc.RpcError, grpc.Future, grpc.Call
):  # pylint: disable=too-many-ancestors
    """A finished, failed call wrapping an exception raised while intercepting.

    The object is simultaneously an ``RpcError``, a ``Future`` and a ``Call``:
    it reports status ``INTERNAL`` with no metadata, is already done, and
    re-raises the stored exception from ``result()`` and from iteration.
    """

    _exception: Exception
    _traceback: types.TracebackType

    def __init__(self, exception: Exception, traceback: types.TracebackType):
        super().__init__()
        self._exception = exception
        self._traceback = traceback

    # --- grpc.Call / RpcContext surface -------------------------------------

    def initial_metadata(self) -> Optional[MetadataType]:
        return None

    def trailing_metadata(self) -> Optional[MetadataType]:
        return None

    def code(self) -> Optional[grpc.StatusCode]:
        return grpc.StatusCode.INTERNAL

    def details(self) -> Optional[str]:
        return "Exception raised while intercepting the RPC"

    def is_active(self) -> bool:
        return False

    def time_remaining(self) -> Optional[float]:
        return None

    def add_callback(self, unused_callback) -> bool:
        # The call is already over, so the callback is never registered.
        return False

    # --- grpc.Future surface -------------------------------------------------

    def cancel(self) -> bool:
        return False

    def cancelled(self) -> bool:
        return False

    def running(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        raise self._exception

    def exception(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[Exception]:
        return self._exception

    def traceback(
        self, ignored_timeout: Optional[float] = None
    ) -> Optional[types.TracebackType]:
        return self._traceback

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        # Already done: invoke the callback immediately.
        fn(self)

    # --- Response-iterator surface -------------------------------------------

    def __iter__(self):
        return self

    def __next__(self):
        raise self._exception

    def next(self):
        return self.__next__()


class _UnaryOutcome(grpc.Call, grpc.Future):
    """A finished, successful unary call.

    Holds the response already received plus the underlying ``grpc.Call``.
    Metadata, status, timing, cancellation and callback registration are
    delegated to that call; the future side reports itself as done.
    """

    _response: Any
    _call: grpc.Call

    def __init__(self, response: Any, call: grpc.Call):
        self._response = response
        self._call = call

    # --- Delegated to the underlying call -----------------------------------

    def initial_metadata(self) -> Optional[MetadataType]:
        return self._call.initial_metadata()

    def trailing_metadata(self) -> Optional[MetadataType]:
        return self._call.trailing_metadata()

    def code(self) -> Optional[grpc.StatusCode]:
        return self._call.code()

    def details(self) -> Optional[str]:
        return self._call.details()

    def is_active(self) -> bool:
        return self._call.is_active()

    def time_remaining(self) -> Optional[float]:
        return self._call.time_remaining()

    def cancel(self) -> bool:
        return self._call.cancel()

    def add_callback(self, callback) -> bool:
        return self._call.add_callback(callback)

    # --- Future already resolved with a response -----------------------------

    def cancelled(self) -> bool:
        return False

    def running(self) -> bool:
        return False

    def done(self) -> bool:
        return True

    def result(self, ignored_timeout: Optional[float] = None):
        return self._response

    def exception(self, ignored_timeout: Optional[float] = None):
        return None

    def traceback(self, ignored_timeout: Optional[float] = None):
        return None

    def add_done_callback(self, fn: DoneCallbackType) -> None:
        # Already done: invoke the callback immediately.
        fn(self)


class _UnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable):
    """Unary-unary multi-callable whose invocations pass through an interceptor.

    ``thunk`` maps a method name to the underlying, un-intercepted
    multi-callable; it is called with the method found on the call details
    the interceptor forwards to its continuation.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _build_call_details(
        self, timeout, metadata, credentials, wait_for_ready, compression
    ):
        """Bundle this callable's method name with the caller's options."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    @staticmethod
    def _split_call_details(new_details, fallback_details) -> Tuple[str, dict]:
        """Return ``(method, kwargs)`` for invoking the underlying callable.

        Attributes missing from ``new_details`` come from ``fallback_details``.
        """
        (
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        ) = _unwrap_client_call_details(new_details, fallback_details)
        options = {
            "timeout": timeout,
            "metadata": metadata,
            "credentials": credentials,
            "wait_for_ready": wait_for_ready,
            "compression": compression,
        }
        return method, options

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC through the interceptor and return only the response."""
        response, _unused_call = self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response

    def _with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Blocking invocation returning ``(response, call_outcome)``.

        ``call_outcome.result()`` is evaluated here, so whatever it raises
        propagates to the caller.
        """
        original_details = self._build_call_details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request):
            method, options = self._split_call_details(
                new_details, original_details
            )
            try:
                response, call = self._thunk(method).with_call(
                    request, **options
                )
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Returned, not raised, so the interceptor receives it as the
                # outcome of the call.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_unary_unary(
            continuation, original_details, request
        )
        return outcome.result(), outcome

    def with_call(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Public wrapper around ``_with_call``."""
        return self._with_call(
            request,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Non-blocking invocation through the interceptor.

        Returns whatever ``intercept_unary_unary`` returns, or a
        ``_FailureOutcome`` if the interceptor raises.
        """
        original_details = self._build_call_details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request):
            method, options = self._split_call_details(
                new_details, original_details
            )
            return self._thunk(method).future(request, **options)

        try:
            return self._interceptor.intercept_unary_unary(
                continuation, original_details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _UnaryStreamMultiCallable(grpc.UnaryStreamMultiCallable):
    """Unary-stream multi-callable whose invocations pass through an interceptor.

    ``thunk`` maps a method name to the underlying, un-intercepted
    multi-callable; it is called with the method found on the call details
    the interceptor forwards to its continuation.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.UnaryStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.UnaryStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _build_call_details(
        self, timeout, metadata, credentials, wait_for_ready, compression
    ):
        """Bundle this callable's method name with the caller's options."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    @staticmethod
    def _split_call_details(new_details, fallback_details) -> Tuple[str, dict]:
        """Return ``(method, kwargs)`` for invoking the underlying callable.

        Attributes missing from ``new_details`` come from ``fallback_details``.
        """
        (
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        ) = _unwrap_client_call_details(new_details, fallback_details)
        options = {
            "timeout": timeout,
            "metadata": metadata,
            "credentials": credentials,
            "wait_for_ready": wait_for_ready,
            "compression": compression,
        }
        return method, options

    def __call__(
        self,
        request: Any,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Start the RPC through the interceptor.

        Returns whatever ``intercept_unary_stream`` returns, or a
        ``_FailureOutcome`` if the interceptor raises.
        """
        original_details = self._build_call_details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request):
            method, options = self._split_call_details(
                new_details, original_details
            )
            return self._thunk(method)(request, **options)

        try:
            return self._interceptor.intercept_unary_stream(
                continuation, original_details, request
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _StreamUnaryMultiCallable(grpc.StreamUnaryMultiCallable):
    """Stream-unary multi-callable whose invocations pass through an interceptor.

    ``thunk`` maps a method name to the underlying, un-intercepted
    multi-callable; it is called with the method found on the call details
    the interceptor forwards to its continuation.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamUnaryClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamUnaryClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _build_call_details(
        self, timeout, metadata, credentials, wait_for_ready, compression
    ):
        """Bundle this callable's method name with the caller's options."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    @staticmethod
    def _split_call_details(new_details, fallback_details) -> Tuple[str, dict]:
        """Return ``(method, kwargs)`` for invoking the underlying callable.

        Attributes missing from ``new_details`` come from ``fallback_details``.
        """
        (
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        ) = _unwrap_client_call_details(new_details, fallback_details)
        options = {
            "timeout": timeout,
            "metadata": metadata,
            "credentials": credentials,
            "wait_for_ready": wait_for_ready,
            "compression": compression,
        }
        return method, options

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Invoke the RPC through the interceptor and return only the response."""
        response, _unused_call = self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )
        return response

    def _with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Blocking invocation returning ``(response, call_outcome)``.

        ``call_outcome.result()`` is evaluated here, so whatever it raises
        propagates to the caller.
        """
        original_details = self._build_call_details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request_iterator):
            method, options = self._split_call_details(
                new_details, original_details
            )
            try:
                response, call = self._thunk(method).with_call(
                    request_iterator, **options
                )
                return _UnaryOutcome(response, call)
            except grpc.RpcError as rpc_error:
                # Returned, not raised, so the interceptor receives it as the
                # outcome of the call.
                return rpc_error
            except Exception as exception:  # pylint:disable=broad-except
                return _FailureOutcome(exception, sys.exc_info()[2])

        outcome = self._interceptor.intercept_stream_unary(
            continuation, original_details, request_iterator
        )
        return outcome.result(), outcome

    def with_call(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Tuple[Any, grpc.Call]:
        """Public wrapper around ``_with_call``."""
        return self._with_call(
            request_iterator,
            timeout=timeout,
            metadata=metadata,
            credentials=credentials,
            wait_for_ready=wait_for_ready,
            compression=compression,
        )

    def future(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ) -> Any:
        """Non-blocking invocation through the interceptor.

        Returns whatever ``intercept_stream_unary`` returns, or a
        ``_FailureOutcome`` if the interceptor raises.
        """
        original_details = self._build_call_details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request_iterator):
            method, options = self._split_call_details(
                new_details, original_details
            )
            return self._thunk(method).future(request_iterator, **options)

        try:
            return self._interceptor.intercept_stream_unary(
                continuation, original_details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _StreamStreamMultiCallable(grpc.StreamStreamMultiCallable):
    """Stream-stream multi-callable whose invocations pass through an interceptor.

    ``thunk`` maps a method name to the underlying, un-intercepted
    multi-callable; it is called with the method found on the call details
    the interceptor forwards to its continuation.
    """

    _thunk: Callable
    _method: str
    _interceptor: grpc.StreamStreamClientInterceptor

    def __init__(
        self,
        thunk: Callable,
        method: str,
        interceptor: grpc.StreamStreamClientInterceptor,
    ):
        self._thunk = thunk
        self._method = method
        self._interceptor = interceptor

    def _build_call_details(
        self, timeout, metadata, credentials, wait_for_ready, compression
    ):
        """Bundle this callable's method name with the caller's options."""
        return _ClientCallDetails(
            self._method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        )

    @staticmethod
    def _split_call_details(new_details, fallback_details) -> Tuple[str, dict]:
        """Return ``(method, kwargs)`` for invoking the underlying callable.

        Attributes missing from ``new_details`` come from ``fallback_details``.
        """
        (
            method,
            timeout,
            metadata,
            credentials,
            wait_for_ready,
            compression,
        ) = _unwrap_client_call_details(new_details, fallback_details)
        options = {
            "timeout": timeout,
            "metadata": metadata,
            "credentials": credentials,
            "wait_for_ready": wait_for_ready,
            "compression": compression,
        }
        return method, options

    def __call__(
        self,
        request_iterator: RequestIterableType,
        timeout: Optional[float] = None,
        metadata: Optional[MetadataType] = None,
        credentials: Optional[grpc.CallCredentials] = None,
        wait_for_ready: Optional[bool] = None,
        compression: Optional[grpc.Compression] = None,
    ):
        """Start the RPC through the interceptor.

        Returns whatever ``intercept_stream_stream`` returns, or a
        ``_FailureOutcome`` if the interceptor raises.
        """
        original_details = self._build_call_details(
            timeout, metadata, credentials, wait_for_ready, compression
        )

        def continuation(new_details, request_iterator):
            method, options = self._split_call_details(
                new_details, original_details
            )
            return self._thunk(method)(request_iterator, **options)

        try:
            return self._interceptor.intercept_stream_stream(
                continuation, original_details, request_iterator
            )
        except Exception as exception:  # pylint:disable=broad-except
            return _FailureOutcome(exception, sys.exc_info()[2])


class _Channel(grpc.Channel):
    """Channel wrapper that applies one client interceptor.

    Each multi-callable factory wraps its result in the matching intercepting
    multi-callable when the interceptor implements that arity's interface;
    otherwise the underlying channel's multi-callable is returned unchanged.
    """

    _channel: grpc.Channel
    _interceptor: Union[
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
    ]

    def __init__(
        self,
        channel: grpc.Channel,
        interceptor: Union[
            grpc.UnaryUnaryClientInterceptor,
            grpc.UnaryStreamClientInterceptor,
            grpc.StreamStreamClientInterceptor,
            grpc.StreamUnaryClientInterceptor,
        ],
    ):
        self._channel = channel
        self._interceptor = interceptor

    def subscribe(
        self, callback: Callable, try_to_connect: Optional[bool] = False
    ):
        self._channel.subscribe(callback, try_to_connect=try_to_connect)

    def unsubscribe(self, callback: Callable):
        self._channel.unsubscribe(callback)

    def _wrap(
        self,
        factory_name: str,
        interceptor_type: type,
        wrapper: Callable,
        method: str,
        request_serializer,
        response_deserializer,
        registered_method,
    ):
        """Build the multi-callable for ``method`` via ``factory_name``.

        Returns ``wrapper(thunk, method, interceptor)`` when the interceptor is
        an instance of ``interceptor_type``, else the plain multi-callable.
        """

        def thunk(m):
            # The factory is looked up on every call, so the underlying
            # channel's attribute is resolved lazily.
            factory = getattr(self._channel, factory_name)
            # pytype: disable=wrong-arg-count
            return factory(
                m,
                request_serializer,
                response_deserializer,
                registered_method,
            )
            # pytype: enable=wrong-arg-count

        if isinstance(self._interceptor, interceptor_type):
            return wrapper(thunk, method, self._interceptor)
        return thunk(method)

    # pylint: disable=arguments-differ
    def unary_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryUnaryMultiCallable:
        return self._wrap(
            "unary_unary",
            grpc.UnaryUnaryClientInterceptor,
            _UnaryUnaryMultiCallable,
            method,
            request_serializer,
            response_deserializer,
            _registered_method,
        )

    # pylint: disable=arguments-differ
    def unary_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.UnaryStreamMultiCallable:
        return self._wrap(
            "unary_stream",
            grpc.UnaryStreamClientInterceptor,
            _UnaryStreamMultiCallable,
            method,
            request_serializer,
            response_deserializer,
            _registered_method,
        )

    # pylint: disable=arguments-differ
    def stream_unary(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamUnaryMultiCallable:
        return self._wrap(
            "stream_unary",
            grpc.StreamUnaryClientInterceptor,
            _StreamUnaryMultiCallable,
            method,
            request_serializer,
            response_deserializer,
            _registered_method,
        )

    # pylint: disable=arguments-differ
    def stream_stream(
        self,
        method: str,
        request_serializer: Optional[SerializingFunction] = None,
        response_deserializer: Optional[DeserializingFunction] = None,
        _registered_method: Optional[bool] = False,
    ) -> grpc.StreamStreamMultiCallable:
        return self._wrap(
            "stream_stream",
            grpc.StreamStreamClientInterceptor,
            _StreamStreamMultiCallable,
            method,
            request_serializer,
            response_deserializer,
            _registered_method,
        )

    def _close(self):
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._close()
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def close(self):
        self._close()


def intercept_channel(
    channel: grpc.Channel,
    *interceptors: Union[
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
    ],
) -> grpc.Channel:
    """Wrap ``channel`` so that RPCs pass through ``interceptors``.

    Wrapping proceeds from the last interceptor to the first, so the first
    interceptor given ends up as the outermost ``_Channel`` layer.

    Raises:
      TypeError: if an interceptor implements none of the four client
        interceptor interfaces.
    """
    accepted_types = (
        grpc.UnaryUnaryClientInterceptor,
        grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor,
        grpc.StreamStreamClientInterceptor,
    )
    # ``interceptors`` is a tuple (varargs), so it can be reversed directly.
    for interceptor in reversed(interceptors):
        if not isinstance(interceptor, accepted_types):
            raise TypeError(
                "interceptor must be "
                "grpc.UnaryUnaryClientInterceptor or "
                "grpc.UnaryStreamClientInterceptor or "
                "grpc.StreamUnaryClientInterceptor or "
                "grpc.StreamStreamClientInterceptor"
            )
        channel = _Channel(channel, interceptor)
    return channel