from functools import reduce
import operator

from sympy.core import Basic, sympify
from sympy.core.add import add, Add, _could_extract_minus_sign
from sympy.core.sorting import default_sort_key
from sympy.functions import adjoint
from sympy.matrices.matrixbase import MatrixBase
from sympy.matrices.expressions.transpose import transpose
from sympy.strategies import (rm_id, unpack, flatten, sort, condition,
    exhaust, do_one, glom)
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix
from sympy.matrices.expressions._shape import validate_matadd_integer as validate
from sympy.utilities.iterables import sift
from sympy.utilities.exceptions import sympy_deprecation_warning


# XXX: MatAdd should perhaps not subclass directly from Add
class MatAdd(MatrixExpr, Add):
    """A Sum of Matrix Expressions

    MatAdd inherits from and operates like SymPy Add

    Examples
    ========

    >>> from sympy import MatAdd, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 5)
    >>> B = MatrixSymbol('B', 5, 5)
    >>> C = MatrixSymbol('C', 5, 5)
    >>> MatAdd(A, B, C)
    A + B + C
    """
    is_MatAdd = True

    # Shape-less placeholder returned for an empty sum.
    identity = GenericZeroMatrix()

    def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
        """Build a matrix sum from ``args``.

        Parameters
        ==========

        args : matrix expressions
            The summands.  Arguments equal to ``cls.identity`` are dropped
            before anything else happens.
        evaluate : bool, optional
            If true, the new object is passed through ``cls._evaluate``
            before being returned.
        check : optional
            Deprecated.  Any value other than ``None`` triggers a
            deprecation warning; ``False`` additionally skips the shape
            validation.
        _sympify : bool, optional
            If true, every remaining argument is passed through ``sympify``.

        Returns
        =======

        ``cls.identity`` when there are no summands (either none were given
        or all of them were the identity placeholder), otherwise the new
        ``MatAdd`` (or the result of ``cls._evaluate`` when ``evaluate`` is
        true).

        Raises
        ======

        TypeError
            If a remaining argument is not a ``MatrixExpr``.
        """
        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor to avoid
        # TypeErrors from GenericZeroMatrix().shape
        args = [arg for arg in args if cls.identity != arg]

        # Every argument may have been the identity placeholder.  An
        # argument-less MatAdd is unusable (``shape`` reads ``self.args[0]``)
        # and there is nothing left to validate, so treat it exactly like
        # the empty call above.
        if not args:
            return cls.identity

        if _sympify:
            args = [sympify(arg) for arg in args]

        if not all(isinstance(arg, MatrixExpr) for arg in args):
            raise TypeError("Mix of Matrix and Scalar symbols")

        obj = Basic.__new__(cls, *args)

        if check is not None:
            sympy_deprecation_warning(
                "Passing check to MatAdd is deprecated and the check argument will be removed in a future version.",
                deprecated_since_version="1.11",
                active_deprecations_target='remove-check-argument-from-matrix-operations')

        # ``check`` defaults to None, so validation runs unless the caller
        # explicitly passed ``check=False``.
        if check is not False:
            validate(*args)

        if evaluate:
            obj = cls._evaluate(obj)

        return obj

    @classmethod
    def _evaluate(cls, expr):
        """Return ``expr`` after applying the module-level ``canonicalize``."""
        return canonicalize(expr)

    @property
    def shape(self):
        """Shape of the sum, taken from its first argument."""
        return self.args[0].shape

    def could_extract_minus_sign(self):
        """Delegate to ``sympy.core.add._could_extract_minus_sign``."""
        return _could_extract_minus_sign(self)

    def expand(self, **kwargs):
        """Expand via the parent classes, then re-canonicalize the result."""
        expanded = super().expand(**kwargs)
        return self._evaluate(expanded)

    def _entry(self, i, j, **kwargs):
        """Return entry ``(i, j)`` as the ``Add`` of each argument's entry."""
        return Add(*[arg._entry(i, j, **kwargs) for arg in self.args])

    def _eval_transpose(self):
        """Transpose every summand and evaluate the resulting sum."""
        return MatAdd(*[transpose(arg) for arg in self.args]).doit()

    def _eval_adjoint(self):
        """Take the adjoint of every summand and evaluate the resulting sum."""
        return MatAdd(*[adjoint(arg) for arg in self.args]).doit()

    def _eval_trace(self):
        """Return the scalar ``Add`` of the traces of the summands."""
        # Imported locally -- presumably to avoid a circular import between
        # this module and ``.trace``; confirm before hoisting.
        from .trace import trace
        return Add(*[trace(arg) for arg in self.args]).doit()

    def doit(self, **hints):
        """Evaluate the sum and return its canonical form.

        With ``deep=True`` (the default when the hint is absent) every
        argument has ``doit(**hints)`` called on it first; otherwise the
        arguments are used as they are.
        """
        deep = hints.get('deep', True)
        args = [arg.doit(**hints) for arg in self.args] if deep else self.args
        return canonicalize(MatAdd(*args))

    def _eval_derivative_matrix_lines(self, x):
        """Concatenate the derivative matrix lines of all summands."""
        add_lines = [arg._eval_derivative_matrix_lines(x) for arg in self.args]
        return [j for i in add_lines for j in i]


add.register_handlerclass((Add, MatAdd), MatAdd)


def factor_of(arg):
    """Return the first element of ``arg.as_coeff_mmul()`` (the coefficient)."""
    return arg.as_coeff_mmul()[0]


def matrix_of(arg):
    """Return the second element of ``arg.as_coeff_mmul()``, passed through ``unpack``."""
    return unpack(arg.as_coeff_mmul()[1])


def combine(cnt, mat):
    """Return ``mat`` when ``cnt == 1``, otherwise ``cnt * mat``."""
    if cnt == 1:
        return mat
    return cnt * mat


def merge_explicit(matadd):
    """ Merge explicit MatrixBase arguments

    Examples
    ========

    >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
    >>> from sympy.matrices.expressions.matadd import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = eye(2)
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatAdd(A, B, C)
    >>> pprint(X)
        [1  0]   [1  2]
    A + [    ] + [    ]
        [0  1]   [3  4]
    >>> pprint(merge_explicit(X))
        [2  2]
    A + [    ]
        [3  5]
    """
    # Partition the arguments by whether they are explicit matrices.
    groups = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase))
    if len(groups[True]) > 1:
        # Symbolic terms stay in place; the explicit matrices are summed
        # into a single trailing argument.
        return MatAdd(*(groups[False] + [reduce(operator.add, groups[True])]))
    return matadd


# Rewrite rules handed to ``do_one`` below, in this order.  The exact
# semantics of the strategy combinators live in ``sympy.strategies``; the
# predicates and helpers passed to them here are the ones defined above.
rules = (rm_id(lambda x: x == 0 or isinstance(x, ZeroMatrix)),
         unpack,
         flatten,
         glom(matrix_of, factor_of, combine),
         merge_explicit,
         sort(default_sort_key))

# The rules are only applied to MatAdd instances (the ``condition``
# predicate); ``exhaust`` presumably repeats them until nothing changes --
# see the sympy.strategies documentation.
canonicalize = exhaust(condition(lambda x: isinstance(x, MatAdd),
                                 do_one(*rules)))