# Copyright 2022-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Internal helpers for CSOT."""

from __future__ import annotations

import functools
import time
from collections import deque
from contextlib import AbstractContextManager
from contextvars import ContextVar, Token
from typing import TYPE_CHECKING, Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast

if TYPE_CHECKING:
    from pymongo.write_concern import WriteConcern

# Context-local state. TIMEOUT holds the configured timeout (None when unset),
# RTT a round-trip-time value, and DEADLINE the time.monotonic() instant at
# which the current timeout expires (infinity when there is none).
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)
RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
DEADLINE: ContextVar[float] = ContextVar("DEADLINE", default=float("inf"))


def get_timeout() -> Optional[float]:
    """Return the timeout set in the current context, or None if unset."""
    return TIMEOUT.get(None)


def get_rtt() -> float:
    """Return the RTT value stored in the current context (0.0 by default)."""
    return RTT.get()


def get_deadline() -> float:
    """Return the current deadline as a ``time.monotonic()`` timestamp.

    The value is ``float("inf")`` when no deadline has been set.
    """
    return DEADLINE.get()


def set_rtt(rtt: float) -> None:
    """Store ``rtt`` as the RTT value for the current context."""
    RTT.set(rtt)


def remaining() -> Optional[float]:
    """Return the time left until the current deadline.

    Returns None when no timeout is active. A timeout of ``0`` is falsy and
    is therefore also treated as "no timeout". The result can be negative
    once the deadline has already passed.
    """
    if not get_timeout():
        return None
    return DEADLINE.get() - time.monotonic()


def clamp_remaining(max_timeout: float) -> float:
    """Return the remaining timeout clamped to a max value."""
    timeout = remaining()
    if timeout is None:
        return max_timeout
    return min(timeout, max_timeout)


class _TimeoutContext(AbstractContextManager):
    """Internal timeout context manager.

    Use :func:`pymongo.timeout` instead::

      with pymongo.timeout(0.5):
          client.test.test.insert_one({})
    """

    def __init__(self, timeout: Optional[float]):
        self._timeout = timeout
        # Tokens returned by ContextVar.set() in __enter__, kept so __exit__
        # can restore the previous values. None while the context is inactive.
        self._tokens: Optional[tuple[Token[Optional[float]], Token[float], Token[float]]] = None

    def __enter__(self) -> _TimeoutContext:
        timeout_token = TIMEOUT.set(self._timeout)
        prev_deadline = DEADLINE.get()
        # A falsy timeout (None or 0) means "no deadline".
        next_deadline = (time.monotonic() + self._timeout) if self._timeout else float("inf")
        # A nested context can only shorten the enclosing deadline, never
        # extend it.
        deadline_token = DEADLINE.set(min(prev_deadline, next_deadline))
        rtt_token = RTT.set(0.0)
        self._tokens = (timeout_token, deadline_token, rtt_token)
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self._tokens:
            timeout_token, deadline_token, rtt_token = self._tokens
            # Drop the tokens before resetting: a Token may be used only once,
            # so a repeated __exit__ must be a no-op rather than raise
            # RuntimeError from ContextVar.reset().
            self._tokens = None
            TIMEOUT.reset(timeout_token)
            DEADLINE.reset(deadline_token)
            RTT.reset(rtt_token)


# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])


def apply(func: F) -> F:
    """Apply the client's timeoutMS to this operation.

    When no timeout is already active in the current context and
    ``self._timeout`` is not None, ``func`` runs inside a
    :class:`_TimeoutContext` built from ``self._timeout``. Otherwise ``func``
    is called unchanged, so an enclosing timeout takes precedence.
    """

    @functools.wraps(func)
    def csot_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        if get_timeout() is None:
            timeout = self._timeout
            if timeout is not None:
                with _TimeoutContext(timeout):
                    return func(self, *args, **kwargs)
        return func(self, *args, **kwargs)

    return cast(F, csot_wrapper)


def apply_write_concern(
    cmd: MutableMapping[str, Any], write_concern: Optional[WriteConcern]
) -> None:
    """Apply the given write concern to a command.

    Nothing is added when ``write_concern`` is falsy or is the server
    default. While a timeout is active, ``wtimeout`` is dropped from the
    write concern document; if the document is then empty, ``cmd`` is left
    untouched.
    """
    if not write_concern or write_concern.is_server_default:
        return
    # Work on a private copy so that removing "wtimeout" can never mutate the
    # caller's write concern, whether or not ``document`` returns a fresh dict.
    wc = dict(write_concern.document)
    if get_timeout() is not None:
        wc.pop("wtimeout", None)
    if wc:
        cmd["writeConcern"] = wc


_MAX_RTT_SAMPLES: int = 10
_MIN_RTT_SAMPLES: int = 2


class MovingMinimum:
    """Tracks a minimum RTT within the last 10 RTT samples."""

    samples: Deque[float]

    def __init__(self) -> None:
        # Bounded deque: appending beyond maxlen discards the oldest sample.
        self.samples = deque(maxlen=_MAX_RTT_SAMPLES)

    def add_sample(self, sample: float) -> None:
        """Record ``sample``; negative samples are ignored."""
        if sample < 0:
            # Likely system time change while waiting for hello response
            # and not using time.monotonic. Ignore it, the next one will
            # probably be valid.
            return
        self.samples.append(sample)

    def get(self) -> float:
        """Get the min, or 0.0 if there aren't enough samples yet."""
        if len(self.samples) >= _MIN_RTT_SAMPLES:
            return min(self.samples)
        return 0.0

    def reset(self) -> None:
        """Discard all recorded samples."""
        self.samples.clear()