from sympy.external import import_module import sympy.plotting.backends.base_backend as base_backend # N.B. # When changing the minimum module version for matplotlib, please change # the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py` __doctest_requires__ = { ("PlotGrid",): ["matplotlib"], } class PlotGrid: """This class helps to plot subplots from already created SymPy plots in a single figure. Examples ======== .. plot:: :context: close-figs :format: doctest :include-source: True >>> from sympy import symbols >>> from sympy.plotting import plot, plot3d, PlotGrid >>> x, y = symbols('x, y') >>> p1 = plot(x, x**2, x**3, (x, -5, 5)) >>> p2 = plot((x**2, (x, -6, 6)), (x, (x, -5, 5))) >>> p3 = plot(x**3, (x, -5, 5)) >>> p4 = plot3d(x*y, (x, -5, 5), (y, -5, 5)) Plotting vertically in a single line: .. plot:: :context: close-figs :format: doctest :include-source: True >>> PlotGrid(2, 1, p1, p2) PlotGrid object containing: Plot[0]:Plot object containing: [0]: cartesian line: x for x over (-5.0, 5.0) [1]: cartesian line: x**2 for x over (-5.0, 5.0) [2]: cartesian line: x**3 for x over (-5.0, 5.0) Plot[1]:Plot object containing: [0]: cartesian line: x**2 for x over (-6.0, 6.0) [1]: cartesian line: x for x over (-5.0, 5.0) Plotting horizontally in a single line: .. plot:: :context: close-figs :format: doctest :include-source: True >>> PlotGrid(1, 3, p2, p3, p4) PlotGrid object containing: Plot[0]:Plot object containing: [0]: cartesian line: x**2 for x over (-6.0, 6.0) [1]: cartesian line: x for x over (-5.0, 5.0) Plot[1]:Plot object containing: [0]: cartesian line: x**3 for x over (-5.0, 5.0) Plot[2]:Plot object containing: [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) Plotting in a grid form: .. plot:: :context: close-figs :format: doctest :include-source: True >>> PlotGrid(2, 2, p1, p2, p3, p4) PlotGrid object containing: Plot[0]:Plot object containing: [0]: cartesian line: x for x over (-5.0, 5.0) [1]: cartesian line: x**2 for x over (-5.0, 5.0) [2]: cartesian line: x**3 for x over (-5.0, 5.0) Plot[1]:Plot object containing: [0]: cartesian line: x**2 for x over (-6.0, 6.0) [1]: cartesian line: x for x over (-5.0, 5.0) Plot[2]:Plot object containing: [0]: cartesian line: x**3 for x over (-5.0, 5.0) Plot[3]:Plot object containing: [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) """ def __init__(self, nrows, ncolumns, *args, show=True, size=None, **kwargs): """ Parameters ========== nrows : The number of rows that should be in the grid of the required subplot. ncolumns : The number of columns that should be in the grid of the required subplot. nrows and ncolumns together define the required grid. Arguments ========= A list of predefined plot objects entered in a row-wise sequence i.e. plot objects which are to be in the top row of the required grid are written first, then the second row objects and so on Keyword arguments ================= show : Boolean The default value is set to ``True``. Set show to ``False`` and the function will not display the subplot. The returned instance of the ``PlotGrid`` class can then be used to save or display the plot by calling the ``save()`` and ``show()`` methods respectively. size : (float, float), optional A tuple in the form (width, height) in inches to specify the size of the overall figure. The default value is set to ``None``, meaning the size will be set by the default backend. """ self.matplotlib = import_module('matplotlib', import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']}, min_module_version='1.1.0', catch=(RuntimeError,)) self.nrows = nrows self.ncolumns = ncolumns self._series = [] self._fig = None self.args = args for arg in args: self._series.append(arg._series) self.size = size if show and self.matplotlib: self.show() def _create_figure(self): gs = self.matplotlib.gridspec.GridSpec(self.nrows, self.ncolumns) mapping = {} c = 0 for i in range(self.nrows): for j in range(self.ncolumns): if c < len(self.args): mapping[gs[i, j]] = self.args[c] c += 1 kw = {} if not self.size else {"figsize": self.size} self._fig = self.matplotlib.pyplot.figure(**kw) for spec, p in mapping.items(): kw = ({"projection": "3d"} if (len(p._series) > 0 and p._series[0].is_3D) else {}) cur_ax = self._fig.add_subplot(spec, **kw) p._plotgrid_fig = self._fig p._plotgrid_ax = cur_ax p.process_series() @property def fig(self): if not self._fig: self._create_figure() return self._fig @property def _backend(self): return self def close(self): self.matplotlib.pyplot.close(self.fig) def show(self): if base_backend._show: self.fig.tight_layout() self.matplotlib.pyplot.show() else: self.close() def save(self, path): self.fig.savefig(path) def __str__(self): plot_strs = [('Plot[%d]:' % i) + str(plot) for i, plot in enumerate(self.args)] return 'PlotGrid object containing:\n' + '\n'.join(plot_strs)