from sympy.assumptions.ask import ask, Q
from sympy.assumptions.refine import handlers_dict
from sympy.core import Basic, sympify, S
from sympy.core.mul import mul, Mul
from sympy.core.numbers import Number, Integer
from sympy.core.symbol import Dummy
from sympy.functions import adjoint
from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust,
        do_one, new)
from sympy.matrices.exceptions import NonInvertibleMatrixError
from sympy.matrices.matrixbase import MatrixBase
from sympy.utilities.exceptions import sympy_deprecation_warning
from sympy.matrices.expressions._shape import validate_matmul_integer as validate

from .inverse import Inverse
from .matexpr import MatrixExpr
from .matpow import MatPow
from .transpose import transpose
from .permutation import PermutationMatrix
from .special import ZeroMatrix, Identity, GenericIdentity, OneMatrix


# XXX: MatMul should perhaps not subclass directly from Mul
class MatMul(MatrixExpr, Mul):
    """
    A product of matrix expressions

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C
    """
    is_MatMul = True

    identity = GenericIdentity()

    def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
        """Create a matrix product from ``args``.

        Parameters
        ==========

        args : scalar and matrix factors, in multiplication order.
        evaluate : if true, the new object is passed through
            ``cls._evaluate`` before being returned.
        check : deprecated.  Any value other than ``None`` emits a
            deprecation warning; ``False`` additionally skips the call to
            ``validate`` on the matrix factors.
        _sympify : if true, every argument is passed through ``sympify``.

        Returns
        =======

        ``cls.identity`` when called without arguments, the bare scalar
        coefficient when no matrix factor is present, and a ``MatMul``
        (canonicalized if ``evaluate`` is set) otherwise.
        """
        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericIdentity().shape
        args = [arg for arg in args if cls.identity != arg]
        if _sympify:
            args = [sympify(arg) for arg in args]
        obj = Basic.__new__(cls, *args)
        factor, matrices = obj.as_coeff_matrices()

        if check is not None:
            sympy_deprecation_warning(
                "Passing check to MatMul is deprecated and the check argument will be removed in a future version.",
                deprecated_since_version="1.11",
                active_deprecations_target='remove-check-argument-from-matrix-operations')

        if check is not False:
            validate(*matrices)

        if not matrices:
            # Should it be
            #
            # return Basic.__new__(cls, factor, GenericIdentity()) ?
            return factor

        if evaluate:
            return cls._evaluate(obj)

        return obj

    @classmethod
    def _evaluate(cls, expr):
        """Return ``expr`` after applying the ``canonicalize`` rule set."""
        return canonicalize(expr)

    @property
    def shape(self):
        """``(rows of the first matrix factor, cols of the last one)``."""
        matrices = [arg for arg in self.args if arg.is_Matrix]
        return (matrices[0].rows, matrices[-1].cols)

    def _entry(self, i, j, expand=True, **kwargs):
        """Return the ``(i, j)`` entry of the product.

        With a single matrix factor the entry is the coefficient times that
        factor's entry.  Otherwise the entry is built as a ``Sum`` over one
        dummy index per inner dimension; the sum is evaluated with ``doit``
        only when ``expand`` ends up true (see the inline comments).

        ``kwargs`` may carry ``dummy_generator``, an iterator yielding the
        summation indices; it is forwarded to the factors' ``_entry`` calls
        so nested products draw from the same iterator.
        """
        # Avoid cyclic imports
        from sympy.concrete.summations import Sum
        from sympy.matrices.immutable import ImmutableMatrix

        coeff, matrices = self.as_coeff_matrices()

        if len(matrices) == 1:  # situation like 2*X, matmul is just X
            return coeff * matrices[0][i, j]

        def f():
            counter = 1
            while True:
                yield Dummy("i_%i" % counter)
                counter += 1

        dummy_generator = kwargs.get("dummy_generator", f())

        # indices[k], indices[k + 1] are the row/column index of factor k:
        # the outer ones are the requested (i, j), the inner ones dummies.
        indices = [None]*(len(matrices) + 1)
        indices[0] = i
        indices[-1] = j
        for k in range(1, len(matrices)):
            indices[k] = next(dummy_generator)

        # Upper summation bound of each inner index (lower bound is 0).
        ind_ranges = [arg.shape[1] - 1 for arg in matrices[:-1]]

        matrices = [
            arg._entry(indices[k], indices[k+1], dummy_generator=dummy_generator)
            for k, arg in enumerate(matrices)]
        expr_in_sum = Mul.fromiter(matrices)
        if any(v.has(ImmutableMatrix) for v in matrices):
            expand = True
        result = coeff*Sum(
                expr_in_sum,
                *zip(indices[1:-1], [0]*len(ind_ranges), ind_ranges)
            )

        # Don't waste time in result.doit() if the sum bounds are symbolic
        if not any(isinstance(v, (Integer, int)) for v in ind_ranges):
            expand = False
        return result.doit() if expand else result

    def as_coeff_matrices(self):
        """Split the arguments into ``(coeff, matrices)``.

        ``coeff`` is the ``Mul`` of all non-matrix arguments and
        ``matrices`` is the list of matrix arguments in their original
        order.

        Raises
        ======

        NotImplementedError
            If the scalar coefficient is noncommutative.
        """
        scalars = [x for x in self.args if not x.is_Matrix]
        matrices = [x for x in self.args if x.is_Matrix]
        coeff = Mul(*scalars)
        if coeff.is_commutative is False:
            raise NotImplementedError("noncommutative scalars in MatMul are not supported.")

        return coeff, matrices

    def as_coeff_mmul(self):
        """Like ``as_coeff_matrices`` but with the matrices wrapped in a ``MatMul``."""
        coeff, matrices = self.as_coeff_matrices()
        return coeff, MatMul(*matrices)

    def expand(self, **kwargs):
        """Expand via the parent classes, then canonicalize the result."""
        expanded = super().expand(**kwargs)
        return self._evaluate(expanded)

    def _eval_transpose(self):
        """Transposition of matrix multiplication.

        Notes
        =====

        The following rules are applied.

        Transposition for matrix multiplied with another matrix:
        `\\left(A B\\right)^{T} = B^{T} A^{T}`

        Transposition for matrix multiplied with scalar:
        `\\left(c A\\right)^{T} = c A^{T}`

        References
        ==========

        .. [1] https://en.wikipedia.org/wiki/Transpose
        """
        coeff, matrices = self.as_coeff_matrices()
        return MatMul(
            coeff, *[transpose(arg) for arg in matrices[::-1]]).doit()

    def _eval_adjoint(self):
        """Adjoint of the product: adjoints of all arguments in reverse order."""
        return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit()

    def _eval_trace(self):
        """Pull a non-unit scalar coefficient out of the trace.

        Returns ``None`` when the coefficient equals 1, i.e. this method
        offers no simplification in that case.
        """
        factor, mmul = self.as_coeff_mmul()
        if factor != 1:
            from .trace import trace
            return factor * trace(mmul.doit())
        return None

    def _eval_determinant(self):
        """Determinant as ``coeff**rows`` times the determinants of the
        square sub-products returned by ``only_squares``."""
        from sympy.matrices.expressions.determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        square_matrices = only_squares(*matrices)
        return factor**self.rows * Mul(*map(Determinant, square_matrices))

    def _eval_inverse(self):
        """Invert factor by factor (in reverse order) when every matrix
        expression factor is square; otherwise return ``Inverse(self)``."""
        if all(arg.is_square for arg in self.args if isinstance(arg, MatrixExpr)):
            return MatMul(*(
                arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1
                    for arg in self.args[::-1]
                )
            ).doit()
        return Inverse(self)

    def doit(self, **hints):
        """Canonicalize the product.

        Unless ``deep=False`` is passed in ``hints``, ``doit`` is first
        applied to every argument with the same hints.
        """
        deep = hints.get('deep', True)
        if deep:
            args = tuple(arg.doit(**hints) for arg in self.args)
        else:
            args = self.args

        # treat scalar*MatrixSymbol or scalar*MatPow separately
        expr = canonicalize(MatMul(*args))
        return expr

    # Needed for partial compatibility with Mul
    def args_cnc(self, cset=False, warn=True, **kwargs):
        """Return ``[commutative_args, noncommutative_args]``.

        With ``cset=True`` the commutative part is returned as a set; if
        that drops duplicates and ``warn`` is true, ``ValueError`` is
        raised.  Extra keyword arguments are accepted and ignored.
        """
        coeff_c = [x for x in self.args if x.is_commutative]
        coeff_nc = [x for x in self.args if not x.is_commutative]
        if cset:
            clen = len(coeff_c)
            coeff_c = set(coeff_c)
            if clen and warn and len(coeff_c) != clen:
                raise ValueError('repeated commutative arguments: %s' %
                                 [ci for ci in coeff_c if list(self.args).count(ci) > 1])
        return [coeff_c, coeff_nc]

    def _eval_derivative_matrix_lines(self, x):
        """Collect the derivative lines of every argument that contains ``x``.

        For each such argument, its own derivative lines are extended with
        the transposed, reversed left-hand factors (``append_first``) and
        the product of the right-hand factors (``append_second``).  An
        ``Identity`` of the matching outer dimension stands in when there
        are no factors on that side.
        """
        from .transpose import Transpose
        with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
        lines = []
        for ind in with_x_ind:
            left_args = self.args[:ind]
            right_args = self.args[ind+1:]

            if right_args:
                right_mat = MatMul.fromiter(right_args)
            else:
                right_mat = Identity(self.shape[1])
            if left_args:
                left_rev = MatMul.fromiter(
                    [Transpose(a).doit() if a.is_Matrix else a
                     for a in reversed(left_args)])
            else:
                left_rev = Identity(self.shape[0])

            d = self.args[ind]._eval_derivative_matrix_lines(x)
            for line in d:
                line.append_first(left_rev)
                line.append_second(right_mat)
                lines.append(line)

        return lines


