"""
Handlers for Content-Encoding.

See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
"""

from __future__ import annotations

import codecs
import io
import typing
import zlib

from ._compat import brotli, zstd
from ._exceptions import DecodingError


class ContentDecoder:
    """
    Base interface for incremental content decoders.

    Subclasses implement `decode()` for each incoming chunk and `flush()`
    for whatever remains once the input is exhausted.
    """

    def decode(self, data: bytes) -> bytes:
        """Decode one chunk of encoded bytes and return the decoded bytes."""
        raise NotImplementedError()  # pragma: no cover

    def flush(self) -> bytes:
        """Return any remaining decoded bytes once all input has been seen."""
        raise NotImplementedError()  # pragma: no cover


class IdentityDecoder(ContentDecoder):
    """
    Handle unencoded data.
    """

    def decode(self, data: bytes) -> bytes:
        """Return `data` unchanged."""
        return data

    def flush(self) -> bytes:
        """Nothing is ever buffered, so there is nothing to flush."""
        return b""


class DeflateDecoder(ContentDecoder):
    """
    Handle 'deflate' decoding.

    See: https://stackoverflow.com/questions/1838699
    """

    def __init__(self) -> None:
        # Only the very first `decode()` call is allowed to fall back to a
        # different decompressor configuration.
        self.first_attempt = True
        self.decompressor = zlib.decompressobj()

    def decode(self, data: bytes) -> bytes:
        """
        Decompress one chunk of 'deflate' data.

        If the first chunk is rejected by the default decompressor, the
        decompressor is replaced with one created with `-zlib.MAX_WBITS`
        and the same chunk is retried once.

        Raises `DecodingError` if decompression fails.
        """
        was_first_attempt = self.first_attempt
        self.first_attempt = False
        try:
            return self.decompressor.decompress(data)
        except zlib.error as exc:
            if was_first_attempt:
                # `first_attempt` is already False, so this recursion runs
                # at most once.
                self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
                return self.decode(data)
            raise DecodingError(str(exc)) from exc

    def flush(self) -> bytes:
        """
        Return any data still held by the decompressor.

        Raises `DecodingError` if the decompressor reports an error.
        """
        try:
            return self.decompressor.flush()
        except zlib.error as exc:  # pragma: no cover
            raise DecodingError(str(exc)) from exc


class GZipDecoder(ContentDecoder):
    """
    Handle 'gzip' decoding.

    See: https://stackoverflow.com/questions/1838699
    """

    def __init__(self) -> None:
        self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16)

    def decode(self, data: bytes) -> bytes:
        """
        Decompress one chunk of 'gzip' data.

        Raises `DecodingError` if decompression fails.
        """
        try:
            return self.decompressor.decompress(data)
        except zlib.error as exc:
            raise DecodingError(str(exc)) from exc

    def flush(self) -> bytes:
        """
        Return any data still held by the decompressor.

        Raises `DecodingError` if the decompressor reports an error.
        """
        try:
            return self.decompressor.flush()
        except zlib.error as exc:  # pragma: no cover
            raise DecodingError(str(exc)) from exc


class BrotliDecoder(ContentDecoder):
    """
    Handle 'brotli' decoding.

    Requires `pip install brotlipy`. See: https://brotlipy.readthedocs.io/
        or   `pip install brotli`. See https://github.com/google/brotli
    Supports both 'brotlipy' and 'Brotli' packages since they share an import
    name. The top branches are for 'brotlipy' and bottom branches for 'Brotli'
    """

    def __init__(self) -> None:
        if brotli is None:  # pragma: no cover
            raise ImportError(
                "Using 'BrotliDecoder', but neither of the 'brotlicffi' or 'brotli' "
                "packages have been installed. "
                "Make sure to install httpx using `pip install httpx[brotli]`."
            ) from None

        self.decompressor = brotli.Decompressor()
        # Tracks whether any non-empty chunk was decoded, so that `flush()`
        # on an empty body does not try to finalise the stream.
        self.seen_data = False
        self._decompress: typing.Callable[[bytes], bytes]
        if hasattr(self.decompressor, "decompress"):
            # The 'brotlicffi' package.
            self._decompress = self.decompressor.decompress  # pragma: no cover
        else:
            # The 'brotli' package.
            self._decompress = self.decompressor.process  # pragma: no cover

    def decode(self, data: bytes) -> bytes:
        """
        Decompress one chunk of 'br' data.

        Empty input is returned as empty output without touching the
        decompressor. Raises `DecodingError` if decompression fails.
        """
        if not data:
            return b""
        self.seen_data = True
        try:
            return self._decompress(data)
        except brotli.error as exc:
            raise DecodingError(str(exc)) from exc

    def flush(self) -> bytes:
        """
        Finalise the stream; always returns `b""`.

        Raises `DecodingError` if finalising the decompressor fails.
        """
        if not self.seen_data:
            return b""
        try:
            if hasattr(self.decompressor, "finish"):
                # Only available in the 'brotlicffi' package.

                # As the decompressor decompresses eagerly, this
                # will never actually emit any data. However, it will
                # potentially throw errors if a truncated or damaged data
                # stream has been used.
                self.decompressor.finish()  # pragma: no cover
            return b""
        except brotli.error as exc:  # pragma: no cover
            raise DecodingError(str(exc)) from exc


