# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bits and pieces used by the driver that don't really fit elsewhere."""
from __future__ import annotations

import functools
import sys
import traceback
from collections import abc
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Container,
    Iterable,
    Mapping,
    NoReturn,
    Optional,
    Sequence,
    TypeVar,
    Union,
    cast,
)

from pymongo import ASCENDING
from pymongo.errors import (
    CursorNotFound,
    DuplicateKeyError,
    ExecutionTimeout,
    NotPrimaryError,
    OperationFailure,
    WriteConcernError,
    WriteError,
    WTimeoutError,
    _wtimeout_error,
)
from pymongo.hello import HelloCompat

if TYPE_CHECKING:
    from pymongo.cursor import _Hint
    from pymongo.operations import _IndexList
    from pymongo.typings import _DocumentOut

# From the SDAM spec, the "node is shutting down" codes.
_SHUTDOWN_CODES: frozenset = frozenset(
    [
        11600,  # InterruptedAtShutdown
        91,  # ShutdownInProgress
    ]
)
# From the SDAM spec, the "not primary" error codes are combined with the
# "node is recovering" error codes (of which the "node is shutting down"
# errors are a subset).
_NOT_PRIMARY_CODES: frozenset = (
    frozenset(
        [
            10058,  # LegacyNotPrimary <=3.2 "not primary" error code
            10107,  # NotWritablePrimary
            13435,  # NotPrimaryNoSecondaryOk
            11602,  # InterruptedDueToReplStateChange
            13436,  # NotPrimaryOrSecondary
            189,  # PrimarySteppedDown
        ]
    )
    | _SHUTDOWN_CODES
)
# From the retryable writes spec.
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
    [
        7,  # HostNotFound
        6,  # HostUnreachable
        89,  # NetworkTimeout
        9001,  # SocketException
        262,  # ExceededTimeLimit
        134,  # ReadConcernMajorityNotAvailableYet
    ]
)

# Server code raised when re-authentication is required
_REAUTHENTICATION_REQUIRED_CODE: int = 391

# Server code raised when authentication fails.
_AUTHENTICATION_FAILURE_CODE: int = 18

# Note - to avoid bugs from forgetting which of these are all lowercase and
# which are camelCase, and at the same time avoid having to add a test for
# every command, use all lowercase here and test against command_name.lower().
_SENSITIVE_COMMANDS: set = {
    "authenticate",
    "saslstart",
    "saslcontinue",
    "getnonce",
    "createuser",
    "updateuser",
    "copydbgetnonce",
    "copydbsaslstart",
    "copydb",
}


def _gen_index_name(keys: _IndexList) -> str:
    """Generate an index name from the set of fields it is over.

    Each item of ``keys`` is formatted as ``"<first>_<second>"`` and the
    resulting pieces are joined with ``"_"``.
    """
    return "_".join(["{}_{}".format(*item) for item in keys])


def _index_list(
    key_or_list: _Hint, direction: Optional[Union[int, str]] = None
) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]:
    """Helper to generate a list of (key, direction) pairs.

    Takes such a list, or a single key, or a single key and direction.

    - With ``direction`` given, ``key_or_list`` must be a ``str`` and a
      one-element list is returned.
    - Without ``direction``, a ``str`` is paired with ``ASCENDING``; an items
      view or a mapping is converted to a list of its items; a list or tuple
      is copied, pairing any bare ``str`` element with ``ASCENDING``.

    Raises ``TypeError`` when ``direction`` is given with a non-string key, or
    when no direction is given and ``key_or_list`` is none of the accepted
    types.
    """
    if direction is not None:
        if not isinstance(key_or_list, str):
            raise TypeError("Expected a string and a direction")
        return [(key_or_list, direction)]
    else:
        if isinstance(key_or_list, str):
            return [(key_or_list, ASCENDING)]
        elif isinstance(key_or_list, abc.ItemsView):
            return list(key_or_list)  # type: ignore[arg-type]
        elif isinstance(key_or_list, abc.Mapping):
            return list(key_or_list.items())
        elif not isinstance(key_or_list, (list, tuple)):
            raise TypeError("if no direction is specified, key_or_list must be an instance of list")
        values: list[tuple[str, int]] = []
        for item in key_or_list:
            if isinstance(item, str):
                # A bare key name defaults to ascending order.
                item = (item, ASCENDING)  # noqa: PLW2901
            values.append(item)
        return values


