from collections.abc import Callable

from sympy.core.basic import Basic
from sympy.external import import_module
import sympy.plotting.backends.base_backend as base_backend
from sympy.printing.latex import latex


# N.B.
# When changing the minimum module version for matplotlib, please change
# the same in the ``SymPyDocTestFinder`` in ``sympy/testing/runtests.py``


def _str_or_latex(label):
    """Return ``label`` as a string usable as a matplotlib label.

    SymPy objects (instances of ``Basic``) are rendered as inline LaTeX;
    anything else is converted with ``str``.
    """
    if isinstance(label, Basic):
        return latex(label, mode='inline')
    return str(label)


def _matplotlib_list(interval_list):
    """
    Returns lists for matplotlib ``fill`` command from a list of bounding
    rectangular intervals.

    Each element of ``interval_list`` is indexed as ``(x_interval,
    y_interval)``; both intervals must expose ``start`` and ``end``.
    Every rectangle contributes its four corners followed by a ``None``
    entry (presumably acting as a break between consecutive rectangles --
    TODO confirm against matplotlib's ``fill`` handling of missing values).
    """
    xlist = []
    ylist = []
    if len(interval_list):
        for intervals in interval_list:
            intervalx = intervals[0]
            intervaly = intervals[1]
            xlist.extend([intervalx.start, intervalx.start,
                          intervalx.end, intervalx.end, None])
            ylist.extend([intervaly.start, intervaly.end,
                          intervaly.end, intervaly.start, None])
    else:
        # XXX Ugly hack. Matplotlib does not accept empty lists for ``fill``
        xlist.extend((None, None, None, None))
        ylist.extend((None, None, None, None))
    return xlist, ylist