mul.register_handlerclass((Mul, MatMul), MatMul)


# Rules
def newmul(*args):
    """Build a ``MatMul`` through ``new``, dropping a leading factor equal to 1."""
    if args[0] == 1:
        args = args[1:]
    return new(MatMul, *args)


def any_zeros(mul):
    """Replace the product by a ``ZeroMatrix`` of the product's shape if any
    argument is zero or a zero matrix."""
    if any(arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix)
                       for arg in mul.args):
        matrices = [arg for arg in mul.args if arg.is_Matrix]
        return ZeroMatrix(matrices[0].rows, matrices[-1].cols)
    return mul


def merge_explicit(matmul):
    """ Merge explicit MatrixBase arguments

    >>> from sympy import MatrixSymbol, Matrix, MatMul, pprint
    >>> from sympy.matrices.expressions.matmul import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = Matrix([[1, 1], [1, 1]])
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatMul(A, B, C)
    >>> pprint(X)
      [1  1] [1  2]
    A*[    ]*[    ]
      [1  1] [3  4]
    >>> pprint(merge_explicit(X))
      [4  6]
    A*[    ]
      [4  6]

    >>> X = MatMul(B, A, C)
    >>> pprint(X)
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    >>> pprint(merge_explicit(X))
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    """
    if not any(isinstance(arg, MatrixBase) for arg in matmul.args):
        return matmul
    newargs = []
    last = matmul.args[0]
    for arg in matmul.args[1:]:
        # Only neighbouring explicit matrices / numbers are multiplied out.
        if isinstance(arg, (MatrixBase, Number)) and isinstance(last, (MatrixBase, Number)):
            last = last * arg
        else:
            newargs.append(last)
            last = arg
    newargs.append(last)

    return MatMul(*newargs)


