# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for creating `messages
<https://www.mongodb.com/docs/manual/reference/mongodb-wire-protocol/>`_ to be sent to
MongoDB.

.. note:: This module is for internal use and is generally not needed by
   application developers.
"""
from __future__ import annotations

import datetime
import logging
import random
import struct
from io import BytesIO as _BytesIO
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterable,
    Mapping,
    MutableMapping,
    NoReturn,
    Optional,
    Union,
)

import bson
from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode
from bson.int64 import Int64
from bson.raw_bson import (
    _RAW_ARRAY_BSON_OPTIONS,
    DEFAULT_RAW_BSON_OPTIONS,
    RawBSONDocument,
    _inflate_bson,
)

try:
    from pymongo import _cmessage  # type: ignore[attr-defined]

    _use_c = True
except ImportError:
    _use_c = False
from pymongo.errors import (
    ConfigurationError,
    CursorNotFound,
    DocumentTooLarge,
    ExecutionTimeout,
    InvalidOperation,
    NotPrimaryError,
    OperationFailure,
    ProtocolError,
)
from pymongo.hello import HelloCompat
from pymongo.helpers import _handle_reauth
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern

if TYPE_CHECKING:
    from datetime import timedelta

    from pymongo.client_session import ClientSession
    from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
    from pymongo.mongo_client import MongoClient
    from pymongo.monitoring import _EventListeners
    from pymongo.pool import Connection
    from pymongo.read_concern import ReadConcern
    from pymongo.read_preferences import _ServerMode
    from pymongo.typings import _Address, _DocumentOut


MAX_INT32 = 2147483647
MIN_INT32 = -2147483648

# Overhead allowed for encoded command documents.
_COMMAND_OVERHEAD = 16382

_INSERT = 0
_UPDATE = 1
_DELETE = 2

_EMPTY = b""
_BSONOBJ = b"\x03"
_ZERO_8 = b"\x00"
_ZERO_16 = b"\x00\x00"
_ZERO_32 = b"\x00\x00\x00\x00"
_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00"
_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff"
_OP_MAP = {
    _INSERT: b"\x04documents\x00\x00\x00\x00\x00",
    _UPDATE: b"\x04updates\x00\x00\x00\x00\x00",
    _DELETE: b"\x04deletes\x00\x00\x00\x00\x00",
}
# Write command name -> name of the field that carries its document list.
_FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"}

_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions(
    unicode_decode_error_handler="replace"
)


def _randint() -> int:
    """Generate a pseudo random 32 bit integer."""
    return random.randint(MIN_INT32, MAX_INT32)  # noqa: S311


def _maybe_add_read_preference(
    spec: MutableMapping[str, Any], read_preference: _ServerMode
) -> MutableMapping[str, Any]:
    """Add $readPreference to spec when appropriate.

    When a $readPreference is added and ``spec`` is not already wrapped,
    it is wrapped as ``{"$query": spec}`` first; the (possibly new)
    mapping is returned.
    """
    mode = read_preference.mode
    document = read_preference.document
    # Only add $readPreference if it's something other than primary to avoid
    # problems with mongos versions that don't support read preferences. Also,
    # for maximum backwards compatibility, don't add $readPreference for
    # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting
    # the secondaryOkay bit has the same effect).
    if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1):
        if "$query" not in spec:
            spec = {"$query": spec}
        spec["$readPreference"] = document
    return spec


def _convert_exception(exception: Exception) -> dict[str, Any]:
    """Convert an Exception into a failure document for publishing."""
    return {"errmsg": str(exception), "errtype": exception.__class__.__name__}


def _convert_write_result(
    operation: str, command: Mapping[str, Any], result: Mapping[str, Any]
) -> dict[str, Any]:
    """Convert a legacy write result to write command format.

    :param operation: "insert", "update" or "delete".
    :param command: the write command that was sent; for inserts its
        "documents" list is used to compute ``n``.
    :param result: the legacy (getLastError style) result document.
    """
    # Based on _merge_legacy from bulk.py
    affected = result.get("n", 0)
    res = {"ok": 1, "n": affected}
    errmsg = result.get("errmsg", result.get("err", ""))
    if errmsg:
        # The write was successful on at least the primary so don't return.
        if result.get("wtimeout"):
            res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}}
        else:
            # The write failed.
            error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg}
            if "errInfo" in result:
                error["errInfo"] = result["errInfo"]
            res["writeErrors"] = [error]
            return res
    if operation == "insert":
        # GLE result for insert is always 0 in most MongoDB versions.
        res["n"] = len(command["documents"])
    elif operation == "update":
        if "upserted" in result:
            res["upserted"] = [{"index": 0, "_id": result["upserted"]}]
        # Versions of MongoDB before 2.6 don't return the _id for an
        # upsert if _id is not an ObjectId.
        elif result.get("updatedExisting") is False and affected == 1:
            # If _id is in both the update document *and* the query spec
            # the update document _id takes precedence.
            update = command["updates"][0]
            _id = update["u"].get("_id", update["q"].get("_id"))
            res["upserted"] = [{"index": 0, "_id": _id}]
    return res


# Cursor option name -> OP_QUERY flag bit.
_OPTIONS = {
    "tailable": 2,
    "oplogReplay": 8,
    "noCursorTimeout": 16,
    "awaitData": 32,
    "allowPartialResults": 128,
}


# Legacy "$modifier" query keys -> find command field names.
_MODIFIERS = {
    "$query": "filter",
    "$orderby": "sort",
    "$hint": "hint",
    "$comment": "comment",
    "$maxScan": "maxScan",
    "$maxTimeMS": "maxTimeMS",
    "$max": "max",
    "$min": "min",
    "$returnKey": "returnKey",
    "$showRecordId": "showRecordId",
    "$showDiskLoc": "showRecordId",  # <= MongoDb 3.0
    "$snapshot": "snapshot",
}


def _gen_find_command(
    coll: str,
    spec: Mapping[str, Any],
    projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
    skip: int,
    limit: int,
    batch_size: Optional[int],
    options: Optional[int],
    read_concern: ReadConcern,
    collation: Optional[Mapping[str, Any]] = None,
    session: Optional[ClientSession] = None,
    allow_disk_use: Optional[bool] = None,
) -> dict[str, Any]:
    """Generate a find command document.

    A ``spec`` containing "$query" is treated as a legacy modifier document:
    its keys are translated through ``_MODIFIERS`` and "$explain" /
    "$readPreference" are dropped. A negative ``limit`` sets ``singleBatch``.
    """
    cmd: dict[str, Any] = {"find": coll}
    if "$query" in spec:
        cmd.update(
            [
                (_MODIFIERS[key], val) if key in _MODIFIERS else (key, val)
                for key, val in spec.items()
            ]
        )
        if "$explain" in cmd:
            cmd.pop("$explain")
        if "$readPreference" in cmd:
            cmd.pop("$readPreference")
    else:
        cmd["filter"] = spec

    if projection:
        cmd["projection"] = projection
    if skip:
        cmd["skip"] = skip
    if limit:
        cmd["limit"] = abs(limit)
        if limit < 0:
            cmd["singleBatch"] = True
    if batch_size:
        cmd["batchSize"] = batch_size
    if read_concern.level and not (session and session.in_transaction):
        cmd["readConcern"] = read_concern.document
    if collation:
        cmd["collation"] = collation
    if allow_disk_use is not None:
        cmd["allowDiskUse"] = allow_disk_use
    if options:
        cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val])

    return cmd


def _gen_get_more_command(
    cursor_id: Optional[int],
    coll: str,
    batch_size: Optional[int],
    max_await_time_ms: Optional[int],
    comment: Optional[Any],
    conn: Connection,
) -> dict[str, Any]:
    """Generate a getMore command document.

    ``comment`` is only included when the connection's max wire version is
    at least 9.
    """
    cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll}
    if batch_size:
        cmd["batchSize"] = batch_size
    if max_await_time_ms is not None:
        cmd["maxTimeMS"] = max_await_time_ms
    if comment is not None and conn.max_wire_version >= 9:
        cmd["comment"] = comment
    return cmd


class _Query:
    """A query operation."""

    __slots__ = (
        "flags",
        "db",
        "coll",
        "ntoskip",
        "spec",
        "fields",
        "codec_options",
        "read_preference",
        "limit",
        "batch_size",
        "name",
        "read_concern",
        "collation",
        "session",
        "client",
        "allow_disk_use",
        "_as_command",
        "exhaust",
    )

    # For compatibility with the _GetMore class.
    conn_mgr = None
    cursor_id = None

    def __init__(
        self,
        flags: int,
        db: str,
        coll: str,
        ntoskip: int,
        spec: Mapping[str, Any],
        fields: Optional[Mapping[str, Any]],
        codec_options: CodecOptions,
        read_preference: _ServerMode,
        limit: int,
        batch_size: int,
        read_concern: ReadConcern,
        collation: Optional[Mapping[str, Any]],
        session: Optional[ClientSession],
        client: MongoClient,
        allow_disk_use: Optional[bool],
        exhaust: bool,
    ):
        self.flags = flags
        self.db = db
        self.coll = coll
        self.ntoskip = ntoskip
        self.spec = spec
        self.fields = fields
        self.codec_options = codec_options
        self.read_preference = read_preference
        self.read_concern = read_concern
        self.limit = limit
        self.batch_size = batch_size
        self.collation = collation
        self.session = session
        self.client = client
        self.allow_disk_use = allow_disk_use
        self.name = "find"
        self._as_command: Optional[tuple[dict[str, Any], str]] = None
        self.exhaust = exhaust

    def reset(self) -> None:
        """Drop the cached command so the next as_command() rebuilds it."""
        self._as_command = None

    def namespace(self) -> str:
        """Return the full namespace, "<db>.<collection>"."""
        return f"{self.db}.{self.coll}"

    def use_command(self, conn: Connection) -> bool:
        """Return True to send this query as a find command (OP_MSG).

        Only an exhaust query against a server with max wire version < 8 falls
        back to legacy OP_QUERY; in that case a read concern that is not
        ``ok_for_legacy`` raises ConfigurationError. The session is always
        validated against the connection.
        """
        use_find_cmd = False
        if not self.exhaust:
            use_find_cmd = True
        elif conn.max_wire_version >= 8:
            # OP_MSG supports exhaust on MongoDB 4.2+
            use_find_cmd = True
        elif not self.read_concern.ok_for_legacy:
            raise ConfigurationError(
                "read concern level of %s is not valid "
                "with a max wire version of %d."
                % (self.read_concern.level, conn.max_wire_version)
            )

        conn.validate_session(self.client, self.session)
        return use_find_cmd

    def as_command(
        self, conn: Connection, apply_timeout: bool = False
    ) -> tuple[dict[str, Any], str]:
        """Return a find command document for this query."""
        # We use the command twice: on the wire and for command monitoring.
        # Generate it once, for speed and to avoid repeating side-effects.
        if self._as_command is not None:
            return self._as_command

        explain = "$explain" in self.spec
        cmd: dict[str, Any] = _gen_find_command(
            self.coll,
            self.spec,
            self.fields,
            self.ntoskip,
            self.limit,
            self.batch_size,
            self.flags,
            self.read_concern,
            self.collation,
            self.session,
            self.allow_disk_use,
        )
        if explain:
            self.name = "explain"
            cmd = {"explain": cmd}
        session = self.session
        conn.add_server_api(cmd)
        if session:
            session._apply_to(cmd, False, self.read_preference, conn)
            # Explain does not support readConcern.
            if not explain and not session.in_transaction:
                session._update_read_concern(cmd, conn)
        conn.send_cluster_time(cmd, session, self.client)
        # Support auto encryption
        client = self.client
        if client._encrypter and not client._encrypter._bypass_auto_encryption:
            cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
        # Support CSOT
        if apply_timeout:
            conn.apply_timeout(client, cmd)
        self._as_command = cmd, self.db
        return self._as_command

    def get_message(
        self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False
    ) -> tuple[int, bytes, int]:
        """Get a query message, possibly setting the secondaryOk bit.

        Returns ``(request_id, message, max_doc_size)``: an OP_MSG when
        ``use_cmd`` is true, otherwise a legacy OP_QUERY.
        """
        # Use the read_preference decided by _socket_from_server.
        self.read_preference = read_preference
        if read_preference.mode:
            # Set the secondaryOk bit.
            flags = self.flags | 4
        else:
            flags = self.flags

        ns = self.namespace()
        spec = self.spec

        if use_cmd:
            spec = self.as_command(conn, apply_timeout=True)[0]
            request_id, msg, size, _ = _op_msg(
                0,
                spec,
                self.db,
                read_preference,
                self.codec_options,
                ctx=conn.compression_context,
            )
            return request_id, msg, size

        # OP_QUERY treats ntoreturn of -1 and 1 the same, return
        # one document and close the cursor. We have to use 2 for
        # batch size if 1 is specified.
        ntoreturn = 2 if self.batch_size == 1 else self.batch_size
        if self.limit:
            if ntoreturn:
                ntoreturn = min(self.limit, ntoreturn)
            else:
                ntoreturn = self.limit

        if conn.is_mongos:
            assert isinstance(spec, MutableMapping)
            spec = _maybe_add_read_preference(spec, read_preference)

        return _query(
            flags,
            ns,
            self.ntoskip,
            ntoreturn,
            spec,
            None if use_cmd else self.fields,
            self.codec_options,
            ctx=conn.compression_context,
        )


class _GetMore:
    """A getmore operation."""

    __slots__ = (
        "db",
        "coll",
        "ntoreturn",
        "cursor_id",
        "max_await_time_ms",
        "codec_options",
        "read_preference",
        "session",
        "client",
        "conn_mgr",
        "_as_command",
        "exhaust",
        "comment",
    )

    name = "getMore"

    def __init__(
        self,
        db: str,
        coll: str,
        ntoreturn: int,
        cursor_id: int,
        codec_options: CodecOptions,
        read_preference: _ServerMode,
        session: Optional[ClientSession],
        client: MongoClient,
        max_await_time_ms: Optional[int],
        conn_mgr: Any,
        exhaust: bool,
        comment: Any,
    ):
        self.db = db
        self.coll = coll
        self.ntoreturn = ntoreturn
        self.cursor_id = cursor_id
        self.codec_options = codec_options
        self.read_preference = read_preference
        self.session = session
        self.client = client
        self.max_await_time_ms = max_await_time_ms
        self.conn_mgr = conn_mgr
        self._as_command: Optional[tuple[dict[str, Any], str]] = None
        self.exhaust = exhaust
        self.comment = comment

    def reset(self) -> None:
        """Drop the cached command so the next as_command() rebuilds it."""
        self._as_command = None

    def namespace(self) -> str:
        """Return the full namespace, "<db>.<collection>"."""
        return f"{self.db}.{self.coll}"

    def use_command(self, conn: Connection) -> bool:
        """Return True to send this getMore as a command (OP_MSG).

        Only an exhaust getMore against a server with max wire version < 8
        returns False. The session is always validated against the connection.
        """
        use_cmd = False
        if not self.exhaust:
            use_cmd = True
        elif conn.max_wire_version >= 8:
            # OP_MSG supports exhaust on MongoDB 4.2+
            use_cmd = True

        conn.validate_session(self.client, self.session)
        return use_cmd

    def as_command(
        self, conn: Connection, apply_timeout: bool = False
    ) -> tuple[dict[str, Any], str]:
        """Return a getMore command document for this query."""
        # See _Query.as_command for an explanation of this caching.
        if self._as_command is not None:
            return self._as_command

        cmd: dict[str, Any] = _gen_get_more_command(
            self.cursor_id,
            self.coll,
            self.ntoreturn,
            self.max_await_time_ms,
            self.comment,
            conn,
        )
        if self.session:
            self.session._apply_to(cmd, False, self.read_preference, conn)
        conn.add_server_api(cmd)
        conn.send_cluster_time(cmd, self.session, self.client)
        # Support auto encryption
        client = self.client
        if client._encrypter and not client._encrypter._bypass_auto_encryption:
            cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
        # Support CSOT. cmd=None: nothing is written into the getMore document
        # here; its maxTimeMS comes only from max_await_time_ms above.
        if apply_timeout:
            conn.apply_timeout(client, cmd=None)
        self._as_command = cmd, self.db
        return self._as_command

    def get_message(
        self, dummy0: Any, conn: Connection, use_cmd: bool = False
    ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]:
        """Get a getmore message.

        Returns ``(request_id, message, max_doc_size)`` for an OP_MSG when
        ``use_cmd`` is true, otherwise ``(request_id, message)`` for a legacy
        OP_GET_MORE. ``dummy0`` is unused (kept for a signature shared with
        _Query.get_message).
        """
        ns = self.namespace()
        ctx = conn.compression_context

        if use_cmd:
            spec = self.as_command(conn, apply_timeout=True)[0]
            if self.conn_mgr and self.exhaust:
                flags = _OpMsg.EXHAUST_ALLOWED
            else:
                flags = 0
            request_id, msg, size, _ = _op_msg(
                flags, spec, self.db, None, self.codec_options, ctx=ctx
            )
            return request_id, msg, size

        return _get_more(ns, self.ntoreturn, self.cursor_id, ctx)


class _RawBatchQuery(_Query):
    """A _Query that runs the parent's checks, then decides on OP_MSG itself."""

    def use_command(self, conn: Connection) -> bool:
        # Compatibility checks.
        super().use_command(conn)
        if conn.max_wire_version >= 8:
            # MongoDB 4.2+ supports exhaust over OP_MSG
            return True
        elif not self.exhaust:
            return True
        return False


