"""Integration method that emulates by-hand techniques. This module also provides functionality to get the steps used to evaluate a particular integral, in the ``integral_steps`` function. This will return nested ``Rule`` s representing the integration rules used. Each ``Rule`` class represents a (maybe parametrized) integration rule, e.g. ``SinRule`` for integrating ``sin(x)`` and ``ReciprocalSqrtQuadraticRule`` for integrating ``1/sqrt(a+b*x+c*x**2)``. The ``eval`` method returns the integration result. The ``manualintegrate`` function computes the integral by calling ``eval`` on the rule returned by ``integral_steps``. The integrator can be extended with new heuristics and evaluation techniques. To do so, extend the ``Rule`` class, implement ``eval`` method, then write a function that accepts an ``IntegralInfo`` object and returns either a ``Rule`` instance or ``None``. If the new technique requires a new match, add the key and call to the antiderivative function to integral_steps. To enable simple substitutions, add the match to find_substitutions. """ from __future__ import annotations from typing import NamedTuple, Type, Callable, Sequence from abc import ABC, abstractmethod from dataclasses import dataclass from collections import defaultdict from collections.abc import Mapping from sympy.core.add import Add from sympy.core.cache import cacheit from sympy.core.containers import Dict from sympy.core.expr import Expr from sympy.core.function import Derivative from sympy.core.logic import fuzzy_not from sympy.core.mul import Mul from sympy.core.numbers import Integer, Number, E from sympy.core.power import Pow from sympy.core.relational import Eq, Ne, Boolean from sympy.core.singleton import S from sympy.core.symbol import Dummy, Symbol, Wild from sympy.functions.elementary.complexes import Abs from sympy.functions.elementary.exponential import exp, log from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, csch, cosh, coth, sech, sinh, tanh, asinh) from sympy.functions.elementary.miscellaneous import sqrt from sympy.functions.elementary.piecewise import Piecewise from sympy.functions.elementary.trigonometric import (TrigonometricFunction, cos, sin, tan, cot, csc, sec, acos, asin, atan, acot, acsc, asec) from sympy.functions.special.delta_functions import Heaviside, DiracDelta from sympy.functions.special.error_functions import (erf, erfi, fresnelc, fresnels, Ci, Chi, Si, Shi, Ei, li) from sympy.functions.special.gamma_functions import uppergamma from sympy.functions.special.elliptic_integrals import elliptic_e, elliptic_f from sympy.functions.special.polynomials import (chebyshevt, chebyshevu, legendre, hermite, laguerre, assoc_laguerre, gegenbauer, jacobi, OrthogonalPolynomial) from sympy.functions.special.zeta_functions import polylog from .integrals import Integral from sympy.logic.boolalg import And from sympy.ntheory.factor_ import primefactors from sympy.polys.polytools import degree, lcm_list, gcd_list, Poly from sympy.simplify.radsimp import fraction from sympy.simplify.simplify import simplify from sympy.solvers.solvers import solve from sympy.strategies.core import switch, do_one, null_safe, condition from sympy.utilities.iterables import iterable from sympy.utilities.misc import debug @dataclass class Rule(ABC): integrand: Expr variable: Symbol @abstractmethod def eval(self) -> Expr: pass @abstractmethod def contains_dont_know(self) -> bool: pass @dataclass class AtomicRule(Rule, ABC): """A simple rule that does not depend on other rules""" def contains_dont_know(self) -> bool: return False @dataclass class ConstantRule(AtomicRule): """integrate(a, x) -> a*x""" def eval(self) -> Expr: return self.integrand * self.variable @dataclass class ConstantTimesRule(Rule): """integrate(a*f(x), x) -> a*integrate(f(x), x)""" constant: Expr other: Expr substep: Rule def eval(self) -> Expr: return self.constant * self.substep.eval() def contains_dont_know(self) -> bool: return self.substep.contains_dont_know() @dataclass class PowerRule(AtomicRule): """integrate(x**a, x)""" base: Expr exp: Expr def eval(self) -> Expr: return Piecewise( ((self.base**(self.exp + 1))/(self.exp + 1), Ne(self.exp, -1)), (log(self.base), True), ) @dataclass class NestedPowRule(AtomicRule): """integrate((x**a)**b, x)""" base: Expr exp: Expr def eval(self) -> Expr: m = self.base * self.integrand return Piecewise((m / (self.exp + 1), Ne(self.exp, -1)), (m * log(self.base), True)) @dataclass class AddRule(Rule): """integrate(f(x) + g(x), x) -> integrate(f(x), x) + integrate(g(x), x)""" substeps: list[Rule] def eval(self) -> Expr: return Add(*(substep.eval() for substep in self.substeps)) def contains_dont_know(self) -> bool: return any(substep.contains_dont_know() for substep in self.substeps) @dataclass class URule(Rule): """integrate(f(g(x))*g'(x), x) -> integrate(f(u), u), u = g(x)""" u_var: Symbol u_func: Expr substep: Rule def eval(self) -> Expr: result = self.substep.eval() if self.u_func.is_Pow: base, exp_ = self.u_func.as_base_exp() if exp_ == -1: # avoid needless -log(1/x) from substitution result = result.subs(log(self.u_var), -log(base)) return result.subs(self.u_var, self.u_func) def contains_dont_know(self) -> bool: return self.substep.contains_dont_know() @dataclass class PartsRule(Rule): """integrate(u(x)*v'(x), x) -> u(x)*v(x) - integrate(u'(x)*v(x), x)""" u: Symbol dv: Expr v_step: Rule second_step: Rule | None # None when is a substep of CyclicPartsRule def eval(self) -> Expr: assert self.second_step is not None v = self.v_step.eval() return self.u * v - self.second_step.eval() def contains_dont_know(self) -> bool: return self.v_step.contains_dont_know() or ( self.second_step is not None and self.second_step.contains_dont_know()) @dataclass class CyclicPartsRule(Rule): """Apply PartsRule multiple times to integrate exp(x)*sin(x)""" parts_rules: list[PartsRule] coefficient: Expr def eval(self) -> Expr: result = [] sign = 1 for rule in self.parts_rules: result.append(sign * rule.u * rule.v_step.eval()) sign *= -1 return Add(*result) / (1 - self.coefficient) def contains_dont_know(self) -> bool: return any(substep.contains_dont_know() for substep in self.parts_rules) @dataclass class TrigRule(AtomicRule, ABC): pass @dataclass class SinRule(TrigRule): """integrate(sin(x), x) -> -cos(x)""" def eval(self) -> Expr: return -cos(self.variable) @dataclass class CosRule(TrigRule): """integrate(cos(x), x) -> sin(x)""" def eval(self) -> Expr: return sin(self.variable) @dataclass class SecTanRule(TrigRule): """integrate(sec(x)*tan(x), x) -> sec(x)""" def eval(self) -> Expr: return sec(self.variable) @dataclass class CscCotRule(TrigRule): """integrate(csc(x)*cot(x), x) -> -csc(x)""" def eval(self) -> Expr: return -csc(self.variable) @dataclass class Sec2Rule(TrigRule): """integrate(sec(x)**2, x) -> tan(x)""" def eval(self) -> Expr: return tan(self.variable) @dataclass class Csc2Rule(TrigRule): """integrate(csc(x)**2, x) -> -cot(x)""" def eval(self) -> Expr: return -cot(self.variable) @dataclass class HyperbolicRule(AtomicRule, ABC): pass @dataclass class SinhRule(HyperbolicRule): """integrate(sinh(x), x) -> cosh(x)""" def eval(self) -> Expr: return cosh(self.variable) @dataclass class CoshRule(HyperbolicRule): """integrate(cosh(x), x) -> sinh(x)""" def eval(self): return sinh(self.variable) @dataclass class ExpRule(AtomicRule): """integrate(a**x, x) -> a**x/ln(a)""" base: Expr exp: Expr def eval(self) -> Expr: return self.integrand / log(self.base) @dataclass class ReciprocalRule(AtomicRule): """integrate(1/x, x) -> ln(x)""" base: Expr def eval(self) -> Expr: return log(self.base) @dataclass class ArcsinRule(AtomicRule): """integrate(1/sqrt(1-x**2), x) -> asin(x)""" def eval(self) -> Expr: return asin(self.variable) @dataclass class ArcsinhRule(AtomicRule): """integrate(1/sqrt(1+x**2), x) -> asin(x)""" def eval(self) -> Expr: return asinh(self.variable) @dataclass class ReciprocalSqrtQuadraticRule(AtomicRule): """integrate(1/sqrt(a+b*x+c*x**2), x) -> log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c)""" a: Expr b: Expr c: Expr def eval(self) -> Expr: a, b, c, x = self.a, self.b, self.c, self.variable return log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c) @dataclass class SqrtQuadraticDenomRule(AtomicRule): """integrate(poly(x)/sqrt(a+b*x+c*x**2), x)""" a: Expr b: Expr c: Expr coeffs: list[Expr] def eval(self) -> Expr: a, b, c, coeffs, x = self.a, self.b, self.c, self.coeffs.copy(), self.variable # Integrate poly/sqrt(a+b*x+c*x**2) using recursion. # coeffs are coefficients of the polynomial. # Let I_n = x**n/sqrt(a+b*x+c*x**2), then # I_n = A * x**(n-1)*sqrt(a+b*x+c*x**2) - B * I_{n-1} - C * I_{n-2} # where A = 1/(n*c), B = (2*n-1)*b/(2*n*c), C = (n-1)*a/(n*c) # See https://github.com/sympy/sympy/pull/23608 for proof. result_coeffs = [] coeffs = coeffs.copy() for i in range(len(coeffs)-2): n = len(coeffs)-1-i coeff = coeffs[i]/(c*n) result_coeffs.append(coeff) coeffs[i+1] -= (2*n-1)*b/2*coeff coeffs[i+2] -= (n-1)*a*coeff d, e = coeffs[-1], coeffs[-2] s = sqrt(a+b*x+c*x**2) constant = d-b*e/(2*c) if constant == 0: I0 = 0 else: step = inverse_trig_rule(IntegralInfo(1/s, x), degenerate=False) I0 = constant*step.eval() return Add(*(result_coeffs[i]*x**(len(coeffs)-2-i) for i in range(len(result_coeffs))), e/c)*s + I0 @dataclass class SqrtQuadraticRule(AtomicRule): """integrate(sqrt(a+b*x+c*x**2), x)""" a: Expr b: Expr c: Expr def eval(self) -> Expr: step = sqrt_quadratic_rule(IntegralInfo(self.integrand, self.variable), degenerate=False) return step.eval() @dataclass class AlternativeRule(Rule): """Multiple ways to do integration.""" alternatives: list[Rule] def eval(self) -> Expr: return self.alternatives[0].eval() def contains_dont_know(self) -> bool: return any(substep.contains_dont_know() for substep in self.alternatives) @dataclass class DontKnowRule(Rule): """Leave the integral as is.""" def eval(self) -> Expr: return Integral(self.integrand, self.variable) def contains_dont_know(self) -> bool: return True @dataclass class DerivativeRule(AtomicRule): """integrate(f'(x), x) -> f(x)""" def eval(self) -> Expr: assert isinstance(self.integrand, Derivative) variable_count = list(self.integrand.variable_count) for i, (var, count) in enumerate(variable_count): if var == self.variable: variable_count[i] = (var, count - 1) break return Derivative(self.integrand.expr, *variable_count) @dataclass class RewriteRule(Rule): """Rewrite integrand to another form that is easier to handle.""" rewritten: Expr substep: Rule def eval(self) -> Expr: return self.substep.eval() def contains_dont_know(self) -> bool: return self.substep.contains_dont_know() @dataclass class CompleteSquareRule(RewriteRule): """Rewrite a+b*x+c*x**2 to a-b**2/(4*c) + c*(x+b/(2*c))**2""" pass @dataclass class PiecewiseRule(Rule): subfunctions: Sequence[tuple[Rule, bool | Boolean]] def eval(self) -> Expr: return Piecewise(*[(substep.eval(), cond) for substep, cond in self.subfunctions]) def contains_dont_know(self) -> bool: return any(substep.contains_dont_know() for substep, _ in self.subfunctions) @dataclass class HeavisideRule(Rule): harg: Expr ibnd: Expr substep: Rule def eval(self) -> Expr: # If we are integrating over x and the integrand has the form # Heaviside(m*x+b)*g(x) == Heaviside(harg)*g(symbol) # then there needs to be continuity at -b/m == ibnd, # so we subtract the appropriate term. result = self.substep.eval() return Heaviside(self.harg) * (result - result.subs(self.variable, self.ibnd)) def contains_dont_know(self) -> bool: return self.substep.contains_dont_know() @dataclass class DiracDeltaRule(AtomicRule): n: Expr a: Expr b: Expr def eval(self) -> Expr: n, a, b, x = self.n, self.a, self.b, self.variable if n == 0: return Heaviside(a+b*x)/b return DiracDelta(a+b*x, n-1)/b @dataclass class TrigSubstitutionRule(Rule): theta: Expr func: Expr rewritten: Expr substep: Rule restriction: bool | Boolean def eval(self) -> Expr: theta, func, x = self.theta, self.func, self.variable func = func.subs(sec(theta), 1/cos(theta)) func = func.subs(csc(theta), 1/sin(theta)) func = func.subs(cot(theta), 1/tan(theta)) trig_function = list(func.find(TrigonometricFunction)) assert len(trig_function) == 1 trig_function = trig_function[0] relation = solve(x - func, trig_function) assert len(relation) == 1 numer, denom = fraction(relation[0]) if isinstance(trig_function, sin): opposite = numer hypotenuse = denom adjacent = sqrt(denom**2 - numer**2) inverse = asin(relation[0]) elif isinstance(trig_function, cos): adjacent = numer hypotenuse = denom opposite = sqrt(denom**2 - numer**2) inverse = acos(relation[0]) else: # tan opposite = numer adjacent = denom hypotenuse = sqrt(denom**2 + numer**2) inverse = atan(relation[0]) substitution = [ (sin(theta), opposite/hypotenuse), (cos(theta), adjacent/hypotenuse), (tan(theta), opposite/adjacent), (theta, inverse) ] return Piecewise( (self.substep.eval().subs(substitution).trigsimp(), self.restriction) ) def contains_dont_know(self) -> bool: return self.substep.contains_dont_know() @dataclass class ArctanRule(AtomicRule): """integrate(a/(b*x**2+c), x) -> a/b / sqrt(c/b) * atan(x/sqrt(c/b))""" a: Expr b: Expr c: Expr def eval(self) -> Expr: a, b, c, x = self.a, self.b, self.c, self.variable return a/b / sqrt(c/b) * atan(x/sqrt(c/b)) @dataclass class OrthogonalPolyRule(AtomicRule, ABC): n: Expr @dataclass class JacobiRule(OrthogonalPolyRule): a: Expr b: Expr def eval(self) -> Expr: n, a, b, x = self.n, self.a, self.b, self.variable return Piecewise( (2*jacobi(n + 1, a - 1, b - 1, x)/(n + a + b), Ne(n + a + b, 0)), (x, Eq(n, 0)), ((a + b + 2)*x**2/4 + (a - b)*x/2, Eq(n, 1))) @dataclass class GegenbauerRule(OrthogonalPolyRule): a: Expr def eval(self) -> Expr: n, a, x = self.n, self.a, self.variable return Piecewise( (gegenbauer(n + 1, a - 1, x)/(2*(a - 1)), Ne(a, 1)), (chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)), (S.Zero, True)) @dataclass class ChebyshevTRule(OrthogonalPolyRule): def eval(self) -> Expr: n, x = self.n, self.variable return Piecewise( ((chebyshevt(n + 1, x)/(n + 1) - chebyshevt(n - 1, x)/(n - 1))/2, Ne(Abs(n), 1)), (x**2/2, True)) @dataclass class ChebyshevURule(OrthogonalPolyRule): def eval(self) -> Expr: n, x = self.n, self.variable return Piecewise( (chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)), (S.Zero, True)) @dataclass class LegendreRule(OrthogonalPolyRule): def eval(self) -> Expr: n, x = self.n, self.variable return(legendre(n + 1, x) - legendre(n - 1, x))/(2*n + 1) @dataclass class HermiteRule(OrthogonalPolyRule): def eval(self) -> Expr: n, x = self.n, self.variable return hermite(n + 1, x)/(2*(n + 1)) @dataclass class LaguerreRule(OrthogonalPolyRule): def eval(self) -> Expr: n, x = self.n, self.variable return laguerre(n, x) - laguerre(n + 1, x) @dataclass class AssocLaguerreRule(OrthogonalPolyRule): a: Expr def eval(self) -> Expr: return -assoc_laguerre(self.n + 1, self.a - 1, self.variable) @dataclass class IRule(AtomicRule, ABC): a: Expr b: Expr @dataclass class CiRule(IRule): def eval(self) -> Expr: a, b, x = self.a, self.b, self.variable return cos(b)*Ci(a*x) - sin(b)*Si(a*x) @dataclass class ChiRule(IRule): def eval(self) -> Expr: a, b, x = self.a, self.b, self.variable return cosh(b)*Chi(a*x) + sinh(b)*Shi(a*x) @dataclass class EiRule(IRule): def eval(self) -> Expr: a, b, x = self.a, self.b, self.variable return exp(b)*Ei(a*x) @dataclass class SiRule(IRule): def eval(self) -> Expr: a, b, x = self.a, self.b, self.variable return sin(b)*Ci(a*x) + cos(b)*Si(a*x) @dataclass class ShiRule(IRule): def eval(self) -> Expr: a, b, x = self.a, self.b, self.variable return sinh(b)*Chi(a*x) + cosh(b)*Shi(a*x) @dataclass class LiRule(IRule): def eval(self) -> Expr: a, b, x = self.a, self.b, self.variable return li(a*x + b)/a @dataclass class ErfRule(AtomicRule): a: Expr b: Expr c: Expr def eval(self) -> Expr: a, b, c, x = self.a, self.b, self.c, self.variable if a.is_extended_real: return Piecewise( (sqrt(S.Pi)/sqrt(-a)/2 * exp(c - b**2/(4*a)) * erf((-2*a*x - b)/(2*sqrt(-a))), a < 0), (sqrt(S.Pi)/sqrt(a)/2 * exp(c - b**2/(4*a)) * erfi((2*a*x + b)/(2*sqrt(a))), True)) return sqrt(S.Pi)/sqrt(a)/2 * exp(c - b**2/(4*a)) * \ erfi((2*a*x + b)/(2*sqrt(a))) @dataclass class FresnelCRule(AtomicRule): a: Expr b: Expr c: Expr def eval(self) -> Expr: a, b, c, x = self.a, self.b, self.c, self.variable return sqrt(S.Pi)/sqrt(2*a) * ( cos(b**2/(4*a) - c)*fresnelc((2*a*x + b)/sqrt(2*a*S.Pi)) + sin(b**2/(4*a) - c)*fresnels((2*a*x + b)/sqrt(2*a*S.Pi))) @dataclass class FresnelSRule(AtomicRule): a: Expr b: Expr c: Expr def eval(self) -> Expr: a, b, c, x = self.a, self.b, self.c, self.variable return sqrt(S.Pi)/sqrt(2*a) * ( cos(b**2/(4*a) - c)*fresnels((2*a*x + b)/sqrt(2*a*S.Pi)) - sin(b**2/(4*a) - c)*fresnelc((2*a*x + b)/sqrt(2*a*S.Pi))) @dataclass class PolylogRule(AtomicRule): a: Expr b: Expr def eval(self) -> Expr: return polylog(self.b + 1, self.a * self.variable) @dataclass class UpperGammaRule(AtomicRule): a: Expr e: Expr def eval(self) -> Expr: a, e, x = self.a, self.e, self.variable return x**e * (-a*x)**(-e) * uppergamma(e + 1, -a*x)/a @dataclass class EllipticFRule(AtomicRule): a: Expr d: Expr def eval(self) -> Expr: return elliptic_f(self.variable, self.d/self.a)/sqrt(self.a) @dataclass class EllipticERule(AtomicRule): a: Expr d: Expr def eval(self) -> Expr: return elliptic_e(self.variable, self.d/self.a)*sqrt(self.a) class IntegralInfo(NamedTuple): integrand: Expr symbol: Symbol def manual_diff(f, symbol): """Derivative of f in form expected by find_substitutions SymPy's derivatives for some trig functions (like cot) are not in a form that works well with finding substitutions; this replaces the derivatives for those particular forms with something that works better. """ if f.args: arg = f.args[0] if isinstance(f, tan): return arg.diff(symbol) * sec(arg)**2 elif isinstance(f, cot): return -arg.diff(symbol) * csc(arg)**2 elif isinstance(f, sec): return arg.diff(symbol) * sec(arg) * tan(arg) elif isinstance(f, csc): return -arg.diff(symbol) * csc(arg) * cot(arg) elif isinstance(f, Add): return sum(manual_diff(arg, symbol) for arg in f.args) elif isinstance(f, Mul): if len(f.args) == 2 and isinstance(f.args[0], Number): return f.args[0] * manual_diff(f.args[1], symbol) return f.diff(symbol) def manual_subs(expr, *args): """ A wrapper for `expr.subs(*args)` with additional logic for substitution of invertible functions. """ if len(args) == 1: sequence = args[0] if isinstance(sequence, (Dict, Mapping)): sequence = sequence.items() elif not iterable(sequence): raise ValueError("Expected an iterable of (old, new) pairs") elif len(args) == 2: sequence = [args] else: raise ValueError("subs accepts either 1 or 2 arguments") new_subs = [] for old, new in sequence: if isinstance(old, log): # If log(x) = y, then exp(a*log(x)) = exp(a*y) # that is, x**a = exp(a*y). Replace nontrivial powers of x # before subs turns them into `exp(y)**a`, but # do not replace x itself yet, to avoid `log(exp(y))`. x0 = old.args[0] expr = expr.replace(lambda x: x.is_Pow and x.base == x0, lambda x: exp(x.exp*new)) new_subs.append((x0, exp(new))) return expr.subs(list(sequence) + new_subs) # Method based on that on SIN, described in "Symbolic Integration: The # Stormy Decade" inverse_trig_functions = (atan, asin, acos, acot, acsc, asec) def find_substitutions(integrand, symbol, u_var): results = [] def test_subterm(u, u_diff): if u_diff == 0: return False substituted = integrand / u_diff debug("substituted: {}, u: {}, u_var: {}".format(substituted, u, u_var)) substituted = manual_subs(substituted, u, u_var).cancel() if substituted.has_free(symbol): return False # avoid increasing the degree of a rational function if integrand.is_rational_function(symbol) and substituted.is_rational_function(u_var): deg_before = max(degree(t, symbol) for t in integrand.as_numer_denom()) deg_after = max(degree(t, u_var) for t in substituted.as_numer_denom()) if deg_after > deg_before: return False return substituted.as_independent(u_var, as_Add=False) def exp_subterms(term: Expr): linear_coeffs = [] terms = [] n = Wild('n', properties=[lambda n: n.is_Integer]) for exp_ in term.find(exp): arg = exp_.args[0] if symbol not in arg.free_symbols: continue match = arg.match(n*symbol) if match: linear_coeffs.append(match[n]) else: terms.append(exp_) if linear_coeffs: terms.append(exp(gcd_list(linear_coeffs)*symbol)) return terms def possible_subterms(term): if isinstance(term, (TrigonometricFunction, HyperbolicFunction, *inverse_trig_functions, exp, log, Heaviside)): return [term.args[0]] elif isinstance(term, (chebyshevt, chebyshevu, legendre, hermite, laguerre)): return [term.args[1]] elif isinstance(term, (gegenbauer, assoc_laguerre)): return [term.args[2]] elif isinstance(term, jacobi): return [term.args[3]] elif isinstance(term, Mul): r = [] for u in term.args: r.append(u) r.extend(possible_subterms(u)) return r elif isinstance(term, Pow): r = [arg for arg in term.args if arg.has(symbol)] if term.exp.is_Integer: r.extend([term.base**d for d in primefactors(term.exp) if 1 < d < abs(term.args[1])]) if term.base.is_Add: r.extend([t for t in possible_subterms(term.base) if t.is_Pow]) return r elif isinstance(term, Add): r = [] for arg in term.args: r.append(arg) r.extend(possible_subterms(arg)) return r return [] for u in list(dict.fromkeys(possible_subterms(integrand) + exp_subterms(integrand))): if u == symbol: continue u_diff = manual_diff(u, symbol) new_integrand = test_subterm(u, u_diff) if new_integrand is not False: constant, new_integrand = new_integrand if new_integrand == integrand.subs(symbol, u_var): continue substitution = (u, constant, new_integrand) if substitution not in results: results.append(substitution) return results def rewriter(condition, rewrite): """Strategy that rewrites an integrand.""" def _rewriter(integral): integrand, symbol = integral debug("Integral: {} is rewritten with {} on symbol: {}".format(integrand, rewrite, symbol)) if condition(*integral): rewritten = rewrite(*integral) if rewritten != integrand: substep = integral_steps(rewritten, symbol) if not isinstance(substep, DontKnowRule) and substep: return RewriteRule(integrand, symbol, rewritten, substep) return _rewriter def proxy_rewriter(condition, rewrite): """Strategy that rewrites an integrand based on some other criteria.""" def _proxy_rewriter(criteria): criteria, integral = criteria integrand, symbol = integral debug("Integral: {} is rewritten with {} on symbol: {} and criteria: {}".format(integrand, rewrite, symbol, criteria)) args = criteria + list(integral) if condition(*args): rewritten = rewrite(*args) if rewritten != integrand: return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol)) return _proxy_rewriter def multiplexer(conditions): """Apply the rule that matches the condition, else None""" def multiplexer_rl(expr): for key, rule in conditions.items(): if key(expr): return rule(expr) return multiplexer_rl def alternatives(*rules): """Strategy that makes an AlternativeRule out of multiple possible results.""" def _alternatives(integral): alts = [] count = 0 debug("List of Alternative Rules") for rule in rules: count = count + 1 debug("Rule {}: {}".format(count, rule)) result = rule(integral) if (result and not isinstance(result, DontKnowRule) and result != integral and result not in alts): alts.append(result) if len(alts) == 1: return alts[0] elif alts: doable = [rule for rule in alts if not rule.contains_dont_know()] if doable: return AlternativeRule(*integral, doable) else: return AlternativeRule(*integral, alts) return _alternatives def constant_rule(integral): return ConstantRule(*integral) def power_rule(integral): integrand, symbol = integral base, expt = integrand.as_base_exp() if symbol not in expt.free_symbols and isinstance(base, Symbol): if simplify(expt + 1) == 0: return ReciprocalRule(integrand, symbol, base) return PowerRule(integrand, symbol, base, expt) elif symbol not in base.free_symbols and isinstance(expt, Symbol): rule = ExpRule(integrand, symbol, base, expt) if fuzzy_not(log(base).is_zero): return rule elif log(base).is_zero: return ConstantRule(1, symbol) return PiecewiseRule(integrand, symbol, [ (rule, Ne(log(base), 0)), (ConstantRule(1, symbol), True) ]) def exp_rule(integral): integrand, symbol = integral if isinstance(integrand.args[0], Symbol): return ExpRule(integrand, symbol, E, integrand.args[0]) def orthogonal_poly_rule(integral): orthogonal_poly_classes = { jacobi: JacobiRule, gegenbauer: GegenbauerRule, chebyshevt: ChebyshevTRule, chebyshevu: ChebyshevURule, legendre: LegendreRule, hermite: HermiteRule, laguerre: LaguerreRule, assoc_laguerre: AssocLaguerreRule } orthogonal_poly_var_index = { jacobi: 3, gegenbauer: 2, assoc_laguerre: 2 } integrand, symbol = integral for klass in orthogonal_poly_classes: if isinstance(integrand, klass): var_index = orthogonal_poly_var_index.get(klass, 1) if (integrand.args[var_index] is symbol and not any(v.has(symbol) for v in integrand.args[:var_index])): return orthogonal_poly_classes[klass](integrand, symbol, *integrand.args[:var_index]) _special_function_patterns: list[tuple[Type, Expr, Callable | None, tuple]] = [] _wilds = [] _symbol = Dummy('x') def special_function_rule(integral): integrand, symbol = integral if not _special_function_patterns: a = Wild('a', exclude=[_symbol], properties=[lambda x: not x.is_zero]) b = Wild('b', exclude=[_symbol]) c = Wild('c', exclude=[_symbol]) d = Wild('d', exclude=[_symbol], properties=[lambda x: not x.is_zero]) e = Wild('e', exclude=[_symbol], properties=[ lambda x: not (x.is_nonnegative and x.is_integer)]) _wilds.extend((a, b, c, d, e)) # patterns consist of a SymPy class, a wildcard expr, an optional # condition coded as a lambda (when Wild properties are not enough), # followed by an applicable rule linear_pattern = a*_symbol + b quadratic_pattern = a*_symbol**2 + b*_symbol + c _special_function_patterns.extend(( (Mul, exp(linear_pattern, evaluate=False)/_symbol, None, EiRule), (Mul, cos(linear_pattern, evaluate=False)/_symbol, None, CiRule), (Mul, cosh(linear_pattern, evaluate=False)/_symbol, None, ChiRule), (Mul, sin(linear_pattern, evaluate=False)/_symbol, None, SiRule), (Mul, sinh(linear_pattern, evaluate=False)/_symbol, None, ShiRule), (Pow, 1/log(linear_pattern, evaluate=False), None, LiRule), (exp, exp(quadratic_pattern, evaluate=False), None, ErfRule), (sin, sin(quadratic_pattern, evaluate=False), None, FresnelSRule), (cos, cos(quadratic_pattern, evaluate=False), None, FresnelCRule), (Mul, _symbol**e*exp(a*_symbol, evaluate=False), None, UpperGammaRule), (Mul, polylog(b, a*_symbol, evaluate=False)/_symbol, None, PolylogRule), (Pow, 1/sqrt(a - d*sin(_symbol, evaluate=False)**2), lambda a, d: a != d, EllipticFRule), (Pow, sqrt(a - d*sin(_symbol, evaluate=False)**2), lambda a, d: a != d, EllipticERule), )) _integrand = integrand.subs(symbol, _symbol) for type_, pattern, constraint, rule in _special_function_patterns: if isinstance(_integrand, type_): match = _integrand.match(pattern) if match: wild_vals = tuple(match.get(w) for w in _wilds if match.get(w) is not None) if constraint is None or constraint(*wild_vals): return rule(integrand, symbol, *wild_vals) def _add_degenerate_step(generic_cond, generic_step: Rule, degenerate_step: Rule | None) -> Rule: if degenerate_step is None: return generic_step if isinstance(generic_step, PiecewiseRule): subfunctions = [(substep, (cond & generic_cond).simplify()) for substep, cond in generic_step.subfunctions] else: subfunctions = [(generic_step, generic_cond)] if isinstance(degenerate_step, PiecewiseRule): subfunctions += degenerate_step.subfunctions else: subfunctions.append((degenerate_step, S.true)) return PiecewiseRule(generic_step.integrand, generic_step.variable, subfunctions) def nested_pow_rule(integral: IntegralInfo): # nested (c*(a+b*x)**d)**e integrand, x = integral a_ = Wild('a', exclude=[x]) b_ = Wild('b', exclude=[x, 0]) pattern = a_+b_*x generic_cond = S.true class NoMatch(Exception): pass def _get_base_exp(expr: Expr) -> tuple[Expr, Expr]: if not expr.has_free(x): return S.One, S.Zero if expr.is_Mul: _, terms = expr.as_coeff_mul() if not terms: return S.One, S.Zero results = [_get_base_exp(term) for term in terms] bases = {b for b, _ in results} bases.discard(S.One) if len(bases) == 1: return bases.pop(), Add(*(e for _, e in results)) raise NoMatch if expr.is_Pow: b, e = expr.base, expr.exp # type: ignore if e.has_free(x): raise NoMatch base_, sub_exp = _get_base_exp(b) return base_, sub_exp * e match = expr.match(pattern) if match: a, b = match[a_], match[b_] base_ = x + a/b nonlocal generic_cond generic_cond = Ne(b, 0) return base_, S.One raise NoMatch try: base, exp_ = _get_base_exp(integrand) except NoMatch: return if generic_cond is S.true: degenerate_step = None else: # equivalent with subs(b, 0) but no need to find b degenerate_step = ConstantRule(integrand.subs(x, 0), x) generic_step = NestedPowRule(integrand, x, base, exp_) return _add_degenerate_step(generic_cond, generic_step, degenerate_step) def inverse_trig_rule(integral: IntegralInfo, degenerate=True): """ Set degenerate=False on recursive call where coefficient of quadratic term is assumed non-zero. """ integrand, symbol = integral base, exp = integrand.as_base_exp() a = Wild('a', exclude=[symbol]) b = Wild('b', exclude=[symbol]) c = Wild('c', exclude=[symbol, 0]) match = base.match(a + b*symbol + c*symbol**2) if not match: return def make_inverse_trig(RuleClass, a, sign_a, c, sign_c, h) -> Rule: u_var = Dummy("u") rewritten = 1/sqrt(sign_a*a + sign_c*c*(symbol-h)**2) # a>0, c>0 quadratic_base = sqrt(c/a)*(symbol-h) constant = 1/sqrt(c) u_func = None if quadratic_base is not symbol: u_func = quadratic_base quadratic_base = u_var standard_form = 1/sqrt(sign_a + sign_c*quadratic_base**2) substep = RuleClass(standard_form, quadratic_base) if constant != 1: substep = ConstantTimesRule(constant*standard_form, symbol, constant, standard_form, substep) if u_func is not None: substep = URule(rewritten, symbol, u_var, u_func, substep) if h != 0: substep = CompleteSquareRule(integrand, symbol, rewritten, substep) return substep a, b, c = [match.get(i, S.Zero) for i in (a, b, c)] generic_cond = Ne(c, 0) if not degenerate or generic_cond is S.true: degenerate_step = None elif b.is_zero: degenerate_step = ConstantRule(a ** exp, symbol) else: degenerate_step = sqrt_linear_rule(IntegralInfo((a + b * symbol) ** exp, symbol)) if simplify(2*exp + 1) == 0: h, k = -b/(2*c), a - b**2/(4*c) # rewrite base to k + c*(symbol-h)**2 non_square_cond = Ne(k, 0) square_step = None if non_square_cond is not S.true: square_step = NestedPowRule(1/sqrt(c*(symbol-h)**2), symbol, symbol-h, S.NegativeOne) if non_square_cond is S.false: return square_step generic_step = ReciprocalSqrtQuadraticRule(integrand, symbol, a, b, c) step = _add_degenerate_step(non_square_cond, generic_step, square_step) if k.is_real and c.is_real: # list of ((rule, base_exp, a, sign_a, b, sign_b), condition) rules = [] for args, cond in ( # don't apply ArccoshRule to x**2-1 ((ArcsinRule, k, 1, -c, -1, h), And(k > 0, c < 0)), # 1-x**2 ((ArcsinhRule, k, 1, c, 1, h), And(k > 0, c > 0)), # 1+x**2 ): if cond is S.true: return make_inverse_trig(*args) if cond is not S.false: rules.append((make_inverse_trig(*args), cond)) if rules: if not k.is_positive: # conditions are not thorough, need fall back rule rules.append((generic_step, S.true)) step = PiecewiseRule(integrand, symbol, rules) else: step = generic_step return _add_degenerate_step(generic_cond, step, degenerate_step) if exp == S.Half: step = SqrtQuadraticRule(integrand, symbol, a, b, c) return _add_degenerate_step(generic_cond, step, degenerate_step) def add_rule(integral): integrand, symbol = integral results = [integral_steps(g, symbol) for g in integrand.as_ordered_terms()] return None if None in results else AddRule(integrand, symbol, results) def mul_rule(integral: IntegralInfo): integrand, symbol = integral # Constant times function case coeff, f = integrand.as_independent(symbol) if coeff != 1: next_step = integral_steps(f, symbol) if next_step is not None: return ConstantTimesRule(integrand, symbol, coeff, f, next_step) def _parts_rule(integrand, symbol) -> tuple[Expr, Expr, Expr, Expr, Rule] | None: # LIATE rule: # log, inverse trig, algebraic, trigonometric, exponential def pull_out_algebraic(integrand): integrand = integrand.cancel().together() # iterating over Piecewise args would not work here algebraic = ([] if isinstance(integrand, Piecewise) or not integrand.is_Mul else [arg for arg in integrand.args if arg.is_algebraic_expr(symbol)]) if algebraic: u = Mul(*algebraic) dv = (integrand / u).cancel() return u, dv def pull_out_u(*functions) -> Callable[[Expr], tuple[Expr, Expr] | None]: def pull_out_u_rl(integrand: Expr) -> tuple[Expr, Expr] | None: if any(integrand.has(f) for f in functions): args = [arg for arg in integrand.args if any(isinstance(arg, cls) for cls in functions)] if args: u = Mul(*args) dv = integrand / u return u, dv return None return pull_out_u_rl liate_rules = [pull_out_u(log), pull_out_u(*inverse_trig_functions), pull_out_algebraic, pull_out_u(sin, cos), pull_out_u(exp)] dummy = Dummy("temporary") # we can integrate log(x) and atan(x) by setting dv = 1 if isinstance(integrand, (log, *inverse_trig_functions)): integrand = dummy * integrand for index, rule in enumerate(liate_rules): result = rule(integrand) if result: u, dv = result # Don't pick u to be a constant if possible if symbol not in u.free_symbols and not u.has(dummy): return None u = u.subs(dummy, 1) dv = dv.subs(dummy, 1) # Don't pick a non-polynomial algebraic to be differentiated if rule == pull_out_algebraic and not u.is_polynomial(symbol): return None # Don't trade one logarithm for another if isinstance(u, log): rec_dv = 1/dv if (rec_dv.is_polynomial(symbol) and degree(rec_dv, symbol) == 1): return None # Can integrate a polynomial times OrthogonalPolynomial if rule == pull_out_algebraic: if dv.is_Derivative or dv.has(TrigonometricFunction) or \ isinstance(dv, OrthogonalPolynomial): v_step = integral_steps(dv, symbol) if v_step.contains_dont_know(): return None else: du = u.diff(symbol) v = v_step.eval() return u, dv, v, du, v_step # make sure dv is amenable to integration accept = False if index < 2: # log and inverse trig are usually worth trying accept = True elif (rule == pull_out_algebraic and dv.args and all(isinstance(a, (sin, cos, exp)) for a in dv.args)): accept = True else: for lrule in liate_rules[index + 1:]: r = lrule(integrand) if r and r[0].subs(dummy, 1).equals(dv): accept = True break if accept: du = u.diff(symbol) v_step = integral_steps(simplify(dv), symbol) if not v_step.contains_dont_know(): v = v_step.eval() return u, dv, v, du, v_step return None def parts_rule(integral): integrand, symbol = integral constant, integrand = integrand.as_coeff_Mul() result = _parts_rule(integrand, symbol) steps = [] if result: u, dv, v, du, v_step = result debug("u : {}, dv : {}, v : {}, du : {}, v_step: {}".format(u, dv, v, du, v_step)) steps.append(result) if isinstance(v, Integral): return # Set a limit on the number of times u can be used if isinstance(u, (sin, cos, exp, sinh, cosh)): cachekey = u.xreplace({symbol: _cache_dummy}) if _parts_u_cache[cachekey] > 2: return _parts_u_cache[cachekey] += 1 # Try cyclic integration by parts a few times for _ in range(4): debug("Cyclic integration {} with v: {}, du: {}, integrand: {}".format(_, v, du, integrand)) coefficient = ((v * du) / integrand).cancel() if coefficient == 1: break if symbol not in coefficient.free_symbols: rule = CyclicPartsRule(integrand, symbol, [PartsRule(None, None, u, dv, v_step, None) for (u, dv, v, du, v_step) in steps], (-1) ** len(steps) * coefficient) if (constant != 1) and rule: rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule) return rule # _parts_rule is sensitive to constants, factor it out next_constant, next_integrand = (v * du).as_coeff_Mul() result = _parts_rule(next_integrand, symbol) if result: u, dv, v, du, v_step = result u *= next_constant du *= next_constant steps.append((u, dv, v, du, v_step)) else: break def make_second_step(steps, integrand): if steps: u, dv, v, du, v_step = steps[0] return PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du)) return integral_steps(integrand, symbol) if steps: u, dv, v, du, v_step = steps[0] rule = PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du)) if (constant != 1) and rule: rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule) return rule def trig_rule(integral): integrand, symbol = integral if integrand == sin(symbol): return SinRule(integrand, symbol) if integrand == cos(symbol): return CosRule(integrand, symbol) if integrand == sec(symbol)**2: return Sec2Rule(integrand, symbol) if integrand == csc(symbol)**2: return Csc2Rule(integrand, symbol) if isinstance(integrand, tan): rewritten = sin(*integrand.args) / cos(*integrand.args) elif isinstance(integrand, cot): rewritten = cos(*integrand.args) / sin(*integrand.args) elif isinstance(integrand, sec): arg = integrand.args[0] rewritten = ((sec(arg)**2 + tan(arg) * sec(arg)) / (sec(arg) + tan(arg))) elif isinstance(integrand, csc): arg = integrand.args[0] rewritten = ((csc(arg)**2 + cot(arg) * csc(arg)) / (csc(arg) + cot(arg))) else: return return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol)) def trig_product_rule(integral: IntegralInfo): integrand, symbol = integral if integrand == sec(symbol) * tan(symbol): return SecTanRule(integrand, symbol) if integrand == csc(symbol) * cot(symbol): return CscCotRule(integrand, symbol) def quadratic_denom_rule(integral): integrand, symbol = integral a = Wild('a', exclude=[symbol]) b = Wild('b', exclude=[symbol]) c = Wild('c', exclude=[symbol]) match = integrand.match(a / (b * symbol ** 2 + c)) if match: a, b, c = match[a], match[b], match[c] general_rule = ArctanRule(integrand, symbol, a, b, c) if b.is_extended_real and c.is_extended_real: positive_cond = c/b > 0 if positive_cond is S.true: return general_rule coeff = a/(2*sqrt(-c)*sqrt(b)) constant = sqrt(-c/b) r1 = 1/(symbol-constant) r2 = 1/(symbol+constant) log_steps = [ReciprocalRule(r1, symbol, symbol-constant), ConstantTimesRule(-r2, symbol, -1, r2, ReciprocalRule(r2, symbol, symbol+constant))] rewritten = sub = r1 - r2 negative_step = AddRule(sub, symbol, log_steps) if coeff != 1: rewritten = Mul(coeff, sub, evaluate=False) negative_step = ConstantTimesRule(rewritten, symbol, coeff, sub, negative_step) negative_step = RewriteRule(integrand, symbol, rewritten, negative_step) if positive_cond is S.false: return negative_step return PiecewiseRule(integrand, symbol, [(general_rule, positive_cond), (negative_step, S.true)]) power = PowerRule(integrand, symbol, symbol, -2) if b != 1: power = ConstantTimesRule(integrand, symbol, 1/b, symbol**-2, power) return PiecewiseRule(integrand, symbol, [(general_rule, Ne(c, 0)), (power, True)]) d = Wild('d', exclude=[symbol]) match2 = integrand.match(a / (b * symbol ** 2 + c * symbol + d)) if match2: b, c = match2[b], match2[c] if b.is_zero: return u = Dummy('u') u_func = symbol + c/(2*b) integrand2 = integrand.subs(symbol, u - c / (2*b)) next_step = integral_steps(integrand2, u) if next_step: return URule(integrand2, symbol, u, u_func, next_step) else: return e = Wild('e', exclude=[symbol]) match3 = integrand.match((a* symbol + b) / (c * symbol ** 2 + d * symbol + e)) if match3: a, b, c, d, e = match3[a], match3[b], match3[c], match3[d], match3[e] if c.is_zero: return denominator = c * symbol**2 + d * symbol + e const = a/(2*c) numer1 = (2*c*symbol+d) numer2 = - const*d + b u = Dummy('u') step1 = URule(integrand, symbol, u, denominator, integral_steps(u**(-1), u)) if const != 1: step1 = ConstantTimesRule(const*numer1/denominator, symbol, const, numer1/denominator, step1) if numer2.is_zero: return step1 step2 = integral_steps(numer2/denominator, symbol) substeps = AddRule(integrand, symbol, [step1, step2]) rewriten = const*numer1/denominator+numer2/denominator return RewriteRule(integrand, symbol, rewriten, substeps) return def sqrt_linear_rule(integral: IntegralInfo): """ Substitute common (a+b*x)**(1/n) """ integrand, x = integral a = Wild('a', exclude=[x]) b = Wild('b', exclude=[x, 0]) a0 = b0 = 0 bases, qs, bs = [], [], [] for pow_ in integrand.find(Pow): # collect all (a+b*x)**(p/q) base, exp_ = pow_.base, pow_.exp if exp_.is_Integer or x not in base.free_symbols: # skip 1/x and sqrt(2) continue if not exp_.is_Rational: # exclude x**pi return match = base.match(a+b*x) if not match: # skip non-linear continue # for sqrt(x+sqrt(x)), although base is non-linear, we can still substitute sqrt(x) a1, b1 = match[a], match[b] if a0*b1 != a1*b0 or not (b0/b1).is_nonnegative: # cannot transform sqrt(x) to sqrt(x+1) or sqrt(-x) return if b0 == 0 or (b0/b1 > 1) is S.true: # choose the latter of sqrt(2*x) and sqrt(x) as representative a0, b0 = a1, b1 bases.append(base) bs.append(b1) qs.append(exp_.q) if b0 == 0: # no such pattern found return q0: Integer = lcm_list(qs) u_x = (a0 + b0*x)**(1/q0) u = Dummy("u") substituted = integrand.subs({base**(S.One/q): (b/b0)**(S.One/q)*u**(q0/q) for base, b, q in zip(bases, bs, qs)}).subs(x, (u**q0-a0)/b0) substep = integral_steps(substituted*u**(q0-1)*q0/b0, u) if not substep.contains_dont_know(): step: Rule = URule(integrand, x, u, u_x, substep) generic_cond = Ne(b0, 0) if generic_cond is not S.true: # possible degenerate case simplified = integrand.subs(dict.fromkeys(bs, 0)) degenerate_step = integral_steps(simplified, x) step = PiecewiseRule(integrand, x, [(step, generic_cond), (degenerate_step, S.true)]) return step def sqrt_quadratic_rule(integral: IntegralInfo, degenerate=True): integrand, x = integral a = Wild('a', exclude=[x]) b = Wild('b', exclude=[x]) c = Wild('c', exclude=[x, 0]) f = Wild('f') n = Wild('n', properties=[lambda n: n.is_Integer and n.is_odd]) match = integrand.match(f*sqrt(a+b*x+c*x**2)**n) if not match: return a, b, c, f, n = match[a], match[b], match[c], match[f], match[n] f_poly = f.as_poly(x) if f_poly is None: return generic_cond = Ne(c, 0) if not degenerate or generic_cond is S.true: degenerate_step = None elif b.is_zero: degenerate_step = integral_steps(f*sqrt(a)**n, x) else: degenerate_step = sqrt_linear_rule(IntegralInfo(f*sqrt(a+b*x)**n, x)) def sqrt_quadratic_denom_rule(numer_poly: Poly, integrand: Expr): denom = sqrt(a+b*x+c*x**2) deg = numer_poly.degree() if deg <= 1: # integrand == (d+e*x)/sqrt(a+b*x+c*x**2) e, d = numer_poly.all_coeffs() if deg == 1 else (S.Zero, numer_poly.as_expr()) # rewrite numerator to A*(2*c*x+b) + B A = e/(2*c) B = d-A*b pre_substitute = (2*c*x+b)/denom constant_step: Rule | None = None linear_step: Rule | None = None if A != 0: u = Dummy("u") pow_rule = PowerRule(1/sqrt(u), u, u, -S.Half) linear_step = URule(pre_substitute, x, u, a+b*x+c*x**2, pow_rule) if A != 1: linear_step = ConstantTimesRule(A*pre_substitute, x, A, pre_substitute, linear_step) if B != 0: constant_step = inverse_trig_rule(IntegralInfo(1/denom, x), degenerate=False) if B != 1: constant_step = ConstantTimesRule(B/denom, x, B, 1/denom, constant_step) # type: ignore if linear_step and constant_step: add = Add(A*pre_substitute, B/denom, evaluate=False) step: Rule | None = RewriteRule(integrand, x, add, AddRule(add, x, [linear_step, constant_step])) else: step = linear_step or constant_step else: coeffs = numer_poly.all_coeffs() step = SqrtQuadraticDenomRule(integrand, x, a, b, c, coeffs) return step if n > 0: # rewrite poly * sqrt(s)**(2*k-1) to poly*s**k / sqrt(s) numer_poly = f_poly * (a+b*x+c*x**2)**((n+1)/2) rewritten = numer_poly.as_expr()/sqrt(a+b*x+c*x**2) substep = sqrt_quadratic_denom_rule(numer_poly, rewritten) generic_step = RewriteRule(integrand, x, rewritten, substep) elif n == -1: generic_step = sqrt_quadratic_denom_rule(f_poly, integrand) else: return # todo: handle n < -1 case return _add_degenerate_step(generic_cond, generic_step, degenerate_step) def hyperbolic_rule(integral: tuple[Expr, Symbol]): integrand, symbol = integral if isinstance(integrand, HyperbolicFunction) and integrand.args[0] == symbol: if integrand.func == sinh: return SinhRule(integrand, symbol) if integrand.func == cosh: return CoshRule(integrand, symbol) u = Dummy('u') if integrand.func == tanh: rewritten = sinh(symbol)/cosh(symbol) return RewriteRule(integrand, symbol, rewritten, URule(rewritten, symbol, u, cosh(symbol), ReciprocalRule(1/u, u, u))) if integrand.func == coth: rewritten = cosh(symbol)/sinh(symbol) return RewriteRule(integrand, symbol, rewritten, URule(rewritten, symbol, u, sinh(symbol), ReciprocalRule(1/u, u, u))) else: rewritten = integrand.rewrite(tanh) if integrand.func == sech: return RewriteRule(integrand, symbol, rewritten, URule(rewritten, symbol, u, tanh(symbol/2), ArctanRule(2/(u**2 + 1), u, S(2), S.One, S.One))) if integrand.func == csch: return RewriteRule(integrand, symbol, rewritten, URule(rewritten, symbol, u, tanh(symbol/2), ReciprocalRule(1/u, u, u))) @cacheit def make_wilds(symbol): a = Wild('a', exclude=[symbol]) b = Wild('b', exclude=[symbol]) m = Wild('m', exclude=[symbol], properties=[lambda n: isinstance(n, Integer)]) n = Wild('n', exclude=[symbol], properties=[lambda n: isinstance(n, Integer)]) return a, b, m, n @cacheit def sincos_pattern(symbol): a, b, m, n = make_wilds(symbol) pattern = sin(a*symbol)**m * cos(b*symbol)**n return pattern, a, b, m, n @cacheit def tansec_pattern(symbol): a, b, m, n = make_wilds(symbol) pattern = tan(a*symbol)**m * sec(b*symbol)**n return pattern, a, b, m, n @cacheit def cotcsc_pattern(symbol): a, b, m, n = make_wilds(symbol) pattern = cot(a*symbol)**m * csc(b*symbol)**n return pattern, a, b, m, n @cacheit def heaviside_pattern(symbol): m = Wild('m', exclude=[symbol]) b = Wild('b', exclude=[symbol]) g = Wild('g') pattern = Heaviside(m*symbol + b) * g return pattern, m, b, g def uncurry(func): def uncurry_rl(args): return func(*args) return uncurry_rl def trig_rewriter(rewrite): def trig_rewriter_rl(args): a, b, m, n, integrand, symbol = args rewritten = rewrite(a, b, m, n, integrand, symbol) if rewritten != integrand: return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol)) return trig_rewriter_rl sincos_botheven_condition = uncurry( lambda a, b, m, n, i, s: m.is_even and n.is_even and m.is_nonnegative and n.is_nonnegative) sincos_botheven = trig_rewriter( lambda a, b, m, n, i, symbol: ( (((1 - cos(2*a*symbol)) / 2) ** (m / 2)) * (((1 + cos(2*b*symbol)) / 2) ** (n / 2)) )) sincos_sinodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd and m >= 3) sincos_sinodd = trig_rewriter( lambda a, b, m, n, i, symbol: ( (1 - cos(a*symbol)**2)**((m - 1) / 2) * sin(a*symbol) * cos(b*symbol) ** n)) sincos_cosodd_condition = uncurry(lambda a, b, m, n, i, s: n.is_odd and n >= 3) sincos_cosodd = trig_rewriter( lambda a, b, m, n, i, symbol: ( (1 - sin(b*symbol)**2)**((n - 1) / 2) * cos(b*symbol) * sin(a*symbol) ** m)) tansec_seceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4) tansec_seceven = trig_rewriter( lambda a, b, m, n, i, symbol: ( (1 + tan(b*symbol)**2) ** (n/2 - 1) * sec(b*symbol)**2 * tan(a*symbol) ** m )) tansec_tanodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd) tansec_tanodd = trig_rewriter( lambda a, b, m, n, i, symbol: ( (sec(a*symbol)**2 - 1) ** ((m - 1) / 2) * tan(a*symbol) * sec(b*symbol) ** n )) tan_tansquared_condition = uncurry(lambda a, b, m, n, i, s: m == 2 and n == 0) tan_tansquared = trig_rewriter( lambda a, b, m, n, i, symbol: ( sec(a*symbol)**2 - 1)) cotcsc_csceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4) cotcsc_csceven = trig_rewriter( lambda a, b, m, n, i, symbol: ( (1 + cot(b*symbol)**2) ** (n/2 - 1) * csc(b*symbol)**2 * cot(a*symbol) ** m )) cotcsc_cotodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd) cotcsc_cotodd = trig_rewriter( lambda a, b, m, n, i, symbol: ( (csc(a*symbol)**2 - 1) ** ((m - 1) / 2) * cot(a*symbol) * csc(b*symbol) ** n )) def trig_sincos_rule(integral): integrand, symbol = integral if any(integrand.has(f) for f in (sin, cos)): pattern, a, b, m, n = sincos_pattern(symbol) match = integrand.match(pattern) if not match: return return multiplexer({ sincos_botheven_condition: sincos_botheven, sincos_sinodd_condition: sincos_sinodd, sincos_cosodd_condition: sincos_cosodd })(tuple( [match.get(i, S.Zero) for i in (a, b, m, n)] + [integrand, symbol])) def trig_tansec_rule(integral): integrand, symbol = integral integrand = integrand.subs({ 1 / cos(symbol): sec(symbol) }) if any(integrand.has(f) for f in (tan, sec)): pattern, a, b, m, n = tansec_pattern(symbol) match = integrand.match(pattern) if not match: return return multiplexer({ tansec_tanodd_condition: tansec_tanodd, tansec_seceven_condition: tansec_seceven, tan_tansquared_condition: tan_tansquared })(tuple( [match.get(i, S.Zero) for i in (a, b, m, n)] + [integrand, symbol])) def trig_cotcsc_rule(integral): integrand, symbol = integral integrand = integrand.subs({ 1 / sin(symbol): csc(symbol), 1 / tan(symbol): cot(symbol), cos(symbol) / tan(symbol): cot(symbol) }) if any(integrand.has(f) for f in (cot, csc)): pattern, a, b, m, n = cotcsc_pattern(symbol) match = integrand.match(pattern) if not match: return return multiplexer({ cotcsc_cotodd_condition: cotcsc_cotodd, cotcsc_csceven_condition: cotcsc_csceven })(tuple( [match.get(i, S.Zero) for i in (a, b, m, n)] + [integrand, symbol])) def trig_sindouble_rule(integral): integrand, symbol = integral a = Wild('a', exclude=[sin(2*symbol)]) match = integrand.match(sin(2*symbol)*a) if match: sin_double = 2*sin(symbol)*cos(symbol)/sin(2*symbol) return integral_steps(integrand * sin_double, symbol) def trig_powers_products_rule(integral): return do_one(null_safe(trig_sincos_rule), null_safe(trig_tansec_rule), null_safe(trig_cotcsc_rule), null_safe(trig_sindouble_rule))(integral) def trig_substitution_rule(integral): integrand, symbol = integral A = Wild('a', exclude=[0, symbol]) B = Wild('b', exclude=[0, symbol]) theta = Dummy("theta") target_pattern = A + B*symbol**2 matches = integrand.find(target_pattern) for expr in matches: match = expr.match(target_pattern) a = match.get(A, S.Zero) b = match.get(B, S.Zero) a_positive = ((a.is_number and a > 0) or a.is_positive) b_positive = ((b.is_number and b > 0) or b.is_positive) a_negative = ((a.is_number and a < 0) or a.is_negative) b_negative = ((b.is_number and b < 0) or b.is_negative) x_func = None if a_positive and b_positive: # a**2 + b*x**2. Assume sec(theta) > 0, -pi/2 < theta < pi/2 x_func = (sqrt(a)/sqrt(b)) * tan(theta) # Do not restrict the domain: tan(theta) takes on any real # value on the interval -pi/2 < theta < pi/2 so x takes on # any value restriction = True elif a_positive and b_negative: # a**2 - b*x**2. Assume cos(theta) > 0, -pi/2 < theta < pi/2 constant = sqrt(a)/sqrt(-b) x_func = constant * sin(theta) restriction = And(symbol > -constant, symbol < constant) elif a_negative and b_positive: # b*x**2 - a**2. Assume sin(theta) > 0, 0 < theta < pi constant = sqrt(-a)/sqrt(b) x_func = constant * sec(theta) restriction = And(symbol > -constant, symbol < constant) if x_func: # Manually simplify sqrt(trig(theta)**2) to trig(theta) # Valid due to assumed domain restriction substitutions = {} for f in [sin, cos, tan, sec, csc, cot]: substitutions[sqrt(f(theta)**2)] = f(theta) substitutions[sqrt(f(theta)**(-2))] = 1/f(theta) replaced = integrand.subs(symbol, x_func).trigsimp() replaced = manual_subs(replaced, substitutions) if not replaced.has(symbol): replaced *= manual_diff(x_func, theta) replaced = replaced.trigsimp() secants = replaced.find(1/cos(theta)) if secants: replaced = replaced.xreplace({ 1/cos(theta): sec(theta) }) substep = integral_steps(replaced, theta) if not substep.contains_dont_know(): return TrigSubstitutionRule(integrand, symbol, theta, x_func, replaced, substep, restriction) def heaviside_rule(integral): integrand, symbol = integral pattern, m, b, g = heaviside_pattern(symbol) match = integrand.match(pattern) if match and 0 != match[g]: # f = Heaviside(m*x + b)*g substep = integral_steps(match[g], symbol) m, b = match[m], match[b] return HeavisideRule(integrand, symbol, m*symbol + b, -b/m, substep) def dirac_delta_rule(integral: IntegralInfo): integrand, x = integral if len(integrand.args) == 1: n = S.Zero else: n = integrand.args[1] if not n.is_Integer or n < 0: return a, b = Wild('a', exclude=[x]), Wild('b', exclude=[x, 0]) match = integrand.args[0].match(a+b*x) if not match: return a, b = match[a], match[b] generic_cond = Ne(b, 0) if generic_cond is S.true: degenerate_step = None else: degenerate_step = ConstantRule(DiracDelta(a, n), x) generic_step = DiracDeltaRule(integrand, x, n, a, b) return _add_degenerate_step(generic_cond, generic_step, degenerate_step) def substitution_rule(integral): integrand, symbol = integral u_var = Dummy("u") substitutions = find_substitutions(integrand, symbol, u_var) count = 0 if substitutions: debug("List of Substitution Rules") ways = [] for u_func, c, substituted in substitutions: subrule = integral_steps(substituted, u_var) count = count + 1 debug("Rule {}: {}".format(count, subrule)) if subrule.contains_dont_know(): continue if simplify(c - 1) != 0: _, denom = c.as_numer_denom() if subrule: subrule = ConstantTimesRule(c * substituted, u_var, c, substituted, subrule) if denom.free_symbols: piecewise = [] could_be_zero = [] if isinstance(denom, Mul): could_be_zero = denom.args else: could_be_zero.append(denom) for expr in could_be_zero: if not fuzzy_not(expr.is_zero): substep = integral_steps(manual_subs(integrand, expr, 0), symbol) if substep: piecewise.append(( substep, Eq(expr, 0) )) piecewise.append((subrule, True)) subrule = PiecewiseRule(substituted, symbol, piecewise) ways.append(URule(integrand, symbol, u_var, u_func, subrule)) if len(ways) > 1: return AlternativeRule(integrand, symbol, ways) elif ways: return ways[0] partial_fractions_rule = rewriter( lambda integrand, symbol: integrand.is_rational_function(), lambda integrand, symbol: integrand.apart(symbol)) cancel_rule = rewriter( # lambda integrand, symbol: integrand.is_algebraic_expr(), # lambda integrand, symbol: isinstance(integrand, Mul), lambda integrand, symbol: True, lambda integrand, symbol: integrand.cancel()) distribute_expand_rule = rewriter( lambda integrand, symbol: ( isinstance(integrand, (Pow, Mul)) or all(arg.is_Pow or arg.is_polynomial(symbol) for arg in integrand.args)), lambda integrand, symbol: integrand.expand()) trig_expand_rule = rewriter( # If there are trig functions with different arguments, expand them lambda integrand, symbol: ( len({a.args[0] for a in integrand.atoms(TrigonometricFunction)}) > 1), lambda integrand, symbol: integrand.expand(trig=True)) def derivative_rule(integral): integrand = integral[0] diff_variables = integrand.variables undifferentiated_function = integrand.expr integrand_variables = undifferentiated_function.free_symbols if integral.symbol in integrand_variables: if integral.symbol in diff_variables: return DerivativeRule(*integral) else: return DontKnowRule(integrand, integral.symbol) else: return ConstantRule(*integral) def rewrites_rule(integral): integrand, symbol = integral if integrand.match(1/cos(symbol)): rewritten = integrand.subs(1/cos(symbol), sec(symbol)) return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol)) def fallback_rule(integral): return DontKnowRule(*integral) # Cache is used to break cyclic integrals. # Need to use the same dummy variable in cached expressions for them to match. # Also record "u" of integration by parts, to avoid infinite repetition. _integral_cache: dict[Expr, Expr | None] = {} _parts_u_cache: dict[Expr, int] = defaultdict(int) _cache_dummy = Dummy("z") def integral_steps(integrand, symbol, **options): """Returns the steps needed to compute an integral. Explanation =========== This function attempts to mirror what a student would do by hand as closely as possible. SymPy Gamma uses this to provide a step-by-step explanation of an integral. The code it uses to format the results of this function can be found at https://github.com/sympy/sympy_gamma/blob/master/app/logic/intsteps.py. Examples ======== >>> from sympy import exp, sin >>> from sympy.integrals.manualintegrate import integral_steps >>> from sympy.abc import x >>> print(repr(integral_steps(exp(x) / (1 + exp(2 * x)), x))) \ # doctest: +NORMALIZE_WHITESPACE URule(integrand=exp(x)/(exp(2*x) + 1), variable=x, u_var=_u, u_func=exp(x), substep=ArctanRule(integrand=1/(_u**2 + 1), variable=_u, a=1, b=1, c=1)) >>> print(repr(integral_steps(sin(x), x))) \ # doctest: +NORMALIZE_WHITESPACE SinRule(integrand=sin(x), variable=x) >>> print(repr(integral_steps((x**2 + 3)**2, x))) \ # doctest: +NORMALIZE_WHITESPACE RewriteRule(integrand=(x**2 + 3)**2, variable=x, rewritten=x**4 + 6*x**2 + 9, substep=AddRule(integrand=x**4 + 6*x**2 + 9, variable=x, substeps=[PowerRule(integrand=x**4, variable=x, base=x, exp=4), ConstantTimesRule(integrand=6*x**2, variable=x, constant=6, other=x**2, substep=PowerRule(integrand=x**2, variable=x, base=x, exp=2)), ConstantRule(integrand=9, variable=x)])) Returns ======= rule : Rule The first step; most rules have substeps that must also be considered. These substeps can be evaluated using ``manualintegrate`` to obtain a result. """ cachekey = integrand.xreplace({symbol: _cache_dummy}) if cachekey in _integral_cache: if _integral_cache[cachekey] is None: # Stop this attempt, because it leads around in a loop return DontKnowRule(integrand, symbol) else: # TODO: This is for future development, as currently # _integral_cache gets no values other than None return (_integral_cache[cachekey].xreplace(_cache_dummy, symbol), symbol) else: _integral_cache[cachekey] = None integral = IntegralInfo(integrand, symbol) def key(integral): integrand = integral.integrand if symbol not in integrand.free_symbols: return Number for cls in (Symbol, TrigonometricFunction, OrthogonalPolynomial): if isinstance(integrand, cls): return cls return type(integrand) def integral_is_subclass(*klasses): def _integral_is_subclass(integral): k = key(integral) return k and issubclass(k, klasses) return _integral_is_subclass result = do_one( null_safe(special_function_rule), null_safe(switch(key, { Pow: do_one(null_safe(power_rule), null_safe(inverse_trig_rule), null_safe(sqrt_linear_rule), null_safe(quadratic_denom_rule)), Symbol: power_rule, exp: exp_rule, Add: add_rule, Mul: do_one(null_safe(mul_rule), null_safe(trig_product_rule), null_safe(heaviside_rule), null_safe(quadratic_denom_rule), null_safe(sqrt_linear_rule), null_safe(sqrt_quadratic_rule)), Derivative: derivative_rule, TrigonometricFunction: trig_rule, Heaviside: heaviside_rule, DiracDelta: dirac_delta_rule, OrthogonalPolynomial: orthogonal_poly_rule, Number: constant_rule })), do_one( null_safe(trig_rule), null_safe(hyperbolic_rule), null_safe(alternatives( rewrites_rule, substitution_rule, condition( integral_is_subclass(Mul, Pow), partial_fractions_rule), condition( integral_is_subclass(Mul, Pow), cancel_rule), condition( integral_is_subclass(Mul, log, *inverse_trig_functions), parts_rule), condition( integral_is_subclass(Mul, Pow), distribute_expand_rule), trig_powers_products_rule, trig_expand_rule )), null_safe(condition(integral_is_subclass(Mul, Pow), nested_pow_rule)), null_safe(trig_substitution_rule) ), fallback_rule)(integral) del _integral_cache[cachekey] return result def manualintegrate(f, var): """manualintegrate(f, var) Explanation =========== Compute indefinite integral of a single variable using an algorithm that resembles what a student would do by hand. Unlike :func:`~.integrate`, var can only be a single symbol. Examples ======== >>> from sympy import sin, cos, tan, exp, log, integrate >>> from sympy.integrals.manualintegrate import manualintegrate >>> from sympy.abc import x >>> manualintegrate(1 / x, x) log(x) >>> integrate(1/x) log(x) >>> manualintegrate(log(x), x) x*log(x) - x >>> integrate(log(x)) x*log(x) - x >>> manualintegrate(exp(x) / (1 + exp(2 * x)), x) atan(exp(x)) >>> integrate(exp(x) / (1 + exp(2 * x))) RootSum(4*_z**2 + 1, Lambda(_i, _i*log(2*_i + exp(x)))) >>> manualintegrate(cos(x)**4 * sin(x), x) -cos(x)**5/5 >>> integrate(cos(x)**4 * sin(x), x) -cos(x)**5/5 >>> manualintegrate(cos(x)**4 * sin(x)**3, x) cos(x)**7/7 - cos(x)**5/5 >>> integrate(cos(x)**4 * sin(x)**3, x) cos(x)**7/7 - cos(x)**5/5 >>> manualintegrate(tan(x), x) -log(cos(x)) >>> integrate(tan(x), x) -log(cos(x)) See Also ======== sympy.integrals.integrals.integrate sympy.integrals.integrals.Integral.doit sympy.integrals.integrals.Integral """ result = integral_steps(f, var).eval() # Clear the cache of u-parts _parts_u_cache.clear() # If we got Piecewise with two parts, put generic first if isinstance(result, Piecewise) and len(result.args) == 2: cond = result.args[0][1] if isinstance(cond, Eq) and result.args[1][1] == True: result = result.func( (result.args[1][0], Ne(*cond.args)), (result.args[0][0], True)) return result