class ZStandardDecoder(ContentDecoder):
    """
    Handle 'zstd' RFC 8878 decoding.

    Requires `pip install zstandard`.
    Can be installed as a dependency of httpx using `pip install httpx[zstd]`.
    """

    # inspired by the ZstdDecoder implementation in urllib3
    def __init__(self) -> None:
        if zstd is None:  # pragma: no cover
            raise ImportError(
                "Using 'ZStandardDecoder', ..."
                "Make sure to install httpx using `pip install httpx[zstd]`."
            ) from None

        self.decompressor = zstd.ZstdDecompressor().decompressobj()
        # Tracks whether any non-empty chunk was decoded. Without it,
        # `flush()` on an empty body would report the stream as incomplete,
        # because `eof` is only set once a frame has been fully decoded.
        self.seen_data = False

    def decode(self, data: bytes) -> bytes:
        """
        Decompress one chunk of 'zstd' data.

        Whenever the current decompressor reports `eof` with `unused_data`
        left over, a fresh decompressor is started on that leftover data, so
        several frames in one chunk are all decoded.

        Empty input is returned as empty output without touching the
        decompressor. Raises `DecodingError` if decompression fails.
        """
        assert zstd is not None
        if not data:
            # Empty input carries no information; skipping it keeps an empty
            # body from being treated as a started-but-truncated stream.
            return b""
        self.seen_data = True
        output = io.BytesIO()
        try:
            output.write(self.decompressor.decompress(data))
            while self.decompressor.eof and self.decompressor.unused_data:
                unused_data = self.decompressor.unused_data
                self.decompressor = zstd.ZstdDecompressor().decompressobj()
                output.write(self.decompressor.decompress(unused_data))
        except zstd.ZstdError as exc:
            raise DecodingError(str(exc)) from exc
        return output.getvalue()

    def flush(self) -> bytes:
        """
        Finalise the stream and return any remaining bytes.

        Returns `b""` when no data was ever decoded. Raises `DecodingError`
        when data was decoded but the decompressor has not reached `eof`.
        """
        if not self.seen_data:
            return b""
        ret = self.decompressor.flush()  # note: this is a no-op
        if not self.decompressor.eof:
            raise DecodingError("Zstandard data is incomplete")  # pragma: no cover
        return bytes(ret)


class MultiDecoder(ContentDecoder):
    """
    Handle the case where multiple encodings have been applied.
    """

    def __init__(self, children: typing.Sequence[ContentDecoder]) -> None:
        """
        'children' should be a sequence of decoders in the order in which
        each was applied.
        """
        # Note that we reverse the order for decoding.
        self.children = list(reversed(children))

    def decode(self, data: bytes) -> bytes:
        """Pass `data` through every child decoder in decoding order."""
        for child in self.children:
            data = child.decode(data)
        return data

    def flush(self) -> bytes:
        """
        Flush every child decoder in decoding order.

        Whatever earlier children emitted is decoded by the next child
        before that child is flushed itself.
        """
        data = b""
        for child in self.children:
            data = child.decode(data) + child.flush()
        return data


class ByteChunker:
    """
    Handles returning byte content in fixed-size chunks.
    """

    def __init__(self, chunk_size: int | None = None) -> None:
        self._buffer = io.BytesIO()
        self._chunk_size = chunk_size

    def decode(self, content: bytes) -> list[bytes]:
        """
        Add `content` to the buffer and return every complete chunk.

        With no chunk size configured, non-empty content is returned as a
        single item and empty content yields an empty list. Otherwise any
        trailing partial chunk stays buffered for the next call.
        """
        if self._chunk_size is None:
            return [content] if content else []

        self._buffer.write(content)
        if self._buffer.tell() >= self._chunk_size:
            value = self._buffer.getvalue()
            chunks = [
                value[i : i + self._chunk_size]
                for i in range(0, len(value), self._chunk_size)
            ]
            if len(chunks[-1]) == self._chunk_size:
                # Everything divided evenly; the buffer is left empty.
                self._buffer.seek(0)
                self._buffer.truncate()
                return chunks
            else:
                # Keep the short final piece buffered for the next call.
                self._buffer.seek(0)
                self._buffer.write(chunks[-1])
                self._buffer.truncate()
                return chunks[:-1]
        else:
            return []

    def flush(self) -> list[bytes]:
        """Return the buffered remainder (if any) and empty the buffer."""
        value = self._buffer.getvalue()
        self._buffer.seek(0)
        self._buffer.truncate()
        return [value] if value else []