class _RawBatchGetMore(_GetMore):
    """A _GetMore that runs the parent's checks, then decides on OP_MSG itself."""

    def use_command(self, conn: Connection) -> bool:
        # Compatibility checks.
        super().use_command(conn)
        if conn.max_wire_version >= 8:
            # MongoDB 4.2+ supports exhaust over OP_MSG
            return True
        elif not self.exhaust:
            return True
        return False


class _CursorAddress(tuple):
    """The server address (host, port) of a cursor, with namespace property."""

    __namespace: Any

    def __new__(cls, address: _Address, namespace: str) -> _CursorAddress:
        self = tuple.__new__(cls, address)
        self.__namespace = namespace
        return self

    @property
    def namespace(self) -> str:
        """The namespace of this cursor."""
        return self.__namespace

    def __hash__(self) -> int:
        # Two _CursorAddress instances with different namespaces
        # must not hash the same.
        return hash((*self, self.__namespace))

    def __eq__(self, other: object) -> bool:
        if isinstance(other, _CursorAddress):
            return tuple(self) == tuple(other) and self.namespace == other.namespace
        return NotImplemented

    def __ne__(self, other: object) -> bool:
        return not self == other


_pack_compression_header = struct.Struct("<iiiiiiB").pack
# Six int32 fields plus one unsigned byte, matching the struct format above.
_COMPRESSION_HEADER_SIZE = 25