def remove_ids(mul):
    """ Remove Identities from a MatMul

    This is a modified version of sympy.strategies.rm_id.
    This is necessary because MatMul may contain both MatrixExprs and Exprs
    as args.

    See Also
    ========

    sympy.strategies.rm_id
    """
    # Separate Exprs from MatrixExprs in args
    factor, mmul = mul.as_coeff_mmul()
    # Apply standard rm_id for MatMuls
    result = rm_id(lambda x: x.is_Identity is True)(mmul)
    if result != mmul:
        return newmul(factor, *result.args)  # Recombine and return
    else:
        return mul


def factor_in_front(mul):
    """Move a non-unit scalar coefficient to the front of the product."""
    factor, matrices = mul.as_coeff_matrices()
    if factor != 1:
        return newmul(factor, *matrices)
    return mul


def combine_powers(mul):
    r"""Combine consecutive powers with the same base into one, e.g.
    $$A \times A^2 \Rightarrow A^3$$

    This also cancels out the possible matrix inverses using the
    knowledgebase of :class:`~.Inverse`, e.g.,
    $$ Y \times X \times X^{-1} \Rightarrow Y $$
    """
    factor, args = mul.as_coeff_matrices()
    new_args = [args[0]]

    for i in range(1, len(args)):
        A = new_args[-1]
        B = args[i]

        # ... * (X*Y) * (X*Y)^-1: the trailing factors already collected
        # match the argument of the inverse, so they cancel.
        if isinstance(B, Inverse) and isinstance(B.arg, MatMul):
            Bargs = B.arg.args
            count = len(Bargs)
            if list(Bargs) == new_args[-count:]:
                new_args = new_args[:-count] + [Identity(B.shape[0])]
                continue

        # (X*Y)^-1 * X * Y * ...: the upcoming factors match the argument
        # of the inverse; overwrite them with identities so later
        # iterations see the cancelled form.
        if isinstance(A, Inverse) and isinstance(A.arg, MatMul):
            Aargs = A.arg.args
            count = len(Aargs)
            if list(Aargs) == args[i:i+count]:
                identity = Identity(A.shape[0])
                new_args[-1] = identity
                for j in range(i, i+count):
                    args[j] = identity
                continue

        # Explicit comparison with False (not ``not ...``): other values of
        # is_square, presumably None for "unknown", still fall through to
        # the power-combining logic below.
        if A.is_square == False or B.is_square == False:
            new_args.append(B)
            continue

        if isinstance(A, MatPow):
            A_base, A_exp = A.args
        else:
            A_base, A_exp = A, S.One

        if isinstance(B, MatPow):
            B_base, B_exp = B.args
        else:
            B_base, B_exp = B, S.One

        if A_base == B_base:
            new_exp = A_exp + B_exp
            new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
            continue
        elif not isinstance(B_base, MatrixBase):
            try:
                B_base_inv = B_base.inverse()
            except NonInvertibleMatrixError:
                B_base_inv = None
            if B_base_inv is not None and A_base == B_base_inv:
                new_exp = A_exp - B_exp
                new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
                continue
        new_args.append(B)

    return newmul(factor, *new_args)