def _index_document(index_list: _IndexList) -> dict[str, Any]:
    """Helper to generate an index specifying document.

    Takes a list of (key, direction) pairs, or a mapping of key to direction.
    Bare ``str`` elements of a list are paired with ``ASCENDING``. Every
    pair is checked with :func:`_validate_index_key_pair`.

    Raises ``TypeError`` if ``index_list`` is not a list, tuple or mapping (or
    if a pair fails validation) and ``ValueError`` if it is empty.
    """
    if not isinstance(index_list, (list, tuple, abc.Mapping)):
        raise TypeError(
            "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list)
        )
    if not index_list:
        raise ValueError("key_or_list must not be empty")

    index: dict[str, Any] = {}

    if isinstance(index_list, abc.Mapping):
        for key in index_list:
            value = index_list[key]
            _validate_index_key_pair(key, value)
            index[key] = value
    else:
        for item in index_list:
            if isinstance(item, str):
                # A bare key name defaults to ascending order.
                item = (item, ASCENDING)  # noqa: PLW2901
            key, value = item
            _validate_index_key_pair(key, value)
            index[key] = value
    return index


def _validate_index_key_pair(key: Any, value: Any) -> None:
    """Raise ``TypeError`` unless ``key`` is a str and ``value`` is a str, int or mapping."""
    if not isinstance(key, str):
        raise TypeError("first item in each key pair must be an instance of str")
    if not isinstance(value, (str, int, abc.Mapping)):
        raise TypeError(
            "second item in each key pair must be 1, -1, "
            "'2d', or another valid MongoDB index specifier."
        )


def _check_command_response(
    response: _DocumentOut,
    max_wire_version: Optional[int],
    allowable_errors: Optional[Container[Union[int, str]]] = None,
    parse_write_concern_error: bool = False,
) -> None:
    """Check the response to a command for errors.

    Returns ``None`` when the response is ok, or when its error code (or, if
    it carries no code, its error message) is listed in ``allowable_errors``.
    Otherwise raises the most specific matching exception:

    - ``WTimeoutError`` / ``WriteConcernError`` when
      ``parse_write_concern_error`` is true and the response contains a
      ``writeConcernError`` (checked before ``ok``).
    - ``NotPrimaryError`` for "not primary" / "node is recovering" failures.
    - ``DuplicateKeyError``, ``ExecutionTimeout`` or ``CursorNotFound`` for
      their respective codes.
    - ``OperationFailure`` for everything else, including a response that has
      no ``ok`` field at all.
    """
    if "ok" not in response:
        # Server didn't recognize our message as a command.
        raise OperationFailure(
            response.get("$err"),  # type: ignore[arg-type]
            response.get("code"),
            response,
            max_wire_version,
        )

    if parse_write_concern_error and "writeConcernError" in response:
        _error = response["writeConcernError"]
        _labels = response.get("errorLabels")
        if _labels:
            # NOTE: this updates the sub-document of ``response`` in place
            # (no copy is made here, unlike in _get_wce_doc).
            _error.update({"errorLabels": _labels})
        _raise_write_concern_error(_error)

    if response["ok"]:
        return

    details = response
    # Mongos returns the error details in a 'raw' object
    # for some errors.
    if "raw" in response:
        for shard in response["raw"].values():
            # Grab the first non-empty raw error from a shard.
            if shard.get("errmsg") and not shard.get("ok"):
                details = shard
                break

    errmsg = details["errmsg"]
    code = details.get("code")

    # For allowable errors, only check for error messages when the code is not
    # included.
    if allowable_errors:
        if code is not None:
            if code in allowable_errors:
                return
        elif errmsg in allowable_errors:
            return

    # Server is "not primary" or "recovering"
    if code is not None:
        if code in _NOT_PRIMARY_CODES:
            raise NotPrimaryError(errmsg, response)
    elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg:
        # No code available: fall back to matching on the message text.
        raise NotPrimaryError(errmsg, response)

    # Other errors
    # findAndModify with upsert can raise duplicate key error
    if code in (11000, 11001, 12582):
        raise DuplicateKeyError(errmsg, code, response, max_wire_version)
    elif code == 50:
        raise ExecutionTimeout(errmsg, code, response, max_wire_version)
    elif code == 43:
        raise CursorNotFound(errmsg, code, response, max_wire_version)

    raise OperationFailure(errmsg, code, response, max_wire_version)


def _raise_last_write_error(write_errors: list[Any]) -> NoReturn:
    """Raise an exception for the last entry of ``write_errors``.

    Raises ``DuplicateKeyError`` when that entry's code is 11000 and
    ``WriteError`` otherwise.
    """
    # If the last batch had multiple errors only report
    # the last error to emulate continue_on_error.
    error = write_errors[-1]
    if error.get("code") == 11000:
        raise DuplicateKeyError(error.get("errmsg"), 11000, error)
    raise WriteError(error.get("errmsg"), error.get("code"), error)