def _compress(
    operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext]
) -> tuple[int, bytes]:
    """Takes message data, compresses it, and adds an OP_COMPRESSED header.

    Returns ``(request_id, message)``.
    """
    compressed = ctx.compress(data)
    request_id = _randint()

    header = _pack_compression_header(
        _COMPRESSION_HEADER_SIZE + len(compressed),  # Total message length
        request_id,  # Request id
        0,  # responseTo
        2012,  # operation id
        operation,  # original operation id
        len(data),  # uncompressed message length
        ctx.compressor_id,  # compressor id
    )
    return request_id, header + compressed


_pack_header = struct.Struct("<iiii").pack


def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]:
    """Takes message data and adds a message header based on the operation.

    Returns ``(request_id, message)`` where message is the 16-byte header
    (length, request id, responseTo=0, opcode) followed by ``data``.
    """
    rid = _randint()
    message = _pack_header(16 + len(data), rid, 0, operation)
    return rid, message + data


_pack_int = struct.Struct("<i").pack
_pack_op_msg_flags_type = struct.Struct("<IB").pack
_pack_byte = struct.Struct("<B").pack


def _op_msg_no_header(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
) -> tuple[bytes, int, int]:
    """Get an OP_MSG message body (without the message header).

    Returns ``(body, total_size, max_doc_size)`` where ``total_size`` is the
    encoded command size plus the type 1 section size, and ``max_doc_size``
    is the largest encoded document in ``docs`` (0 when there are none).

    Note: this method handles multiple documents in a type one payload but
    it does not perform batch splitting and the total message size is
    only checked *after* generating the entire message.
    """
    # Encode the command document in payload 0 without checking keys.
    encoded = _dict_to_bson(command, False, opts)
    flags_type = _pack_op_msg_flags_type(flags, 0)
    total_size = len(encoded)
    max_doc_size = 0
    if identifier and docs is not None:
        type_one = _pack_byte(1)
        cstring = _make_c_string(identifier)
        encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs]
        size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4
        encoded_size = _pack_int(size)
        total_size += size
        # default=0: an empty document list must not raise ValueError.
        max_doc_size = max((len(doc) for doc in encoded_docs), default=0)
        data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs]
    else:
        data = [flags_type, encoded]
    return b"".join(data), total_size, max_doc_size