class TextChunker:
    """
    Handles returning text content in fixed-size chunks.
    """

    def __init__(self, chunk_size: int | None = None) -> None:
        self._buffer = io.StringIO()
        self._chunk_size = chunk_size

    def decode(self, content: str) -> list[str]:
        """
        Add `content` to the buffer and return every complete chunk.

        With no chunk size configured, non-empty content is returned as a
        single item and empty content yields an empty list. Otherwise any
        trailing partial chunk stays buffered for the next call.
        """
        if self._chunk_size is None:
            return [content] if content else []

        self._buffer.write(content)
        if self._buffer.tell() >= self._chunk_size:
            value = self._buffer.getvalue()
            chunks = [
                value[i : i + self._chunk_size]
                for i in range(0, len(value), self._chunk_size)
            ]
            if len(chunks[-1]) == self._chunk_size:
                # Everything divided evenly; the buffer is left empty.
                self._buffer.seek(0)
                self._buffer.truncate()
                return chunks
            else:
                # Keep the short final piece buffered for the next call.
                self._buffer.seek(0)
                self._buffer.write(chunks[-1])
                self._buffer.truncate()
                return chunks[:-1]
        else:
            return []

    def flush(self) -> list[str]:
        """Return the buffered remainder (if any) and empty the buffer."""
        value = self._buffer.getvalue()
        self._buffer.seek(0)
        self._buffer.truncate()
        return [value] if value else []


class TextDecoder:
    """
    Handles incrementally decoding bytes into text
    """

    def __init__(self, encoding: str = "utf-8") -> None:
        # Undecodable input is substituted rather than raising.
        self.decoder = codecs.getincrementaldecoder(encoding)(errors="replace")

    def decode(self, data: bytes) -> str:
        """Decode one chunk of bytes and return the resulting text."""
        return self.decoder.decode(data)

    def flush(self) -> str:
        """Signal end of input to the decoder and return any remaining text."""
        return self.decoder.decode(b"", True)


class LineDecoder:
    """
    Handles incrementally reading lines from text.

    Has the same behaviour as the stdlib splitlines,
    but handling the input iteratively.
    """

    def __init__(self) -> None:
        # Pieces of a line that has not been terminated yet.
        self.buffer: list[str] = []
        # True when the previous input ended with a `\r` that was held back.
        self.trailing_cr: bool = False

    def decode(self, text: str) -> list[str]:
        """
        Feed `text` in and return the lines completed by it.

        Text after the last line ending stays buffered until a later call
        or `flush()`.
        """
        # See https://docs.python.org/3/library/stdtypes.html#str.splitlines
        NEWLINE_CHARS = "\n\r\x0b\x0c\x1c\x1d\x1e\x85\u2028\u2029"

        # We always push a trailing `\r` into the next decode iteration.
        if self.trailing_cr:
            text = "\r" + text
            self.trailing_cr = False
        if text.endswith("\r"):
            self.trailing_cr = True
            text = text[:-1]

        if not text:
            # NOTE: the edge case input of empty text doesn't occur in practice,
            # because other httpx internals filter out this value
            return []  # pragma: no cover

        trailing_newline = text[-1] in NEWLINE_CHARS
        lines = text.splitlines()

        if len(lines) == 1 and not trailing_newline:
            # No new lines, buffer the input and continue.
            self.buffer.append(lines[0])
            return []

        if self.buffer:
            # Include any existing buffer in the first portion of the
            # splitlines result.
            lines = ["".join(self.buffer) + lines[0]] + lines[1:]
            self.buffer = []

        if not trailing_newline:
            # If the last segment of splitlines is not newline terminated,
            # then drop it from our output and start a new buffer.
            self.buffer = [lines.pop()]

        return lines

    def flush(self) -> list[str]:
        """
        Return the final pending line, if any, and reset the decoder state.

        A held-back trailing `\\r` on its own still produces one (possibly
        empty) line.
        """
        if not self.buffer and not self.trailing_cr:
            return []

        lines = ["".join(self.buffer)]
        self.buffer = []
        self.trailing_cr = False
        return lines


# Maps Content-Encoding names to the decoder class that handles them.
SUPPORTED_DECODERS = {
    "identity": IdentityDecoder,
    "gzip": GZipDecoder,
    "deflate": DeflateDecoder,
    "br": BrotliDecoder,
    "zstd": ZStandardDecoder,
}


# Drop the encodings whose optional backing package is not available.
if brotli is None:
    SUPPORTED_DECODERS.pop("br")  # pragma: no cover
if zstd is None:
    SUPPORTED_DECODERS.pop("zstd")  # pragma: no cover