# Don't have to check for the success of importing matplotlib in each case;
# we will only be using this backend if we can successfully import matplotlib
class MatplotlibBackend(base_backend.Plot):
    """
    This class implements the functionalities to use Matplotlib with SymPy
    plotting functions.
    """

    def __init__(self, *series, **kwargs):
        super().__init__(*series, **kwargs)
        self.matplotlib = import_module('matplotlib',
            import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
            min_module_version='1.1.0', catch=(RuntimeError,))
        self.plt = self.matplotlib.pyplot
        self.cm = self.matplotlib.cm
        self.LineCollection = self.matplotlib.collections.LineCollection
        self.aspect = kwargs.get('aspect_ratio', 'auto')
        if self.aspect != 'auto':
            # ``aspect_ratio`` is given as a pair; store the single float
            # ``aspect_ratio[1] / aspect_ratio[0]`` that is later handed to
            # ``Axes.set_aspect``.
            self.aspect = float(self.aspect[1]) / self.aspect[0]
        # PlotGrid can provide its figure and axes to be populated with
        # the data from the series.
        self._plotgrid_fig = kwargs.pop("fig", None)
        self._plotgrid_ax = kwargs.pop("ax", None)

    def _create_figure(self):
        """Set ``self.fig`` and ``self.ax``.

        Reuses the figure/axes supplied by PlotGrid when present; otherwise
        creates a new figure with one subplot, using a 3D projection if any
        series is 3D. 2D axes get their spines moved to the data origin.
        """
        def set_spines(ax):
            ax.spines['left'].set_position('zero')
            ax.spines['right'].set_color('none')
            ax.spines['bottom'].set_position('zero')
            ax.spines['top'].set_color('none')
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')

        has_3d_series = any(s.is_3D for s in self._series)

        if self._plotgrid_fig is not None:
            self.fig = self._plotgrid_fig
            self.ax = self._plotgrid_ax
            if not has_3d_series:
                set_spines(self.ax)
        else:
            self.fig = self.plt.figure(figsize=self.size)
            if has_3d_series:
                self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
            else:
                self.ax = self.fig.add_subplot(1, 1, 1)
                set_spines(self.ax)

    @staticmethod
    def get_segments(x, y, z=None):
        """ Convert two lists of coordinates to a list of segments to be used
        with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`.

        Parameters
        ==========

        x : list
            List of x-coordinates

        y : list
            List of y-coordinates

        z : list
            List of z-coordinates for a 3D line.

        Returns
        =======

        A numpy masked array of shape ``(n - 1, 2, dim)`` where ``n`` is the
        number of points and ``dim`` is 2 or 3: one start/end pair per
        consecutive pair of points.
        """
        np = import_module('numpy')
        if z is not None:
            dim = 3
            points = (x, y, z)
        else:
            dim = 2
            points = (x, y)
        points = np.ma.array(points).T.reshape(-1, 1, dim)
        return np.ma.concatenate([points[:-1], points[1:]], axis=1)

    def _process_series(self, series, ax):
        """
        Draw every data series onto ``ax`` and apply the plot-wide options.

        Parameters
        ==========

        series : list
            The data series to draw. Each one is dispatched on its
            ``is_2Dline``, ``is_contour``, ``is_3Dline``, ``is_3Dsurface``,
            ``is_implicit`` or ``is_generic`` flag.

        ax : Axes
            The matplotlib axes (2D, or ``Axes3D``) to draw on.

        Raises
        ======

        NotImplementedError
            If a series matches none of the supported kinds.
        """
        np = import_module('numpy')
        mpl_toolkits = import_module(
            'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})

        # XXX Workaround for matplotlib issue
        # https://github.com/matplotlib/matplotlib/issues/17130
        xlims, ylims, zlims = [], [], []

        for s in series:
            # Create the collections
            if s.is_2Dline:
                if s.is_parametric:
                    x, y, _ = s.get_data()
                else:
                    x, y = s.get_data()
                if (isinstance(s.line_color, (int, float)) or
                        callable(s.line_color)):
                    # Numeric/callable colors vary along the line, so draw
                    # it as a LineCollection coloured by the series' array.
                    segments = self.get_segments(x, y)
                    collection = self.LineCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    lbl = _str_or_latex(s.label)
                    ax.plot(x, y, label=lbl, color=s.line_color)
            elif s.is_contour:
                ax.contour(*s.get_data())
            elif s.is_3Dline:
                x, y, z, _ = s.get_data()
                if (isinstance(s.line_color, (int, float)) or
                        callable(s.line_color)):
                    art3d = mpl_toolkits.mplot3d.art3d
                    segments = self.get_segments(x, y, z)
                    collection = art3d.Line3DCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    lbl = _str_or_latex(s.label)
                    ax.plot(x, y, z, label=lbl, color=s.line_color)

                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_3Dsurface:
                if s.is_parametric:
                    x, y, z, _, _ = s.get_data()
                else:
                    x, y, z = s.get_data()
                collection = ax.plot_surface(x, y, z,
                    cmap=getattr(self.cm, 'viridis', self.cm.jet),
                    rstride=1, cstride=1, linewidth=0.1)
                if isinstance(s.surface_color, (float, int, Callable)):
                    color_array = s.get_color_array()
                    color_array = color_array.reshape(color_array.size)
                    collection.set_array(color_array)
                else:
                    collection.set_color(s.surface_color)

                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_implicit:
                points = s.get_data()
                if len(points) == 2:
                    # interval math plotting
                    x, y = _matplotlib_list(points[0])
                    ax.fill(x, y, facecolor=s.line_color, edgecolor='None')
                else:
                    # use contourf or contour depending on whether it is
                    # an inequality or equality.
                    # XXX: ``contour`` plots multiple lines. Should be fixed.
                    ListedColormap = self.matplotlib.colors.ListedColormap
                    colormap = ListedColormap(["white", s.line_color])
                    xarray, yarray, zarray, plot_type = points
                    if plot_type == 'contour':
                        ax.contour(xarray, yarray, zarray, cmap=colormap)
                    else:
                        ax.contourf(xarray, yarray, zarray, cmap=colormap)
            elif s.is_generic:
                if s.type == "markers":
                    # s.rendering_kw["color"] = s.line_color
                    ax.plot(*s.args, **s.rendering_kw)
                elif s.type == "annotations":
                    ax.annotate(*s.args, **s.rendering_kw)
                elif s.type == "fill":
                    # s.rendering_kw["color"] = s.line_color
                    ax.fill_between(*s.args, **s.rendering_kw)
                elif s.type == "rectangles":
                    # s.rendering_kw["color"] = s.line_color
                    ax.add_patch(
                        self.matplotlib.patches.Rectangle(
                            *s.args, **s.rendering_kw))
            else:
                # Report the offending series (previously this formatted
                # the axes object, which did not identify the problem).
                raise NotImplementedError(
                    '{} is not supported in the SymPy plotting module '
                    'with matplotlib backend. Please report this issue.'
                    .format(s))

        Axes3D = mpl_toolkits.mplot3d.Axes3D
        is_3d_axes = isinstance(ax, Axes3D)
        if not is_3d_axes:
            ax.autoscale_view(
                scalex=ax.get_autoscalex_on(),
                scaley=ax.get_autoscaley_on())
        else:
            # XXX Workaround for matplotlib issue
            # https://github.com/matplotlib/matplotlib/issues/17130
            # Each axis spans the union of the collected per-series limits,
            # falling back to [0, 1] when no series supplied any.
            for set_lim, lims in ((ax.set_xlim, xlims),
                                  (ax.set_ylim, ylims),
                                  (ax.set_zlim, zlims)):
                if lims:
                    lims = np.array(lims)
                    set_lim((np.amin(lims[:, 0]), np.amax(lims[:, 1])))
                else:
                    set_lim([0, 1])

        # Set global options.
        # TODO The 3D stuff
        # XXX The order of those is important.
        if self.xscale and not is_3d_axes:
            ax.set_xscale(self.xscale)
        if self.yscale and not is_3d_axes:
            ax.set_yscale(self.yscale)
        # XXX in the distant future remove this check. Note that this is a
        # lexicographic string comparison of the version, not a numeric one.
        if not is_3d_axes or self.matplotlib.__version__ >= '1.2.0':
            ax.set_autoscale_on(self.autoscale)
        if self.axis_center:
            val = self.axis_center
            if is_3d_axes:
                pass
            elif val == 'center':
                ax.spines['left'].set_position('center')
                ax.spines['bottom'].set_position('center')
            elif val == 'auto':
                # Put a spine at 0 when 0 lies inside the current limits,
                # otherwise centre it.
                xl, xh = ax.get_xlim()
                yl, yh = ax.get_ylim()
                pos_left = ('data', 0) if xl*xh <= 0 else 'center'
                pos_bottom = ('data', 0) if yl*yh <= 0 else 'center'
                ax.spines['left'].set_position(pos_left)
                ax.spines['bottom'].set_position(pos_bottom)
            else:
                ax.spines['left'].set_position(('data', val[0]))
                ax.spines['bottom'].set_position(('data', val[1]))
        if not self.axis:
            ax.set_axis_off()
        if self.legend:
            if ax.legend():
                ax.legend_.set_visible(self.legend)
        if self.margin:
            ax.set_xmargin(self.margin)
            ax.set_ymargin(self.margin)
        if self.title:
            ax.set_title(self.title)
        if self.xlabel:
            xlbl = _str_or_latex(self.xlabel)
            ax.set_xlabel(xlbl, position=(1, 0))
        if self.ylabel:
            ylbl = _str_or_latex(self.ylabel)
            ax.set_ylabel(ylbl, position=(0, 1))
        if is_3d_axes and self.zlabel:
            zlbl = _str_or_latex(self.zlabel)
            ax.set_zlabel(zlbl, position=(0, 1))

        # xlim and ylim should always be set last so that the plot limits
        # don't get altered by the steps above.
        if self.xlim:
            ax.set_xlim(self.xlim)
        if self.ylim:
            ax.set_ylim(self.ylim)
        # Apply the aspect to the axes actually being drawn on (previously
        # ``self.ax``, which ignored the ``ax`` argument).
        ax.set_aspect(self.aspect)

    def process_series(self):
        """
        Create the figure and axes, then draw every series onto them via
        ``_process_series()``.
        """
        self._create_figure()
        self._process_series(self._series, self.ax)

    def show(self):
        """Render the plot and display it, or close it when showing is
        disabled through ``base_backend._show``."""
        self.process_series()
        #TODO after fixing https://github.com/ipython/ipython/issues/1255
        # you can uncomment the next line and remove the pyplot.show() call
        #self.fig.show()
        if base_backend._show:
            self.fig.tight_layout()
            self.plt.show()
        else:
            self.close()

    def save(self, path):
        """Render the plot and write the figure to ``path``."""
        self.process_series()
        self.fig.savefig(path)

    def close(self):
        """Close this plot's matplotlib figure."""
        self.plt.close(self.fig)