def _raise_write_concern_error(error: Any) -> NoReturn:
    """Raise ``WTimeoutError`` or ``WriteConcernError`` for a writeConcernError doc."""
    if _wtimeout_error(error):
        # Make sure we raise WTimeoutError
        raise WTimeoutError(error.get("errmsg"), error.get("code"), error)
    raise WriteConcernError(error.get("errmsg"), error.get("code"), error)


def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]:
    """Return the writeConcernError or None.

    When ``result`` also has top-level ``errorLabels``, they are attached to a
    copy of the writeConcernError document, leaving ``result`` untouched.
    """
    wce = result.get("writeConcernError")
    if wce:
        # The server reports errorLabels at the top level but it's more
        # convenient to attach it to the writeConcernError doc itself.
        error_labels = result.get("errorLabels")
        if error_labels:
            # Copy to avoid changing the original document.
            wce = wce.copy()
            wce["errorLabels"] = error_labels
    return wce


def _check_write_command_response(result: Mapping[str, Any]) -> None:
    """Backward compatibility helper for write command error handling.

    Raises for the last write error if ``result`` has any ``writeErrors``;
    otherwise raises for its write concern error, if present.
    """
    # Prefer write errors over write concern errors
    write_errors = result.get("writeErrors")
    if write_errors:
        _raise_last_write_error(write_errors)

    wce = _get_wce_doc(result)
    if wce:
        _raise_write_concern_error(wce)


def _fields_list_to_dict(
    fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
) -> Mapping[str, Any]:
    """Takes a sequence of field names and returns a matching dictionary.

    ["a", "b"] becomes {"a": 1, "b": 1}

    and

    ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}

    A mapping is returned unchanged. ``option_name`` is only used in the
    ``TypeError`` messages raised for non-string elements or for an argument
    that is neither a mapping, a sequence nor a set.
    """
    if isinstance(fields, abc.Mapping):
        return fields

    # NOTE: a plain str is itself a Sequence of str, so it passes this check
    # and yields one key per character.
    if isinstance(fields, (abc.Sequence, abc.Set)):
        if not all(isinstance(field, str) for field in fields):
            raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
        return dict.fromkeys(fields, 1)

    raise TypeError(f"{option_name} must be a mapping or list of key names")


def _handle_exception() -> None:
    """Print exceptions raised by subscribers to stderr."""
    # Heavily influenced by logging.Handler.handleError.

    # See note here:
    # https://docs.python.org/3.4/library/sys.html#sys.__stderr__
    if sys.stderr:
        einfo = sys.exc_info()
        try:
            traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
        except OSError:
            # Best effort only: a failure to write to stderr is ignored.
            pass
        finally:
            # Drop the local reference to the exception info.
            del einfo


# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
F = TypeVar("F", bound=Callable[..., Any])


def _handle_reauth(func: F) -> F:
    """Decorator that retries ``func`` once after re-authenticating.

    The wrapper removes an optional ``no_reauth`` keyword argument before
    calling ``func``. If ``func`` raises ``OperationFailure`` with the
    "reauthentication required" code and ``no_reauth`` is falsy, the wrapper
    looks through the positional arguments for a ``Connection`` (or a
    ``_BulkWriteContext``, whose ``conn`` is used), re-authenticates it and
    calls ``func`` a second time. Any other failure, or the absence of a
    usable connection, re-raises the original exception.
    """

    # Preserve the wrapped function's name, docstring and other metadata.
    @functools.wraps(func)
    def inner(*args: Any, **kwargs: Any) -> Any:
        no_reauth = kwargs.pop("no_reauth", False)
        # Imported at call time rather than at module import time.
        from pymongo.message import _BulkWriteContext
        from pymongo.pool import Connection

        try:
            return func(*args, **kwargs)
        except OperationFailure as exc:
            if no_reauth:
                raise
            if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
                # Look for an argument that either is a Connection
                # or has a connection attribute, so we can trigger
                # a reauth.
                conn = None
                for arg in args:
                    if isinstance(arg, Connection):
                        conn = arg
                        break
                    if isinstance(arg, _BulkWriteContext):
                        conn = arg.conn
                        break
                if conn:
                    conn.authenticate(reauthenticate=True)
                else:
                    raise
                # Second and final attempt; errors propagate to the caller.
                return func(*args, **kwargs)
            raise

    return cast(F, inner)