def _op_msg_compressed(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int, int]:
    """Internal compressed OP_MSG message helper."""
    msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
    rid, msg = _compress(2013, msg, ctx)
    return rid, msg, total_size, max_bson_size


def _op_msg_uncompressed(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
) -> tuple[int, bytes, int, int]:
    """Internal OP_MSG message helper."""
    data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
    request_id, op_message = __pack_message(2013, data)
    return request_id, op_message, total_size, max_bson_size


if _use_c:
    _op_msg_uncompressed = _cmessage._op_msg


def _op_msg(
    flags: int,
    command: MutableMapping[str, Any],
    dbname: str,
    read_preference: Optional[_ServerMode],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int, int]:
    """Get an OP_MSG message.

    Mutates ``command``: sets "$db" and possibly "$readPreference". For
    insert/update/delete, the document list field is popped so it can be
    sent as a type 1 section, then restored before returning.
    """
    command["$db"] = dbname
    # getMore commands do not send $readPreference.
    if read_preference is not None and "$readPreference" not in command:
        # Only send $readPreference if it's not primary (the default).
        if read_preference.mode:
            command["$readPreference"] = read_preference.document
    name = next(iter(command))
    try:
        identifier = _FIELD_MAP[name]
        docs = command.pop(identifier)
    except KeyError:
        identifier = ""
        docs = None
    try:
        if ctx:
            return _op_msg_compressed(flags, command, identifier, docs, opts, ctx)
        return _op_msg_uncompressed(flags, command, identifier, docs, opts)
    finally:
        # Add the field back to the command.
        if identifier:
            command[identifier] = docs


def _query_impl(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
) -> tuple[bytes, int]:
    """Get an OP_QUERY message body and the larger of the two encoded documents."""
    encoded = _dict_to_bson(query, False, opts)
    if field_selector:
        efs = _dict_to_bson(field_selector, False, opts)
    else:
        efs = b""
    max_bson_size = max(len(encoded), len(efs))
    return (
        b"".join(
            [
                _pack_int(options),
                _make_c_string(collection_name),
                _pack_int(num_to_skip),
                _pack_int(num_to_return),
                encoded,
                efs,
            ]
        ),
        max_bson_size,
    )


