"""Implementation of :class:`ModularInteger` class. """ from __future__ import annotations from typing import Any import operator from sympy.polys.polyutils import PicklableWithSlots from sympy.polys.polyerrors import CoercionFailed from sympy.polys.domains.domainelement import DomainElement from sympy.utilities import public from sympy.utilities.exceptions import sympy_deprecation_warning @public class ModularInteger(PicklableWithSlots, DomainElement): """A class representing a modular integer. """ mod, dom, sym, _parent = None, None, None, None __slots__ = ('val',) def parent(self): return self._parent def __init__(self, val): if isinstance(val, self.__class__): self.val = val.val % self.mod else: self.val = self.dom.convert(val) % self.mod def modulus(self): return self.mod def __hash__(self): return hash((self.val, self.mod)) def __repr__(self): return "%s(%s)" % (self.__class__.__name__, self.val) def __str__(self): return "%s mod %s" % (self.val, self.mod) def __int__(self): return int(self.val) def to_int(self): sympy_deprecation_warning( """ModularInteger.to_int() is deprecated. Use int(a) or K = GF(p) and K.to_int(a) instead of a.to_int(). """, deprecated_since_version="1.13", active_deprecations_target="modularinteger-to-int", ) if self.sym: if self.val <= self.mod // 2: return self.val else: return self.val - self.mod else: return self.val def __pos__(self): return self def __neg__(self): return self.__class__(-self.val) @classmethod def _get_val(cls, other): if isinstance(other, cls): return other.val else: try: return cls.dom.convert(other) except CoercionFailed: return None def __add__(self, other): val = self._get_val(other) if val is not None: return self.__class__(self.val + val) else: return NotImplemented def __radd__(self, other): return self.__add__(other) def __sub__(self, other): val = self._get_val(other) if val is not None: return self.__class__(self.val - val) else: return NotImplemented def __rsub__(self, other): return (-self).__add__(other) def __mul__(self, other): val = self._get_val(other) if val is not None: return self.__class__(self.val * val) else: return NotImplemented def __rmul__(self, other): return self.__mul__(other) def __truediv__(self, other): val = self._get_val(other) if val is not None: return self.__class__(self.val * self._invert(val)) else: return NotImplemented def __rtruediv__(self, other): return self.invert().__mul__(other) def __mod__(self, other): val = self._get_val(other) if val is not None: return self.__class__(self.val % val) else: return NotImplemented def __rmod__(self, other): val = self._get_val(other) if val is not None: return self.__class__(val % self.val) else: return NotImplemented def __pow__(self, exp): if not exp: return self.__class__(self.dom.one) if exp < 0: val, exp = self.invert().val, -exp else: val = self.val return self.__class__(pow(val, int(exp), self.mod)) def _compare(self, other, op): val = self._get_val(other) if val is None: return NotImplemented return op(self.val, val % self.mod) def _compare_deprecated(self, other, op): val = self._get_val(other) if val is None: return NotImplemented sympy_deprecation_warning( """Ordered comparisons with modular integers are deprecated. Use e.g. int(a) < int(b) instead of a < b. """, deprecated_since_version="1.13", active_deprecations_target="modularinteger-compare", stacklevel=4, ) return op(self.val, val % self.mod) def __eq__(self, other): return self._compare(other, operator.eq) def __ne__(self, other): return self._compare(other, operator.ne) def __lt__(self, other): return self._compare_deprecated(other, operator.lt) def __le__(self, other): return self._compare_deprecated(other, operator.le) def __gt__(self, other): return self._compare_deprecated(other, operator.gt) def __ge__(self, other): return self._compare_deprecated(other, operator.ge) def __bool__(self): return bool(self.val) @classmethod def _invert(cls, value): return cls.dom.invert(value, cls.mod) def invert(self): return self.__class__(self._invert(self.val)) _modular_integer_cache: dict[tuple[Any, Any, Any], type[ModularInteger]] = {} def ModularIntegerFactory(_mod, _dom, _sym, parent): """Create custom class for specific integer modulus.""" try: _mod = _dom.convert(_mod) except CoercionFailed: ok = False else: ok = True if not ok or _mod < 1: raise ValueError("modulus must be a positive integer, got %s" % _mod) key = _mod, _dom, _sym try: cls = _modular_integer_cache[key] except KeyError: class cls(ModularInteger): mod, dom, sym = _mod, _dom, _sym _parent = parent if _sym: cls.__name__ = "SymmetricModularIntegerMod%s" % _mod else: cls.__name__ = "ModularIntegerMod%s" % _mod _modular_integer_cache[key] = cls return cls