from __future__ import annotations import codecs import email.message import ipaddress import mimetypes import os import re import time import typing from pathlib import Path from urllib.request import getproxies import sniffio from ._types import PrimitiveData if typing.TYPE_CHECKING: # pragma: no cover from ._urls import URL _HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"} _HTML5_FORM_ENCODING_REPLACEMENTS.update( {chr(c): "%{:02X}".format(c) for c in range(0x1F + 1) if c != 0x1B} ) _HTML5_FORM_ENCODING_RE = re.compile( r"|".join([re.escape(c) for c in _HTML5_FORM_ENCODING_REPLACEMENTS.keys()]) ) def normalize_header_key( value: str | bytes, lower: bool, encoding: str | None = None, ) -> bytes: """ Coerce str/bytes into a strictly byte-wise HTTP header key. """ if isinstance(value, bytes): bytes_value = value else: bytes_value = value.encode(encoding or "ascii") return bytes_value.lower() if lower else bytes_value def normalize_header_value(value: str | bytes, encoding: str | None = None) -> bytes: """ Coerce str/bytes into a strictly byte-wise HTTP header value. """ if isinstance(value, bytes): return value return value.encode(encoding or "ascii") def primitive_value_to_str(value: PrimitiveData) -> str: """ Coerce a primitive data type into a string value. Note that we prefer JSON-style 'true'/'false' for boolean values here. """ if value is True: return "true" elif value is False: return "false" elif value is None: return "" return str(value) def is_known_encoding(encoding: str) -> bool: """ Return `True` if `encoding` is a known codec. """ try: codecs.lookup(encoding) except LookupError: return False return True def format_form_param(name: str, value: str) -> bytes: """ Encode a name/value pair within a multipart form. """ def replacer(match: typing.Match[str]) -> str: return _HTML5_FORM_ENCODING_REPLACEMENTS[match.group(0)] value = _HTML5_FORM_ENCODING_RE.sub(replacer, value) return f'{name}="{value}"'.encode() def get_ca_bundle_from_env() -> str | None: if "SSL_CERT_FILE" in os.environ: ssl_file = Path(os.environ["SSL_CERT_FILE"]) if ssl_file.is_file(): return str(ssl_file) if "SSL_CERT_DIR" in os.environ: ssl_path = Path(os.environ["SSL_CERT_DIR"]) if ssl_path.is_dir(): return str(ssl_path) return None def parse_header_links(value: str) -> list[dict[str, str]]: """ Returns a list of parsed link headers, for more info see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link The generic syntax of those is: Link: < uri-reference >; param1=value1; param2="value2" So for instance: Link; '; type="image/jpeg",;' would return [ {"url": "http:/.../front.jpeg", "type": "image/jpeg"}, {"url": "http://.../back.jpeg"}, ] :param value: HTTP Link entity-header field :return: list of parsed link headers """ links: list[dict[str, str]] = [] replace_chars = " '\"" value = value.strip(replace_chars) if not value: return links for val in re.split(", *<", value): try: url, params = val.split(";", 1) except ValueError: url, params = val, "" link = {"url": url.strip("<> '\"")} for param in params.split(";"): try: key, value = param.split("=") except ValueError: break link[key.strip(replace_chars)] = value.strip(replace_chars) links.append(link) return links def parse_content_type_charset(content_type: str) -> str | None: # We used to use `cgi.parse_header()` here, but `cgi` became a dead battery. # See: https://peps.python.org/pep-0594/#cgi msg = email.message.Message() msg["content-type"] = content_type return msg.get_content_charset(failobj=None) SENSITIVE_HEADERS = {"authorization", "proxy-authorization"} def obfuscate_sensitive_headers( items: typing.Iterable[tuple[typing.AnyStr, typing.AnyStr]], ) -> typing.Iterator[tuple[typing.AnyStr, typing.AnyStr]]: for k, v in items: if to_str(k.lower()) in SENSITIVE_HEADERS: v = to_bytes_or_str("[secure]", match_type_of=v) yield k, v def port_or_default(url: URL) -> int | None: if url.port is not None: return url.port return {"http": 80, "https": 443}.get(url.scheme) def same_origin(url: URL, other: URL) -> bool: """ Return 'True' if the given URLs share the same origin. """ return ( url.scheme == other.scheme and url.host == other.host and port_or_default(url) == port_or_default(other) ) def is_https_redirect(url: URL, location: URL) -> bool: """ Return 'True' if 'location' is a HTTPS upgrade of 'url' """ if url.host != location.host: return False return ( url.scheme == "http" and port_or_default(url) == 80 and location.scheme == "https" and port_or_default(location) == 443 ) def get_environment_proxies() -> dict[str, str | None]: """Gets proxy information from the environment""" # urllib.request.getproxies() falls back on System # Registry and Config for proxies on Windows and macOS. # We don't want to propagate non-HTTP proxies into # our configuration such as 'TRAVIS_APT_PROXY'. proxy_info = getproxies() mounts: dict[str, str | None] = {} for scheme in ("http", "https", "all"): if proxy_info.get(scheme): hostname = proxy_info[scheme] mounts[f"{scheme}://"] = ( hostname if "://" in hostname else f"http://{hostname}" ) no_proxy_hosts = [host.strip() for host in proxy_info.get("no", "").split(",")] for hostname in no_proxy_hosts: # See https://curl.haxx.se/libcurl/c/CURLOPT_NOPROXY.html for details # on how names in `NO_PROXY` are handled. if hostname == "*": # If NO_PROXY=* is used or if "*" occurs as any one of the comma # separated hostnames, then we should just bypass any information # from HTTP_PROXY, HTTPS_PROXY, ALL_PROXY, and always ignore # proxies. return {} elif hostname: # NO_PROXY=.google.com is marked as "all://*.google.com, # which disables "www.google.com" but not "google.com" # NO_PROXY=google.com is marked as "all://*google.com, # which disables "www.google.com" and "google.com". # (But not "wwwgoogle.com") # NO_PROXY can include domains, IPv6, IPv4 addresses and "localhost" # NO_PROXY=example.com,::1,localhost,192.168.0.0/16 if "://" in hostname: mounts[hostname] = None elif is_ipv4_hostname(hostname): mounts[f"all://{hostname}"] = None elif is_ipv6_hostname(hostname): mounts[f"all://[{hostname}]"] = None elif hostname.lower() == "localhost": mounts[f"all://{hostname}"] = None else: mounts[f"all://*{hostname}"] = None return mounts def to_bytes(value: str | bytes, encoding: str = "utf-8") -> bytes: return value.encode(encoding) if isinstance(value, str) else value def to_str(value: str | bytes, encoding: str = "utf-8") -> str: return value if isinstance(value, str) else value.decode(encoding) def to_bytes_or_str(value: str, match_type_of: typing.AnyStr) -> typing.AnyStr: return value if isinstance(match_type_of, str) else value.encode() def unquote(value: str) -> str: return value[1:-1] if value[0] == value[-1] == '"' else value def guess_content_type(filename: str | None) -> str | None: if filename: return mimetypes.guess_type(filename)[0] or "application/octet-stream" return None def peek_filelike_length(stream: typing.Any) -> int | None: """ Given a file-like stream object, return its length in number of bytes without reading it into memory. """ try: # Is it an actual file? fd = stream.fileno() # Yup, seems to be an actual file. length = os.fstat(fd).st_size except (AttributeError, OSError): # No... Maybe it's something that supports random access, like `io.BytesIO`? try: # Assuming so, go to end of stream to figure out its length, # then put it back in place. offset = stream.tell() length = stream.seek(0, os.SEEK_END) stream.seek(offset) except (AttributeError, OSError): # Not even that? Sorry, we're doomed... return None return length class Timer: async def _get_time(self) -> float: library = sniffio.current_async_library() if library == "trio": import trio return trio.current_time() else: import asyncio return asyncio.get_event_loop().time() def sync_start(self) -> None: self.started = time.perf_counter() async def async_start(self) -> None: self.started = await self._get_time() def sync_elapsed(self) -> float: now = time.perf_counter() return now - self.started async def async_elapsed(self) -> float: now = await self._get_time() return now - self.started class URLPattern: """ A utility class currently used for making lookups against proxy keys... # Wildcard matching... >>> pattern = URLPattern("all://") >>> pattern.matches(httpx.URL("http://example.com")) True # Witch scheme matching... >>> pattern = URLPattern("https://") >>> pattern.matches(httpx.URL("https://example.com")) True >>> pattern.matches(httpx.URL("http://example.com")) False # With domain matching... >>> pattern = URLPattern("https://example.com") >>> pattern.matches(httpx.URL("https://example.com")) True >>> pattern.matches(httpx.URL("http://example.com")) False >>> pattern.matches(httpx.URL("https://other.com")) False # Wildcard scheme, with domain matching... >>> pattern = URLPattern("all://example.com") >>> pattern.matches(httpx.URL("https://example.com")) True >>> pattern.matches(httpx.URL("http://example.com")) True >>> pattern.matches(httpx.URL("https://other.com")) False # With port matching... >>> pattern = URLPattern("https://example.com:1234") >>> pattern.matches(httpx.URL("https://example.com:1234")) True >>> pattern.matches(httpx.URL("https://example.com")) False """ def __init__(self, pattern: str) -> None: from ._urls import URL if pattern and ":" not in pattern: raise ValueError( f"Proxy keys should use proper URL forms rather " f"than plain scheme strings. " f'Instead of "{pattern}", use "{pattern}://"' ) url = URL(pattern) self.pattern = pattern self.scheme = "" if url.scheme == "all" else url.scheme self.host = "" if url.host == "*" else url.host self.port = url.port if not url.host or url.host == "*": self.host_regex: typing.Pattern[str] | None = None elif url.host.startswith("*."): # *.example.com should match "www.example.com", but not "example.com" domain = re.escape(url.host[2:]) self.host_regex = re.compile(f"^.+\\.{domain}$") elif url.host.startswith("*"): # *example.com should match "www.example.com" and "example.com" domain = re.escape(url.host[1:]) self.host_regex = re.compile(f"^(.+\\.)?{domain}$") else: # example.com should match "example.com" but not "www.example.com" domain = re.escape(url.host) self.host_regex = re.compile(f"^{domain}$") def matches(self, other: URL) -> bool: if self.scheme and self.scheme != other.scheme: return False if ( self.host and self.host_regex is not None and not self.host_regex.match(other.host) ): return False if self.port is not None and self.port != other.port: return False return True @property def priority(self) -> tuple[int, int, int]: """ The priority allows URLPattern instances to be sortable, so that we can match from most specific to least specific. """ # URLs with a port should take priority over URLs without a port. port_priority = 0 if self.port is not None else 1 # Longer hostnames should match first. host_priority = -len(self.host) # Longer schemes should match first. scheme_priority = -len(self.scheme) return (port_priority, host_priority, scheme_priority) def __hash__(self) -> int: return hash(self.pattern) def __lt__(self, other: URLPattern) -> bool: return self.priority < other.priority def __eq__(self, other: typing.Any) -> bool: return isinstance(other, URLPattern) and self.pattern == other.pattern def is_ipv4_hostname(hostname: str) -> bool: try: ipaddress.IPv4Address(hostname.split("/")[0]) except Exception: return False return True def is_ipv6_hostname(hostname: str) -> bool: try: ipaddress.IPv6Address(hostname.split("/")[0]) except Exception: return False return True