def _query_compressed(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int]:
    """Internal compressed query message helper."""
    op_query, max_bson_size = _query_impl(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )
    rid, msg = _compress(2004, op_query, ctx)
    return rid, msg, max_bson_size


def _query_uncompressed(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
) -> tuple[int, bytes, int]:
    """Internal query message helper."""
    op_query, max_bson_size = _query_impl(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )
    rid, msg = __pack_message(2004, op_query)
    return rid, msg, max_bson_size


if _use_c:
    _query_uncompressed = _cmessage._query_message


def _query(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int]:
    """Get a **query** message."""
    if ctx:
        return _query_compressed(
            options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx
        )
    return _query_uncompressed(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )


_pack_long_long = struct.Struct("<q").pack


def _get_more_impl(collection_name: str, num_to_return: int, cursor_id: int) -> bytes:
    """Get an OP_GET_MORE message body."""
    return b"".join(
        [
            _ZERO_32,
            _make_c_string(collection_name),
            _pack_int(num_to_return),
            _pack_long_long(cursor_id),
        ]
    )


def _get_more_compressed(
    collection_name: str,
    num_to_return: int,
    cursor_id: int,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes]:
    """Internal compressed getMore message helper."""
    return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx)


def _get_more_uncompressed(
    collection_name: str, num_to_return: int, cursor_id: int
) -> tuple[int, bytes]:
    """Internal getMore message helper."""
    return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id))


if _use_c:
    _get_more_uncompressed = _cmessage._get_more_message


def _get_more(
    collection_name: str,
    num_to_return: int,
    cursor_id: int,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes]:
    """Get a **getMore** message."""
    if ctx:
        return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx)
    return _get_more_uncompressed(collection_name, num_to_return, cursor_id)


class _BulkWriteContext:
    """A wrapper around Connection for use with write splitting functions."""

    __slots__ = (
        "db_name",
        "conn",
        "op_id",
        "name",
        "field",
        "publish",
        "start_time",
        "listeners",
        "session",
        "compress",
        "op_type",
        "codec",
    )

    def __init__(
        self,
        database_name: str,
        cmd_name: str,
        conn: Connection,
        operation_id: int,
        listeners: _EventListeners,
        session: ClientSession,
        op_type: int,
        codec: CodecOptions,
    ):
        self.db_name = database_name
        self.conn = conn
        self.op_id = operation_id
        self.listeners = listeners
        self.publish = listeners.enabled_for_commands
        self.name = cmd_name
        # Name of the field holding the write's documents ("documents", ...).
        self.field = _FIELD_MAP[self.name]
        self.start_time = datetime.datetime.now()
        self.session = session
        self.compress = bool(conn.compression_context)
        self.op_type = op_type
        self.codec = codec

    def __batch_command(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
    ) -> tuple[int, bytes, list[Mapping[str, Any]]]:
        """Encode the next batch; raise InvalidOperation if nothing fits."""
        namespace = self.db_name + ".$cmd"
        request_id, msg, to_send = _do_batched_op_msg(
            namespace, self.op_type, cmd, docs, self.codec, self
        )
        if not to_send:
            raise InvalidOperation("cannot do an empty bulk write")
        return request_id, msg, to_send

    def execute(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
    ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]:
        """Send the next batch acknowledged; return (reply, documents sent)."""
        request_id, msg, to_send = self.__batch_command(cmd, docs)
        result = self.write_command(cmd, request_id, msg, to_send, client)
        client._process_response(result, self.session)
        return result, to_send

    def execute_unack(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
    ) -> list[Mapping[str, Any]]:
        """Send the next batch unacknowledged; return the documents sent."""
        request_id, msg, to_send = self.__batch_command(cmd, docs)
        # Though this isn't strictly a "legacy" write, the helper
        # handles publishing commands and sending our message
        # without receiving a result. Send 0 for max_doc_size
        # to disable size checking. Size checking is handled while
        # the documents are encoded to BSON.
        self.unack_write(cmd, request_id, msg, 0, to_send, client)
        return to_send

    @property
    def max_bson_size(self) -> int:
        """A proxy for Connection.max_bson_size."""
        return self.conn.max_bson_size

    @property
    def max_message_size(self) -> int:
        """A proxy for Connection.max_message_size."""
        if self.compress:
            # Subtract 16 bytes for the message header.
            return self.conn.max_message_size - 16
        return self.conn.max_message_size

    @property
    def max_write_batch_size(self) -> int:
        """A proxy for Connection.max_write_batch_size."""
        return self.conn.max_write_batch_size

    @property
    def max_split_size(self) -> int:
        """The maximum size of a BSON command before batch splitting."""
        return self.max_bson_size

    def unack_write(
        self,
        cmd: MutableMapping[str, Any],
        request_id: int,
        msg: bytes,
        max_doc_size: int,
        docs: list[Mapping[str, Any]],
        client: MongoClient,
    ) -> Optional[Mapping[str, Any]]:
        """A proxy for Connection.unack_write that handles event publishing."""
        # Attach the documents up front (as write_command does) so the
        # STARTED log and _convert_write_result see the full command even
        # when command monitoring is disabled.
        cmd[self.field] = docs
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=self.db_name,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=self.conn.id,
                serverConnectionId=self.conn.server_connection_id,
                serverHost=self.conn.address[0],
                serverPort=self.conn.address[1],
                serviceId=self.conn.service_id,
            )
        if self.publish:
            cmd = self._start(cmd, request_id, docs)
        try:
            result = self.conn.unack_write(msg, max_doc_size)  # type: ignore[func-returns-value]
            duration = datetime.datetime.now() - self.start_time
            if result is not None:
                reply = _convert_write_result(self.name, cmd, result)
            else:
                # Comply with APM spec.
                reply = {"ok": 1}
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.SUCCEEDED,
                    durationMS=duration,
                    reply=reply,
                    commandName=next(iter(cmd)),
                    databaseName=self.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=self.conn.id,
                    serverConnectionId=self.conn.server_connection_id,
                    serverHost=self.conn.address[0],
                    serverPort=self.conn.address[1],
                    serviceId=self.conn.service_id,
                )
            if self.publish:
                self._succeed(request_id, reply, duration)
        except Exception as exc:
            duration = datetime.datetime.now() - self.start_time
            if isinstance(exc, OperationFailure):
                failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details)  # type: ignore[arg-type]
            elif isinstance(exc, NotPrimaryError):
                failure = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=self.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=self.conn.id,
                    serverConnectionId=self.conn.server_connection_id,
                    serverHost=self.conn.address[0],
                    serverPort=self.conn.address[1],
                    serviceId=self.conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
            if self.publish:
                assert self.start_time is not None
                self._fail(request_id, failure, duration)
            raise
        finally:
            # Restart the clock for the next batch sent through this context.
            self.start_time = datetime.datetime.now()
        return result

    @_handle_reauth
    def write_command(
        self,
        cmd: MutableMapping[str, Any],
        request_id: int,
        msg: bytes,
        docs: list[Mapping[str, Any]],
        client: MongoClient,
    ) -> dict[str, Any]:
        """A proxy for Connection.write_command that handles event publishing."""
        cmd[self.field] = docs
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=self.db_name,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=self.conn.id,
                serverConnectionId=self.conn.server_connection_id,
                serverHost=self.conn.address[0],
                serverPort=self.conn.address[1],
                serviceId=self.conn.service_id,
            )
        if self.publish:
            self._start(cmd, request_id, docs)
        try:
            reply = self.conn.write_command(request_id, msg, self.codec)
            duration = datetime.datetime.now() - self.start_time
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.SUCCEEDED,
                    durationMS=duration,
                    reply=reply,
                    commandName=next(iter(cmd)),
                    databaseName=self.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=self.conn.id,
                    serverConnectionId=self.conn.server_connection_id,
                    serverHost=self.conn.address[0],
                    serverPort=self.conn.address[1],
                    serviceId=self.conn.service_id,
                )
            if self.publish:
                self._succeed(request_id, reply, duration)
        except Exception as exc:
            duration = datetime.datetime.now() - self.start_time
            if isinstance(exc, (NotPrimaryError, OperationFailure)):
                failure: _DocumentOut = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=self.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=self.conn.id,
                    serverConnectionId=self.conn.server_connection_id,
                    serverHost=self.conn.address[0],
                    serverPort=self.conn.address[1],
                    serviceId=self.conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )

            if self.publish:
                self._fail(request_id, failure, duration)
            raise
        finally:
            # Restart the clock for the next batch sent through this context.
            self.start_time = datetime.datetime.now()
        return reply

    def _start(
        self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]]
    ) -> MutableMapping[str, Any]:
        """Publish a CommandStartedEvent."""
        cmd[self.field] = docs
        self.listeners.publish_command_start(
            cmd,
            self.db_name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
        )
        return cmd

    def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None:
        """Publish a CommandSucceededEvent."""
        self.listeners.publish_command_success(
            duration,
            reply,
            self.name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
            database_name=self.db_name,
        )

    def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None:
        """Publish a CommandFailedEvent."""
        self.listeners.publish_command_failure(
            duration,
            failure,
            self.name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
            database_name=self.db_name,
        )


