# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for creating `messages
`_ to be sent to
MongoDB.
.. note:: This module is for internal use and is generally not needed by
application developers.
"""
from __future__ import annotations
import datetime
import logging
import random
import struct
from io import BytesIO as _BytesIO
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Mapping,
MutableMapping,
NoReturn,
Optional,
Union,
)
import bson
from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode
from bson.int64 import Int64
from bson.raw_bson import (
_RAW_ARRAY_BSON_OPTIONS,
DEFAULT_RAW_BSON_OPTIONS,
RawBSONDocument,
_inflate_bson,
)
try:
from pymongo import _cmessage # type: ignore[attr-defined]
_use_c = True
except ImportError:
_use_c = False
from pymongo.errors import (
ConfigurationError,
CursorNotFound,
DocumentTooLarge,
ExecutionTimeout,
InvalidOperation,
NotPrimaryError,
OperationFailure,
ProtocolError,
)
from pymongo.hello import HelloCompat
from pymongo.helpers import _handle_reauth
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from datetime import timedelta
from pymongo.client_session import ClientSession
from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
from pymongo.mongo_client import MongoClient
from pymongo.monitoring import _EventListeners
from pymongo.pool import Connection
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address, _DocumentOut
# Bounds of a signed 32-bit integer; _randint draws request ids from this range.
MAX_INT32 = 2147483647
MIN_INT32 = -2147483648
# Overhead allowed for encoded command documents.
_COMMAND_OVERHEAD = 16382

# Write operation type codes (keys of _OP_MAP).
_INSERT = 0
_UPDATE = 1
_DELETE = 2

# Pre-encoded byte fragments used when hand-assembling wire messages.
_EMPTY = b""
_BSONOBJ = b"\x03"  # BSON element type byte for an embedded document.
# Zero-filled fields of 8, 16, 32 and 64 bits respectively.
_ZERO_8 = b"\x00"
_ZERO_16 = b"\x00\x00"
_ZERO_32 = b"\x00\x00\x00\x00"
_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00"
# NOTE(review): presumably numberToSkip=0 followed by numberToReturn=-1
# (little-endian int32s) for legacy OP_QUERY messages -- confirm at use site.
_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff"
# For each write type: a BSON array element header ("\x04" type byte plus the
# NUL-terminated field name) followed by four zero bytes -- presumably a
# length placeholder patched in later; confirm against the encoder.
_OP_MAP = {
    _INSERT: b"\x04documents\x00\x00\x00\x00\x00",
    _UPDATE: b"\x04updates\x00\x00\x00\x00\x00",
    _DELETE: b"\x04deletes\x00\x00\x00\x00\x00",
}
# Maps a write command name to the field that carries its document list.
_FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"}

# Codec options that decode invalid UTF-8 using the "replace" error handler.
_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions(
    unicode_decode_error_handler="replace"
)
def _randint() -> int:
    """Return a pseudo-random signed 32-bit integer (not cryptographically secure)."""
    lowest, highest = MIN_INT32, MAX_INT32
    return random.randint(lowest, highest)  # noqa: S311
def _maybe_add_read_preference(
spec: MutableMapping[str, Any], read_preference: _ServerMode
) -> MutableMapping[str, Any]:
"""Add $readPreference to spec when appropriate."""
mode = read_preference.mode
document = read_preference.document
# Only add $readPreference if it's something other than primary to avoid
# problems with mongos versions that don't support read preferences. Also,
# for maximum backwards compatibility, don't add $readPreference for
# secondaryPreferred unless tags or maxStalenessSeconds are in use (setting
# the secondaryOkay bit has the same effect).
if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1):
if "$query" not in spec:
spec = {"$query": spec}
spec["$readPreference"] = document
return spec
def _convert_exception(exception: Exception) -> dict[str, Any]:
"""Convert an Exception into a failure document for publishing."""
return {"errmsg": str(exception), "errtype": exception.__class__.__name__}
def _convert_write_result(
operation: str, command: Mapping[str, Any], result: Mapping[str, Any]
) -> dict[str, Any]:
"""Convert a legacy write result to write command format."""
# Based on _merge_legacy from bulk.py
affected = result.get("n", 0)
res = {"ok": 1, "n": affected}
errmsg = result.get("errmsg", result.get("err", ""))
if errmsg:
# The write was successful on at least the primary so don't return.
if result.get("wtimeout"):
res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}}
else:
# The write failed.
error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg}
if "errInfo" in result:
error["errInfo"] = result["errInfo"]
res["writeErrors"] = [error]
return res
if operation == "insert":
# GLE result for insert is always 0 in most MongoDB versions.
res["n"] = len(command["documents"])
elif operation == "update":
if "upserted" in result:
res["upserted"] = [{"index": 0, "_id": result["upserted"]}]
# Versions of MongoDB before 2.6 don't return the _id for an
# upsert if _id is not an ObjectId.
elif result.get("updatedExisting") is False and affected == 1:
# If _id is in both the update document *and* the query spec
# the update document _id takes precedence.
update = command["updates"][0]
_id = update["u"].get("_id", update["q"].get("_id"))
res["upserted"] = [{"index": 0, "_id": _id}]
return res
# OP_QUERY flag bits and the find-command boolean option each one maps to.
# _gen_find_command sets an option to True for every bit present in the
# flags integer it is given.
_OPTIONS = {
    "tailable": 2,
    "oplogReplay": 8,
    "noCursorTimeout": 16,
    "awaitData": 32,
    "allowPartialResults": 128,
}

# Legacy "$"-prefixed query modifiers and the find-command field each maps
# to; applied by _gen_find_command when the spec is wrapped in "$query".
_MODIFIERS = {
    "$query": "filter",
    "$orderby": "sort",
    "$hint": "hint",
    "$comment": "comment",
    "$maxScan": "maxScan",
    "$maxTimeMS": "maxTimeMS",
    "$max": "max",
    "$min": "min",
    "$returnKey": "returnKey",
    "$showRecordId": "showRecordId",
    "$showDiskLoc": "showRecordId",  # <= MongoDb 3.0
    "$snapshot": "snapshot",
}
def _gen_find_command(
coll: str,
spec: Mapping[str, Any],
projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
skip: int,
limit: int,
batch_size: Optional[int],
options: Optional[int],
read_concern: ReadConcern,
collation: Optional[Mapping[str, Any]] = None,
session: Optional[ClientSession] = None,
allow_disk_use: Optional[bool] = None,
) -> dict[str, Any]:
"""Generate a find command document."""
cmd: dict[str, Any] = {"find": coll}
if "$query" in spec:
cmd.update(
[
(_MODIFIERS[key], val) if key in _MODIFIERS else (key, val)
for key, val in spec.items()
]
)
if "$explain" in cmd:
cmd.pop("$explain")
if "$readPreference" in cmd:
cmd.pop("$readPreference")
else:
cmd["filter"] = spec
if projection:
cmd["projection"] = projection
if skip:
cmd["skip"] = skip
if limit:
cmd["limit"] = abs(limit)
if limit < 0:
cmd["singleBatch"] = True
if batch_size:
cmd["batchSize"] = batch_size
if read_concern.level and not (session and session.in_transaction):
cmd["readConcern"] = read_concern.document
if collation:
cmd["collation"] = collation
if allow_disk_use is not None:
cmd["allowDiskUse"] = allow_disk_use
if options:
cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val])
return cmd
def _gen_get_more_command(
cursor_id: Optional[int],
coll: str,
batch_size: Optional[int],
max_await_time_ms: Optional[int],
comment: Optional[Any],
conn: Connection,
) -> dict[str, Any]:
"""Generate a getMore command document."""
cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll}
if batch_size:
cmd["batchSize"] = batch_size
if max_await_time_ms is not None:
cmd["maxTimeMS"] = max_await_time_ms
if comment is not None and conn.max_wire_version >= 9:
cmd["comment"] = comment
return cmd
class _Query:
"""A query operation."""
__slots__ = (
"flags",
"db",
"coll",
"ntoskip",
"spec",
"fields",
"codec_options",
"read_preference",
"limit",
"batch_size",
"name",
"read_concern",
"collation",
"session",
"client",
"allow_disk_use",
"_as_command",
"exhaust",
)
# For compatibility with the _GetMore class.
conn_mgr = None
cursor_id = None
def __init__(
self,
flags: int,
db: str,
coll: str,
ntoskip: int,
spec: Mapping[str, Any],
fields: Optional[Mapping[str, Any]],
codec_options: CodecOptions,
read_preference: _ServerMode,
limit: int,
batch_size: int,
read_concern: ReadConcern,
collation: Optional[Mapping[str, Any]],
session: Optional[ClientSession],
client: MongoClient,
allow_disk_use: Optional[bool],
exhaust: bool,
):
self.flags = flags
self.db = db
self.coll = coll
self.ntoskip = ntoskip
self.spec = spec
self.fields = fields
self.codec_options = codec_options
self.read_preference = read_preference
self.read_concern = read_concern
self.limit = limit
self.batch_size = batch_size
self.collation = collation
self.session = session
self.client = client
self.allow_disk_use = allow_disk_use
self.name = "find"
self._as_command: Optional[tuple[dict[str, Any], str]] = None
self.exhaust = exhaust
def reset(self) -> None:
self._as_command = None
def namespace(self) -> str:
return f"{self.db}.{self.coll}"
def use_command(self, conn: Connection) -> bool:
use_find_cmd = False
if not self.exhaust:
use_find_cmd = True
elif conn.max_wire_version >= 8:
# OP_MSG supports exhaust on MongoDB 4.2+
use_find_cmd = True
elif not self.read_concern.ok_for_legacy:
raise ConfigurationError(
"read concern level of %s is not valid "
"with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version)
)
conn.validate_session(self.client, self.session)
return use_find_cmd
def as_command(
self, conn: Connection, apply_timeout: bool = False
) -> tuple[dict[str, Any], str]:
"""Return a find command document for this query."""
# We use the command twice: on the wire and for command monitoring.
# Generate it once, for speed and to avoid repeating side-effects.
if self._as_command is not None:
return self._as_command
explain = "$explain" in self.spec
cmd: dict[str, Any] = _gen_find_command(
self.coll,
self.spec,
self.fields,
self.ntoskip,
self.limit,
self.batch_size,
self.flags,
self.read_concern,
self.collation,
self.session,
self.allow_disk_use,
)
if explain:
self.name = "explain"
cmd = {"explain": cmd}
session = self.session
conn.add_server_api(cmd)
if session:
session._apply_to(cmd, False, self.read_preference, conn)
# Explain does not support readConcern.
if not explain and not session.in_transaction:
session._update_read_concern(cmd, conn)
conn.send_cluster_time(cmd, session, self.client)
# Support auto encryption
client = self.client
if client._encrypter and not client._encrypter._bypass_auto_encryption:
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
# Support CSOT
if apply_timeout:
conn.apply_timeout(client, cmd)
self._as_command = cmd, self.db
return self._as_command
def get_message(
self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False
) -> tuple[int, bytes, int]:
"""Get a query message, possibly setting the secondaryOk bit."""
# Use the read_preference decided by _socket_from_server.
self.read_preference = read_preference
if read_preference.mode:
# Set the secondaryOk bit.
flags = self.flags | 4
else:
flags = self.flags
ns = self.namespace()
spec = self.spec
if use_cmd:
spec = self.as_command(conn, apply_timeout=True)[0]
request_id, msg, size, _ = _op_msg(
0,
spec,
self.db,
read_preference,
self.codec_options,
ctx=conn.compression_context,
)
return request_id, msg, size
# OP_QUERY treats ntoreturn of -1 and 1 the same, return
# one document and close the cursor. We have to use 2 for
# batch size if 1 is specified.
ntoreturn = self.batch_size == 1 and 2 or self.batch_size
if self.limit:
if ntoreturn:
ntoreturn = min(self.limit, ntoreturn)
else:
ntoreturn = self.limit
if conn.is_mongos:
assert isinstance(spec, MutableMapping)
spec = _maybe_add_read_preference(spec, read_preference)
return _query(
flags,
ns,
self.ntoskip,
ntoreturn,
spec,
None if use_cmd else self.fields,
self.codec_options,
ctx=conn.compression_context,
)
class _GetMore:
"""A getmore operation."""
__slots__ = (
"db",
"coll",
"ntoreturn",
"cursor_id",
"max_await_time_ms",
"codec_options",
"read_preference",
"session",
"client",
"conn_mgr",
"_as_command",
"exhaust",
"comment",
)
name = "getMore"
def __init__(
self,
db: str,
coll: str,
ntoreturn: int,
cursor_id: int,
codec_options: CodecOptions,
read_preference: _ServerMode,
session: Optional[ClientSession],
client: MongoClient,
max_await_time_ms: Optional[int],
conn_mgr: Any,
exhaust: bool,
comment: Any,
):
self.db = db
self.coll = coll
self.ntoreturn = ntoreturn
self.cursor_id = cursor_id
self.codec_options = codec_options
self.read_preference = read_preference
self.session = session
self.client = client
self.max_await_time_ms = max_await_time_ms
self.conn_mgr = conn_mgr
self._as_command: Optional[tuple[dict[str, Any], str]] = None
self.exhaust = exhaust
self.comment = comment
def reset(self) -> None:
self._as_command = None
def namespace(self) -> str:
return f"{self.db}.{self.coll}"
def use_command(self, conn: Connection) -> bool:
use_cmd = False
if not self.exhaust:
use_cmd = True
elif conn.max_wire_version >= 8:
# OP_MSG supports exhaust on MongoDB 4.2+
use_cmd = True
conn.validate_session(self.client, self.session)
return use_cmd
def as_command(
self, conn: Connection, apply_timeout: bool = False
) -> tuple[dict[str, Any], str]:
"""Return a getMore command document for this query."""
# See _Query.as_command for an explanation of this caching.
if self._as_command is not None:
return self._as_command
cmd: dict[str, Any] = _gen_get_more_command(
self.cursor_id,
self.coll,
self.ntoreturn,
self.max_await_time_ms,
self.comment,
conn,
)
if self.session:
self.session._apply_to(cmd, False, self.read_preference, conn)
conn.add_server_api(cmd)
conn.send_cluster_time(cmd, self.session, self.client)
# Support auto encryption
client = self.client
if client._encrypter and not client._encrypter._bypass_auto_encryption:
cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options)
# Support CSOT
if apply_timeout:
conn.apply_timeout(client, cmd=None)
self._as_command = cmd, self.db
return self._as_command
def get_message(
self, dummy0: Any, conn: Connection, use_cmd: bool = False
) -> Union[tuple[int, bytes, int], tuple[int, bytes]]:
"""Get a getmore message."""
ns = self.namespace()
ctx = conn.compression_context
if use_cmd:
spec = self.as_command(conn, apply_timeout=True)[0]
if self.conn_mgr and self.exhaust:
flags = _OpMsg.EXHAUST_ALLOWED
else:
flags = 0
request_id, msg, size, _ = _op_msg(
flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context
)
return request_id, msg, size
return _get_more(ns, self.ntoreturn, self.cursor_id, ctx)
class _RawBatchQuery(_Query):
    """_Query variant whose use_command re-derives the OP_MSG decision itself."""

    def use_command(self, conn: Connection) -> bool:
        # Run the base-class checks (read concern / session validation) for
        # their side effects; the base return value is not used.
        super().use_command(conn)
        # MongoDB 4.2+ (wire version 8) supports exhaust over OP_MSG.
        return conn.max_wire_version >= 8 or not self.exhaust
class _RawBatchGetMore(_GetMore):
    """_GetMore variant whose use_command re-derives the OP_MSG decision itself."""

    def use_command(self, conn: Connection) -> bool:
        # Run the base-class checks (session validation) for their side
        # effects; the base return value is not used.
        super().use_command(conn)
        # MongoDB 4.2+ (wire version 8) supports exhaust over OP_MSG.
        return conn.max_wire_version >= 8 or not self.exhaust
class _CursorAddress(tuple):
"""The server address (host, port) of a cursor, with namespace property."""
__namespace: Any
def __new__(cls, address: _Address, namespace: str) -> _CursorAddress:
self = tuple.__new__(cls, address)
self.__namespace = namespace
return self
@property
def namespace(self) -> str:
"""The namespace this cursor."""
return self.__namespace
def __hash__(self) -> int:
# Two _CursorAddress instances with different namespaces
# must not hash the same.
return ((*self, self.__namespace)).__hash__()
def __eq__(self, other: object) -> bool:
if isinstance(other, _CursorAddress):
return tuple(self) == tuple(other) and self.namespace == other.namespace
return NotImplemented
def __ne__(self, other: object) -> bool:
return not self == other
_pack_compression_header = struct.Struct(" tuple[int, bytes]:
"""Takes message data, compresses it, and adds an OP_COMPRESSED header."""
compressed = ctx.compress(data)
request_id = _randint()
header = _pack_compression_header(
_COMPRESSION_HEADER_SIZE + len(compressed), # Total message length
request_id, # Request id
0, # responseTo
2012, # operation id
operation, # original operation id
len(data), # uncompressed message length
ctx.compressor_id,
) # compressor id
return request_id, header + compressed
_pack_header = struct.Struct(" tuple[int, bytes]:
"""Takes message data and adds a message header based on the operation.
Returns the resultant message string.
"""
rid = _randint()
message = _pack_header(16 + len(data), rid, 0, operation)
return rid, message + data
_pack_int = struct.Struct(" tuple[bytes, int, int]:
"""Get a OP_MSG message.
Note: this method handles multiple documents in a type one payload but
it does not perform batch splitting and the total message size is
only checked *after* generating the entire message.
"""
# Encode the command document in payload 0 without checking keys.
encoded = _dict_to_bson(command, False, opts)
flags_type = _pack_op_msg_flags_type(flags, 0)
total_size = len(encoded)
max_doc_size = 0
if identifier and docs is not None:
type_one = _pack_byte(1)
cstring = _make_c_string(identifier)
encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs]
size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4
encoded_size = _pack_int(size)
total_size += size
max_doc_size = max(len(doc) for doc in encoded_docs)
data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs]
else:
data = [flags_type, encoded]
return b"".join(data), total_size, max_doc_size
def _op_msg_compressed(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int, int]:
    """Internal compressed OP_MSG message helper.

    Encodes the OP_MSG body, then wraps it in an OP_COMPRESSED message with
    original opcode 2013.

    :return: ``(request_id, message_bytes, total_size, max_doc_size)``.
    """
    msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
    rid, msg = _compress(2013, msg, ctx)
    return rid, msg, total_size, max_bson_size
def _op_msg_uncompressed(
    flags: int,
    command: Mapping[str, Any],
    identifier: str,
    docs: Optional[list[Mapping[str, Any]]],
    opts: CodecOptions,
) -> tuple[int, bytes, int, int]:
    """Internal uncompressed OP_MSG message helper.

    Encodes the OP_MSG body and prefixes a standard header with opcode 2013.

    :return: ``(request_id, message_bytes, total_size, max_doc_size)``.
    """
    data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts)
    request_id, op_message = __pack_message(2013, data)
    return request_id, op_message, total_size, max_bson_size


# Prefer the C implementation when the extension module is available.
if _use_c:
    _op_msg_uncompressed = _cmessage._op_msg
def _op_msg(
    flags: int,
    command: MutableMapping[str, Any],
    dbname: str,
    read_preference: Optional[_ServerMode],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int, int]:
    """Build an OP_MSG message for *command*, compressed when *ctx* is set.

    *command* is updated in place as follows:

    - ``$db`` is set.
    - ``$readPreference`` is added for a non-primary *read_preference*,
      unless already present. getMore passes None, so it never gets one.
    - For insert/update/delete, the document list is temporarily removed
      and sent as a type-1 section, then put back afterwards.

    :return: ``(request_id, message_bytes, total_size, max_doc_size)``.
    """
    command["$db"] = dbname
    if (
        read_preference is not None
        and read_preference.mode
        and "$readPreference" not in command
    ):
        # Primary (falsy mode) is the server default and is not sent.
        command["$readPreference"] = read_preference.document

    command_name = next(iter(command))
    identifier = _FIELD_MAP.get(command_name, "")
    if identifier and identifier in command:
        docs = command.pop(identifier)
    else:
        identifier, docs = "", None

    try:
        if not ctx:
            return _op_msg_uncompressed(flags, command, identifier, docs, opts)
        return _op_msg_compressed(flags, command, identifier, docs, opts, ctx)
    finally:
        # Restore the document list removed above.
        if identifier:
            command[identifier] = docs
def _query_impl(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
) -> tuple[bytes, int]:
    """Encode the body of an OP_QUERY message (everything after the header).

    :return: ``(body_bytes, max_bson_size)``, where ``max_bson_size`` is the
        larger of the encoded query and the encoded field selector. An absent
        or empty selector is encoded as an empty byte string.
    """
    query_bson = _dict_to_bson(query, False, opts)
    selector_bson = _dict_to_bson(field_selector, False, opts) if field_selector else b""
    largest = max(len(query_bson), len(selector_bson))
    body = b"".join(
        (
            _pack_int(options),
            _make_c_string(collection_name),
            _pack_int(num_to_skip),
            _pack_int(num_to_return),
            query_bson,
            selector_bson,
        )
    )
    return body, largest
def _query_compressed(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes, int]:
    """Build an OP_QUERY (opcode 2004) message wrapped in OP_COMPRESSED.

    :return: ``(request_id, message_bytes, max_bson_size)``.
    """
    body, max_bson_size = _query_impl(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )
    return (*_compress(2004, body, ctx), max_bson_size)
def _query_uncompressed(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
) -> tuple[int, bytes, int]:
    """Build an uncompressed OP_QUERY (opcode 2004) message.

    :return: ``(request_id, message_bytes, max_bson_size)``.
    """
    body, max_bson_size = _query_impl(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
    )
    return (*__pack_message(2004, body), max_bson_size)


# Prefer the C implementation when the extension module is available.
if _use_c:
    _query_uncompressed = _cmessage._query_message
def _query(
    options: int,
    collection_name: str,
    num_to_skip: int,
    num_to_return: int,
    query: Mapping[str, Any],
    field_selector: Optional[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes, int]:
    """Build a legacy OP_QUERY message, compressed when *ctx* is truthy.

    :return: ``(request_id, message_bytes, max_bson_size)``.
    """
    if not ctx:
        return _query_uncompressed(
            options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
        )
    return _query_compressed(
        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx
    )
_pack_long_long = struct.Struct(" bytes:
"""Get an OP_GET_MORE message."""
return b"".join(
[
_ZERO_32,
_make_c_string(collection_name),
_pack_int(num_to_return),
_pack_long_long(cursor_id),
]
)
def _get_more_compressed(
    collection_name: str,
    num_to_return: int,
    cursor_id: int,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
) -> tuple[int, bytes]:
    """Build an OP_GET_MORE (opcode 2005) message wrapped in OP_COMPRESSED.

    :return: ``(request_id, message_bytes)``.
    """
    body = _get_more_impl(collection_name, num_to_return, cursor_id)
    return _compress(2005, body, ctx)
def _get_more_uncompressed(
    collection_name: str, num_to_return: int, cursor_id: int
) -> tuple[int, bytes]:
    """Build an uncompressed OP_GET_MORE (opcode 2005) message.

    :return: ``(request_id, message_bytes)``.
    """
    body = _get_more_impl(collection_name, num_to_return, cursor_id)
    return __pack_message(2005, body)


# Prefer the C implementation when the extension module is available.
if _use_c:
    _get_more_uncompressed = _cmessage._get_more_message
def _get_more(
    collection_name: str,
    num_to_return: int,
    cursor_id: int,
    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
) -> tuple[int, bytes]:
    """Build a legacy OP_GET_MORE message, compressed when *ctx* is truthy.

    :return: ``(request_id, message_bytes)``.
    """
    if not ctx:
        return _get_more_uncompressed(collection_name, num_to_return, cursor_id)
    return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx)
class _BulkWriteContext:
    """A wrapper around Connection for use with write splitting functions.

    Holds per-bulk-write state: database and command names, connection,
    operation id, event listeners, session, op type and codec options. It
    sends each batch and, around every send, emits command log records (when
    DEBUG logging is enabled) and command-monitoring events (when listeners
    are enabled).
    """

    __slots__ = (
        "db_name",
        "conn",
        "op_id",
        "name",
        "field",
        "publish",
        "start_time",
        "listeners",
        "session",
        "compress",
        "op_type",
        "codec",
    )

    def __init__(
        self,
        database_name: str,
        cmd_name: str,
        conn: Connection,
        operation_id: int,
        listeners: _EventListeners,
        session: ClientSession,
        op_type: int,
        codec: CodecOptions,
    ):
        self.db_name = database_name
        self.conn = conn
        self.op_id = operation_id
        self.listeners = listeners
        # Whether command-monitoring events are published.
        self.publish = listeners.enabled_for_commands
        self.name = cmd_name
        # Document-list field for this command ("documents"/"updates"/
        # "deletes"); raises KeyError for any other command name.
        self.field = _FIELD_MAP[self.name]
        # Reference time for event durations; reset after every write.
        self.start_time = datetime.datetime.now()
        self.session = session
        self.compress = bool(conn.compression_context)
        self.op_type = op_type
        self.codec = codec

    def __batch_command(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
    ) -> tuple[int, bytes, list[Mapping[str, Any]]]:
        """Encode the next batch as an OP_MSG via _do_batched_op_msg.

        Returns ``(request_id, message_bytes, to_send)``. ``to_send`` holds
        the documents included in this batch; presumably a prefix of *docs*
        that fits the size limits -- confirm in _do_batched_op_msg.

        :raises InvalidOperation: if no document was included.
        """
        namespace = self.db_name + ".$cmd"
        request_id, msg, to_send = _do_batched_op_msg(
            namespace, self.op_type, cmd, docs, self.codec, self
        )
        if not to_send:
            raise InvalidOperation("cannot do an empty bulk write")
        return request_id, msg, to_send

    def execute(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
    ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]:
        """Send one acknowledged batch.

        Returns ``(reply, docs_sent)`` after passing the reply to
        ``client._process_response``.
        """
        request_id, msg, to_send = self.__batch_command(cmd, docs)
        result = self.write_command(cmd, request_id, msg, to_send, client)
        client._process_response(result, self.session)
        return result, to_send

    def execute_unack(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
    ) -> list[Mapping[str, Any]]:
        """Send one unacknowledged batch and return the documents sent."""
        request_id, msg, to_send = self.__batch_command(cmd, docs)
        # Though this isn't strictly a "legacy" write, the helper
        # handles publishing commands and sending our message
        # without receiving a result. Send 0 for max_doc_size
        # to disable size checking. Size checking is handled while
        # the documents are encoded to BSON.
        self.unack_write(cmd, request_id, msg, 0, to_send, client)
        return to_send

    @property
    def max_bson_size(self) -> int:
        """A proxy for Connection.max_bson_size."""
        return self.conn.max_bson_size

    @property
    def max_message_size(self) -> int:
        """A proxy for Connection.max_message_size (16 bytes less when compressing)."""
        if self.compress:
            # Subtract 16 bytes for the message header.
            return self.conn.max_message_size - 16
        return self.conn.max_message_size

    @property
    def max_write_batch_size(self) -> int:
        """A proxy for Connection.max_write_batch_size."""
        return self.conn.max_write_batch_size

    @property
    def max_split_size(self) -> int:
        """The maximum size of a BSON command before batch splitting.

        Equal to max_bson_size here; the encrypted subclass overrides it.
        """
        return self.max_bson_size

    def unack_write(
        self,
        cmd: MutableMapping[str, Any],
        request_id: int,
        msg: bytes,
        max_doc_size: int,
        docs: list[Mapping[str, Any]],
        client: MongoClient,
    ) -> Optional[Mapping[str, Any]]:
        """A proxy for Connection.unack_write that handles event publishing.

        Logs STARTED/SUCCEEDED/FAILED records and publishes the matching
        events. Returns whatever Connection.unack_write returned; any
        exception is logged/published and then re-raised.
        """
        # NOTE(review): unlike write_command, docs are attached to cmd only
        # inside _start (i.e. when publishing). So the STARTED log record, and
        # _convert_write_result's read of command["documents"] for inserts,
        # may see a cmd without the document list -- confirm.
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=self.db_name,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=self.conn.id,
                serverConnectionId=self.conn.server_connection_id,
                serverHost=self.conn.address[0],
                serverPort=self.conn.address[1],
                serviceId=self.conn.service_id,
            )
        if self.publish:
            cmd = self._start(cmd, request_id, docs)
        try:
            result = self.conn.unack_write(msg, max_doc_size)  # type: ignore[func-returns-value]
            duration = datetime.datetime.now() - self.start_time
            if result is not None:
                reply = _convert_write_result(self.name, cmd, result)
            else:
                # Comply with APM spec.
                reply = {"ok": 1}
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.SUCCEEDED,
                    durationMS=duration,
                    reply=reply,
                    commandName=next(iter(cmd)),
                    databaseName=self.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=self.conn.id,
                    serverConnectionId=self.conn.server_connection_id,
                    serverHost=self.conn.address[0],
                    serverPort=self.conn.address[1],
                    serviceId=self.conn.service_id,
                )
            if self.publish:
                self._succeed(request_id, reply, duration)
        except Exception as exc:
            duration = datetime.datetime.now() - self.start_time
            # Build the failure document reported to logs and listeners.
            if isinstance(exc, OperationFailure):
                failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details)  # type: ignore[arg-type]
            elif isinstance(exc, NotPrimaryError):
                failure = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=self.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=self.conn.id,
                    serverConnectionId=self.conn.server_connection_id,
                    serverHost=self.conn.address[0],
                    serverPort=self.conn.address[1],
                    serviceId=self.conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )
            if self.publish:
                assert self.start_time is not None
                self._fail(request_id, failure, duration)
            raise
        finally:
            # Restart the clock so the next batch's duration starts here.
            self.start_time = datetime.datetime.now()
        return result

    @_handle_reauth
    def write_command(
        self,
        cmd: MutableMapping[str, Any],
        request_id: int,
        msg: bytes,
        docs: list[Mapping[str, Any]],
        client: MongoClient,
    ) -> dict[str, Any]:
        """A proxy for Connection.write_command that handles event publishing.

        Attaches *docs* to *cmd* under self.field and logs/publishes the
        STARTED event. Sends the pre-encoded *msg*, then logs/publishes
        SUCCEEDED with the reply or FAILED with the error, which is re-raised.
        """
        cmd[self.field] = docs
        if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
            _debug_log(
                _COMMAND_LOGGER,
                clientId=client._topology_settings._topology_id,
                message=_CommandStatusMessage.STARTED,
                command=cmd,
                commandName=next(iter(cmd)),
                databaseName=self.db_name,
                requestId=request_id,
                operationId=request_id,
                driverConnectionId=self.conn.id,
                serverConnectionId=self.conn.server_connection_id,
                serverHost=self.conn.address[0],
                serverPort=self.conn.address[1],
                serviceId=self.conn.service_id,
            )
        if self.publish:
            self._start(cmd, request_id, docs)
        try:
            reply = self.conn.write_command(request_id, msg, self.codec)
            duration = datetime.datetime.now() - self.start_time
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.SUCCEEDED,
                    durationMS=duration,
                    reply=reply,
                    commandName=next(iter(cmd)),
                    databaseName=self.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=self.conn.id,
                    serverConnectionId=self.conn.server_connection_id,
                    serverHost=self.conn.address[0],
                    serverPort=self.conn.address[1],
                    serviceId=self.conn.service_id,
                )
            if self.publish:
                self._succeed(request_id, reply, duration)
        except Exception as exc:
            duration = datetime.datetime.now() - self.start_time
            # Server errors carry their own details document.
            if isinstance(exc, (NotPrimaryError, OperationFailure)):
                failure: _DocumentOut = exc.details  # type: ignore[assignment]
            else:
                failure = _convert_exception(exc)
            if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
                _debug_log(
                    _COMMAND_LOGGER,
                    clientId=client._topology_settings._topology_id,
                    message=_CommandStatusMessage.FAILED,
                    durationMS=duration,
                    failure=failure,
                    commandName=next(iter(cmd)),
                    databaseName=self.db_name,
                    requestId=request_id,
                    operationId=request_id,
                    driverConnectionId=self.conn.id,
                    serverConnectionId=self.conn.server_connection_id,
                    serverHost=self.conn.address[0],
                    serverPort=self.conn.address[1],
                    serviceId=self.conn.service_id,
                    isServerSideError=isinstance(exc, OperationFailure),
                )

            if self.publish:
                self._fail(request_id, failure, duration)
            raise
        finally:
            # Restart the clock so the next batch's duration starts here.
            self.start_time = datetime.datetime.now()
        return reply

    def _start(
        self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]]
    ) -> MutableMapping[str, Any]:
        """Publish a CommandStartedEvent.

        Attaches *docs* to *cmd* under self.field first; returns *cmd*.
        """
        cmd[self.field] = docs
        self.listeners.publish_command_start(
            cmd,
            self.db_name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
        )
        return cmd

    def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None:
        """Publish a CommandSucceededEvent."""
        self.listeners.publish_command_success(
            duration,
            reply,
            self.name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
            database_name=self.db_name,
        )

    def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None:
        """Publish a CommandFailedEvent."""
        self.listeners.publish_command_failure(
            duration,
            failure,
            self.name,
            request_id,
            self.conn.address,
            self.conn.server_connection_id,
            self.op_id,
            self.conn.service_id,
            database_name=self.db_name,
        )
# From the Client Side Encryption spec:
# Because automatic encryption increases the size of commands, the driver
# MUST split bulk writes at a reduced size limit before undergoing automatic
# encryption. The write payload MUST be split at 2MiB (2097152).
_MAX_SPLIT_SIZE_ENC = 2097152
class _EncryptedBulkWriteContext(_BulkWriteContext):
    """Bulk write context used with automatic client-side encryption.

    Each batch is encoded with ``_encode_batched_write_command``, decoded
    back into a raw command document, and sent through
    ``Connection.command`` rather than as a pre-built OP_MSG.
    ``max_split_size`` is reduced to ``_MAX_SPLIT_SIZE_ENC``.
    """

    __slots__ = ()

    def __batch_command(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]]
    ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
        """Encode one batch; return ``(batched_command, docs_in_batch)``.

        :raises InvalidOperation: if no document was included.
        """
        encoded, batch_docs = _encode_batched_write_command(
            self.db_name + ".$cmd", self.op_type, cmd, docs, self.codec, self
        )
        if not batch_docs:
            raise InvalidOperation("cannot do an empty bulk write")

        # Chop off the OP_QUERY header to get a properly batched write command:
        # skip past the NUL ending the namespace (searched from offset 4, after
        # the flags) plus 8 more bytes -- presumably numberToSkip and
        # numberToReturn.
        body_offset = encoded.index(b"\x00", 4) + 9
        payload = _inflate_bson(memoryview(encoded)[body_offset:], DEFAULT_RAW_BSON_OPTIONS)
        return payload, batch_docs

    def execute(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
    ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]:
        """Send one acknowledged batch; return ``(reply, docs_sent)``."""
        payload, batch_docs = self.__batch_command(cmd, docs)
        reply: Mapping[str, Any] = self.conn.command(
            self.db_name,
            payload,
            codec_options=self.codec,
            session=self.session,
            client=client,
        )
        return reply, batch_docs

    def execute_unack(
        self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient
    ) -> list[Mapping[str, Any]]:
        """Send one batch with ``w=0`` and return the documents sent."""
        payload, batch_docs = self.__batch_command(cmd, docs)
        unacknowledged = WriteConcern(w=0)
        self.conn.command(
            self.db_name,
            payload,
            write_concern=unacknowledged,
            session=self.session,
            client=client,
        )
        return batch_docs

    @property
    def max_split_size(self) -> int:
        """Reduce the batch splitting size."""
        return _MAX_SPLIT_SIZE_ENC
def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn:
    """Raise :class:`DocumentTooLarge` for an oversized write document.

    Only ``"insert"`` reports the document size and the size limit. Other
    operations get a message that names the operation only.

    :param operation: write operation name, e.g. ``"insert"``.
    :param doc_size: encoded size of the offending document, in bytes.
    :param max_size: size limit to report, in bytes.
    :raises DocumentTooLarge: always.
    """
    if operation != "insert":
        # There's nothing intelligent we can say about size for update and
        # delete, so only the operation is named.
        raise DocumentTooLarge(f"{operation!r} command document too large")
    template = (
        "BSON document too large (%d bytes)"
        " - the connected server supports"
        " BSON document sizes up to %d"
        " bytes."
    )
    raise DocumentTooLarge(template % (doc_size, max_size))
# OP_MSG -------------------------------------------------------------
# Write operation code -> NUL-terminated identifier written at the start of
# the OP_MSG type 1 section that carries that operation's documents.
_OP_MSG_MAP = {
    op_code: identifier + b"\x00"
    for op_code, identifier in (
        (_INSERT, b"documents"),
        (_UPDATE, b"updates"),
        (_DELETE, b"deletes"),
    )
}
def _batched_op_msg_impl(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
    buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], int]:
    """Create a batched OP_MSG write.

    Writes into *buf*, starting at its current position, the OP_MSG flag
    bytes, a type 0 section holding *command*, and a type 1 section holding
    as many of *docs* as fit. Anything the caller already wrote to *buf*
    (e.g. a message header) is kept and counts toward the size limit.

    The batch stops before the first document that would make the buffer
    exceed ``ctx.max_message_size`` bytes, or once
    ``ctx.max_write_batch_size`` documents have been added.

    :param operation: write operation code; must be a key of ``_OP_MSG_MAP``.
    :param command: the command document, written as the type 0 section.
    :param docs: candidate documents for the type 1 section.
    :param ack: False for an unacknowledged write. This sets the flag bytes
        to ``02 00 00 00`` and enables a client-side ``max_bson_size`` check
        per document.
    :param opts: codec options used to encode *command* and each document.
    :param ctx: supplies ``max_bson_size``, ``max_write_batch_size`` and
        ``max_message_size``.
    :param buf: output buffer, written in place.
    :return: ``(to_send, length)``. *to_send* is the documents included in
        this batch. *length* is the buffer offset just past the last written
        document, i.e. the total bytes including anything the caller wrote
        first.
    :raises InvalidOperation: if *operation* is not a key of ``_OP_MSG_MAP``.
    :raises DocumentTooLarge: via ``_raise_document_too_large``. Raised if the
        first document alone pushes the buffer past ``max_message_size``, or,
        for unacknowledged writes, if a considered document exceeds
        ``max_bson_size``.
    """
    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    max_message_size = ctx.max_message_size
    # Unacknowledged writes set flag value 0x02 (presumably the OP_MSG
    # moreToCome bit, so the server sends no reply - confirm against the
    # wire protocol spec).
    flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00"
    # Flags
    buf.write(flags)
    # Type 0 Section
    buf.write(b"\x00")
    buf.write(_dict_to_bson(command, False, opts))
    # Type 1 Section
    buf.write(b"\x01")
    # Offset of the type 1 section's size field; patched after the loop once
    # the batch is known.
    size_location = buf.tell()
    # Save space for size
    buf.write(b"\x00\x00\x00\x00")
    try:
        # Sequence identifier, e.g. b"documents\x00" for inserts.
        buf.write(_OP_MSG_MAP[operation])
    except KeyError:
        raise InvalidOperation("Unknown command") from None
    to_send = []
    idx = 0
    for doc in docs:
        # Encode the current operation
        value = _dict_to_bson(doc, False, opts)
        doc_length = len(value)
        # Buffer size if this document were appended.
        new_message_size = buf.tell() + doc_length
        # Does first document exceed max_message_size?
        doc_too_large = idx == 0 and (new_message_size > max_message_size)
        # When OP_MSG is used unacknowledged we have to check
        # document size client side or applications won't be notified.
        # Otherwise we let the server deal with documents that are too large
        # since ordered=False causes those documents to be skipped instead of
        # halting the bulk write operation.
        unacked_doc_too_large = not ack and (doc_length > max_bson_size)
        if doc_too_large or unacked_doc_too_large:
            # Operation name taken from _FIELD_MAP's key at index `operation`
            # (assumes _FIELD_MAP is ordered by operation code - verify).
            write_op = list(_FIELD_MAP.keys())[operation]
            _raise_document_too_large(write_op, len(value), max_bson_size)
        # We have enough data, return this batch.
        if new_message_size > max_message_size:
            break
        buf.write(value)
        to_send.append(doc)
        idx += 1
        # We have enough documents, return this batch.
        if idx == max_write_batch_size:
            break
    # Write type 1 section size
    # The size counts from the size field itself through the last document.
    length = buf.tell()
    buf.seek(size_location)
    buf.write(_pack_int(length - size_location))
    return to_send, length
def _encode_batched_op_msg(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]]]:
    """Encode the next batched insert, update, or delete as an OP_MSG body.

    No message header is written; the returned bytes start at the OP_MSG
    flag bytes.

    :return: ``(body, batch)`` - the encoded bytes and the documents that
        were included in this batch.
    """
    body = _BytesIO()
    batch = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, body)[0]
    return body.getvalue(), batch


# Use the C extension's implementation when it is available.
if _use_c:
    _encode_batched_op_msg = _cmessage._encode_batched_op_msg
def _batched_op_msg_compressed(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
    """Create the next batched insert, update, or delete as a compressed OP_MSG.

    The OP_MSG body comes from ``_encode_batched_op_msg``. It is handed to
    ``_compress`` with opcode 2013 (OP_MSG) and the connection's compression
    context.

    :return: ``(request_id, message, batch)``.
    """
    body, batch = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx)
    assert ctx.conn.compression_context is not None
    op_msg_opcode = 2013
    request_id, compressed = _compress(op_msg_opcode, body, ctx.conn.compression_context)
    return request_id, compressed, batch
def _batched_op_msg(
    operation: int,
    command: Mapping[str, Any],
    docs: list[Mapping[str, Any]],
    ack: bool,
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
    """OP_MSG implementation entry point (uncompressed).

    The header fields at offsets 0 (message length) and 4 (request id) are
    reserved first and patched once the body length is known.

    :return: ``(request_id, message, batch)``.
    """
    out = _BytesIO()
    # Placeholder for message length and request id.
    out.write(_ZERO_64)
    # responseTo = 0, then opCode bytes dd 07 00 00 (0x07dd == 2013, OP_MSG).
    out.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00")
    batch, total_length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, out)
    request_id = _randint()
    # Patch the request id (offset 4), then the message length (offset 0).
    out.seek(4)
    out.write(_pack_int(request_id))
    out.seek(0)
    out.write(_pack_int(total_length))
    return request_id, out.getvalue(), batch


# Use the C extension's implementation when it is available.
if _use_c:
    _batched_op_msg = _cmessage._batched_op_msg
def _do_batched_op_msg(
    namespace: str,
    operation: int,
    command: MutableMapping[str, Any],
    docs: list[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[int, bytes, list[Mapping[str, Any]]]:
    """Create the next batched insert, update, or delete operation using OP_MSG.

    Sets ``command["$db"]`` to the part of *namespace* before the first dot,
    mutating *command*. The write is unacknowledged only when *command* has a
    ``writeConcern`` whose ``w`` is falsy (e.g. 0); a missing ``w`` counts as
    acknowledged. The message is compressed when the connection has a
    compression context.

    :return: ``(request_id, message, batch)``.
    """
    command["$db"] = namespace.partition(".")[0]
    ack = bool(command["writeConcern"].get("w", 1)) if "writeConcern" in command else True
    build = _batched_op_msg_compressed if ctx.conn.compression_context else _batched_op_msg
    return build(operation, command, docs, ack, opts, ctx)
# End OP_MSG -----------------------------------------------------
def _encode_batched_write_command(
    namespace: str,
    operation: int,
    command: MutableMapping[str, Any],
    docs: list[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: _BulkWriteContext,
) -> tuple[bytes, list[Mapping[str, Any]]]:
    """Encode the next batched insert, update, or delete command.

    The result is an OP_QUERY body with no message header; the layout is
    produced by ``_batched_write_command_impl``.

    :return: ``(body, batch)`` - the encoded bytes and the documents that
        were included in this batch.
    """
    body = _BytesIO()
    batch = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, body)[0]
    return body.getvalue(), batch


# Use the C extension's implementation when it is available.
if _use_c:
    _encode_batched_write_command = _cmessage._encode_batched_write_command
def _batched_write_command_impl(
    namespace: str,
    operation: int,
    command: MutableMapping[str, Any],
    docs: list[Mapping[str, Any]],
    opts: CodecOptions,
    ctx: _BulkWriteContext,
    buf: _BytesIO,
) -> tuple[list[Mapping[str, Any]], int]:
    """Create a batched OP_QUERY write command.

    Writes into *buf* an OP_QUERY body: flags, the namespace C string and the
    skip/limit fields. Its query document is *command*, extended in place
    with an array field (opened by ``_OP_MAP[operation]``) that holds as many
    of *docs* as fit, keyed "0", "1", ...

    The first document is always included (unless it is too large). After
    that, the batch stops before a document once the buffer size plus that
    document's key and encoded size would reach ``ctx.max_split_size``, or
    once ``ctx.max_write_batch_size`` documents have been added.

    :param namespace: namespace written as a C string, e.g. ``db + ".$cmd"``.
    :param operation: write operation code; must be a key of ``_OP_MAP``.
    :param command: the write command document. It is encoded with
        ``bson.encode`` defaults, not with *opts*.
    :param docs: candidate documents for the batch.
    :param opts: codec options used to encode each of *docs*.
    :param ctx: supplies ``max_bson_size``, ``max_write_batch_size`` and
        ``max_split_size``.
    :param buf: output buffer, written in place.
    :return: ``(to_send, length)``. *to_send* is the documents included in
        this batch. *length* is the buffer offset just past the closing NUL
        bytes.
    :raises InvalidOperation: if *operation* is not a key of ``_OP_MAP``.
    :raises DocumentTooLarge: via ``_raise_document_too_large``, if a
        considered document's encoded size exceeds
        ``max_bson_size + _COMMAND_OVERHEAD``.
    """
    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    # Max BSON object size + 16k - 2 bytes for ending NUL bytes.
    # Server guarantees there is enough room: SERVER-10643.
    max_cmd_size = max_bson_size + _COMMAND_OVERHEAD
    max_split_size = ctx.max_split_size
    # No options
    buf.write(_ZERO_32)
    # Namespace as C string
    buf.write(namespace.encode("utf8"))
    buf.write(_ZERO_8)
    # Skip: 0, Limit: -1
    buf.write(_SKIPLIM)
    # Where to write command document length
    command_start = buf.tell()
    buf.write(encode(command))
    # Start of payload: step back over the command document's trailing NUL
    # byte so the batch array can be appended inside the command document.
    # That NUL is restored by the _ZERO_16 write after the loop.
    buf.seek(-1, 2)
    # Work around some Jython weirdness.
    buf.truncate()
    try:
        buf.write(_OP_MAP[operation])
    except KeyError:
        raise InvalidOperation("Unknown command") from None
    # Where to write list document length.
    # Assumes each _OP_MAP entry ends with a 4-byte placeholder for the array
    # document's length (hence tell() - 4) - confirm against _OP_MAP.
    list_start = buf.tell() - 4
    to_send = []
    idx = 0
    for doc in docs:
        # Encode the current operation
        key = str(idx).encode("utf8")
        value = _dict_to_bson(doc, False, opts)
        # Is there enough room to add this document? max_cmd_size accounts for
        # the two trailing null bytes.
        doc_too_large = len(value) > max_cmd_size
        if doc_too_large:
            # Operation name taken from _FIELD_MAP's key at index `operation`
            # (assumes _FIELD_MAP is ordered by operation code - verify).
            write_op = list(_FIELD_MAP.keys())[operation]
            _raise_document_too_large(write_op, len(value), max_bson_size)
        # Always keep the first document; afterwards stop once appending this
        # element would reach max_split_size.
        enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size
        enough_documents = idx >= max_write_batch_size
        if enough_data or enough_documents:
            break
        # Array element: type byte (_BSONOBJ, presumably "embedded
        # document"), the index key as a C string, then the document.
        buf.write(_BSONOBJ)
        buf.write(key)
        buf.write(_ZERO_8)
        buf.write(value)
        to_send.append(doc)
        idx += 1
    # Finalize the current OP_QUERY message.
    # Close list and command documents
    buf.write(_ZERO_16)
    # Patch the array and command document lengths (no request id is written
    # here). The array length excludes the command document's final NUL,
    # hence the -1.
    length = buf.tell()
    buf.seek(list_start)
    buf.write(_pack_int(length - list_start - 1))
    buf.seek(command_start)
    buf.write(_pack_int(length - command_start))
    return to_send, length
class _OpReply:
"""A MongoDB OP_REPLY response message."""
__slots__ = ("flags", "cursor_id", "number_returned", "documents")
UNPACK_FROM = struct.Struct(" list[bytes]:
"""Check the response header from the database, without decoding BSON.
Check the response for errors and unpack.
Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
OperationFailure.
:param cursor_id: cursor_id we sent to get this response -
used for raising an informative exception when we get cursor id not
valid at server response.
"""
if self.flags & 1:
# Shouldn't get this response if we aren't doing a getMore
if cursor_id is None:
raise ProtocolError("No cursor id for getMore operation")
# Fake a getMore command response. OP_GET_MORE provides no
# document.
msg = "Cursor not found, cursor id: %d" % (cursor_id,)
errobj = {"ok": 0, "errmsg": msg, "code": 43}
raise CursorNotFound(msg, 43, errobj)
elif self.flags & 2:
error_object: dict = bson.BSON(self.documents).decode()
# Fake the ok field if it doesn't exist.
error_object.setdefault("ok", 0)
if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR):
raise NotPrimaryError(error_object["$err"], error_object)
elif error_object.get("code") == 50:
default_msg = "operation exceeded time limit"
raise ExecutionTimeout(
error_object.get("$err", default_msg), error_object.get("code"), error_object
)
raise OperationFailure(
"database error: %s" % error_object.get("$err"),
error_object.get("code"),
error_object,
)
if self.documents:
return [self.documents]
return []
def unpack_response(
    self,
    cursor_id: Optional[int] = None,
    codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
    user_fields: Optional[Mapping[str, Any]] = None,
    legacy_response: bool = False,
) -> list[dict[str, Any]]:
    """Check this OP_REPLY for errors, then decode its BSON document(s).

    Errors are detected by :meth:`raw_response`, which can raise
    CursorNotFound, NotPrimaryError, ExecutionTimeout, OperationFailure or
    ProtocolError.

    :param cursor_id: cursor_id we sent to get this response - used for
        raising an informative exception when the server reports the cursor
        id as not valid.
    :param codec_options: an instance of
        :class:`~bson.codec_options.CodecOptions`
    :param user_fields: Response fields that should be decoded
        using the TypeDecoders from codec_options, passed to
        bson._decode_all_selective.
    :param legacy_response: if true, decode every document with
        :func:`bson.decode_all` and ignore *user_fields*.
    :return: the list of decoded documents.
    """
    self.raw_response(cursor_id)
    if not legacy_response:
        return bson._decode_all_selective(self.documents, codec_options, user_fields)
    return bson.decode_all(self.documents, codec_options)
def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
    """Decode an OP_REPLY command response and return its document.

    The reply is expected to contain exactly one document (asserted).
    """
    decoded = self.unpack_response(codec_options=codec_options)
    assert self.number_returned == 1
    first_doc = decoded[0]
    return first_doc
def raw_command_response(self) -> NoReturn:
    """Unsupported for OP_REPLY: always raises NotImplementedError."""
    # This should never be called on _OpReply; only OP_MSG replies expose
    # their raw command bytes.
    raise NotImplementedError
@property
def more_to_come(self) -> bool:
    """Whether the moreToCome bit is set; always False for OP_REPLY."""
    return False
@classmethod
def unpack(cls, msg: bytes) -> _OpReply:
    """Construct an _OpReply from raw bytes.

    The leading 20 bytes hold the fixed fields decoded by ``cls.UNPACK_FROM``.
    Everything after them is kept, undecoded, as the reply's documents.
    """
    # PYTHON-945: ignore starting_from field.
    flags, cursor_id, _starting_from, number_returned = cls.UNPACK_FROM(msg)
    return cls(flags, cursor_id, number_returned, msg[20:])
class _OpMsg:
"""A MongoDB OP_MSG response message."""
__slots__ = ("flags", "cursor_id", "number_returned", "payload_document")
UNPACK_FROM = struct.Struct(" list[Mapping[str, Any]]:
"""
cursor_id is ignored
user_fields is used to determine which fields must not be decoded
"""
inflated_response = _decode_selective(
RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS
)
return [inflated_response]
def unpack_response(
    self,
    cursor_id: Optional[int] = None,
    codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
    user_fields: Optional[Mapping[str, Any]] = None,
    legacy_response: bool = False,
) -> list[dict[str, Any]]:
    """Unpack an OP_MSG command response.

    :param cursor_id: Ignored, for compatibility with _OpReply.
    :param codec_options: an instance of
        :class:`~bson.codec_options.CodecOptions`
    :param user_fields: Response fields that should be decoded
        using the TypeDecoders from codec_options, passed to
        bson._decode_all_selective.
    :param legacy_response: must be false; an OP_MSG reply is never a
        legacy response (asserted).
    :return: the list of decoded documents from the payload.
    """
    # If _OpMsg is in-use, this cannot be a legacy response.
    assert not legacy_response
    payload = self.payload_document
    return bson._decode_all_selective(payload, codec_options, user_fields)
def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
    """Decode an OP_MSG command response and return its first document."""
    decoded = self.unpack_response(codec_options=codec_options)
    first_doc = decoded[0]
    return first_doc
def raw_command_response(self) -> bytes:
    """Return the undecoded bytes of the command response payload."""
    raw: bytes = self.payload_document
    return raw
@property
def more_to_come(self) -> bool:
    """Whether this reply's flags have the moreToCome bit set."""
    return (self.flags & self.MORE_TO_COME) != 0
@classmethod
def unpack(cls, msg: bytes) -> _OpMsg:
    """Construct an _OpMsg from raw bytes.

    Only replies made of a single type 0 section are accepted. The flags
    must be either zero or exactly the moreToCome bit.

    :raises ProtocolError: if the checksumPresent flag or any other
        unsupported flag is set, if the first section is not type 0, or if
        the message holds more than one section.
    """
    flags, section_kind, section_size = cls.UNPACK_FROM(msg)
    if flags:
        if flags & cls.CHECKSUM_PRESENT:
            raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}")
        # Anything other than exactly the moreToCome bit is unsupported.
        if flags != cls.MORE_TO_COME:
            raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}")
    if section_kind != 0:
        raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{section_kind:x}")
    # 5 bytes of flags + section kind precede the first section's document;
    # the document must account for the rest of the message.
    if section_size + 5 != len(msg):
        raise ProtocolError("Unsupported OP_MSG reply: >1 section")
    return cls(flags, msg[5:])
# Dispatch table: reply class OP_CODE -> that class's `unpack` parser.
_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = {
    reply_cls.OP_CODE: reply_cls.unpack for reply_cls in (_OpReply, _OpMsg)
}