# mypy: ignore-errors

from typing import Any, Iterable

from ._vendor.packaging.version import InvalidVersion, Version
from .version import __version__ as internal_version


__all__ = ["TorchVersion"]


class TorchVersion(str):
    """A string with magic powers to compare to both Version and iterables!

    Prior to 1.10.0 torch.__version__ was stored as a str and so many did
    comparisons against torch.__version__ as if it were a str. In order to not
    break them we have TorchVersion which masquerades as a str while also
    having the ability to compare against both packaging.version.Version as
    well as tuples of values, eg. (1, 2, 1)

    The comparison operators (``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``)
    compare parsed versions. If either side cannot be parsed as a version,
    the corresponding ``str`` comparison method is used instead.

    Examples:
        Comparing a TorchVersion object to a Version object
            TorchVersion('1.10.0a') > Version('1.2')

        Comparing a TorchVersion object to a Tuple object
            TorchVersion('1.10.0a') > (1, 2)    # 1.2
            TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1

        Comparing a TorchVersion object against a string
            TorchVersion('1.10.0a') > '1.2'
            TorchVersion('1.10.0a') > '1.2.1'
    """

    def _convert_to_version(self, inp: Any) -> Any:
        """Coerce ``inp`` into a ``Version``.

        Args:
            inp: a ``Version`` (returned unchanged), a ``str`` (parsed), or
                any other iterable whose items are stringified, joined with
                ``"."`` and parsed.

        Raises:
            InvalidVersion: if ``inp`` is none of the above, or if the
                resulting string cannot be parsed as a version.
        """
        if isinstance(inp, Version):
            return inp
        elif isinstance(inp, str):
            return Version(inp)
        elif isinstance(inp, Iterable):
            # Ideally this should work for most cases by attempting to group
            # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
            # Examples:
            #   * (1,)       -> Version("1")
            #   * (1, 20)    -> Version("1.20")
            #   * (1, 20, 1) -> Version("1.20.1")
            return Version(".".join(str(item) for item in inp))
        else:
            raise InvalidVersion(inp)

    def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
        """Apply the comparison dunder ``method`` between ``self`` and ``cmp``.

        Both operands are converted to ``Version`` and compared with
        ``Version.<method>``. If either conversion raises ``InvalidVersion``
        (e.g. ``self`` is ``'parrot'``), the ``str`` implementation of
        ``method`` is used instead. That implementation may return
        ``NotImplemented`` for non-str ``cmp``.
        """
        try:
            this = Version(self)
            other = self._convert_to_version(cmp)
        except InvalidVersion:
            # Fall back to regular string comparison if dealing with an invalid
            # version like 'parrot'
            return getattr(super(), method)(cmp)
        return getattr(this, method)(other)


# Install the comparison operators after class creation. Defining __eq__ in
# the class body would implicitly set __hash__ to None (making TorchVersion
# unhashable). Assigning it afterwards keeps the inherited str.__hash__.
# __ne__ is included so that `!=` agrees with `==`. Without it, the inherited
# str.__ne__ compares raw strings: '1.10' == '1.10.0' would be True while
# '1.10' != '1.10.0' was also True.
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ne__", "__ge__", "__le__"]:
    setattr(
        TorchVersion,
        cmp_method,
        # Bind the method name as a default argument so each lambda keeps its
        # own value instead of the loop variable's final value.
        lambda x, y, method=cmp_method: x._cmp_wrapper(y, method),
    )

__version__ = TorchVersion(internal_version)