# From the Client Side Encryption spec:
# Because automatic encryption increases the size of commands, the driver
# MUST split bulk writes at a reduced size limit before undergoing automatic
# encryption. The write payload MUST be split at 2MiB (2097152).
_MAX_SPLIT_SIZE_ENC = 2097152


class _EncryptedBulkWriteContext(_BulkWriteContext):
    """A _BulkWriteContext that sends batches through Connection.command.

    Batches are split at _MAX_SPLIT_SIZE_ENC instead of max_bson_size.
    """

    __slots__ = ()

    def __batch_command(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
    ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
        """Encode the next batch and decode it back into a command document."""
        namespace = self.db_name + ".$cmd"
        msg, to_send = _encode_batched_write_command(
            namespace, self.op_type, cmd, docs, self.codec, self
        )
        if not to_send:
            raise InvalidOperation("cannot do an empty bulk write")

        # Chop off the OP_QUERY header to get a properly batched write command.
        # Skips the 4-byte flags, the NUL-terminated namespace and the two
        # int32 skip/return fields (assumes the layout built by _query_impl).
        cmd_start = msg.index(b"\x00", 4) + 9
        outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS)
        return outgoing, to_send

    def execute(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
    ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]:
        """Run the next batch as an acknowledged command; return (reply, sent)."""
        batched_cmd, to_send = self.__batch_command(cmd, docs)
        result: Mapping[str, Any] = self.conn.command(
            self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client
        )
        return result, to_send

    def execute_unack(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
    ) -> list[Mapping[str, Any]]:
        """Run the next batch with w=0; return the documents sent."""
        batched_cmd, to_send = self.__batch_command(cmd, docs)
        self.conn.command(
            self.db_name,
            batched_cmd,
            write_concern=WriteConcern(w=0),
            session=self.session,
            client=client,
        )
        return to_send

    @property
    def max_split_size(self) -> int:
        """Reduce the batch splitting size."""
        return _MAX_SPLIT_SIZE_ENC


def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn:
    """Internal helper for raising DocumentTooLarge."""
    if operation == "insert":
        raise DocumentTooLarge(
            "BSON document too large (%d bytes)"
            " - the connected server supports"
            " BSON document sizes up to %d"
            " bytes." % (doc_size, max_size)
        )
    else:
        # There's nothing intelligent we can say
        # about size for update and delete
        raise DocumentTooLarge(f"{operation!r} command document too large")


# OP_MSG -------------------------------------------------------------


_OP_MSG_MAP = {
    _INSERT: b"documents\x00",
    _UPDATE: b"updates\x00",
    _DELETE: b"deletes\x00",
}


