# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Tools to parse and validate a MongoDB URI.""" from __future__ import annotations import re import sys import warnings from typing import ( TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Sized, Union, cast, ) from urllib.parse import unquote_plus from pymongo.client_options import _parse_ssl_options from pymongo.common import ( INTERNAL_URI_OPTION_NAME_MAP, SRV_SERVICE_NAME, URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary, get_validated_options, ) from pymongo.errors import ConfigurationError, InvalidURI from pymongo.srv_resolver import _have_dnspython, _SrvResolver from pymongo.typings import _Address if TYPE_CHECKING: from pymongo.pyopenssl_context import SSLContext SCHEME = "mongodb://" SCHEME_LEN = len(SCHEME) SRV_SCHEME = "mongodb+srv://" SRV_SCHEME_LEN = len(SRV_SCHEME) DEFAULT_PORT = 27017 def _unquoted_percent(s: str) -> bool: """Check for unescaped percent signs. :param s: A string. `s` can have things like '%25', '%2525', and '%E2%85%A8' but cannot have unquoted percent like '%foo'. """ for i in range(len(s)): if s[i] == "%": sub = s[i : i + 3] # If unquoting yields the same string this means there was an # unquoted %. if unquote_plus(sub) == sub: return True return False def parse_userinfo(userinfo: str) -> tuple[str, str]: """Validates the format of user information in a MongoDB URI. Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", "]", "@") as per RFC 3986 must be escaped. Returns a 2-tuple containing the unescaped username followed by the unescaped password. :param userinfo: A string of the form : """ if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): raise InvalidURI( "Username and password must be escaped according to " "RFC 3986, use urllib.parse.quote_plus" ) user, _, passwd = userinfo.partition(":") # No password is expected with GSSAPI authentication. if not user: raise InvalidURI("The empty string is not valid username.") return unquote_plus(user), unquote_plus(passwd) def parse_ipv6_literal_host( entity: str, default_port: Optional[int] ) -> tuple[str, Optional[Union[str, int]]]: """Validates an IPv6 literal host:port string. Returns a 2-tuple of IPv6 literal followed by port where port is default_port if it wasn't specified in entity. :param entity: A string that represents an IPv6 literal enclosed in braces (e.g. '[::1]' or '[::1]:27017'). :param default_port: The port number to use when one wasn't specified in entity. """ if entity.find("]") == -1: raise ValueError( "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." ) i = entity.find("]:") if i == -1: return entity[1:-1], default_port return entity[1:i], entity[i + 2 :] def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: """Validates a host string Returns a 2-tuple of host followed by port where port is default_port if it wasn't specified in the string. :param entity: A host or host:port string where host could be a hostname or IP address. :param default_port: The port number to use when one wasn't specified in entity. """ host = entity port: Optional[Union[str, int]] = default_port if entity[0] == "[": host, port = parse_ipv6_literal_host(entity, default_port) elif entity.endswith(".sock"): return entity, default_port elif entity.find(":") != -1: if entity.count(":") > 1: raise ValueError( "Reserved characters such as ':' must be " "escaped according RFC 2396. An IPv6 " "address literal must be enclosed in '[' " "and ']' according to RFC 2732." ) host, port = host.split(":", 1) if isinstance(port, str): if not port.isdigit() or int(port) > 65535 or int(port) <= 0: raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}") port = int(port) # Normalize hostname to lowercase, since DNS is case-insensitive: # http://tools.ietf.org/html/rfc4343 # This prevents useless rediscovery if "foo.com" is in the seed list but # "FOO.com" is in the hello response. return host.lower(), port # Options whose values are implicitly determined by tlsInsecure. _IMPLICIT_TLSINSECURE_OPTS = { "tlsallowinvalidcertificates", "tlsallowinvalidhostnames", "tlsdisableocspendpointcheck", } def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: """Helper method for split_options which creates the options dict. Also handles the creation of a list for the URI tag_sets/ readpreferencetags portion, and the use of a unicode options string. """ options = _CaseInsensitiveDictionary() for uriopt in opts.split(delim): key, value = uriopt.split("=") if key.lower() == "readpreferencetags": options.setdefault(key, []).append(value) else: if key in options: warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) if key.lower() == "authmechanismproperties": val = value else: val = unquote_plus(value) options[key] = val return options def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: """Raise appropriate errors when conflicting TLS options are present in the options dictionary. :param options: Instance of _CaseInsensitiveDictionary containing MongoDB URI options. """ # Implicitly defined options must not be explicitly specified. tlsinsecure = options.get("tlsinsecure") if tlsinsecure is not None: for opt in _IMPLICIT_TLSINSECURE_OPTS: if opt in options: err_msg = "URI options %s and %s cannot be specified simultaneously." raise InvalidURI( err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) ) # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") if tlsallowinvalidcerts is not None: if "tlsdisableocspendpointcheck" in options: err_msg = "URI options %s and %s cannot be specified simultaneously." raise InvalidURI( err_msg % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) ) if tlsallowinvalidcerts is True: options["tlsdisableocspendpointcheck"] = True # Handle co-occurence of CRL and OCSP-related options. tlscrlfile = options.get("tlscrlfile") if tlscrlfile is not None: for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): if options.get(opt) is True: err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." raise InvalidURI(err_msg % (opt,)) if "ssl" in options and "tls" in options: def truth_value(val: Any) -> Any: if val in ("true", "false"): return val == "true" if isinstance(val, bool): return val return val if truth_value(options.get("ssl")) != truth_value(options.get("tls")): err_msg = "Can not specify conflicting values for URI options %s and %s." raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls"))) return options def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: """Issue appropriate warnings when deprecated options are present in the options dictionary. Removes deprecated option key, value pairs if the options dictionary is found to also have the renamed option. :param options: Instance of _CaseInsensitiveDictionary containing MongoDB URI options. """ for optname in list(options): if optname in URI_OPTIONS_DEPRECATION_MAP: mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] if mode == "renamed": newoptname = message if newoptname in options: warn_msg = "Deprecated option '%s' ignored in favor of '%s'." warnings.warn( warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), DeprecationWarning, stacklevel=2, ) options.pop(optname) continue warn_msg = "Option '%s' is deprecated, use '%s' instead." warnings.warn( warn_msg % (options.cased_key(optname), newoptname), DeprecationWarning, stacklevel=2, ) elif mode == "removed": warn_msg = "Option '%s' is deprecated. %s." warnings.warn( warn_msg % (options.cased_key(optname), message), DeprecationWarning, stacklevel=2, ) return options def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: """Normalizes option names in the options dictionary by converting them to their internally-used names. :param options: Instance of _CaseInsensitiveDictionary containing MongoDB URI options. """ # Expand the tlsInsecure option. tlsinsecure = options.get("tlsinsecure") if tlsinsecure is not None: for opt in _IMPLICIT_TLSINSECURE_OPTS: # Implicit options are logically the same as tlsInsecure. options[opt] = tlsinsecure for optname in list(options): intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) if intname is not None: options[intname] = options.pop(optname) return options def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: """Validates and normalizes options passed in a MongoDB URI. Returns a new dictionary of validated and normalized options. If warn is False then errors will be thrown for invalid options, otherwise they will be ignored and a warning will be issued. :param opts: A dict of MongoDB URI options. :param warn: If ``True`` then warnings will be logged and invalid options will be ignored. Otherwise invalid options will cause errors. """ return get_validated_options(opts, warn) def split_options( opts: str, validate: bool = True, warn: bool = False, normalize: bool = True ) -> MutableMapping[str, Any]: """Takes the options portion of a MongoDB URI, validates each option and returns the options in a dictionary. :param opt: A string representing MongoDB URI options. :param validate: If ``True`` (the default), validate and normalize all options. :param warn: If ``False`` (the default), suppress all warnings raised during validation of options. :param normalize: If ``True`` (the default), renames all options to their internally-used names. """ and_idx = opts.find("&") semi_idx = opts.find(";") try: if and_idx >= 0 and semi_idx >= 0: raise InvalidURI("Can not mix '&' and ';' for option separators.") elif and_idx >= 0: options = _parse_options(opts, "&") elif semi_idx >= 0: options = _parse_options(opts, ";") elif opts.find("=") != -1: options = _parse_options(opts, None) else: raise ValueError except ValueError: raise InvalidURI("MongoDB URI options are key=value pairs.") from None options = _handle_security_options(options) options = _handle_option_deprecations(options) if normalize: options = _normalize_options(options) if validate: options = cast(_CaseInsensitiveDictionary, validate_options(options, warn)) if options.get("authsource") == "": raise InvalidURI("the authSource database cannot be an empty string") return options def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: """Takes a string of the form host1[:port],host2[:port]... and splits it into (host, port) tuples. If [:port] isn't present the default_port is used. Returns a set of 2-tuples containing the host name (or IP) followed by port number. :param hosts: A string of the form host1[:port],host2[:port],... :param default_port: The port number to use when one wasn't specified for a host. """ nodes = [] for entity in hosts.split(","): if not entity: raise ConfigurationError("Empty host (or extra comma in host list).") port = default_port # Unix socket entities don't have ports if entity.endswith(".sock"): port = None nodes.append(parse_host(entity, port)) return nodes # Prohibited characters in database name. DB names also can't have ".", but for # backward-compat we allow "db.collection" in URI. _BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]") _ALLOWED_TXT_OPTS = frozenset( ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"] ) def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: # Ensure directConnection was not True if there are multiple seeds. if len(nodes) > 1 and options.get("directconnection"): raise ConfigurationError("Cannot specify multiple hosts with directConnection=true") if options.get("loadbalanced"): if len(nodes) > 1: raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true") if options.get("directconnection"): raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true") if options.get("replicaset"): raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") def parse_uri( uri: str, default_port: Optional[int] = DEFAULT_PORT, validate: bool = True, warn: bool = False, normalize: bool = True, connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, ) -> dict[str, Any]: """Parse and validate a MongoDB URI. Returns a dict of the form:: { 'nodelist': , 'username': or None, 'password': or None, 'database': or None, 'collection': or None, 'options': , 'fqdn': or None } If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done to build nodelist and options. :param uri: The MongoDB URI to parse. :param default_port: The port number to use when one wasn't specified for a host in the URI. :param validate: If ``True`` (the default), validate and normalize all options. Default: ``True``. :param warn: When validating, if ``True`` then will warn the user then ignore any invalid options or values. If ``False``, validation will error when options are unsupported or values are invalid. Default: ``False``. :param normalize: If ``True``, convert names of URI options to their internally-used names. Default: ``True``. :param connect_timeout: The maximum time in milliseconds to wait for a response from the DNS server. :param srv_service_name: A custom SRV service name .. versionchanged:: 4.6 The delimiting slash (``/``) between hosts and connection options is now optional. For example, "mongodb://example.com?tls=true" is now a valid URI. .. versionchanged:: 4.0 To better follow RFC 3986, unquoted percent signs ("%") are no longer supported. .. versionchanged:: 3.9 Added the ``normalize`` parameter. .. versionchanged:: 3.6 Added support for mongodb+srv:// URIs. .. versionchanged:: 3.5 Return the original value of the ``readPreference`` MongoDB URI option instead of the validated read preference mode. .. versionchanged:: 3.1 ``warn`` added so invalid options can be ignored. """ if uri.startswith(SCHEME): is_srv = False scheme_free = uri[SCHEME_LEN:] elif uri.startswith(SRV_SCHEME): if not _have_dnspython(): python_path = sys.executable or "python" raise ConfigurationError( 'The "dnspython" module must be ' "installed to use mongodb+srv:// URIs. " "To fix this error install pymongo again:\n " "%s -m pip install pymongo>=4.3" % (python_path) ) is_srv = True scheme_free = uri[SRV_SCHEME_LEN:] else: raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") if not scheme_free: raise InvalidURI("Must provide at least one hostname or IP.") user = None passwd = None dbase = None collection = None options = _CaseInsensitiveDictionary() host_plus_db_part, _, opts = scheme_free.partition("?") if "/" in host_plus_db_part: host_part, _, dbase = host_plus_db_part.partition("/") else: host_part = host_plus_db_part if dbase: dbase = unquote_plus(dbase) if "." in dbase: dbase, collection = dbase.split(".", 1) if _BAD_DB_CHARS.search(dbase): raise InvalidURI('Bad database name "%s"' % dbase) else: dbase = None if opts: options.update(split_options(opts, validate, warn, normalize)) if srv_service_name is None: srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) if "@" in host_part: userinfo, _, hosts = host_part.rpartition("@") user, passwd = parse_userinfo(userinfo) else: hosts = host_part if "/" in hosts: raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) hosts = unquote_plus(hosts) fqdn = None srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") if is_srv: if options.get("directConnection"): raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") nodes = split_hosts(hosts, default_port=None) if len(nodes) != 1: raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") fqdn, port = nodes[0] if port is not None: raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. connect_timeout = connect_timeout or options.get("connectTimeoutMS") dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) nodes = dns_resolver.get_hosts() dns_options = dns_resolver.get_options() if dns_options: parsed_dns_options = split_options(dns_options, validate, warn, normalize) if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: raise ConfigurationError( "Only authSource, replicaSet, and loadBalanced are supported from DNS" ) for opt, val in parsed_dns_options.items(): if opt not in options: options[opt] = val if options.get("loadBalanced") and srv_max_hosts: raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") if options.get("replicaSet") and srv_max_hosts: raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") if "tls" not in options and "ssl" not in options: options["tls"] = True if validate else "true" elif not is_srv and options.get("srvServiceName") is not None: raise ConfigurationError( "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" ) elif not is_srv and srv_max_hosts: raise ConfigurationError( "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" ) else: nodes = split_hosts(hosts, default_port=default_port) _check_options(nodes, options) return { "nodelist": nodes, "username": user, "password": passwd, "database": dbase, "collection": collection, "options": options, "fqdn": fqdn, } def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: """Parse KMS TLS connection options.""" if not kms_tls_options: return {} if not isinstance(kms_tls_options, dict): raise TypeError("kms_tls_options must be a dict") contexts = {} for provider, options in kms_tls_options.items(): if not isinstance(options, dict): raise TypeError(f'kms_tls_options["{provider}"] must be a dict') options.setdefault("tls", True) opts = _CaseInsensitiveDictionary(options) opts = _handle_security_options(opts) opts = _normalize_options(opts) opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) if ssl_context is None: raise ConfigurationError("TLS is required for KMS providers") if allow_invalid_hostnames: raise ConfigurationError("Insecure TLS options prohibited") for n in [ "tlsInsecure", "tlsAllowInvalidCertificates", "tlsAllowInvalidHostnames", "tlsDisableCertificateRevocationCheck", ]: if n in opts: raise ConfigurationError(f"Insecure TLS options prohibited: {n}") contexts[provider] = ssl_context return contexts if __name__ == "__main__": import pprint try: pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 except InvalidURI as exc: print(exc) # noqa: T201 sys.exit(0)