def combine_permutations(mul):
    """Refine products of permutation matrices as the products of cycles.
    """
    args = mul.args
    nargs = len(args)
    if nargs < 2:
        return mul

    result = [args[0]]
    for i in range(1, nargs):
        A = result[-1]
        B = args[i]
        if isinstance(A, PermutationMatrix) and \
            isinstance(B, PermutationMatrix):
            cycle_1 = A.args[0]
            cycle_2 = B.args[0]
            result[-1] = PermutationMatrix(cycle_1 * cycle_2)
        else:
            result.append(B)

    return MatMul(*result)


def combine_one_matrices(mul):
    """
    Combine products of OneMatrix

    e.g. OneMatrix(2, 3) * OneMatrix(3, 4) -> 3 * OneMatrix(2, 4)
    """
    factor, args = mul.as_coeff_matrices()
    new_args = [args[0]]

    for B in args[1:]:
        A = new_args[-1]
        if not isinstance(A, OneMatrix) or not isinstance(B, OneMatrix):
            new_args.append(B)
            continue
        new_args.pop()
        new_args.append(OneMatrix(A.shape[0], B.shape[1]))
        # Each entry of the merged product is a sum of A.shape[1] ones.
        factor *= A.shape[1]

    return newmul(factor, *new_args)


def distribute_monom(mul):
    """
    Simplify MatMul expressions by distributing
    rational term to MatMul.

    e.g. 2*(A+B) -> 2*A + 2*B
    """
    args = mul.args
    if len(args) == 2:
        from .matadd import MatAdd
        if args[0].is_MatAdd and args[1].is_Rational:
            return MatAdd(*[MatMul(mat, args[1]).doit() for mat in args[0].args])
        if args[1].is_MatAdd and args[0].is_Rational:
            return MatAdd(*[MatMul(args[0], mat).doit() for mat in args[1].args])
    return mul


# Rewrite rules tried in this order by ``do_one`` inside ``canonicalize``.
rules = (
    distribute_monom, any_zeros, remove_ids, combine_one_matrices, combine_powers, unpack, rm_id(lambda x: x == 1),
    merge_explicit, factor_in_front, flatten, combine_permutations)

canonicalize = exhaust(typed({MatMul: do_one(*rules)}))


def only_squares(*matrices):
    """factor matrices only if they are square"""
    if matrices[0].rows != matrices[-1].cols:
        raise RuntimeError("Invalid matrices being multiplied")
    out = []
    start = 0
    for i, M in enumerate(matrices):
        # matrices[start:i+1] multiplies out to a square matrix: close the
        # group here and start the next one after it.
        if M.cols == matrices[start].rows:
            out.append(MatMul(*matrices[start:i+1]).doit())
            start = i+1
    return out


def refine_MatMul(expr, assumptions):
    """
    Cancel adjacent matrix factors that multiply to the identity under
    ``assumptions``.

    A factor followed by its transpose is replaced by ``Identity`` when the
    second factor is known to be orthogonal, and a factor followed by its
    adjoint (conjugate transpose) is replaced by ``Identity`` when the
    second factor is known to be unitary.

    >>> from sympy import MatrixSymbol, Q, assuming, refine
    >>> X = MatrixSymbol('X', 2, 2)
    >>> expr = X * X.T
    >>> print(expr)
    X*X.T
    >>> with assuming(Q.orthogonal(X)):
    ...     print(refine(expr))
    I
    """
    newargs = []
    exprargs = []

    # Scalars go straight to the output; only matrix factors are paired up.
    for arg in expr.args:
        if arg.is_Matrix:
            exprargs.append(arg)
        else:
            newargs.append(arg)

    last = exprargs[0]
    for arg in exprargs[1:]:
        if arg == last.T and ask(Q.orthogonal(arg), assumptions):
            last = Identity(arg.shape[0])
        # A unitary U satisfies U*U^H == I where U^H is the conjugate
        # transpose, so the partner to look for is the adjoint; the
        # elementwise conjugate would only be valid for symmetric U.
        elif arg == last.adjoint() and ask(Q.unitary(arg), assumptions):
            last = Identity(arg.shape[0])
        else:
            newargs.append(last)
            last = arg
    newargs.append(last)

    return MatMul(*newargs)


handlers_dict['MatMul'] = refine_MatMul