def _batched_op_msg_impl(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
    buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], int]:
    """Create a batched OP_MSG write."""
    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    max_message_size = ctx.max_message_size

    flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00"
    # Flags
    buf.write(flags)

    # Type 0 Section
    buf.write(b"\x00")
    buf.write(_dict_to_bson(command, False, opts))

    # Type 1 Section
    buf.write(b"\x01")
    size_location = buf.tell()
    # Save space for size
    buf.write(b"\x00\x00\x00\x00")
    try:
        buf.write(_OP_MSG_MAP[operation])
    except KeyError:
        raise InvalidOperation("Unknown command") from None

    to_send = []
    idx = 0
    for doc in docs:
        # Encode the current operation
        value = _dict_to_bson(doc, False, opts)
        doc_length = len(value)
        new_message_size = buf.tell() + doc_length
        # Does first document exceed max_message_size?
        doc_too_large = idx == 0 and (new_message_size > max_message_size)
        # When OP_MSG is used unacknowledged we have to check
        # document size client side or applications won't be notified.
        # Otherwise we let the server deal with documents that are too large
        # since ordered=False causes those documents to be skipped instead of
        # halting the bulk write operation.
        unacked_doc_too_large = not ack and (doc_length > max_bson_size)
        if doc_too_large or unacked_doc_too_large:
            write_op = list(_FIELD_MAP.keys())[operation]
            _raise_document_too_large(write_op, len(value), max_bson_size)
        # We have enough data, return this batch.
