# Copyright 2018 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import warnings from typing import Any, Iterable, Optional, Union from pymongo.hello import HelloCompat from pymongo.helpers import _SENSITIVE_COMMANDS _SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"} _NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} _NO_COMPRESSION.update(_SENSITIVE_COMMANDS) def _have_snappy() -> bool: try: import snappy # type:ignore[import] # noqa: F401 return True except ImportError: return False def _have_zlib() -> bool: try: import zlib # noqa: F401 return True except ImportError: return False def _have_zstd() -> bool: try: import zstandard # noqa: F401 return True except ImportError: return False def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]: try: # `value` is string. compressors = value.split(",") # type: ignore[union-attr] except AttributeError: # `value` is an iterable. compressors = list(value) for compressor in compressors[:]: if compressor not in _SUPPORTED_COMPRESSORS: compressors.remove(compressor) warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2) elif compressor == "snappy" and not _have_snappy(): compressors.remove(compressor) warnings.warn( "Wire protocol compression with snappy is not available. " "You must install the python-snappy module for snappy support.", stacklevel=2, ) elif compressor == "zlib" and not _have_zlib(): compressors.remove(compressor) warnings.warn( "Wire protocol compression with zlib is not available. " "The zlib module is not available.", stacklevel=2, ) elif compressor == "zstd" and not _have_zstd(): compressors.remove(compressor) warnings.warn( "Wire protocol compression with zstandard is not available. " "You must install the zstandard module for zstandard support.", stacklevel=2, ) return compressors def validate_zlib_compression_level(option: str, value: Any) -> int: try: level = int(value) except Exception: raise TypeError(f"{option} must be an integer, not {value!r}.") from None if level < -1 or level > 9: raise ValueError("%s must be between -1 and 9, not %d." % (option, level)) return level class CompressionSettings: def __init__(self, compressors: list[str], zlib_compression_level: int): self.compressors = compressors self.zlib_compression_level = zlib_compression_level def get_compression_context( self, compressors: Optional[list[str]] ) -> Union[SnappyContext, ZlibContext, ZstdContext, None]: if compressors: chosen = compressors[0] if chosen == "snappy": return SnappyContext() elif chosen == "zlib": return ZlibContext(self.zlib_compression_level) elif chosen == "zstd": return ZstdContext() return None return None class SnappyContext: compressor_id = 1 @staticmethod def compress(data: bytes) -> bytes: import snappy return snappy.compress(data) class ZlibContext: compressor_id = 2 def __init__(self, level: int): self.level = level def compress(self, data: bytes) -> bytes: import zlib return zlib.compress(data, self.level) class ZstdContext: compressor_id = 3 @staticmethod def compress(data: bytes) -> bytes: # ZstdCompressor is not thread safe. # TODO: Use a pool? import zstandard return zstandard.ZstdCompressor().compress(data) def decompress(data: bytes, compressor_id: int) -> bytes: if compressor_id == SnappyContext.compressor_id: # python-snappy doesn't support the buffer interface. # https://github.com/andrix/python-snappy/issues/65 # This only matters when data is a memoryview since # id(bytes(data)) == id(data) when data is a bytes. import snappy return snappy.uncompress(bytes(data)) elif compressor_id == ZlibContext.compressor_id: import zlib return zlib.decompress(data) elif compressor_id == ZstdContext.compressor_id: # ZstdDecompressor is not thread safe. # TODO: Use a pool? import zstandard return zstandard.ZstdDecompressor().decompress(data) else: raise ValueError("Unknown compressorId %d" % (compressor_id,))