if new_message_size > max_message_size: break buf.write(value) to_send.append(doc) idx += 1 # We have enough documents, return this batch. if idx == max_write_batch_size: break # Write type 1 section size length = buf.tell() buf.seek(size_location) buf.write(_pack_int(length - size_location)) return to_send, length def _encode_batched_op_msg( operation: int, command: Mapping[str, Any], docs: list[Mapping[str, Any]], ack: bool, opts: CodecOptions, ctx: _BulkWriteContext, ) -> tuple[bytes, list[Mapping[str, Any]]]: """Encode the next batched insert, update, or delete operation as OP_MSG. """ buf = _BytesIO() to_send, _ = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) return buf.getvalue(), to_send if _use_c: _encode_batched_op_msg = _cmessage._encode_batched_op_msg def _batched_op_msg_compressed( operation: int, command: Mapping[str, Any], docs: list[Mapping[str, Any]], ack: bool, opts: CodecOptions, ctx: _BulkWriteContext, ) -> tuple[int, bytes, list[Mapping[str, Any]]]: """Create the next batched insert, update, or delete operation with OP_MSG, compressed. 
""" data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) assert ctx.conn.compression_context is not None request_id, msg = _compress(2013, data, ctx.conn.compression_context) return request_id, msg, to_send def _batched_op_msg( operation: int, command: Mapping[str, Any], docs: list[Mapping[str, Any]], ack: bool, opts: CodecOptions, ctx: _BulkWriteContext, ) -> tuple[int, bytes, list[Mapping[str, Any]]]: """OP_MSG implementation entry point.""" buf = _BytesIO() # Save space for message length and request id buf.write(_ZERO_64) # responseTo, opCode buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") to_send, length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) # Header - request id and message length buf.seek(4) request_id = _randint() buf.write(_pack_int(request_id)) buf.seek(0) buf.write(_pack_int(length)) return request_id, buf.getvalue(), to_send if _use_c: _batched_op_msg = _cmessage._batched_op_msg def _do_batched_op_msg( namespace: str, operation: int, command: MutableMapping[str, Any], docs: list[Mapping[str, Any]], opts: CodecOptions, ctx: _BulkWriteContext, ) -> tuple[int, bytes, list[Mapping[str, Any]]]: """Create the next batched insert, update, or delete operation using OP_MSG. 
""" command["$db"] = namespace.split(".", 1)[0] if "writeConcern" in command: ack = bool(command["writeConcern"].get("w", 1)) else: ack = True if ctx.conn.compression_context: return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) return _batched_op_msg(operation, command, docs, ack, opts, ctx) # End OP_MSG ----------------------------------------------------- def _encode_batched_write_command( namespace: str, operation: int, command: MutableMapping[str, Any], docs: list[Mapping[str, Any]], opts: CodecOptions, ctx: _BulkWriteContext, ) -> tuple[bytes, list[Mapping[str, Any]]]: """Encode the next batched insert, update, or delete command.""" buf = _BytesIO() to_send, _ = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf) return buf.getvalue(), to_send if _use_c: _encode_batched_write_command = _cmessage._encode_batched_write_command def _batched_write_command_impl( namespace: str, operation: int, command: MutableMapping[str, Any], docs: list[Mapping[str, Any]], opts: CodecOptions, ctx: _BulkWriteContext, buf: _BytesIO, ) -> tuple[list[Mapping[str, Any]], int]: """Create a batched OP_QUERY write command.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size # Max BSON object size + 16k - 2 bytes for ending NUL bytes. # Server guarantees there is enough room: SERVER-10643. max_cmd_size = max_bson_size + _COMMAND_OVERHEAD max_split_size = ctx.max_split_size # No options buf.write(_ZERO_32) # Namespace as C string buf.write(namespace.encode("utf8")) buf.write(_ZERO_8) # Skip: 0, Limit: -1 buf.write(_SKIPLIM) # Where to write command document length command_start = buf.tell() buf.write(encode(command)) # Start of payload buf.seek(-1, 2) # Work around some Jython weirdness. 
buf.truncate() try: buf.write(_OP_MAP[operation]) except KeyError: raise InvalidOperation("Unknown command") from None # Where to write list document length list_start = buf.tell() - 4 to_send = [] idx = 0 for doc in docs: # Encode the current operation key = str(idx).encode("utf8") value = _dict_to_bson(doc, False, opts) # Is there enough room to add this document? max_cmd_size accounts for # the two trailing null bytes. doc_too_large = len(value) > max_cmd_size if doc_too_large: write_op = list(_FIELD_MAP.keys())[operation] _raise_document_too_large(write_op, len(value), max_bson_size) enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size enough_documents = idx >= max_write_batch_size if enough_data or enough_documents: break buf.write(_BSONOBJ) buf.write(key) buf.write(_ZERO_8) buf.write(value) to_send.append(doc) idx += 1 # Finalize the current OP_QUERY message. # Close list and command documents buf.write(_ZERO_16) # Write document lengths and request id length = buf.tell() buf.seek(list_start) buf.write(_pack_int(length - list_start - 1)) buf.seek(command_start) buf.write(_pack_int(length - command_start)) return to_send, length class _OpReply: """A MongoDB OP_REPLY response message.""" __slots__ = ("flags", "cursor_id", "number_returned", "documents") UNPACK_FROM = struct.Struct("<iqii").unpack_from OP_CODE = 1 def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes): self.flags = flags self.cursor_id = Int64(cursor_id) self.number_returned = number_returned self.documents = documents def raw_response( self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None ) -> list[bytes]: """Check the response header from the database, without decoding BSON. Check the response for errors and unpack. Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or OperationFailure. 
:param cursor_id: cursor_id we sent to get this response - used for raising an informative exception when we get cursor id not valid at server response. """ if self.flags & 1: # Shouldn't get this response if we aren't doing a getMore if cursor_id is None: raise ProtocolError("No cursor id for getMore operation") # Fake a getMore command response. OP_GET_MORE provides no # document. msg = "Cursor not found, cursor id: %d" % (cursor_id,) errobj = {"ok": 0, "errmsg": msg, "code": 43} raise CursorNotFound(msg, 43, errobj) elif self.flags & 2: error_object: dict = bson.BSON(self.documents).decode() # Fake the ok field if it doesn't exist. error_object.setdefault("ok", 0) if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): raise NotPrimaryError(error_object["$err"], error_object) elif error_object.get("code") == 50: default_msg = "operation exceeded time limit" raise ExecutionTimeout( error_object.get("$err", default_msg), error_object.get("code"), error_object ) raise OperationFailure( "database error: %s" % error_object.get("$err"), error_object.get("code"), error_object, ) if self.documents: return [self.documents] return [] def unpack_response( self, cursor_id: Optional[int] = None, codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, ) -> list[dict[str, Any]]: """Unpack a response from the database and decode the BSON document(s). Check the response for errors and unpack, returning a dictionary containing the response data. Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or OperationFailure. :param cursor_id: cursor_id we sent to get this response - used for raising an informative exception when we get cursor id not valid at server response :param codec_options: an instance of :class:`~bson.codec_options.CodecOptions` :param user_fields: Response fields that should be decoded using the TypeDecoders from codec_options, passed to bson._decode_all_selective. 
""" self.raw_response(cursor_id) if legacy_response: return bson.decode_all(self.documents, codec_options) return bson._decode_all_selective(self.documents, codec_options, user_fields) def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: """Unpack a command response.""" docs = self.unpack_response(codec_options=codec_options) assert self.number_returned == 1 return docs[0] def raw_command_response(self) -> NoReturn: """Return the bytes of the command response.""" # This should never be called on _OpReply. raise NotImplementedError @property def more_to_come(self) -> bool: """Is the moreToCome bit set on this response?""" return False @classmethod def unpack(cls, msg: bytes) -> _OpReply: """Construct an _OpReply from raw bytes.""" # PYTHON-945: ignore starting_from field. flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg) documents = msg[20:] return cls(flags, cursor_id, number_returned, documents) class _OpMsg: """A MongoDB OP_MSG response message.""" __slots__ = ("flags", "cursor_id", "number_returned", "payload_document") UNPACK_FROM = struct.Struct("<IBi").unpack_from OP_CODE = 2013 # Flag bits. CHECKSUM_PRESENT = 1 MORE_TO_COME = 1 << 1 EXHAUST_ALLOWED = 1 << 16 # Only present on requests. 
def __init__(self, flags: int, payload_document: bytes): self.flags = flags self.payload_document = payload_document def raw_response( self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = {}, ) -> list[Mapping[str, Any]]: """ cursor_id is ignored user_fields is used to determine which fields must not be decoded """ inflated_response = _decode_selective( RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS ) return [inflated_response] def unpack_response( self, cursor_id: Optional[int] = None, codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, user_fields: Optional[Mapping[str, Any]] = None, legacy_response: bool = False, ) -> list[dict[str, Any]]: """Unpack a OP_MSG command response. :param cursor_id: Ignored, for compatibility with _OpReply. :param codec_options: an instance of :class:`~bson.codec_options.CodecOptions` :param user_fields: Response fields that should be decoded using the TypeDecoders from codec_options, passed to bson._decode_all_selective. """ # If _OpMsg is in-use, this cannot be a legacy response. 
assert not legacy_response return bson._decode_all_selective(self.payload_document, codec_options, user_fields) def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: """Unpack a command response.""" return self.unpack_response(codec_options=codec_options)[0] def raw_command_response(self) -> bytes: """Return the bytes of the command response.""" return self.payload_document @property def more_to_come(self) -> bool: """Is the moreToCome bit set on this response?""" return bool(self.flags & self.MORE_TO_COME) @classmethod def unpack(cls, msg: bytes) -> _OpMsg: """Construct an _OpMsg from raw bytes.""" flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) if flags != 0: if flags & cls.CHECKSUM_PRESENT: raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}") if flags ^ cls.MORE_TO_COME: raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}") if first_payload_type != 0: raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}") if len(msg) != first_payload_size + 5: raise ProtocolError("Unsupported OP_MSG reply: >1 section") payload_document = msg[5:] return cls(flags, payload_document) _UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { _OpReply.OP_CODE: _OpReply.unpack, _OpMsg.OP_CODE: _OpMsg.unpack, }