Source code for ultraplot.figure
#!/usr/bin/env python3
"""
The figure class used for all ultraplot figures.
"""
import functools
import inspect
import os
from numbers import Integral
from packaging import version
try:
from typing import List, Optional, Tuple, Union
except ImportError:
from typing_extensions import List, Optional, Tuple, Union
import matplotlib.axes as maxes
import matplotlib.figure as mfigure
import matplotlib.gridspec as mgridspec
import matplotlib.projections as mproj
import matplotlib.text as mtext
import matplotlib.transforms as mtransforms
import numpy as np
try:
from typing import override
except:
from typing_extensions import override
from . import axes as paxes
from . import constructor
from . import gridspec as pgridspec
from .config import rc, rc_matplotlib
from .internals import (
_not_none,
_pop_params,
_pop_rc,
_translate_loc,
context,
docstring,
ic, # noqa: F401
labels,
warnings,
)
from .utils import _Crawler, units
# Public API of this module: only the Figure class is exported.
__all__ = ["Figure"]
# Preset figure widths or sizes based on academic journal recommendations.
# NOTE: Please feel free to add to this!
# Each value is either a single figure *width* or a ``(width, height)`` tuple.
# Strings are unit specifications (e.g. ``"89mm"``); bare numbers are inches.
# The key order is kept stable because it is listed in error messages.
JOURNAL_SIZES = dict(
    # American Association for the Advancement of Science
    aaas1="5.5cm",
    aaas2="12cm",
    # American Geophysical Union: (width, height) pairs
    agu1=("95mm", "115mm"),
    agu2=("190mm", "115mm"),
    agu3=("95mm", "230mm"),
    agu4=("190mm", "230mm"),
    # American Meteorological Society (inches)
    ams1=3.2,
    ams2=4.5,
    ams3=5.5,
    ams4=6.5,
    # Copernicus Publications
    cop1="8.3cm",
    cop2="12cm",
    # Nature Research
    nat1="89mm",
    nat2="183mm",
    # Proceedings of the National Academy of Sciences
    pnas1="8.7cm",
    pnas2="11.4cm",
    pnas3="17.8cm",
)
# Figure docstring
# Parameter documentation shared by Figure() and the subplots docstrings,
# registered below as the ``figure.figure`` snippet.
_figure_docstring = """
refnum : int, optional
The reference subplot number. The `refwidth`, `refheight`, and `refaspect`
keyword args are applied to this subplot, and the aspect ratio is conserved
for this subplot in the `~Figure.auto_layout`. The default is the first
subplot created in the figure.
refaspect : float or 2-tuple of float, optional
The reference subplot aspect ratio. If scalar, this indicates the width
divided by height. If 2-tuple, this indicates the (width, height). Ignored
if both `figwidth` *and* `figheight` or both `refwidth` *and* `refheight` were
passed. The default value is ``1`` or the "data aspect ratio" if the latter
is explicitly fixed (as with `~ultraplot.axes.PlotAxes.imshow` plots and
`~ultraplot.axes.GeoAxes` projections; see :func:`~matplotlib.axes.Axes.set_aspect`).
refwidth, refheight : unit-spec, default: :rc:`subplots.refwidth`
The width, height of the reference subplot.
%(units.in)s
Ignored if `figwidth`, `figheight`, or `figsize` was passed. If you
specify just one, `refaspect` will be respected.
ref, aspect, axwidth, axheight
Aliases for `refnum`, `refaspect`, `refwidth`, `refheight`.
*These may be deprecated in a future release.*
figwidth, figheight : unit-spec, optional
The figure width and height. Default behavior is to use `refwidth`.
%(units.in)s
If you specify just one, `refaspect` will be respected.
width, height
Aliases for `figwidth`, `figheight`.
figsize : 2-tuple, optional
Tuple specifying the figure ``(width, height)``.
sharex, sharey, share \
: {0, False, 1, 'labels', 'labs', 2, 'limits', 'lims', 3, True, 4, 'all'}, \
default: :rc:`subplots.share`
The axis sharing "level" for the *x* axis, *y* axis, or both
axes. Options are as follows:
* ``0`` or ``False``: No axis sharing. This also sets the default `spanx`
and `spany` values to ``False``.
* ``1`` or ``'labels'`` or ``'labs'``: Only draw axis labels on the bottommost
row or leftmost column of subplots. Tick labels still appear on every subplot.
* ``2`` or ``'limits'`` or ``'lims'``: As above but force the axis limits, scales,
and tick locations to be identical. Tick labels still appear on every subplot.
* ``3`` or ``True``: As above but only show the tick labels on the bottommost
row and leftmost column of subplots.
* ``4`` or ``'all'``: As above but also share the axis limits, scales, and
tick locations between subplots not in the same row or column.
spanx, spany, span : bool or {0, 1}, default: :rc:`subplots.span`
Whether to use "spanning" axis labels for the *x* axis, *y* axis, or both
axes. Default is ``False`` if `sharex`, `sharey`, or `share` are ``0`` or
``False``. When ``True``, a single, centered axis label is used for all axes
with bottom and left edges in the same row or column. This can considerably
reduce redundancy in your figure. "Spanning" labels integrate with "shared" axes. For
example, for a 3-row, 3-column figure, with ``sharey > 1`` and ``spany == True``,
your figure will have 1 y axis label instead of 9 y axis labels.
alignx, aligny, align : bool or {0, 1}, default: :rc:`subplots.align`
Whether to `"align" axis labels \
<https://matplotlib.org/stable/gallery/subplots_axes_and_figures/align_labels_demo.html>`__
for the *x* axis, *y* axis, or both axes. Aligned labels always appear in the same
row or column. This is ignored if `spanx`, `spany`, or `span` are ``True``.
%(gridspec.shared)s
%(gridspec.scalar)s
tight : bool, default: :rc:`subplots.tight`
Whether automatic calls to `~Figure.auto_layout` should include
:ref:`tight layout adjustments <ug_tight>`. If you manually specified a spacing
in the call to `~ultraplot.ui.subplots`, it will be used to override the tight
layout spacing. For example, with ``left=1``, the left margin is set to 1
em-width, while the remaining margin widths are calculated automatically.
%(gridspec.tight)s
journal : str, optional
String corresponding to an academic journal standard used to control the figure
width `figwidth` and, if specified, the figure height `figheight`. See the below
table. Feel free to add to this table by submitting a pull request.
.. _journal_table:
=========== ==================== \
===============================================================================
Key Size description Organization
=========== ==================== \
===============================================================================
``'aaas1'`` 1-column \
`American Association for the Advancement of Science <aaas_>`_ (e.g. *Science*)
``'aaas2'`` 2-column ”
``'agu1'`` 1-column `American Geophysical Union <agu_>`_
``'agu2'`` 2-column ”
``'agu3'`` full height 1-column ”
``'agu4'`` full height 2-column ”
``'ams1'`` 1-column `American Meteorological Society <ams_>`_
``'ams2'`` small 2-column ”
``'ams3'`` medium 2-column ”
``'ams4'`` full 2-column ”
``'cop1'`` 1-column \
`Copernicus Publications <cop_>`_ (e.g. *The Cryosphere*, *Geoscientific Model Development*)
``'cop2'`` 2-column ”
``'nat1'`` 1-column `Nature Research <nat_>`_
``'nat2'`` 2-column ”
``'pnas1'`` 1-column \
`Proceedings of the National Academy of Sciences <pnas_>`_
``'pnas2'`` 2-column ”
``'pnas3'`` landscape page ”
=========== ==================== \
===============================================================================
.. _aaas: \
https://www.sciencemag.org/authors/instructions-preparing-initial-manuscript
.. _agu: \
https://www.agu.org/Publish-with-AGU/Publish/Author-Resources/Graphic-Requirements
.. _ams: \
https://www.ametsoc.org/ams/index.cfm/publications/authors/journal-and-bams-authors/figure-information-for-authors/
.. _cop: \
https://publications.copernicus.org/for_authors/manuscript_preparation.html#figurestables
.. _nat: \
https://www.nature.com/nature/for-authors/formatting-guide
.. _pnas: \
https://www.pnas.org/page/authors/format
"""
docstring._snippet_manager["figure.figure"] = _figure_docstring
# Multiple subplots
# Parameter documentation for the grid-of-subplots methods, registered below
# as the ``figure.subplots_params`` snippet.
_subplots_params_docstring = """
array : `ultraplot.gridspec.GridSpec` or array-like of int, optional
The subplot grid specifier. If a :class:`~ultraplot.gridspec.GridSpec`, one subplot is
drawn for each unique :class:`~ultraplot.gridspec.GridSpec` slot. If a 2D array of integers,
one subplot is drawn for each unique integer in the array. Think of this array as
a "picture" of the subplot grid -- for example, the array ``[[1, 1], [2, 3]]``
creates one long subplot in the top row and two smaller subplots in the bottom row.
Integers must range from 1 to the number of plots, and ``0`` indicates an
empty space -- for example, ``[[1, 1, 1], [2, 0, 3]]`` creates one long subplot
in the top row with two subplots in the bottom row separated by a space.
nrows, ncols : int, default: 1
The number of rows and columns in the subplot grid. Ignored
if `array` was passed. Use these arguments for simple subplot grids.
order : {'C', 'F'}, default: 'C'
Whether subplots are numbered in row-major (``'C'``) or column-major (``'F'``)
order. Analogous to `numpy.array` ordering. This controls the order that
subplots appear in the `SubplotGrid` returned by this function, and the order
of subplot a-b-c labels (see `~ultraplot.axes.Axes.format`).
%(axes.proj)s
To use different projections for different subplots, you have
two options:
* Pass a *list* of projection specifications, one for each subplot.
For example, ``uplt.subplots(ncols=2, proj=('cart', 'robin'))``.
* Pass a *dictionary* of projection specifications, where the
keys are integers or tuples of integers that indicate the projection
to use for the corresponding subplot number(s). If a key is not
provided, the default projection ``'cartesian'`` is used. For example,
``uplt.subplots(ncols=4, proj={2: 'merc', (3, 4): 'stere'})`` creates
a figure with a default Cartesian axes for the first subplot, a Mercator
projection for the second subplot, and a Stereographic projection
for the third and fourth subplots.
%(axes.proj_kw)s
If dictionary of properties, applies globally. If list or dictionary of
dictionaries, applies to specific subplots, as with `proj`. For example,
``uplt.subplots(ncols=2, proj='cyl', proj_kw=({'lon_0': 0}, {'lon_0': 180}))``
centers the projection in the left subplot on the prime meridian and in the
right subplot on the international dateline.
%(axes.backend)s
If string, applies to all subplots. If list or dict, applies to specific
subplots, as with `proj`.
%(gridspec.shared)s
%(gridspec.vector)s
%(gridspec.tight)s
"""
docstring._snippet_manager["figure.subplots_params"] = _subplots_params_docstring
# Extra args docstring
# Shared "**kwargs" entry, reused by the subplot, subplots and axes docstrings
# through the ``%(figure.axes_params)s`` snippet.
_axes_params_docstring = """
**kwargs
Passed to the ultraplot class `ultraplot.axes.CartesianAxes`, `ultraplot.axes.PolarAxes`,
`ultraplot.axes.GeoAxes`, or `ultraplot.axes.ThreeAxes`. This can include keyword
arguments for projection-specific ``format`` commands.
"""
docstring._snippet_manager["figure.axes_params"] = _axes_params_docstring
# Multiple subplots docstring
# Complete docstring template registered as ``figure.subplots``; it stitches
# together the ``figure.subplots_params``, ``figure.figure`` and
# ``figure.axes_params`` snippets.
_subplots_docstring = """
Add an arbitrary grid of subplots to the figure.
Parameters
----------
%(figure.subplots_params)s
Other parameters
----------------
%(figure.figure)s
%(figure.axes_params)s
Returns
-------
axs : SubplotGrid
The axes instances stored in a `SubplotGrid`.
See also
--------
ultraplot.ui.figure
ultraplot.ui.subplots
ultraplot.figure.Figure.subplot
ultraplot.figure.Figure.add_subplot
ultraplot.gridspec.SubplotGrid
ultraplot.axes.Axes
"""
docstring._snippet_manager["figure.subplots"] = _subplots_docstring
# Single subplot docstring
# Complete docstring template for adding one subplot, registered below as
# the ``figure.subplot`` snippet.
_subplot_docstring = """
Add a subplot axes to the figure.
Parameters
----------
*args : int, tuple, or `~matplotlib.gridspec.SubplotSpec`, optional
The subplot location specifier. Your options are:
* A single 3-digit integer argument specifying the number of rows,
number of columns, and gridspec number (using row-major indexing).
* Three positional arguments specifying the number of rows, number of
columns, and gridspec number (int) or number range (2-tuple of int).
* A `~matplotlib.gridspec.SubplotSpec` instance generated by indexing
an ultraplot :class:`~ultraplot.gridspec.GridSpec`.
For integer input, the implied geometry must be compatible with the implied
geometry from previous calls -- for example, ``fig.add_subplot(331)`` followed
by ``fig.add_subplot(132)`` is valid because the 1 row of the second input can
be tiled into the 3 rows of the first input, but ``fig.add_subplot(232)``
will raise an error because 2 rows cannot be tiled into 3 rows. For
`~matplotlib.gridspec.SubplotSpec` input, the `~matplotlib.gridspec.SubplotSpec`
must be derived from the :class:`~ultraplot.gridspec.GridSpec` used in previous calls.
These restrictions arise because we allocate a single,
unique `~Figure.gridspec` for each figure.
number : int, optional
The axes number used for a-b-c labeling. See `~ultraplot.axes.Axes.format` for
details. By default this is incremented automatically based on the other subplots
in the figure. Use e.g. ``number=None`` or ``number=False`` to ensure the subplot
has no a-b-c label. Note the number corresponding to `a` is ``1``, not ``0``.
autoshare : bool, default: True
Whether to automatically share the *x* and *y* axes with subplots spanning the
same rows and columns based on the figure-wide `sharex` and `sharey` settings.
This has no effect if :rcraw:`subplots.share` is ``False`` or if ``sharex=False``
or ``sharey=False`` were passed to the figure.
%(axes.proj)s
%(axes.proj_kw)s
%(axes.backend)s
Other parameters
----------------
%(figure.axes_params)s
See also
--------
ultraplot.figure.Figure.add_axes
ultraplot.figure.Figure.subplots
ultraplot.figure.Figure.add_subplots
"""
docstring._snippet_manager["figure.subplot"] = _subplot_docstring
# Single axes
# Complete docstring template for adding a free-floating (non-subplot) axes,
# registered below as the ``figure.axes`` snippet.
_axes_docstring = """
Add a non-subplot axes to the figure.
Parameters
----------
rect : 4-tuple of float
The (left, bottom, width, height) dimensions of the axes in
figure-relative coordinates.
%(axes.proj)s
%(axes.proj_kw)s
%(axes.backend)s
Other parameters
----------------
%(figure.axes_params)s
See also
--------
ultraplot.figure.Figure.subplot
ultraplot.figure.Figure.add_subplot
ultraplot.figure.Figure.subplots
ultraplot.figure.Figure.add_subplots
"""
docstring._snippet_manager["figure.axes"] = _axes_docstring
# Colorbar or legend panel docstring
# Template shared by the figure-level legend and colorbar docstrings. The
# ``{name}`` placeholders are filled with str.format() below, which is why
# literal braces (the ``align`` option set) are doubled as ``{{...}}``. The
# ``%(...)s`` snippets are left intact for the docstring snippet manager.
_space_docstring = """
loc : str, optional
The {name} location. Valid location keys are as follows.
%(axes.panel_loc)s
space : float or str, default: None
The fixed space between the {name} and the subplot grid edge.
%(units.em)s
When the :ref:`tight layout algorithm <ug_tight>` is active for the figure,
`space` is computed automatically (see `pad`). Otherwise, `space` is set to
a suitable default.
pad : float or str, default: :rc:`subplots.innerpad` or :rc:`subplots.panelpad`
The :ref:`tight layout padding <ug_tight>` between the {name} and the
subplot grid. Default is :rcraw:`subplots.innerpad` for the first {name}
and :rcraw:`subplots.panelpad` for subsequently "stacked" {name}s.
%(units.em)s
row, rows
Aliases for `span` for {name}s on the left or right side.
col, cols
Aliases for `span` for {name}s on the top or bottom side.
span : int or 2-tuple of int, default: None
Integer(s) indicating the span of the {name} across rows and columns of
subplots. For example, ``fig.{name}(loc='b', col=1)`` draws a {name} beneath
the leftmost column of subplots, and ``fig.{name}(loc='b', cols=(1, 2))``
draws a {name} beneath the left two columns of subplots. By default
the {name} will span every subplot row and column.
align : {{'center', 'top', 't', 'bottom', 'b', 'left', 'l', 'right', 'r'}}, optional
For outer {name}s only. How to align the {name} against the
subplot edge. The values ``'top'`` and ``'bottom'`` are valid for left and
right {name}s and ``'left'`` and ``'right'`` are valid for top and bottom
{name}s. The default is always ``'center'``.
"""
# Register one specialized copy per panel type.
docstring._snippet_manager["figure.legend_space"] = _space_docstring.format(
name="legend"
) # noqa: E501
docstring._snippet_manager["figure.colorbar_space"] = _space_docstring.format(
name="colorbar"
) # noqa: E501
# Save docstring
# Docstring template for the figure saving methods, registered below as the
# ``figure.save`` snippet.
_save_docstring = """
Save the figure.
Parameters
----------
path : path-like, optional
The file path. User paths are expanded with `os.path.expanduser`.
**kwargs
Passed to `~matplotlib.figure.Figure.savefig`.
See also
--------
Figure.save
Figure.savefig
matplotlib.figure.Figure.savefig
"""
docstring._snippet_manager["figure.save"] = _save_docstring
def _get_journal_size(preset):
"""
Return the width and height corresponding to the given preset.
"""
value = JOURNAL_SIZES.get(preset, None)
if value is None:
raise ValueError(
f"Unknown preset figure size specifier {preset!r}. "
"Current options are: " + ", ".join(map(repr, JOURNAL_SIZES.keys()))
)
figwidth = figheight = None
try:
figwidth, figheight = value
except (TypeError, ValueError):
figwidth = value
return figwidth, figheight
def _add_canvas_preprocessor(canvas, method, cache=False):
"""
Return a pre-processer that can be used to override instance-level
canvas draw() and print_figure() methods. This applies tight layout
and aspect ratio-conserving adjustments and aligns labels. Required
so canvas methods instantiate renderers with the correct dimensions.
"""
# NOTE: Renderer must be (1) initialized with the correct figure size or
# (2) changed inplace during draw, but vector graphic renderers *cannot*
# be changed inplace. So options include (1) monkey patch
# canvas.get_width_height, overriding figure.get_size_inches, and exploit
# the FigureCanvasAgg.get_renderer() implementation (because FigureCanvasAgg
# queries the bbox directly rather than using get_width_height() so requires
# workaround), (2) override bbox and bbox_inches as *properties* (but these
# are really complicated, dangerous, and result in unnecessary extra draws),
# or (3) simply override canvas draw methods. Our choice is #3.
def _canvas_preprocess(self, *args, **kwargs):
fig = self.figure # update even if not stale! needed after saves
func = getattr(type(self), method) # the original method
# Bail out if we are already adjusting layout
# NOTE: The _is_adjusting check necessary when inserting new
# gridspec rows or columns with the qt backend.
# NOTE: Return value for macosx _draw is the renderer, for qt draw is
# nothing, and for print_figure is some figure object, but this block
# has never been invoked when calling print_figure.
if fig._is_adjusting:
if method == "_draw": # macosx backend
return fig._get_renderer()
else:
return
# Adjust layout
# NOTE: The authorized_context is needed because some backends disable
# constrained layout or tight layout before printing the figure.
ctx1 = fig._context_adjusting(cache=cache)
ctx2 = fig._context_authorized() # skip backend set_constrained_layout()
ctx3 = rc.context(fig._render_context) # draw with figure-specific setting
with ctx1, ctx2, ctx3:
fig.auto_layout()
return func(self, *args, **kwargs)
# Add preprocessor
setattr(canvas, method, _canvas_preprocess.__get__(canvas))
return canvas
def _clear_border_cache(func):
"""
Decorator that clears the border cache after function execution.
"""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
if hasattr(self, "_cached_border_axes"):
delattr(self, "_cached_border_axes")
return result
return wrapper
[docs]
class Figure(mfigure.Figure):
"""
The `~matplotlib.figure.Figure` subclass used by ultraplot.
"""
# Shared error and warning messages
# Appended to the ValueError raised in __init__ for an invalid sharex/sharey.
_share_message = (
"Axis sharing level can be 0 or False (share nothing), "
"1 or 'labels' or 'labs' (share axis labels), "
"2 or 'limits' or 'lims' (share axis limits and axis labels), "
"3 or True (share axis limits, axis labels, and tick labels), "
"or 4 or 'all' (share axis labels and tick labels in the same gridspec "
"rows and columns and share axis limits across all subplots)."
)
# Appended to the warning emitted when matplotlib's ``subplotpars`` is passed.
_space_message = (
"To set the left, right, bottom, top, wspace, or hspace gridspec values, "
"pass them as keyword arguments to uplt.figure() or uplt.subplots(). Please "
"note they are now specified in physical units, with strings interpreted by "
"uplt.units() and floats interpreted as font size-widths."
)
# Appended to the warnings about matplotlib's tight/constrained layout options,
# which __init__ ignores or switches off.
_tight_message = (
"ultraplot uses its own tight layout algorithm that is activated by default. "
"To disable it, set uplt.rc['subplots.tight'] to False or pass tight=False "
"to uplt.subplots(). For details, see fig.auto_layout()."
)
# Class-level flag; __init__ sets it to False on the class itself so the
# interactive-backend size warning is shown only once per session.
_warn_interactive = True # disabled after first warning
def __repr__(self):
    """
    Return a compact summary of the figure grid geometry and sizing settings.

    The gridspec geometry is shown when a gridspec exists. Sizing
    attributes are shown rounded to two decimals, and only when they are
    not ``None``.
    """
    opts = {}
    for attr in ("refaspect", "refwidth", "refheight", "figwidth", "figheight"):
        value = getattr(self, "_" + attr)
        if value is not None:
            # Convert to a builtin float: under NumPy >= 2 the repr of a NumPy
            # scalar is e.g. ``np.float64(2.5)``, which would leak into the
            # figure repr as ``refwidth=np.float64(2.5)``.
            opts[attr] = float(np.round(value, 2))
    geom = ""
    if self.gridspec:
        nrows, ncols = self.gridspec.get_geometry()
        geom = f"nrows={nrows}, ncols={ncols}, "
    summary = ", ".join(f"{key}={value!r}" for key, value in opts.items())
    return f"Figure({geom}{summary})"
# NOTE: If _rename_kwargs argument is an invalid identifier, it is
# simply used in the warning message.
# The decorators are ultraplot internals. _rename_kwargs handles keywords
# renamed in version 0.7.0 (``axpad`` -> ``innerpad``), and _snippet_manager
# presumably fills the ``%(...)s`` snippets in the docstring below.
@docstring._obfuscate_kwargs
@docstring._snippet_manager
@warnings._rename_kwargs(
    "0.7.0", axpad="innerpad", autoformat="uplt.rc.autoformat = {}"
)
def __init__(
    self,
    *,
    refnum=None,
    ref=None,
    refaspect=None,
    aspect=None,
    refwidth=None,
    refheight=None,
    axwidth=None,
    axheight=None,
    figwidth=None,
    figheight=None,
    width=None,
    height=None,
    journal=None,
    sharex=None,
    sharey=None,
    share=None, # used for default spaces
    spanx=None,
    spany=None,
    span=None,
    alignx=None,
    aligny=None,
    align=None,
    left=None,
    right=None,
    top=None,
    bottom=None,
    wspace=None,
    hspace=None,
    space=None,
    tight=None,
    outerpad=None,
    innerpad=None,
    panelpad=None,
    wpad=None,
    hpad=None,
    pad=None,
    wequal=None,
    hequal=None,
    equal=None,
    wgroup=None,
    hgroup=None,
    group=None,
    **kwargs,
):
    """
    Parameters
    ----------
    %(figure.figure)s
    Other parameters
    ----------------
    %(figure.format)s
    **kwargs
    Passed to `matplotlib.figure.Figure`.
    See also
    --------
    Figure.format
    ultraplot.ui.figure
    ultraplot.ui.subplots
    matplotlib.figure.Figure
    """
    # Add figure sizing settings
    # NOTE: We cannot capture user-input 'figsize' here because it gets
    # automatically filled by the figure manager. See ui.figure().
    # NOTE: The figure size is adjusted according to these arguments by the
    # canvas preprocessor. Although in special case where both 'figwidth' and
    # 'figheight' were passed we update 'figsize' to limit side effects.
    # Resolve aliases; the primary keyword wins over its alias.
    refnum = _not_none(refnum=refnum, ref=ref, default=1) # never None
    refaspect = _not_none(refaspect=refaspect, aspect=aspect)
    refwidth = _not_none(refwidth=refwidth, axwidth=axwidth)
    refheight = _not_none(refheight=refheight, axheight=axheight)
    figwidth = _not_none(figwidth=figwidth, width=width)
    figheight = _not_none(figheight=figheight, height=height)
    # Conflicts are collected first and reported once all sizes are resolved.
    messages = []
    if journal is not None:
        jwidth, jheight = _get_journal_size(journal)
        if jwidth is not None and figwidth is not None:
            messages.append(("journal", journal, "figwidth", figwidth))
        if jheight is not None and figheight is not None:
            messages.append(("journal", journal, "figheight", figheight))
        # Journal presets take precedence over explicit figure dimensions.
        figwidth = _not_none(jwidth, figwidth)
        figheight = _not_none(jheight, figheight)
    # Explicit figure dimensions take precedence over reference dimensions.
    if figwidth is not None and refwidth is not None:
        messages.append(("figwidth", figwidth, "refwidth", refwidth))
        refwidth = None
    if figheight is not None and refheight is not None:
        messages.append(("figheight", figheight, "refheight", refheight))
        refheight = None
    # Nothing specified at all: size the figure from the reference width.
    if (
        figwidth is None
        and figheight is None
        and refwidth is None
        and refheight is None
    ): # noqa: E501
        refwidth = rc["subplots.refwidth"] # always inches
    # A (width, height) pair is reduced to a width / height ratio.
    if np.iterable(refaspect):
        refaspect = refaspect[0] / refaspect[1]
    for key1, val1, key2, val2 in messages:
        warnings._warn_ultraplot(
            f"Got conflicting figure size arguments {key1}={val1!r} and "
            f"{key2}={val2!r}. Ignoring {key2!r}."
        )
    self._refnum = refnum
    self._refaspect = refaspect
    self._refaspect_default = 1 # updated for imshow and geographic plots
    self._refwidth = units(refwidth, "in")
    self._refheight = units(refheight, "in")
    self._figwidth = figwidth = units(figwidth, "in")
    self._figheight = figheight = units(figheight, "in")
    # Add special consideration for interactive backends
    # The nbagg/ipympl backends cannot be auto-sized, so unless both figure
    # dimensions were given, fall back to rc['figure.figsize'] and drop the
    # reference dimensions.
    backend = _not_none(rc.backend, "")
    backend = backend.lower()
    interactive = "nbagg" in backend or "ipympl" in backend
    if not interactive:
        pass
    elif figwidth is None or figheight is None:
        figsize = rc["figure.figsize"] # modified by ultraplot
        self._figwidth = figwidth = _not_none(figwidth, figsize[0])
        self._figheight = figheight = _not_none(figheight, figsize[1])
        self._refwidth = self._refheight = None # critical!
        if self._warn_interactive:
            # Assigned on the class (not the instance) so every later figure
            # skips this warning too.
            Figure._warn_interactive = False # set class attribute
            warnings._warn_ultraplot(
                "Auto-sized ultraplot figures are not compatible with interactive "
                "backends like '%matplotlib widget' and '%matplotlib notebook'. "
                f"Reverting to the figure size ({figwidth}, {figheight}). To make "
                "auto-sized figures, please consider using the non-interactive "
                "(default) backend. This warning message is shown the first time "
                "you create a figure without explicitly specifying the size."
            )
    # Add space settings
    # NOTE: This is analogous to 'subplotpars' but we don't worry about
    # user mutability. Think it's perfectly fine to ask users to simply
    # pass these to uplt.figure() or uplt.subplots(). Also overriding
    # 'subplots_adjust' would be confusing since we switch to absolute
    # units and that function is heavily used outside of ultraplot.
    params = {
        "left": left,
        "right": right,
        "top": top,
        "bottom": bottom,
        "wspace": wspace,
        "hspace": hspace,
        "space": space,
        "wequal": wequal,
        "hequal": hequal,
        "equal": equal,
        "wgroup": wgroup,
        "hgroup": hgroup,
        "group": group,
        "wpad": wpad,
        "hpad": hpad,
        "pad": pad,
        "outerpad": outerpad,
        "innerpad": innerpad,
        "panelpad": panelpad,
    }
    self._gridspec_params = params # used to initialize the gridspec
    # Figure-level spaces must be scalars (strings count as scalar unit specs).
    for key, value in tuple(params.items()):
        if not isinstance(value, str) and np.iterable(value) and len(value) > 1:
            raise ValueError(
                f"Invalid gridspec parameter {key}={value!r}. Space parameters "
                "passed to Figure() must be scalar. For vector spaces use "
                "GridSpec() or pass space parameters to subplots()."
            )
    # Add tight layout setting and ignore native settings
    # matplotlib's own layout options are popped from kwargs (so they never
    # reach the base class) and reported; the rc equivalents are forced off.
    pars = kwargs.pop("subplotpars", None)
    if pars is not None:
        warnings._warn_ultraplot(
            f"Ignoring subplotpars={pars!r}. " + self._space_message
        )
    if kwargs.pop("tight_layout", None):
        warnings._warn_ultraplot(
            "Ignoring tight_layout=True. " + self._tight_message
        )
    if kwargs.pop("constrained_layout", None):
        warnings._warn_ultraplot(
            "Ignoring constrained_layout=True. " + self._tight_message
        )
    if rc_matplotlib.get("figure.autolayout", False):
        warnings._warn_ultraplot(
            "Setting rc['figure.autolayout'] to False. " + self._tight_message
        )
    if rc_matplotlib.get("figure.constrained_layout.use", False):
        warnings._warn_ultraplot(
            "Setting rc['figure.constrained_layout.use'] to False. "
            + self._tight_message # noqa: E501
        )
    try:
        rc_matplotlib["figure.autolayout"] = False # this is rcParams
    except KeyError:
        pass
    try:
        rc_matplotlib["figure.constrained_layout.use"] = False # this is rcParams
    except KeyError:
        pass
    self._tight_active = _not_none(tight, rc["subplots.tight"])
    # Translate share settings
    # Named levels map to integers; True means level 3. Valid levels are 0-4.
    translate = {"labels": 1, "labs": 1, "limits": 2, "lims": 2, "all": 4}
    sharex = _not_none(sharex, share, rc["subplots.share"])
    sharey = _not_none(sharey, share, rc["subplots.share"])
    sharex = 3 if sharex is True else translate.get(sharex, sharex)
    sharey = 3 if sharey is True else translate.get(sharey, sharey)
    if sharex not in range(5):
        raise ValueError(f"Invalid sharex={sharex!r}. " + self._share_message)
    if sharey not in range(5):
        raise ValueError(f"Invalid sharey={sharey!r}. " + self._share_message)
    self._sharex = int(sharex)
    self._sharey = int(sharey)
    # Translate span and align settings
    # Spanning labels default to off when the corresponding axis is unshared.
    spanx = _not_none(
        spanx, span, False if not sharex else None, rc["subplots.span"]
    ) # noqa: E501
    spany = _not_none(
        spany, span, False if not sharey else None, rc["subplots.span"]
    ) # noqa: E501
    if spanx and (alignx or align): # only warn when explicitly requested
        warnings._warn_ultraplot('"alignx" has no effect when spanx=True.')
    if spany and (aligny or align):
        warnings._warn_ultraplot('"aligny" has no effect when spany=True.')
    self._spanx = bool(spanx)
    self._spany = bool(spany)
    alignx = _not_none(alignx, align, rc["subplots.align"])
    aligny = _not_none(aligny, align, rc["subplots.align"])
    self._alignx = bool(alignx)
    self._aligny = bool(aligny)
    # Initialize the figure
    # NOTE: Super labels are stored inside {axes: text} dictionaries
    self._gridspec = None
    self._panel_dict = {"left": [], "right": [], "bottom": [], "top": []}
    self._subplot_dict = {} # subplots indexed by number
    self._subplot_counter = 0 # avoid add_subplot() returning an existing subplot
    self._is_adjusting = False
    self._is_authorized = False
    self._includepanels = None
    self._render_context = {}
    # Split rc settings and format() keywords out of kwargs; the remainder
    # goes to the matplotlib base class.
    rc_kw, rc_mode = _pop_rc(kwargs)
    kw_format = _pop_params(kwargs, self._format_signature)
    if figwidth is not None and figheight is not None:
        kwargs["figsize"] = (figwidth, figheight)
    # The base class runs inside the same "authorized" context that the canvas
    # preprocessor uses to let backend layout calls through.
    with self._context_authorized():
        super().__init__(**kwargs)
    # Super labels. We don't rely on private matplotlib _suptitle attribute and
    # _align_axis_labels supports arbitrary spanning labels for subplot groups.
    # NOTE: Don't use 'anchor' rotation mode otherwise switching to horizontal
    # left and right super labels causes overlap. Current method is fine.
    self._suptitle = self.text(0.5, 0.95, "", ha="center", va="bottom")
    self._supxlabel_dict = {} # an axes: label mapping
    self._supylabel_dict = {} # an axes: label mapping
    self._suplabel_dict = {"left": {}, "right": {}, "bottom": {}, "top": {}}
    self._share_label_groups = {"x": {}, "y": {}} # explicit label-sharing groups
    self._suptitle_pad = rc["suptitle.pad"]
    d = self._suplabel_props = {} # store the super label props
    d["left"] = {"va": "center", "ha": "right"}
    d["right"] = {"va": "center", "ha": "left"}
    d["bottom"] = {"va": "top", "ha": "center"}
    d["top"] = {"va": "bottom", "ha": "center"}
    d = self._suplabel_pad = {} # store the super label padding
    d["left"] = rc["leftlabel.pad"]
    d["right"] = rc["rightlabel.pad"]
    d["bottom"] = rc["bottomlabel.pad"]
    d["top"] = rc["toplabel.pad"]
    # Format figure
    # NOTE: This ignores user-input rc_mode.
    self.format(rc_kw=rc_kw, rc_mode=1, skip_axes=True, **kw_format)
[docs]
@override
def draw(self, renderer):
    """
    Draw the figure after resolving figure-level tick label sharing and any
    explicit axis label sharing groups.
    """
    # Tick label sharing is resolved here rather than per subplot. Shareable
    # groups must be all cartesian or all geographic (panels may be mixed).
    # Enabled tick labels are pushed to the outermost subplot of each row (x)
    # or column (y), using the figure border axes and their non-colorbar
    # panels.
    for axis in ("x", "y"):
        self._share_ticklabels(axis=axis)
    self._apply_share_label_groups()
    super().draw(renderer)
def _share_ticklabels(self, *, axis: str) -> None:
    """
    Synchronize tick label visibility across each gridspec row or column.

    The subplots control their own limits; only the tick *labels* are
    handled here at the figure level, where the cross-axes bookkeeping is
    simpler. Does nothing unless the figure is stale.

    Parameters
    ----------
    axis : {'x', 'y'}
        ``'x'`` groups axes by gridspec row and handles the top/bottom labels.
        ``'y'`` groups axes by column and handles the left/right labels.
    """
    if not self.stale:
        return
    border_axes = self._get_border_axes()
    if axis == "x":
        sides = ("top", "bottom")
    else:
        sides = ("left", "right")
    visible = list(self._iter_axes(panels=True, hidden=False))
    grouped = self._group_axes_by_axis(visible, axis)
    # Tick param names differ between matplotlib versions.
    label_keys = self._label_key_map()
    for members in grouped.values():
        # Only main axes (no panels) contribute to the baseline state.
        baseline, skip = self._compute_baseline_tick_state(members, axis, label_keys)
        if skip:
            continue
        for ax in members:
            # Respect the figure border and panel opposite-side constraints.
            masked = self._apply_border_mask(ax, baseline, sides, border_axes)
            # Only sharing level 3 and above hides inner tick labels.
            if self._effective_share_level(ax, axis, sides) < 3:
                continue
            self._set_ticklabel_state(ax, axis, masked)
    self.stale = True
def _label_key_map(self):
    """
    Map generic ``label<side>`` tick parameter names to the keys used by the
    installed Matplotlib version.

    The translation comes from the first axes yielded by ``_iter_axes``
    (panels included). When the figure has no axes yet, an identity
    mapping is returned.
    """
    names = ("labelleft", "labelright", "labeltop", "labelbottom")
    reference = next(self._iter_axes(panels=True), None)
    if reference is None:
        return {name: name for name in names}
    return {name: reference._label_key(name) for name in names}
def _group_axes_by_axis(self, axes, axis: str):
    """
    Bucket axes by the first gridspec row (``axis='x'``) or column
    (``axis='y'``) they occupy.

    Panels are grouped like any other axes. Axes whose subplotspec cannot be
    read are left out. Returns a ``defaultdict(list)`` keyed by the start
    index of the row or column span.
    """
    from collections import defaultdict

    span_attr = "rowspan" if axis == "x" else "colspan"
    groups = defaultdict(list)
    for ax in axes:
        try:
            start = getattr(ax.get_subplotspec(), span_attr).start
        except Exception:
            # Best effort: axes without a usable subplotspec are not grouped.
            continue
        groups[start].append(ax)
    return groups
def _compute_baseline_tick_state(self, group_axes, axis: str, label_keys):
"""
Build a baseline ticklabel visibility dict from MAIN axes (panels excluded).
Returns (baseline_dict, skip_group: bool). Emits warnings when encountering
unsupported or mixed subplot types.
"""
baseline = {}
subplot_types = set()
unsupported_found = False
sides = ("top", "bottom") if axis == "x" else ("left", "right")
for axi in group_axes:
# Only main axes "vote"
if getattr(axi, "_panel_side", None):
continue
# Supported axes types
if not isinstance(
axi, (paxes.CartesianAxes, paxes._CartopyAxes, paxes._BasemapAxes)
):
warnings._warn_ultraplot(
f"Tick label sharing not implemented for {type(axi)} subplots."
)
unsupported_found = True
break
subplot_types.add(type(axi))
# Collect label visibility state
if isinstance(axi, paxes.CartesianAxes):
params = getattr(axi, f"{axis}axis").get_tick_params()
for side in sides:
key = label_keys[f"label{side}"]
if params.get(key):
baseline[key] = params[key]
elif isinstance(axi, paxes.GeoAxes):
for side in sides:
key = f"label{side}"
if axi._is_ticklabel_on(key):
baseline[key] = axi._is_ticklabel_on(key)
if unsupported_found:
return {}, True
# We cannot mix types (yet) within a group
if len(subplot_types) > 1:
warnings._warn_ultraplot(
"Tick label sharing not implemented for mixed subplot types."
)
return {}, True
return baseline, False
def _apply_border_mask(
    self, axi, baseline: dict, sides: tuple[str, str], outer_axes
):
    """
    Return a copy of *baseline* restricted to what *axi* may display.

    For each of *sides*, the label is forced off in either of two cases.
    The first is when *axi* is not listed in ``outer_axes[side]``, i.e. it
    is not on that figure border. The second is when *axi* is a panel and
    *side* is the side opposite its ``_panel_side``. For cartesian axes the
    label name is translated with ``axi._label_key``.
    """
    from .axes.cartesian import OPPOSITE_SIDE

    masked = dict(baseline)
    panel_side = getattr(axi, "_panel_side", None)
    for side in sides:
        name = f"label{side}"
        if isinstance(axi, paxes.CartesianAxes):
            name = axi._label_key(name)  # version-dependent parameter name
        on_border = axi in outer_axes[side]
        faces_parent = bool(panel_side) and OPPOSITE_SIDE[panel_side] == side
        if not on_border or faces_parent:
            masked[name] = False
    return masked
def _effective_share_level(self, axi, axis: str, sides: tuple[str, str]) -> int:
"""
Compute the effective share level for an axes, considering panel groups and
adjacent panels. Fixes the original variable leak by checking any relevant side.
"""
level = getattr(self, f"_share{axis}")
# If figure-level sharing is disabled (0/False), don't promote due to panels
if not level or (isinstance(level, (int, float)) and level < 1):
return level
# Panel group-level sharing
if getattr(axi, f"_panel_share{axis}_group", None):
return 3
# Panel member sharing
if getattr(axi, "_panel_side", None) and getattr(axi, f"_share{axis}", None):
return 3
# Adjacent panels on any relevant side
panel_dict = getattr(axi, "_panel_dict", {})
for side in sides:
side_panels = panel_dict.get(side) or []
if side_panels and getattr(side_panels[0], f"_share{axis}", False):
return 3
return level
def _set_ticklabel_state(self, axi, axis: str, state: dict):
"""Apply the computed ticklabel state to cartesian or geo axes."""
if state:
# Normalize "x"/"y" values to booleans for both Geo and Cartesian axes
cleaned = {k: (True if v in ("x", "y") else v) for k, v in state.items()}
if isinstance(axi, paxes.GeoAxes):
axi._toggle_gridliner_labels(**cleaned)
else:
getattr(axi, f"{axis}axis").set_tick_params(**cleaned)
def _context_adjusting(self, cache=True):
    """
    Return a state context that marks the figure as being adjusted.

    While active, ``_is_adjusting`` is True. This prevents draws triggered by
    figure resizes from re-running the auto layout, which could otherwise
    loop forever. With ``cache=False`` the ``_cachedRenderer`` attribute is
    also temporarily set to ``None`` so the cached renderer is ignored.
    """
    state = {"_is_adjusting": True}
    if cache:
        return context._state_context(self, **state)
    return context._state_context(self, **state, _cachedRenderer=None)
def _context_authorized(self):
    """
    Return a state context that flags the enclosed calls as internal.

    Inside it ``_is_authorized`` is True, which suppresses the warnings that
    no-op methods otherwise emit to guide new users.
    """
    flags = {"_is_authorized": True}
    return context._state_context(self, **flags)
@staticmethod
def _parse_backend(backend=None, basemap=None):
"""
Handle deprecation of basemap and cartopy package.
"""
# Basemap is currently being developed again so are removing the deprecation warning
if backend == "basemap":
warnings._warn_ultraplot(
f"{backend=} will be deprecated in next major release (v2.0). See https://github.com/Ultraplot/ultraplot/pull/243"
)
return backend
def _parse_proj(
    self,
    proj=None,
    projection=None,
    proj_kw=None,
    projection_kw=None,
    backend=None,
    basemap=None,
    **kwargs,
):
    """
    Translate the user-input projection into a registered matplotlib
    axes class. Input projection can be a string, an ultraplot axes,
    `cartopy.crs.Projection`, or `mpl_toolkits.basemap.Basemap`.

    Parameters
    ----------
    proj, projection : optional
        Aliases for the projection. Defaults to ``'cartesian'``. Strings are
        lowercased before lookup.
    proj_kw, projection_kw : dict, optional
        Aliases for keyword arguments passed to `~ultraplot.constructor.Proj`
        when a geographic projection has to be constructed.
    backend : str, optional
        Geographic backend, validated by `_parse_backend`.
    basemap : optional
        Forwarded to `_parse_backend`.
    **kwargs
        Remaining `add_subplot` keyword arguments.

    Returns
    -------
    dict
        *kwargs* with ``projection`` set to ``'ultraplot_<name>'`` and, when a
        geographic projection was constructed, ``map_projection`` set.

    Raises
    ------
    ValueError
        If *proj* is a plain (non-ultraplot) matplotlib axes, or if it is an
        unknown string while neither cartopy nor basemap is available.
    """
    # Parse arguments
    proj = _not_none(proj=proj, projection=projection, default="cartesian")
    proj_kw = _not_none(proj_kw=proj_kw, projection_kw=projection_kw, default={})
    backend = self._parse_backend(backend, basemap)
    if isinstance(proj, str):
        proj = proj.lower()
    # NOTE: These checks previously tested ``self`` (the figure, never an
    # axes), so they could never trigger. Test the projection argument: an
    # ultraplot axes is translated to its registered name, and a plain
    # matplotlib axes is rejected with a clear error.
    if isinstance(proj, paxes.Axes):
        proj = proj._name
    elif isinstance(proj, maxes.Axes):
        raise ValueError("Matplotlib axes cannot be added to ultraplot figures.")
    # Search axes projections
    name = None
    if isinstance(proj, str):
        try:
            mproj.get_projection_class("ultraplot_" + proj)
        except (KeyError, ValueError):
            pass
        else:
            name = proj
    # Helpful error message
    if (
        name is None
        and backend is None
        and isinstance(proj, str)
        and constructor.Projection is object
        and constructor.Basemap is object
    ):
        raise ValueError(
            f"Invalid projection name {proj!r}. If you are trying to generate a "
            "GeoAxes with a cartopy.crs.Projection or mpl_toolkits.basemap.Basemap "
            "then cartopy or basemap must be installed. Otherwise the known axes "
            f"subclasses are:\n{paxes._cls_table}"
        )
    # Search geographic projections
    # NOTE: Also raises errors due to unexpected projection type
    if name is None:
        proj = constructor.Proj(proj, backend=backend, include_axes=True, **proj_kw)
        name = proj._proj_backend
        kwargs["map_projection"] = proj
    kwargs["projection"] = "ultraplot_" + name
    return kwargs
def _get_align_axes(self, side):
"""
Return the main axes along the edge of the figure.
For 'left'/'right': select one extreme axis per row (leftmost/rightmost).
For 'top'/'bottom': select one extreme axis per column (topmost/bottommost).
"""
axs = tuple(self._subplot_dict.values())
if not axs:
return []
if side not in ("left", "right", "top", "bottom"):
raise ValueError(f"Invalid side {side!r}.")
from .utils import _get_subplot_layout
grid = _get_subplot_layout(
self._gridspec, list(self._iter_axes(panels=False, hidden=False))
)[0]
# From the @side we find the first non-zero
# entry in each row or column and collect the axes
if side == "left":
options = grid
elif side == "right":
options = grid[:, ::-1]
elif side == "top":
options = grid.T
else: # bottom
options = grid.T[:, ::-1]
uids = set()
for option in options:
idx = np.where(option != None)[0]
if idx.size > 0:
first = idx.min()
number = option[first].number
uids.add(number)
axs = []
# Collect correct axes
for axi in self._iter_axes():
if axi.number in uids and axi not in axs:
axs.append(axi)
return axs
def _get_border_axes(
    self, *, same_type=False, force_recalculate=False
) -> dict[str, list[paxes.Axes]]:
    """
    Identifies axes located on the outer boundaries of the GridSpec layout.
    Returns a dictionary with keys 'top', 'bottom', 'left', 'right', each
    containing a list of axes on that border.

    Parameters
    ----------
    same_type : bool, optional
        Accepted but not used by this method.
    force_recalculate : bool, optional
        If True, ignore any result cached on ``self._cached_border_axes``
        and recompute it.

    Returns
    -------
    dict of str to list of axes
        When the full computation runs, the result is also stored on
        ``self._cached_border_axes``. The early returns (no gridspec, empty
        geometry, no axes) are not cached.
    """
    # Reuse the cached result unless a recalculation is requested
    if hasattr(self, "_cached_border_axes") and not force_recalculate:
        return self._cached_border_axes
    border_axes = dict(
        left=[],
        right=[],
        top=[],
        bottom=[],
    )
    gs = self.gridspec
    if gs is None:
        return border_axes
    all_axes = []
    for axi in self._iter_axes(panels=True):
        all_axes.append(axi)
    # Handle empty cases
    nrows, ncols = gs.nrows, gs.ncols
    if nrows == 0 or ncols == 0 or not all_axes:
        return border_axes
    # We cannot use the gridspec on the axes as it
    # is modified when a colorbar is added. Use self.gridspec
    # as a reference.
    # NOTE(review): the line below actually reads the gridspec from
    # ``self.axes[0]``, which contradicts the comment above; confirm intent.
    # Reconstruct the grid based on axis locations. Note that
    # spanning axes will fit into one of the boxes. Check
    # this with unittest to see how empty axes are handled
    gs = self.axes[0].get_gridspec()
    # Total grid shape, presumably including panel slots (``*_total``)
    shape = (gs.nrows_total, gs.ncols_total)
    # Cell -> axes occupying it (None when empty)
    grid = np.zeros(shape, dtype=object)
    grid.fill(None)
    # Cell -> integer id of the occupying axes "type"
    grid_axis_type = np.zeros(shape, dtype=int)
    # Type key -> integer id, assigned in order of first appearance
    seen_axis_type = dict()
    # Axes -> integer type id (recorded for hidden axes too)
    ax_type_mapping = dict()
    for axi in self._iter_axes(panels=True, hidden=True):
        gs = axi.get_subplotspec()
        # Top-left cell of this axes; ``x`` is the row and ``y`` the column
        x, y = np.unravel_index(gs.num1, shape)
        # Assumed to be inclusive (row_start, row_stop, col_start, col_stop)
        # given the ``+ 1`` below -- TODO confirm against GridSpec
        span = gs._get_rows_columns()
        xleft, xright, yleft, yright = span
        xspan = xright - xleft + 1
        yspan = yright - yleft + 1
        number = axi.number  # NOTE(review): unused in this loop
        # GeoAxes are distinguished by projection, others by class
        axis_type = type(axi)
        if isinstance(axi, (paxes.GeoAxes)):
            axis_type = axi.projection
        if axis_type not in seen_axis_type:
            seen_axis_type[axis_type] = len(seen_axis_type)
        type_number = seen_axis_type[axis_type]
        ax_type_mapping[axi] = type_number
        # Only visible axes occupy cells in the reconstructed grid
        if axi.get_visible():
            grid[x : x + xspan, y : y + yspan] = axi
            grid_axis_type[x : x + xspan, y : y + yspan] = type_number
    # We check for all axes is they are a border or not
    # Note we could also write the crawler in a way where
    # it find the borders by moving around in the grid, without spawning on each axis point. We may change
    # this in the future
    for axi in all_axes:
        axis_type = ax_type_mapping[axi]
        number = axi.number
        if axi.number is None:
            # Axes without a number (presumably panels) use the negated
            # number of their panel parent as the crawler target
            number = -axi._panel_parent.number
        crawler = _Crawler(
            ax=axi,
            grid=grid,
            target=number,
            axis_type=axis_type,
            grid_axis_type=grid_axis_type,
        )
        # Each (direction, is_border) pair marks whether axi touches that edge
        for direction, is_border in crawler.find_edges():
            if is_border and axi not in border_axes[direction]:
                border_axes[direction].append(axi)
    self._cached_border_axes = border_axes
    return border_axes
def _get_align_coord(self, side, axs, align="center", includepanels=False):
    """
    Compute the figure coordinate used to place a spanning label or title.

    Parameters
    ----------
    side : {'top', 'bottom', 'left', 'right'}
        Edge the label belongs to. Left/right give a vertical (y) position,
        top/bottom give a horizontal (x) position.
    axs : list of axes
        Ultraplot subplots to span. Panels are replaced by their parents.
    align : {'left', 'center', 'right'}, default 'center'
        Horizontal alignment, used only for top/bottom. Left/right positions
        are always centered.
    includepanels : bool, default False
        Whether each subplot's panels take part in the extent calculation.

    Returns
    -------
    pos : float
        Position in figure-relative coordinates.
    ax : axes
        Main subplot (never a panel) that should own the spanning label.

    Raises
    ------
    RuntimeError
        If any axes is not an ultraplot axes, or is not a subplot.
    """
    if any(not isinstance(ax, paxes.Axes) for ax in axs):
        raise RuntimeError("Axes must be ultraplot axes.")
    if any(not isinstance(ax, maxes.SubplotBase) for ax in axs):
        raise RuntimeError("Axes must be subplots.")
    vertical = side in ("left", "right")
    s = "y" if vertical else "x"
    members = [ax._panel_parent or ax for ax in axs]  # deflect to main axes
    if includepanels:
        members = [
            member
            for ax in members
            for member in ax._iter_axes(panels=True, children=False)
        ]
    ranges = np.array([ax._range_subplotspec(s) for ax in members])
    # argmin/argmax return the first occurrence of the extreme value.
    lo_index = np.argmin(ranges[:, 0])
    hi_index = np.argmax(ranges[:, 1])
    box_lo = members[lo_index].get_subplotspec().get_position(self)
    box_hi = members[hi_index].get_subplotspec().get_position(self)
    if vertical:
        # The 'lo' subplotspec is the one nearest the top of the figure.
        pos = 0.5 * (box_lo.y1 + box_hi.y0)
    elif align == "left":
        pos = box_lo.x0
    elif align == "right":
        pos = box_hi.x1
    else:
        pos = 0.5 * (box_lo.x0 + box_hi.x1)
    owner = members[(lo_index + hi_index) // 2]
    return pos, owner._panel_parent or owner
def _get_offset_coord(self, side, axs, renderer, *, pad=None, extra=None):
"""
Return the figure coordinate for offsetting super labels and super titles.
"""
s = "x" if side in ("left", "right") else "y"
cs = []
objs = tuple(
_
for ax in axs
for _ in ax._iter_axes(panels=True, children=True, hidden=True)
) # noqa: E501
objs = objs + (extra or ()) # e.g. top super labels
for obj in objs:
bbox = obj.get_tightbbox(renderer) # cannot use cached bbox
attr = s + "max" if side in ("top", "right") else s + "min"
c = getattr(bbox, attr)
c = (c, 0) if side in ("left", "right") else (0, c)
c = self.transFigure.inverted().transform(c)
c = c[0] if side in ("left", "right") else c[1]
cs.append(c)
width, height = self.get_size_inches()
if pad is None:
pad = self._suplabel_pad[side] / 72
pad = pad / width if side in ("left", "right") else pad / height
return min(cs) - pad if side in ("left", "bottom") else max(cs) + pad
def _get_renderer(self):
"""
Get a renderer at all costs. See matplotlib's tight_layout.py.
"""
if hasattr(self, "_cached_render"):
renderer = self._cachedRenderer
else:
canvas = self.canvas
if canvas and hasattr(canvas, "get_renderer"):
renderer = canvas.get_renderer()
else:
from matplotlib.backends.backend_agg import FigureCanvasAgg
canvas = FigureCanvasAgg(self)
renderer = canvas.get_renderer()
return renderer
@_clear_border_cache
def _add_axes_panel(
    self,
    ax: "paxes.Axes",
    side: Optional[str] = None,
    span: Optional[Union[int, Tuple[int, int]]] = None,
    row: Optional[int] = None,
    col: Optional[int] = None,
    rows: Optional[Union[int, Tuple[int, int]]] = None,
    cols: Optional[Union[int, Tuple[int, int]]] = None,
    **kwargs,
) -> "paxes.Axes":
    """
    Add an axes panel.

    Parameters
    ----------
    ax : `~ultraplot.axes.Axes`
        The subplot to attach the panel to. Twin axes are redirected to
        their parent. Passing an existing panel adds another panel on the
        same side of that panel's parent.
    side : str, optional
        Panel side, translated with ``_translate_loc``. Defaults to the side
        of *ax* when it is a panel, otherwise ``'right'``.
    span, row, col, rows, cols : int or 2-tuple of int, optional
        Override for the rows (left/right panels) or columns (top/bottom
        panels) spanned by the panel. ``span`` acts as a fallback for both
        orientations.
    **kwargs
        Parameters of ``GridSpec._insert_panel_slot`` are extracted. The rest
        are passed to `add_subplot`.

    Returns
    -------
    `~ultraplot.axes.Axes`
        The new panel axes.

    Raises
    ------
    RuntimeError
        If *ax* is not an ultraplot subplot, if a panel is requested on a
        different side of an existing panel, or if no gridspec is active.
    ValueError
        If the span keywords do not match the panel orientation.
    """
    # Interpret args
    # NOTE: Axis sharing not implemented for figure panels, 99% of the
    # time this is just used as construct for adding global colorbars and
    # legends, really not worth implementing axis sharing
    ax = ax._altx_parent or ax
    ax = ax._alty_parent or ax
    if not isinstance(ax, paxes.Axes):
        raise RuntimeError("Cannot add panels to non-ultraplot axes.")
    if not isinstance(ax, maxes.SubplotBase):
        raise RuntimeError("Cannot add panels to non-subplot axes.")
    orig = ax._panel_side
    if orig is None:
        pass
    elif side is None or side == orig:
        ax, side = ax._panel_parent, orig
    else:
        raise RuntimeError(f"Cannot add {side!r} panel to existing {orig!r} panel.")
    side = _translate_loc(side, "panel", default=_not_none(orig, "right"))
    # Add and setup the panel accounting for index changes
    # NOTE: Always put tick labels on the 'outside' and permit arbitrary
    # keyword arguments passed from the user.
    gs = self.gridspec
    if not gs:
        raise RuntimeError("The gridspec must be active.")
    kw = _pop_params(kwargs, gs._insert_panel_slot)
    # Validate and determine span override from span/row/col/rows/cols parameters
    span_override = None
    if side in ("left", "right"):
        # Vertical panels: should use rows parameter, not cols
        if _not_none(cols, col) is not None and _not_none(rows, row) is None:
            raise ValueError(
                f"For {side!r} panels (vertical), use 'rows=' or 'row=' "
                "to specify span, not 'cols=' or 'col='."
            )
        if span is not None and _not_none(rows, row) is None:
            warnings._warn_ultraplot(
                f"For {side!r} panels (vertical), prefer 'rows=' over 'span=' "
                "for clarity. Using 'span' as rows."
            )
        span_override = _not_none(rows, row, span)
    else:
        # Horizontal panels: should use cols parameter, not rows
        if _not_none(rows, row) is not None and _not_none(cols, col, span) is None:
            raise ValueError(
                f"For {side!r} panels (horizontal), use 'cols=' or 'span=' "
                "to specify span, not 'rows=' or 'row='."
            )
        span_override = _not_none(cols, col, span)
    # Pass span_override to gridspec if provided
    if span_override is not None:
        kw["span_override"] = span_override
    ss, share = gs._insert_panel_slot(side, ax, **kw)
    # Guard: GeoAxes with non-rectilinear projections cannot share with panels
    if isinstance(ax, paxes.GeoAxes) and not ax._is_rectilinear():
        if share:
            warnings._warn_ultraplot(
                "Panel sharing disabled for non-rectilinear GeoAxes projections."
            )
        share = False
    kwargs["autoshare"] = False
    kwargs.setdefault("number", False)  # power users might number panels
    pax = self.add_subplot(ss, **kwargs)
    pax._panel_side = side
    pax._panel_share = share
    pax._panel_parent = ax
    ax._panel_dict[side].append(pax)
    ax._apply_auto_share()
    axis = pax.yaxis if side in ("left", "right") else pax.xaxis
    getattr(axis, "tick_" + side)()  # set tick and tick label position
    axis.set_label_position(side)  # set label position
    # Sync limits and formatters with parent when sharing to ensure consistent ticks
    # Copy limits for the shared axis
    # Note: for non-geo axes this is handled by auto sharing
    if share and isinstance(ax, paxes.GeoAxes):
        # Align with backend: for GeoAxes, use lon/lat degree formatters on panels.
        # Otherwise, copy the parent's axis formatters.
        fmt_key = "deglat" if side in ("left", "right") else "deglon"
        axis.set_major_formatter(constructor.Formatter(fmt_key))
        # Update limits
        axis._set_lim(
            *getattr(ax, f"get_{'y' if side in ('left','right') else 'x'}lim")(),
            auto=True,
        )
    # Push main axes tick labels to the outside relative to the added panel
    # Skip this for filled panels (colorbars/legends)
    if not kw.get("filled", False) and share:
        if isinstance(ax, paxes.GeoAxes):
            if side == "top":
                ax._toggle_gridliner_labels(labeltop=False)
            elif side == "bottom":
                ax._toggle_gridliner_labels(labelbottom=False)
            elif side == "left":
                ax._toggle_gridliner_labels(labelleft=False)
            elif side == "right":
                ax._toggle_gridliner_labels(labelright=False)
        else:
            if side == "top":
                ax.xaxis.set_tick_params(**{ax._label_key("labeltop"): False})
            elif side == "bottom":
                ax.xaxis.set_tick_params(**{ax._label_key("labelbottom"): False})
            elif side == "left":
                ax.yaxis.set_tick_params(**{ax._label_key("labelleft"): False})
            elif side == "right":
                ax.yaxis.set_tick_params(**{ax._label_key("labelright"): False})
    # Panel labels: prefer outside only for non-sharing top/right; otherwise keep off
    if side == "top":
        if not share:
            pax.xaxis.set_tick_params(
                **{
                    pax._label_key("labeltop"): True,
                    pax._label_key("labelbottom"): False,
                }
            )
        else:
            # Hand the parent's top x tick labels over to the panel.
            on = ax.xaxis.get_tick_params()[ax._label_key("labeltop")]
            pax.xaxis.set_tick_params(**{pax._label_key("labeltop"): on})
            # NOTE: This previously called ``ax.yaxis.set_tick_params(
            # labeltop=False)``. On the y axis ``labeltop`` addresses the
            # second tick label, which hid the parent's right-side y labels.
            # Turn off the top x labels, as the 'right' branch does for y.
            ax.xaxis.set_tick_params(**{ax._label_key("labeltop"): False})
    elif side == "right":
        if not share:
            pax.yaxis.set_tick_params(
                **{
                    pax._label_key("labelright"): True,
                    pax._label_key("labelleft"): False,
                }
            )
        else:
            # Hand the parent's right y tick labels over to the panel.
            on = ax.yaxis.get_tick_params()[ax._label_key("labelright")]
            pax.yaxis.set_tick_params(**{pax._label_key("labelright"): on})
            ax.yaxis.set_tick_params(**{ax._label_key("labelright"): False})
    return pax
@_clear_border_cache
def _add_figure_panel(
    self,
    side: Optional[str] = None,
    span: Optional[Union[int, Tuple[int, int]]] = None,
    row: Optional[int] = None,
    col: Optional[int] = None,
    rows: Optional[Union[int, Tuple[int, int]]] = None,
    cols: Optional[Union[int, Tuple[int, int]]] = None,
    **kwargs,
) -> "paxes.Axes":
    """
    Add a panel along one edge of the whole figure.

    Parameters
    ----------
    side : str, optional
        Panel side, translated with ``_translate_loc`` (default ``'right'``).
    span, row, rows, col, cols : optional
        Extent of the panel. Left/right panels accept only ``span``/``row``/
        ``rows``; top/bottom panels accept only ``span``/``col``/``cols``.
    **kwargs
        Passed to ``GridSpec._insert_panel_slot``.

    Returns
    -------
    `~ultraplot.axes.Axes`
        The new panel. It is appended to ``self._panel_dict[side]`` and has
        no panel parent.

    Raises
    ------
    ValueError
        If a keyword of the wrong orientation is given.
    RuntimeError
        If no gridspec is active.
    """
    side = _translate_loc(side, "panel", default="right")
    if side in ("left", "right"):
        forbidden = (("col", col), ("cols", cols))
        aliases = {"span": span, "row": row, "rows": rows}
    else:
        forbidden = (("row", row), ("rows", rows))
        aliases = {"span": span, "col": col, "cols": cols}
    for key, value in forbidden:
        if value is not None:
            raise ValueError(f"Invalid keyword {key!r} for {side!r} panel.")
    span = _not_none(**aliases)
    # This is only called internally (colorbars and legends), so arbitrary
    # axes keyword arguments are not forwarded to add_subplot.
    gs = self.gridspec
    if not gs:
        raise RuntimeError("The gridspec must be active.")
    slot, _ = gs._insert_panel_slot(side, span, filled=True, **kwargs)
    pax = self.add_subplot(slot, autoshare=False, number=False)
    self._panel_dict[side].append(pax)
    pax._panel_side = side
    pax._panel_share = False
    pax._panel_parent = None
    return pax
@_clear_border_cache
def _add_subplot(self, *args, **kwargs):
    """
    The driver function for adding single subplots.

    Accepted positional arguments:

    * nothing, treated as ``(1, 1, 1)``;
    * a three-digit integer such as ``221``;
    * a subplot or `~matplotlib.gridspec.SubplotSpec` derived from the
      figure's ultraplot `~ultraplot.gridspec.GridSpec`;
    * ``(nrows, ncols, index)``, where *index* is either an integer or a
      ``(start, stop)`` pair of 1-based indices. The geometry must evenly
      divide the active gridspec geometry.

    Keyword arguments are passed through `_parse_proj` and then on to
    `matplotlib.figure.Figure.add_subplot`.

    Returns
    -------
    The new axes. It is also stored in ``self._subplot_dict`` when its
    ``number`` is truthy.

    Raises
    ------
    ValueError
        For out-of-range integer codes, foreign subplotspecs, incompatible
        geometries, out-of-range indices, or unrecognized positional args.
    """
    # Parse arguments
    kwargs = self._parse_proj(**kwargs)
    args = args or (1, 1, 1)
    gs = self.gridspec
    # Integer arg
    if len(args) == 1 and isinstance(args[0], Integral):
        if not 111 <= args[0] <= 999:
            raise ValueError(f"Input {args[0]} must fall between 111 and 999.")
        # e.g. 221 -> (2, 2, 1), handled by the row/column branch below
        args = tuple(map(int, str(args[0])))
    # Subplot spec
    if len(args) == 1 and isinstance(
        args[0], (maxes.SubplotBase, mgridspec.SubplotSpec)
    ):
        ss = args[0]
        if isinstance(ss, maxes.SubplotBase):
            ss = ss.get_subplotspec()
        if gs is None:
            gs = ss.get_topmost_subplotspec().get_gridspec()
        if not isinstance(gs, pgridspec.GridSpec):
            raise ValueError(
                "Input subplotspec must be derived from a ultraplot.GridSpec."
            )
        if ss.get_topmost_subplotspec().get_gridspec() is not gs:
            raise ValueError(
                "Input subplotspec must be derived from the active figure gridspec."
            )
    # Row and column spec
    # TODO: How to pass spacing parameters to gridspec? Consider overriding
    # subplots adjust? Or require using gridspec manually?
    elif (
        len(args) == 3
        and all(isinstance(arg, Integral) for arg in args[:2])
        and all(isinstance(arg, Integral) for arg in np.atleast_1d(args[2]))
    ):
        nrows, ncols, num = args
        # A scalar index becomes (i, i); a pair gives the first/last cell
        i, j = np.resize(num, 2)
        if gs is None:
            gs = pgridspec.GridSpec(nrows, ncols)
        orows, ocols = gs.get_geometry()
        if orows % nrows:
            raise ValueError(
                f"The input number of rows {nrows} does not divide the "
                f"figure gridspec number of rows {orows}."
            )
        if ocols % ncols:
            raise ValueError(
                f"The input number of columns {ncols} does not divide the "
                f"figure gridspec number of columns {ocols}."
            )
        if any(_ < 1 or _ > nrows * ncols for _ in (i, j)):
            raise ValueError(
                "The input subplot indices must fall between "
                f"1 and {nrows * ncols}. Instead got {i} and {j}."
            )
        # Scale the requested geometry onto the (finer) figure gridspec
        rowfact, colfact = orows // nrows, ocols // ncols
        irow, icol = divmod(i - 1, ncols)  # convert to zero-based
        jrow, jcol = divmod(j - 1, ncols)
        irow, icol = irow * rowfact, icol * colfact
        jrow, jcol = (jrow + 1) * rowfact - 1, (jcol + 1) * colfact - 1
        ss = gs[irow : jrow + 1, icol : jcol + 1]
    # Otherwise
    else:
        raise ValueError(f"Invalid add_subplot positional arguments {args!r}.")
    # Add the subplot
    # NOTE: Pass subplotspec as keyword arg for mpl >= 3.4 workaround
    # NOTE: Must assign unique label to each subplot or else subsequent calls
    # to add_subplot() in mpl < 3.4 may return an already-drawn subplot in the
    # wrong location due to gridspec override. Is against OO package design.
    self.gridspec = gs  # trigger layout adjustment
    self._subplot_counter += 1  # unique label for each subplot
    kwargs.setdefault("label", f"subplot_{self._subplot_counter}")
    kwargs.setdefault("number", 1 + max(self._subplot_dict, default=0))
    kwargs.pop("refwidth", None)  # TODO: remove this
    ax = super().add_subplot(ss, _subplot_spec=ss, **kwargs)
    # Allow sharing for GeoAxes if rectilinear
    # Sharing is switched off figure-wide when the new GeoAxes is not
    # rectilinear or its projection differs from the first main axes.
    if self._sharex or self._sharey:
        if len(self.axes) > 1 and isinstance(ax, paxes.GeoAxes):
            # Compare it with a reference
            ref = next(self._iter_axes(hidden=False, children=False, panels=False))
            unshare = False
            if not ax._is_rectilinear():
                unshare = True
            elif hasattr(ax, "projection") and hasattr(ref, "projection"):
                if ax.projection != ref.projection:
                    unshare = True
            if unshare:
                self._unshare_axes()
                # Only warn once. Note, if axes are reshared
                # the warning is not reset. This is however,
                # very unlikely to happen as GeoAxes are not
                # typically shared and unshared.
                # NOTE(review): the message says the projection is not
                # rectilinear even when unsharing was caused by a projection
                # mismatch with ``ref``; consider distinguishing the cases.
                warnings._warn_ultraplot(
                    f"GeoAxes can only be shared for rectilinear projections, {ax.projection=} is not a rectilinear projection."
                )
    if ax.number:
        self._subplot_dict[ax.number] = ax
    return ax
def _unshare_axes(self):
for which in "xyz":
self._toggle_axis_sharing(which=which, share=False)
# Force setting extent
# This is necessary to ensure that the axes are properly
# aligned and we don't get weird scaling issues for
# geographic axes. This action is expensive for GeoAxes
for ax in self.axes:
if isinstance(ax, paxes.GeoAxes) and hasattr(ax, "set_global"):
ax.set_global()
def _toggle_axis_sharing(
    self,
    *,
    which="y",
    share=True,
    panels=False,
    children=False,
    hidden=False,
):
    """
    Share or unshare axes in the figure along a given direction.

    Parameters:
    - which: 'x', 'y', 'z', or 'view'. Any other value emits a warning and
      does nothing.
    - share: the sharing level. For 'x'/'y' it is stored on the figure as
      ``_sharex``/``_sharey``. A value equal to 0 (including False)
      unshares every selected axes. Any other value joins the axes into
      sharing groups.
    - panels: Whether to include panel axes.
    - children: Whether to include child axes.
    - hidden: Whether to include hidden axes.

    When sharing, 'x' axes are grouped by the first row of their
    subplotspec and 'y' axes by the first column. For 'z' and 'view' the
    grouping key is None, so all selected axes form a single group. In each
    group the first axes is the reference that the others are joined to.
    """
    if which not in ("x", "y", "z", "view"):
        warnings._warn_ultraplot(
            f"Attempting to (un)share {which=}. Options are ('x', 'y', 'z', 'view')"
        )
        return
    axes = list(self._iter_axes(hidden=hidden, children=children, panels=panels))
    if which == "x":
        self._sharex = share
    elif which == "y":
        self._sharey = share
    # Unshare first if needed
    if share == 0:
        for ax in axes:
            ax._unshare(which=which)
        return

    # Grouping logic based on GridSpec
    def get_key(ax):
        ss = ax.get_subplotspec()
        if which == "x":
            return ss.rowspan.start  # same row
        elif which == "y":
            return ss.colspan.start  # same col
        # 'z'/'view' fall through and return None (single group)

    # Create groups of axes that should share
    groups = {}
    for ax in axes:
        key = get_key(ax)
        groups.setdefault(key, []).append(ax)
    # Re-join axes per group
    for group in groups.values():
        ref = group[0]
        for other in group[1:]:
            ref._shared_axes[which].join(ref, other)
            # The following manual adjustments are necessary because the
            # join method does not automatically propagate the sharing state
            # and axis properties to the other axes. This ensures that the
            # shared axes behave consistently.
            if which == "x":
                other._sharex = ref
                # Share tick locators/formatters with the reference
                other.xaxis.major = ref.xaxis.major
                other.xaxis.minor = ref.xaxis.minor
                lim = ref.get_xlim()
                other.set_xlim(*lim, emit=False, auto=ref.get_autoscalex_on())
                other.xaxis._scale = ref.xaxis._scale
            if which == "y":
                # This logic is from sharey
                other._sharey = ref
                other.yaxis.major = ref.yaxis.major
                other.yaxis.minor = ref.yaxis.minor
                lim = ref.get_ylim()
                other.set_ylim(*lim, emit=False, auto=ref.get_autoscaley_on())
                other.yaxis._scale = ref.yaxis._scale
def _add_subplots(
    self,
    array=None,
    nrows=1,
    ncols=1,
    order="C",
    proj=None,
    projection=None,
    proj_kw=None,
    projection_kw=None,
    backend=None,
    basemap=None,
    **kwargs,
):
    """
    The driver function for adding multiple subplots.

    Parameters
    ----------
    array : array-like of int or `~matplotlib.gridspec.GridSpec`, optional
        1D or 2D layout of subplot numbers, where 0 or None marks an empty
        slot. A 1D array is one row (``order='C'``) or one column
        (``order='F'``). When None, an ``nrows`` by ``ncols`` layout numbered
        in *order* is used. When a gridspec, its geometry is used and the
        gridspec itself is updated and reused. The input array is never
        modified.
    nrows, ncols : int, optional
        Layout geometry, used only when *array* is None.
    order : {'C', 'F'}, optional
        Row-major or column-major numbering.
    proj, projection, proj_kw, projection_kw, backend : optional
        Per-subplot projection settings. Each is a single value, a sequence
        in subplot-number order, or a dict keyed by subplot number (or a
        tuple of numbers).
    basemap : optional
        Forwarded to `_parse_backend`.
    **kwargs
        Figure `format` keywords, gridspec keywords, and `add_subplot`
        keywords. ``gridspec_kw``/``subplot_kw`` are merged in with a
        warning.

    Returns
    -------
    `~ultraplot.gridspec.SubplotGrid`
        The new subplots, ordered by number.

    Raises
    ------
    ValueError
        For an invalid *order*, non-1D/2D or negative layouts, or
        inconsistent per-subplot dictionaries.
    """

    # Clunky helper function
    # TODO: Consider deprecating and asking users to use add_subplot()
    def _axes_dict(naxs, input, kw=False, default=None):
        # Expand scalar/sequence/dict input into a {number: value} mapping
        # covering subplot numbers 1..naxs (values copied when kw=True).
        # First build up dictionary
        if not kw:  # 'string' or {1: 'string1', (2, 3): 'string2'}
            if np.iterable(input) and not isinstance(input, (str, dict)):
                input = {num + 1: item for num, item in enumerate(input)}
            elif not isinstance(input, dict):
                input = {range(1, naxs + 1): input}
        else:  # {key: value} or {1: {key: value1}, (2, 3): {key: value2}}
            nested = [isinstance(_, dict) for _ in input.values()]
            if not any(nested):  # any([]) == False
                input = {range(1, naxs + 1): input.copy()}
            elif not all(nested):
                raise ValueError(f"Invalid input {input!r}.")
        # Unfurl keys that contain multiple axes numbers
        output = {}
        for nums, item in input.items():
            nums = np.atleast_1d(nums)
            for num in nums.flat:
                output[num] = item.copy() if kw else item
        # Fill with default values
        for num in range(1, naxs + 1):
            if num not in output:
                output[num] = {} if kw else default
        if output.keys() != set(range(1, naxs + 1)):
            raise ValueError(
                f"Have {naxs} axes, but {input!r} includes props for the axes: "
                + ", ".join(map(repr, sorted(output)))
                + "."
            )
        return output

    # Build the subplot array
    # NOTE: Currently this may ignore user-input nrows/ncols without warning
    if order not in ("C", "F"):  # better error message
        raise ValueError(f"Invalid order={order!r}. Options are 'C' or 'F'.")
    gs = None
    if array is None or isinstance(array, mgridspec.GridSpec):
        if array is not None:
            gs, nrows, ncols = array, array.nrows, array.ncols
        array = np.arange(1, nrows * ncols + 1)[..., None]
        array = array.reshape((nrows, ncols), order=order)
    else:
        # Copy first: np.atleast_1d returns ndarray input as-is, so replacing
        # the None placeholders below would otherwise mutate the caller's array.
        array = np.atleast_1d(array).copy()
        array[array == None] = 0  # None or 0 both valid placeholders # noqa: E711
        array = array.astype(int)
        if array.ndim == 1:  # interpret as single row or column
            array = array[None, :] if order == "C" else array[:, None]
        elif array.ndim != 2:
            raise ValueError(f"Expected 1D or 2D array of integers. Got {array}.")
    # Parse input format, gridspec, and projection arguments
    # NOTE: Permit figure format keywords for e.g. 'collabels' (more intuitive)
    nums = np.unique(array[array != 0])
    naxs = len(nums)
    if any(num < 0 or not isinstance(num, Integral) for num in nums.flat):
        raise ValueError(f"Expected array of positive integers. Got {array}.")
    proj = _not_none(projection=projection, proj=proj)
    proj = _axes_dict(naxs, proj, kw=False, default="cartesian")
    proj_kw = _not_none(projection_kw=projection_kw, proj_kw=proj_kw) or {}
    proj_kw = _axes_dict(naxs, proj_kw, kw=True)
    backend = self._parse_backend(backend, basemap)
    backend = _axes_dict(naxs, backend, kw=False)
    axes_kw = {
        num: {"proj": proj[num], "proj_kw": proj_kw[num], "backend": backend[num]}
        for num in proj
    }
    for key in ("gridspec_kw", "subplot_kw"):
        kw = kwargs.pop(key, None)
        if not kw:
            continue
        warnings._warn_ultraplot(
            f"{key!r} is not necessary in ultraplot. Pass the "
            "parameters as keyword arguments instead."
        )
        kwargs.update(kw or {})
    figure_kw = _pop_params(kwargs, self._format_signature)
    gridspec_kw = _pop_params(kwargs, pgridspec.GridSpec._update_params)
    # Create or update the gridspec and add subplots with subplotspecs
    # NOTE: The gridspec is added to the figure when we pass the subplotspec
    if gs is None:
        gs = pgridspec.GridSpec(*array.shape, **gridspec_kw)
    else:
        gs.update(**gridspec_kw)
    axs = naxs * [None]  # list of axes
    # Cell indices occupied by each positive number, in ascending order
    axids = [np.where(array == i) for i in np.sort(np.unique(array)) if i > 0]
    axcols = np.array([[x.min(), x.max()] for _, x in axids])
    axrows = np.array([[y.min(), y.max()] for y, _ in axids])
    for idx in range(naxs):
        num = idx + 1
        x0, x1 = axcols[idx, 0], axcols[idx, 1]
        y0, y1 = axrows[idx, 0], axrows[idx, 1]
        ss = gs[y0 : y1 + 1, x0 : x1 + 1]
        kw = {**kwargs, **axes_kw[num], "number": num}
        axs[idx] = self.add_subplot(ss, **kw)
    self.format(skip_axes=True, **figure_kw)
    return pgridspec.SubplotGrid(axs)
def _align_axis_label(self, x):
"""
Align *x* and *y* axis labels in the perpendicular and parallel directions.
"""
# NOTE: Always use 'align' if 'span' is True to get correct offset
# NOTE: Must trigger axis sharing here so that super label alignment
# with tight=False is valid. Kind of kludgey but oh well.
seen = set()
span = getattr(self, "_span" + x)
align = getattr(self, "_align" + x)
for ax in self._subplot_dict.values():
if isinstance(ax, paxes.CartesianAxes):
ax._apply_axis_sharing() # always!
else:
continue
pos = getattr(ax, x + "axis").get_label_position()
if ax in seen or pos not in ("bottom", "left"):
continue # already aligned or cannot align
axs = ax._get_span_axes(pos, panels=False) # returns panel or main axes
if self._has_share_label_groups(x) and any(
self._is_share_label_group_member(axi, x) for axi in axs
):
continue # explicit label groups override default spanning
if any(getattr(ax, "_share" + x) for ax in axs):
continue # nothing to align or axes have parents
seen.update(axs)
if span or align:
if hasattr(self, "_align_label_groups"):
group = self._align_label_groups[x]
else:
group = getattr(self, "_align_" + x + "label_grp", None)
if group is not None: # fail silently to avoid fragile API changes
for ax in axs[1:]:
group.join(axs[0], ax) # add to grouper
if span:
self._update_axis_label(pos, axs)
# Apply explicit label-sharing groups for this axis
self._apply_share_label_groups(axis=x)
def _register_share_label_group(self, axes, *, target, source=None):
"""
Register an explicit label-sharing group for a subset of axes.
"""
if not axes:
return
axes = list(axes)
axes = [ax for ax in axes if ax is not None and ax.figure is self]
if len(axes) < 2:
return
# Preserve order while de-duplicating
seen = set()
unique = []
for ax in axes:
ax_id = id(ax)
if ax_id in seen:
continue
seen.add(ax_id)
unique.append(ax)
axes = unique
if len(axes) < 2:
return
# Split by label side if mixed
axes_by_side = {}
if target == "x":
for ax in axes:
axes_by_side.setdefault(ax.xaxis.get_label_position(), []).append(ax)
else:
for ax in axes:
axes_by_side.setdefault(ax.yaxis.get_label_position(), []).append(ax)
if len(axes_by_side) > 1:
for side, side_axes in axes_by_side.items():
side_source = source if source in side_axes else None
self._register_share_label_group_for_side(
side_axes, target=target, side=side, source=side_source
)
return
side, side_axes = next(iter(axes_by_side.items()))
self._register_share_label_group_for_side(
side_axes, target=target, side=side, source=source
)
def _register_share_label_group_for_side(self, axes, *, target, side, source=None):
    """
    Register one explicit label-sharing group whose members all place their
    `target` label on the same `side`.

    The group lives in ``self._share_label_groups[target]``. Its key is the
    sorted tuple of the members' ``id()`` values, so registering the same set
    of axes again updates the existing entry instead of adding a new one.
    Label text and styling come from `source` when it is a member with
    non-blank label text. Otherwise they come from the first member with
    non-blank label text. Re-registering with blank labels keeps the
    previously stored text and properties.
    """
    if not axes:
        return
    members = [ax for ax in axes if ax is not None and ax.figure is self]
    if len(members) < 2:
        return

    attr = f"{target}axis"

    def _nonblank_label(ax):
        # The axis label artist if it carries visible text, otherwise None
        lab = getattr(ax, attr).label
        return lab if lab.get_text().strip() else None

    chosen = _nonblank_label(source) if source in members else None
    if chosen is None:
        chosen = next(
            (lab for lab in map(_nonblank_label, members) if lab is not None),
            None,
        )

    if chosen is None:
        text, props = "", None
    else:
        text = chosen.get_text()
        props = {
            "color": chosen.get_color(),
            "fontproperties": chosen.get_font_properties(),
            "rotation": chosen.get_rotation(),
            "rotation_mode": chosen.get_rotation_mode(),
            "ha": chosen.get_ha(),
            "va": chosen.get_va(),
        }

    key = tuple(sorted(id(ax) for ax in members))
    registry = self._share_label_groups[target]
    existing = registry.get(key)
    if existing is None:
        registry[key] = {
            "axes": members,
            "side": side,
            "text": text if text.strip() else "",
            "props": props,
        }
        return
    existing["axes"] = members
    existing["side"] = side
    if text.strip():
        existing["text"] = text
        existing["props"] = props
def _is_share_label_group_member(self, ax, axis):
    """
    Report whether `ax` appears in at least one explicit label-sharing group
    registered for `axis` (``'x'`` or ``'y'``). Unknown axis names yield
    ``False``.
    """
    for group in self._share_label_groups.get(axis, {}).values():
        if ax in group["axes"]:
            return True
    return False
def _has_share_label_groups(self, axis):
    """
    Report whether any explicit label-sharing group is registered for
    `axis`. Missing or empty registries yield ``False``.
    """
    registered = self._share_label_groups.get(axis)
    return bool(registered)
def _clear_share_label_groups(self, axes=None, *, target=None):
    """
    Clear explicit label-sharing groups, optionally filtered by axes.

    Parameters
    ----------
    axes : iterable of Axes, optional
        If ``None``, every group for the selected target(s) is dropped.
        Otherwise only groups containing at least one of these axes are
        dropped, and any spanning super label keyed by one of these axes has
        its text cleared.
    target : {'x', 'y'}, optional
        Restrict clearing to one axis. By default both are cleared.
    """
    targets = ("x", "y") if target is None else (target,)
    # NOTE: Materialize the filter set once, before looping over targets.
    # Building it inside the loop silently skipped the second target when
    # `axes` was a one-shot iterator (e.g. a generator).
    axes_set = None if axes is None else {ax for ax in axes if ax is not None}
    for axis in targets:
        groups = self._share_label_groups.get(axis, {})
        if axes_set is None:
            groups.clear()
            continue
        for key in list(groups):
            if any(ax in axes_set for ax in groups[key]["axes"]):
                del groups[key]
        # Clear any existing spanning labels tied to these axes
        suplabels = self._supxlabel_dict if axis == "x" else self._supylabel_dict
        for ax in axes_set:
            if ax in suplabels:
                suplabels[ax].set_text("")
def _apply_share_label_groups(self, axis=None):
    """
    Apply explicit label-sharing groups, overriding default label sharing.

    For every registered group with at least two visible member axes still
    owned by this figure, the shared label text and cached text properties
    are refreshed from the members. They are then written onto the label of
    the axes chosen by ``_get_align_coord``, which is converted into a
    spanning label with ``_update_axis_label``.

    Parameters
    ----------
    axis : {'x', 'y'}, optional
        Restrict processing to one axis. Any other value (including ``None``)
        processes both.
    """
    def _order_axes_for_side(axs, side):
        # Sort so the axes nearest the labeled side come first. For 'bottom'
        # the largest end row comes first; for 'top' the smallest start row
        # comes first. 'right'/'left' do the same with columns. Presumably
        # _range_subplotspec returns a (start, end) grid span -- TODO confirm.
        if side in ("bottom", "top"):
            key = (
                (lambda ax: ax._range_subplotspec("y")[1])
                if side == "bottom"
                else (lambda ax: ax._range_subplotspec("y")[0])
            )
            reverse = side == "bottom"
        else:
            key = (
                (lambda ax: ax._range_subplotspec("x")[1])
                if side == "right"
                else (lambda ax: ax._range_subplotspec("x")[0])
            )
            reverse = side == "right"
        try:
            return sorted(axs, key=key, reverse=reverse)
        except Exception:
            # Best effort: keep the input order if spans are unavailable
            return list(axs)
    axes = (axis,) if axis in ("x", "y") else ("x", "y")
    for target in axes:
        groups = self._share_label_groups.get(target, {})
        for group in groups.values():
            # Ignore members that moved to another figure or were hidden
            axs = [
                ax for ax in group["axes"] if ax.figure is self and ax.get_visible()
            ]
            if len(axs) < 2:
                continue
            side = group["side"]
            ordered_axs = _order_axes_for_side(axs, side)
            # Refresh label text from any axis with non-empty text
            label = None
            for ax in ordered_axs:
                candidate = getattr(ax, f"{target}axis").label
                if candidate.get_text().strip():
                    label = candidate
                    break
            # Fall back to the text/props cached at registration time
            text = group["text"]
            props = group["props"]
            if label is not None:
                text = label.get_text()
                props = {
                    "color": label.get_color(),
                    "fontproperties": label.get_font_properties(),
                    "rotation": label.get_rotation(),
                    "rotation_mode": label.get_rotation_mode(),
                    "ha": label.get_ha(),
                    "va": label.get_va(),
                }
                group["text"] = text
                group["props"] = props
            if not text:
                continue
            # Pick the axes that will host the label; errors skip this group
            try:
                _, ax = self._get_align_coord(
                    side, ordered_axs, includepanels=self._includepanels
                )
            except Exception:
                continue
            axlab = getattr(ax, f"{target}axis").label
            axlab.set_text(text)
            if props is not None:
                axlab.set_color(props["color"])
                axlab.set_fontproperties(props["fontproperties"])
                axlab.set_rotation(props["rotation"])
                axlab.set_rotation_mode(props["rotation_mode"])
                axlab.set_ha(props["ha"])
                axlab.set_va(props["va"])
            # Convert the host label into a figure-level spanning label
            self._update_axis_label(side, ordered_axs)
def _align_super_labels(self, side, renderer):
    """
    Place all super labels on one side of the figure at a common offset.

    Parameters
    ----------
    side : {'left', 'right', 'bottom', 'top'}
        The figure side whose labels are aligned.
    renderer : `~matplotlib.backend_bases.RendererBase`
        Forwarded to ``_get_offset_coord``.

    Raises
    ------
    ValueError
        If `side` is not one of the four figure sides.
    """
    # NOTE: Ensure title is offset only here.
    for subplot in self._subplot_dict.values():
        subplot._apply_title_above()
    if side not in ("left", "right", "bottom", "top"):
        raise ValueError(f"Invalid side {side!r}.")
    labels_by_axes = self._suplabel_dict[side]
    anchors = tuple(ax for ax, lab in labels_by_axes.items() if lab.get_text())
    if not anchors:
        return
    offset = self._get_offset_coord(side, anchors, renderer)
    # Left/right labels move horizontally, bottom/top labels vertically
    coord_name = "x" if side in ("left", "right") else "y"
    for lab in labels_by_axes.values():
        lab.update({coord_name: offset})
def _align_super_title(self, renderer):
    """
    Position the figure super title above the top row of axes.

    The horizontal coordinate comes from ``_get_align_coord`` using the
    title's current horizontal alignment. The vertical coordinate comes from
    ``_get_offset_coord`` and accounts for the title pad and any non-empty
    top super labels. The title's existing ha/va (e.g. set via
    ``suptitle_kw``) are re-applied unchanged. Nothing happens when the
    title is empty or there are no top-aligned axes.
    """
    title = self._suptitle
    if not title.get_text():
        return
    top_axes = self._get_align_axes("top")  # returns outermost panels
    if not top_axes:
        return
    top_labels = tuple(
        lab for lab in self._suplabel_dict["top"].values() if lab.get_text()
    )
    # Convert the pad from points to a fraction of the figure height
    fig_height = self.get_size_inches()[1]
    pad = (self._suptitle_pad / 72) / fig_height
    halign = title.get_ha()
    valign = title.get_va()
    x, _ = self._get_align_coord(
        "top",
        top_axes,
        includepanels=self._includepanels,
        align=halign,
    )
    y = self._get_offset_coord("top", top_axes, renderer, pad=pad, extra=top_labels)
    title.set_ha(halign)
    title.set_va(valign)
    title.set_position((x, y))
def _update_axis_label(self, side, axs):
    """
    Update the aligned axis label for the input axes.

    The label text of the central axes (chosen by ``_get_align_coord``) is
    transferred onto a figure-level "spanning" text object stored in
    ``self._supxlabel_dict`` or ``self._supylabel_dict``. Every axes label in
    `axs` is then replaced by a whitespace placeholder.

    Parameters
    ----------
    side : str
        Label side. ``'bottom'`` and ``'top'`` update x-axis labels; any
        other value updates y-axis labels.
    axs : sequence of `~ultraplot.axes.Axes`
        The axes spanned by the label.
    """
    # Get the central axis and the spanning label (initialize if it does not exist)
    # NOTE: Previously we secretly used matplotlib axis labels for spanning labels,
    # offsetting them between two subplots if necessary. Now we track designated
    # 'super' labels and replace the actual labels with spaces so they still impact
    # the tight bounding box and thus allocate space for the spanning label.
    # `x` is the letter of the labeled axis, `y` the perpendicular coordinate.
    x, y = "xy" if side in ("bottom", "top") else "yx"
    c, ax = self._get_align_coord(side, axs, includepanels=self._includepanels)
    axlab = getattr(ax, x + "axis").label  # the central label
    suplabs = getattr(self, "_sup" + x + "label_dict")  # dict of spanning labels
    suplab = suplabs.get(ax, None)
    if suplab is None and not axlab.get_text().strip():
        return  # nothing to transfer from the normal label
    if suplab is not None and not suplab.get_text().strip():
        return  # nothing to update on the super label
    if suplab is None:
        # First use: create the figure text and copy alignment and rotation
        props = ("ha", "va", "rotation", "rotation_mode")
        suplab = suplabs[ax] = self.text(0, 0, "")
        suplab.update({prop: getattr(axlab, "get_" + prop)() for prop in props})
    # Copy text from the central label to the spanning label
    # NOTE: Must use spaces rather than newlines, otherwise tight layout
    # won't make room. Reason is Text implementation (see Text._get_layout())
    labels._transfer_label(axlab, suplab)  # text, color, and font properties
    # Placeholder with as many lines as the spanning label, each one space
    count = 1 + suplab.get_text().count("\n")
    space = "\n".join(" " * count)
    for ax in axs:  # includes original 'axis'
        axis = getattr(ax, x + "axis")
        axis.label.set_text(space)
    # Update spanning label position then add simple monkey patch
    # NOTE: Simply using axis._update_label_position() when this is
    # called is not sufficient. Fails with e.g. inline backend.
    # The coordinate along the labeled axis is `c` in figure units; the other
    # coordinate is copied from the central label in pixels.
    t = mtransforms.IdentityTransform()  # set in pixels
    cx, cy = axlab.get_position()
    if x == "x":
        trans = mtransforms.blended_transform_factory(self.transFigure, t)
        coord = (c, cy)
    else:
        trans = mtransforms.blended_transform_factory(t, self.transFigure)
        coord = (cx, c)
    suplab.set_transform(trans)
    suplab.set_position(coord)
    # Patch the central label's set_y (x labels) or set_x (y labels) so later
    # position updates of the central label are mirrored onto the spanning one
    setpos = getattr(mtext.Text, "set_" + y)
    def _set_coord(self, *args, **kwargs):  # noqa: E306
        setpos(self, *args, **kwargs)
        setpos(suplab, *args, **kwargs)
    setattr(axlab, "set_" + y, _set_coord.__get__(axlab))
def _update_super_labels(self, side, labels, **kwargs):
    """
    Assign the figure super labels and update settings.

    Parameters
    ----------
    side : {'left', 'right', 'bottom', 'top'}
        The figure side whose super labels are updated.
    labels : sequence of str or None
        One entry per axes returned by ``_get_align_axes(side)``. A ``None``
        entry, or a falsy `labels`, leaves the corresponding text unchanged.
    **kwargs
        Extra text properties. They are combined with the side's rc settings,
        applied to existing labels, and stored for labels created later.

    Raises
    ------
    ValueError
        If `side` is invalid or the number of labels does not match the
        number of axes along that side.
    """
    # Update the label parameters
    if side not in ("left", "right", "bottom", "top"):
        raise ValueError(f"Invalid side {side!r}.")
    kw = rc.fill(
        {
            "color": side + "label.color",
            "rotation": side + "label.rotation",
            "size": side + "label.size",
            "weight": side + "label.weight",
            "family": "font.family",
        },
        context=True,
    )
    kw.update(kwargs)  # used when updating *existing* labels
    props = self._suplabel_props[side]
    props.update(kw)  # used when creating *new* labels
    # Get the label axes
    # WARNING: In case users added labels then changed the subplot geometry we
    # have to remove labels whose axes don't match the current 'align' axes.
    axs = self._get_align_axes(side)
    if not axs:
        return  # occurs if called while adding axes
    if not labels:
        labels = [None for _ in axs]  # indicates that text should not be updated
    if not kw and all(_ is None for _ in labels):
        return  # nothing to update
    if len(labels) != len(axs):
        raise ValueError(
            f"Got {len(labels)} {side} labels but found {len(axs)} axes "
            f"along the {side} side of the figure."
        )
    src = self._suplabel_dict[side]
    extra = src.keys() - set(axs)  # a new set, so popping below is safe
    for ax in extra:  # e.g. while adding axes
        # NOTE: Drop the stale entry from the dict as well as from the figure.
        # Leaving it behind kept the removed text in super label alignment and
        # made later calls try to remove the same artist again, which can fail
        # because it is no longer in the figure's text list.
        stale = src.pop(ax)
        text = stale.get_text()
        if text:
            warnings._warn_ultraplot(
                f"Removing {side} label with text {text!r} from axes {ax.number}."
            )
        stale.remove()  # remove from the figure
    # Update the label text
    tf = self.transFigure
    for ax, label in zip(axs, labels):
        if ax in src:
            obj = src[ax]
        elif side in ("left", "right"):
            trans = mtransforms.blended_transform_factory(tf, ax.transAxes)
            obj = src[ax] = self.text(0, 0.5, "", transform=trans)
            obj.update(props)
        else:
            trans = mtransforms.blended_transform_factory(ax.transAxes, tf)
            obj = src[ax] = self.text(0.5, 0, "", transform=trans)
            obj.update(props)
        if kw:
            obj.update(kw)
        if label is not None:
            obj.set_text(label)
def _update_super_title(self, title, **kwargs):
    """
    Apply rc-driven and user-supplied text properties to the figure super
    title, then replace its text unless `title` is ``None``.
    """
    rc_props = rc.fill(
        {
            "size": "suptitle.size",
            "weight": "suptitle.weight",
            "color": "suptitle.color",
            "family": "font.family",
        },
        context=True,
    )
    merged = {**rc_props, **kwargs}  # explicit keywords win over rc settings
    if merged:
        self._suptitle.update(merged)
    if title is None:
        return
    self._suptitle.set_text(title)
[docs]
@_clear_border_cache
@docstring._concatenate_inherited
@docstring._snippet_manager
def add_axes(self, rect, **kwargs):
    """
    %(figure.axes)s
    """
    # Translate projection-related keywords before delegating to matplotlib
    parsed_kwargs = self._parse_proj(**kwargs)
    new_axes = super().add_axes(rect, **parsed_kwargs)
    return new_axes
[docs]
@docstring._concatenate_inherited
@docstring._snippet_manager
def add_subplot(self, *args, **kwargs):
    """
    %(figure.subplot)s
    """
    # Public entry point; the shared implementation lives in _add_subplot
    new_subplot = self._add_subplot(*args, **kwargs)
    return new_subplot
[docs]
@docstring._snippet_manager
def subplot(self, *args, **kwargs):  # shorthand
    """
    %(figure.subplot)s
    """
    # Shorthand alias that forwards everything to _add_subplot
    new_subplot = self._add_subplot(*args, **kwargs)
    return new_subplot
[docs]
@docstring._snippet_manager
def add_subplots(self, *args, **kwargs):
    """
    %(figure.subplots)s
    """
    # Public entry point; the shared implementation lives in _add_subplots
    grid = self._add_subplots(*args, **kwargs)
    return grid
[docs]
@docstring._snippet_manager
def subplots(self, *args, **kwargs):
    """
    %(figure.subplots)s
    """
    # Alias of add_subplots; both forward to _add_subplots
    grid = self._add_subplots(*args, **kwargs)
    return grid
[docs]
def auto_layout(self, renderer=None, aspect=None, tight=None, resize=None):
    """
    Automatically adjust the figure size and subplot positions. This is
    triggered automatically whenever the figure is drawn.

    Parameters
    ----------
    renderer : `~matplotlib.backend_bases.RendererBase`, optional
        The renderer. If ``None`` a default renderer will be produced.
    aspect : bool, optional
        Whether to update the figure size based on the reference subplot aspect
        ratio. By default, this is ``True``. This only has an effect if the
        aspect ratio is fixed (e.g., due to an image plot or geographic projection).
    tight : bool, optional
        Whether to update the figure size and subplot positions according to
        a "tight layout". By default, this takes on the value of `tight` passed
        to `Figure`. If nothing was passed, it is :rc:`subplots.tight`.
    resize : bool, optional
        If ``False``, the current figure dimensions are fixed and automatic
        figure resizing is disabled. By default, the figure size may change
        unless both `figwidth` and `figheight` or `figsize` were passed
        to `~Figure.subplots`, `~Figure.set_size_inches` was called manually,
        or the figure was resized manually with an interactive backend.
    """
    # *Impossible* to get notebook backend to work with auto resizing so we
    # just do the tight layout adjustments and skip resizing.
    gs = self.gridspec
    # NOTE: Honor an explicitly passed renderer as documented above. The
    # argument used to be overwritten unconditionally by _get_renderer().
    if renderer is None:
        renderer = self._get_renderer()
    if aspect is None:
        aspect = True
    if tight is None:
        tight = self._tight_active
    if resize is False:  # fix the size
        self._figwidth, self._figheight = self.get_size_inches()
        self._refwidth = self._refheight = None  # critical!

    # Helper functions
    # NOTE: Have to draw legends and colorbars early (before reaching axes
    # draw methods) because we have to take them into account for alignment.
    # Also requires another figure resize (which triggers a gridspec update).
    def _draw_content():
        for ax in self._iter_axes(hidden=False, children=True):
            ax._add_queued_guides()  # may trigger resizes if panels are added

    def _align_content():  # noqa: E306
        for axis in "xy":
            self._align_axis_label(axis)
        for side in ("left", "right", "top", "bottom"):
            self._align_super_labels(side, renderer)
        self._align_super_title(renderer)

    # Update the layout
    # WARNING: Tried to avoid two figure resizes but made
    # subsequent tight layout really weird. Have to resize twice.
    _draw_content()
    if not gs:
        return
    if aspect:
        gs._auto_layout_aspect()
    _align_content()
    if tight:
        gs._auto_layout_tight(renderer)
    _align_content()
[docs]
@warnings._rename_kwargs(
    "0.10.0", mathtext_fallback="uplt.rc.mathtext_fallback = {}"
)
@docstring._snippet_manager
def format(
    self,
    axs=None,
    *,
    figtitle=None,
    suptitle=None,
    suptitle_kw=None,
    llabels=None,
    leftlabels=None,
    leftlabels_kw=None,
    rlabels=None,
    rightlabels=None,
    rightlabels_kw=None,
    blabels=None,
    bottomlabels=None,
    bottomlabels_kw=None,
    tlabels=None,
    toplabels=None,
    toplabels_kw=None,
    rowlabels=None,
    collabels=None,  # aliases
    includepanels=None,
    **kwargs,
):
    """
    Modify figure-wide labels and call ``format`` for the
    input axes. By default the numbered subplots are used.

    Parameters
    ----------
    axs : sequence of `~ultraplot.axes.Axes`, optional
        The axes to format. Default is the numbered subplots.
    %(figure.format)s

    Important
    ---------
    `leftlabelpad`, `toplabelpad`, `rightlabelpad`, and `bottomlabelpad`
    keywords are actually :ref:`configuration settings <ug_config>`.
    We explicitly document these arguments here because it is common to
    change them for specific figures. But many :ref:`other configuration
    settings <ug_format>` can be passed to ``format`` too.

    Other parameters
    ----------------
    %(axes.format)s
    %(cartesian.format)s
    %(polar.format)s
    %(geo.format)s
    %(rc.format)s

    See also
    --------
    ultraplot.axes.Axes.format
    ultraplot.axes.CartesianAxes.format
    ultraplot.axes.PolarAxes.format
    ultraplot.axes.GeoAxes.format
    ultraplot.gridspec.SubplotGrid.format
    ultraplot.config.Configurator.context
    """
    # Initiate context block
    axs = axs or self._subplot_dict.values()
    skip_axes = kwargs.pop("skip_axes", False)  # internal keyword arg
    rc_kw, rc_mode = _pop_rc(kwargs)
    with rc.context(rc_kw, mode=rc_mode):
        # Update background patch
        kw = rc.fill({"facecolor": "figure.facecolor"}, context=True)
        self.patch.update(kw)
        # Update super title and label spacing
        pad = rc.find("suptitle.pad", context=True)  # super title
        if pad is not None:
            self._suptitle_pad = pad
        for side in tuple(self._suplabel_pad):  # super labels
            pad = rc.find(side + "label.pad", context=True)
            if pad is not None:
                self._suplabel_pad[side] = pad
        if includepanels is not None:
            self._includepanels = includepanels
        # Update super title and labels text and settings
        suptitle_kw = suptitle_kw or {}
        leftlabels_kw = leftlabels_kw or {}
        rightlabels_kw = rightlabels_kw or {}
        bottomlabels_kw = bottomlabels_kw or {}
        toplabels_kw = toplabels_kw or {}
        self._update_super_title(
            _not_none(figtitle=figtitle, suptitle=suptitle),
            **suptitle_kw,
        )
        self._update_super_labels(
            "left",
            _not_none(rowlabels=rowlabels, leftlabels=leftlabels, llabels=llabels),
            **leftlabels_kw,
        )
        self._update_super_labels(
            "right",
            _not_none(rightlabels=rightlabels, rlabels=rlabels),
            **rightlabels_kw,
        )
        self._update_super_labels(
            "bottom",
            _not_none(bottomlabels=bottomlabels, blabels=blabels),
            **bottomlabels_kw,
        )
        self._update_super_labels(
            "top",
            _not_none(collabels=collabels, toplabels=toplabels, tlabels=tlabels),
            **toplabels_kw,
        )
        # Update the main axes
        if skip_axes:  # avoid recursion
            return
        # Remove all keywords that are not in the allowed signature parameters
        kws = {
            cls: _pop_params(kwargs, sig)
            for cls, sig in paxes.Axes._format_signatures.items()
        }
        classes = set()  # track used dictionaries

        def _axis_has_share_label_text(ax, axis):
            # True if `ax` belongs to an explicit label-sharing group with text
            groups = self._share_label_groups.get(axis, {})
            for group in groups.values():
                if ax in group["axes"] and str(group.get("text", "")).strip():
                    return True
            return False

        def _axis_has_label_text(ax, axis):
            # True if `ax` already carries a non-blank axis label
            text = ax.get_xlabel() if axis == "x" else ax.get_ylabel()
            return bool(text and text.strip())

        for number, ax in enumerate(axs, start=1):  # number from 1
            store_old_number = ax.number
            if ax.number != number:
                ax.number = number
            # NOTE: Restore the original number even if formatting raises, so
            # the figure is not left with temporarily renumbered axes.
            try:
                # Collect keywords for every format signature this axes matches
                # (the set.add() call records which signatures were used).
                kw = {
                    key: value
                    for cls, cls_kw in kws.items()
                    for key, value in cls_kw.items()
                    if isinstance(ax, cls) and not classes.add(cls)
                }
                # Explicit label-sharing groups take precedence over xlabel/ylabel
                if kw.get("xlabel") is not None and self._has_share_label_groups("x"):
                    if _axis_has_share_label_text(ax, "x") or _axis_has_label_text(ax, "x"):
                        kw.pop("xlabel", None)
                if kw.get("ylabel") is not None and self._has_share_label_groups("y"):
                    if _axis_has_share_label_text(ax, "y") or _axis_has_label_text(ax, "y"):
                        kw.pop("ylabel", None)
                ax.format(rc_kw=rc_kw, rc_mode=rc_mode, skip_figure=True, **kw, **kwargs)
            finally:
                ax.number = store_old_number
        # Warn unused keyword argument(s)
        kw = {
            key: value
            for name in kws.keys() - classes
            for key, value in kws[name].items()
        }
        if kw:
            warnings._warn_ultraplot(
                f"Ignoring unused projection-specific format() keyword argument(s): {kw}"  # noqa: E501
            )
[docs]
@docstring._concatenate_inherited
@docstring._snippet_manager
def colorbar(
    self,
    mappable,
    values=None,
    loc: Optional[str] = None,
    location: Optional[str] = None,
    row: Optional[int] = None,
    col: Optional[int] = None,
    rows: Optional[Union[int, Tuple[int, int]]] = None,
    cols: Optional[Union[int, Tuple[int, int]]] = None,
    span: Optional[Union[int, Tuple[int, int]]] = None,
    space: Optional[Union[float, str]] = None,
    pad: Optional[Union[float, str]] = None,
    width: Optional[Union[float, str]] = None,
    **kwargs,
):
    """
    Add a colorbar along the side of the figure.

    Parameters
    ----------
    %(axes.colorbar_args)s
    length : float, default: :rc:`colorbar.length`
        The colorbar length. Units are relative to the span of the rows and
        columns of subplots.
    shrink : float, optional
        Alias for `length`. This is included for consistency with
        `matplotlib.figure.Figure.colorbar`.
    width : unit-spec, default: :rc:`colorbar.width`
        The colorbar width.
        %(units.in)s
    %(figure.colorbar_space)s
        Has no visible effect if `length` is ``1``.

    Other parameters
    ----------------
    %(axes.colorbar_kwargs)s

    See also
    --------
    ultraplot.axes.Axes.colorbar
    matplotlib.figure.Figure.colorbar
    """
    # Backwards compatibility
    ax = kwargs.pop("ax", None)
    ref = kwargs.pop("ref", None)
    cax = kwargs.pop("cax", None)
    if isinstance(values, maxes.Axes):
        cax = _not_none(cax_positional=values, cax=cax)
        values = None
    if isinstance(loc, maxes.Axes):
        ax = _not_none(ax_positional=loc, ax=ax)
        loc = None
    # NOTE: Resolve the anchor axes *after* the positional-axes handling above.
    # It used to be computed first, so an axes passed positionally in place of
    # `loc` was discarded and a figure-wide colorbar was drawn instead.
    loc_ax = ref if ref is not None else ax
    # NOTE: Merge the `location` alias once so it is honored for axes-anchored
    # colorbars too (it was previously only read for figure panels).
    loc = _not_none(loc=loc, location=location)

    def _is_axes_group(obj):
        # True for lists/arrays/grids of axes, False for one axes or a string
        return np.iterable(obj) and not isinstance(obj, (str, maxes.Axes))

    def _panel_side():
        # The figure edge implied by `loc`, or None for non-edge locations
        loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"])
        return loc_trans if loc_trans in ("left", "right", "top", "bottom") else None

    def _grid_extent(axi):
        # Decoded (row1, row2, col1, col2) gridspec extent, or None if the
        # object has no subplotspec.
        if not hasattr(axi, "get_subplotspec"):
            return None
        ss = axi.get_subplotspec()
        if ss is None:
            return None
        ss = ss.get_topmost_subplotspec()
        r1, r2, c1, c2 = ss._get_rows_columns()
        gs = ss.get_gridspec()
        if gs is not None:
            try:
                r1, r2 = gs._decode_indices(r1, r2, which="h")
                c1, c2 = gs._decode_indices(c1, c2, which="w")
            except ValueError:
                # Non-panel decode can fail for panel or nested specs.
                pass
        return r1, r2, c1, c2

    # Helpful warning
    if kwargs.pop("use_gridspec", None) is not None:
        warnings._warn_ultraplot(
            "Ignoring the 'use_gridspec' keyword. ultraplot always allocates "
            "additional space for colorbars using the figure gridspec "
            "rather than 'stealing space' from the parent subplot."
        )
    # Fill this axes
    if cax is not None:
        with context._state_context(cax, _internal_call=True):  # do not wrap pcolor
            cb = super().colorbar(mappable, cax=cax, **kwargs)
    # Axes panel colorbar
    elif loc_ax is not None:
        is_group = _is_axes_group(loc_ax)
        # Check if span parameters are provided
        has_span = _not_none(span, row, col, rows, cols) is not None
        # Infer span from loc_ax if it is a list and no span provided
        if not has_span and is_group:
            side = _panel_side()
            if side:
                r_min, r_max = float("inf"), float("-inf")
                c_min, c_max = float("inf"), float("-inf")
                valid_ax = False
                for axi in loc_ax:
                    extent = _grid_extent(axi)
                    if extent is None:
                        continue
                    r1, r2, c1, c2 = extent
                    r_min = min(r_min, r1)
                    r_max = max(r_max, r2)
                    c_min = min(c_min, c1)
                    c_max = max(c_max, c2)
                    valid_ax = True
                if valid_ax:
                    if side in ("left", "right"):
                        rows = (r_min + 1, r_max + 1)
                    else:
                        cols = (c_min + 1, c_max + 1)
                    has_span = True
        # Extract a single axes from array if span is provided
        # Otherwise, pass the array as-is for normal colorbar behavior
        if has_span and is_group:
            # Pick the best axis to anchor to based on the colorbar side
            side = _panel_side()
            best_ax = None
            best_coord = float("-inf")
            # If side is determined, search for the edge axis
            if side:
                for axi in loc_ax:
                    extent = _grid_extent(axi)
                    if extent is None:
                        continue
                    r1, r2, c1, c2 = extent
                    # Larger score = closer to the requested figure edge
                    score = {"right": c2, "left": -c1, "bottom": r2, "top": -r1}[side]
                    if score > best_coord:
                        best_coord = score
                        best_ax = axi
            # Fallback to first axis
            if best_ax is None:
                try:
                    ax_single = next(iter(loc_ax))
                except (TypeError, StopIteration):
                    ax_single = loc_ax
            else:
                ax_single = best_ax
        else:
            ax_single = loc_ax
        # Pass span parameters through to axes colorbar
        cb = ax_single.colorbar(
            mappable,
            values,
            space=space,
            pad=pad,
            width=width,
            loc=loc,
            span=span,
            row=row,
            col=col,
            rows=rows,
            cols=cols,
            **kwargs,
        )
    # Figure panel colorbar
    else:
        loc = _not_none(loc, default="r")
        ax = self._add_figure_panel(
            loc,
            row=row,
            col=col,
            rows=rows,
            cols=cols,
            span=span,
            width=width,
            space=space,
            pad=pad,
        )
        cb = ax.colorbar(mappable, values, loc="fill", **kwargs)
    return cb
[docs]
@docstring._concatenate_inherited
@docstring._snippet_manager
def legend(
    self,
    handles=None,
    labels=None,
    loc=None,
    location=None,
    row=None,
    col=None,
    rows=None,
    cols=None,
    span=None,
    space=None,
    pad=None,
    width=None,
    **kwargs,
):
    """
    Add a legend along the side of the figure.

    Parameters
    ----------
    %(axes.legend_args)s
    %(figure.legend_space)s
    width : unit-spec, optional
        The space allocated for the legend box. This does nothing if
        the :ref:`tight layout algorithm <ug_tight>` is active for the figure.
        %(units.in)s

    Other parameters
    ----------------
    %(axes.legend_kwargs)s

    See also
    --------
    ultraplot.axes.Axes.legend
    matplotlib.axes.Axes.legend
    """
    ax = kwargs.pop("ax", None)
    ref = kwargs.pop("ref", None)
    loc_ax = ref if ref is not None else ax
    # NOTE: Merge the `location` alias once so it is honored for axes-anchored
    # legends too (it was previously only read for figure panels).
    loc = _not_none(loc=loc, location=location)

    def _is_axes_group(obj):
        # True for lists/arrays/grids of axes, False for one axes or a string
        return np.iterable(obj) and not isinstance(obj, (str, maxes.Axes))

    def _panel_side():
        # The figure edge implied by `loc`, or None for non-edge locations
        loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"])
        return loc_trans if loc_trans in ("left", "right", "top", "bottom") else None

    def _grid_extent(axi):
        # Decoded (row1, row2, col1, col2) gridspec extent, or None if the
        # object has no subplotspec.
        if not hasattr(axi, "get_subplotspec"):
            return None
        ss = axi.get_subplotspec()
        if ss is None:
            return None
        ss = ss.get_topmost_subplotspec()
        r1, r2, c1, c2 = ss._get_rows_columns()
        gs = ss.get_gridspec()
        if gs is not None:
            try:
                r1, r2 = gs._decode_indices(r1, r2, which="h")
                c1, c2 = gs._decode_indices(c1, c2, which="w")
            except ValueError:
                pass
        return r1, r2, c1, c2

    # Axes panel legend
    if loc_ax is not None:
        content_ax = ax if ax is not None else loc_ax
        content_is_group = _is_axes_group(content_ax)
        loc_is_group = _is_axes_group(loc_ax)
        # Check if span parameters are provided
        has_span = _not_none(span, row, col, rows, cols) is not None
        # Automatically collect handles and labels from content axes if not provided
        # Case 1: content_ax is a list (we must auto-collect)
        # Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend
        # won't find content_ax handles)
        must_collect = content_is_group or (content_ax is not loc_ax)
        if must_collect and handles is None and labels is None:
            handles, labels = [], []
            if content_is_group:
                for axi in content_ax:
                    axi_handles, axi_labels = axi.get_legend_handles_labels()
                    handles.extend(axi_handles)
                    labels.extend(axi_labels)
            else:
                handles, labels = content_ax.get_legend_handles_labels()
        # Infer span from loc_ax if it is a list and no span provided
        if not has_span and loc_is_group:
            side = _panel_side()
            if side:
                r_min, r_max = float("inf"), float("-inf")
                c_min, c_max = float("inf"), float("-inf")
                valid_ax = False
                for axi in loc_ax:
                    extent = _grid_extent(axi)
                    if extent is None:
                        continue
                    r1, r2, c1, c2 = extent
                    r_min = min(r_min, r1)
                    r_max = max(r_max, r2)
                    c_min = min(c_min, c1)
                    c_max = max(c_max, c2)
                    valid_ax = True
                if valid_ax:
                    if side in ("left", "right"):
                        rows = (r_min + 1, r_max + 1)
                    else:
                        cols = (c_min + 1, c_max + 1)
                    has_span = True
        # Extract a single axes from array if span is provided (or if ref is a list)
        # Otherwise, pass the array as-is for normal legend behavior
        if has_span and loc_is_group:
            # Pick the best axis to anchor to based on the legend side
            side = _panel_side()
            best_ax = None
            best_coord = float("-inf")
            # If side is determined, search for the edge axis
            if side:
                for axi in loc_ax:
                    extent = _grid_extent(axi)
                    if extent is None:
                        continue
                    r1, r2, c1, c2 = extent
                    # Larger score = closer to the requested figure edge
                    score = {"right": c2, "left": -c1, "bottom": r2, "top": -r1}[side]
                    if score > best_coord:
                        best_coord = score
                        best_ax = axi
            # Fallback to first axis if no best axis found (or side is None)
            if best_ax is None:
                try:
                    ax_single = next(iter(loc_ax))
                except (TypeError, StopIteration):
                    ax_single = loc_ax
            else:
                ax_single = best_ax
        else:
            ax_single = loc_ax
        # Plain lists have no legend method; wrap them in a SubplotGrid
        if isinstance(ax_single, list):
            try:
                ax_single = pgridspec.SubplotGrid(ax_single)
            except ValueError:
                ax_single = ax_single[0]
        leg = ax_single.legend(
            handles,
            labels,
            loc=loc,
            space=space,
            pad=pad,
            width=width,
            span=span,
            row=row,
            col=col,
            rows=rows,
            cols=cols,
            **kwargs,
        )
    # Figure panel legend
    else:
        loc = _not_none(loc, default="r")
        ax = self._add_figure_panel(
            loc,
            row=row,
            col=col,
            rows=rows,
            cols=cols,
            span=span,
            width=width,
            space=space,
            pad=pad,
        )
        leg = ax.legend(handles, labels, loc="fill", **kwargs)
    return leg
[docs]
@docstring._snippet_manager
def save(self, filename, **kwargs):
    """
    %(figure.save)s
    """
    # Convenience alias: all the work happens in savefig()
    writer = self.savefig
    return writer(filename, **kwargs)
[docs]
@docstring._concatenate_inherited
@docstring._snippet_manager
def savefig(self, filename, **kwargs):
    """
    %(figure.save)s
    """
    # Automatically expand the user name. Undocumented because we
    # do not want to overwrite the matplotlib docstring.
    # NOTE: Path-like objects (e.g. pathlib.Path("~/fig.png")) are expanded too;
    # the previous str-only check left '~' unexpanded for them.
    # os.path.expanduser accepts os.PathLike and returns a str.
    if isinstance(filename, (str, os.PathLike)):
        filename = os.path.expanduser(filename)
    # NOTE: this draw ensures that we are applying ultraplots layout adjustment. It is unclear what changed with ultraplot's history that makes this necessary, but it seems to cause no issues. Future devs, if unnecessary remove this line and test.
    self.canvas.draw()
    super().savefig(filename, **kwargs)
[docs]
@docstring._concatenate_inherited
def set_canvas(self, canvas):
    """
    Set the figure canvas, first wrapping the canvas instance's
    `~matplotlib.backend_bases.FigureCanvasBase.draw` and
    `~matplotlib.backend_bases.FigureCanvasBase.print_figure` methods with
    ultraplot's layout preprocessor.

    Parameters
    ----------
    canvas : `~matplotlib.backend_bases.FigureCanvasBase`
        The figure canvas.

    See also
    --------
    matplotlib.figure.Figure.set_canvas
    """
    # NOTE: See _add_canvas_preprocessor for details. print_figure is wrapped
    # without renderer caching because print methods (print_pdf, print_png,
    # etc.) call Figure.draw(). Caching there caused (1) wrong figure size
    # and/or bounds after saving *then* displaying in qt or inline notebook
    # backends, and (2) figures failing to update after successive edits in the
    # inline notebook backend. The plain 'draw' method is wrapped with caching.
    _add_canvas_preprocessor(canvas, "print_figure", cache=False)  # saves, inlines
    _add_canvas_preprocessor(canvas, "draw", cache=True)  # renderer displays
    super().set_canvas(canvas)
def _is_same_size(self, figsize, eps=None):
    """
    Return whether `figsize` matches the current figure size to within `eps`
    inches (default ``0.01``) in both dimensions. A ``None`` size always
    counts as unchanged.
    """
    tolerance = _not_none(eps, 0.01)
    current = self.get_size_inches()
    if figsize is not None:
        return np.all(np.isclose(figsize, current, rtol=0, atol=tolerance))
    return True  # e.g. GridSpec._calc_figsize() returned None
[docs]
@docstring._concatenate_inherited
def set_size_inches(self, w, h=None, *, forward=True, internal=False, eps=None):
    """
    Set the figure size. If this is being called manually or from an interactive
    backend, update the default layout with this fixed size. If the figure size is
    unchanged or this is an internal call, do not update the default layout.

    Parameters
    ----------
    *args : float
        The width and height passed as positional arguments or a 2-tuple.
    forward : bool, optional
        Whether to update the canvas.
    internal : bool, optional
        Whether this is an internal resize.
    eps : float, optional
        The deviation from the current size in inches required to treat this
        as a user-triggered figure resize that fixes the layout.

    Raises
    ------
    ValueError
        If the requested size is not finite.

    See also
    --------
    matplotlib.figure.Figure.set_size_inches
    """
    # Parse input args
    figsize = w if h is None else (w, h)
    if not np.all(np.isfinite(figsize)):
        raise ValueError(f"Figure size must be finite, not {figsize}.")
    # Fix the figure size if this is a user action from an interactive backend
    # NOTE: If we fail to detect 'user' resize from the user, not only will
    # result be incorrect, but qt backend will crash because it detects a
    # recursive size change, since preprocessor size will differ.
    # NOTE: Bitmap renderers calculate the figure size in inches from
    # int(Figure.bbox.[width|height]) which rounds to whole pixels. When
    # renderer calls set_size_inches, size may be effectively the same, but
    # slightly changed due to roundoff error! Therefore only compare approx size.
    attrs = ("_is_idle_drawing", "_is_drawing", "_draw_pending")
    backend = any(getattr(self.canvas, attr, None) for attr in attrs)
    internal = internal or self._is_adjusting
    samesize = self._is_same_size(figsize, eps)
    ctx = context._empty_context()  # context not necessary most of the time
    if not backend and not internal and not samesize:
        ctx = self._context_adjusting()  # do not trigger layout solver
        self._figwidth, self._figheight = figsize
        self._refwidth = self._refheight = None  # critical!
    # Apply the figure size
    # NOTE: If size changes we always update the gridspec to enforce fixed spaces
    # and panel widths (necessary since axes use figure relative coords)
    with ctx:  # avoid recursion
        super().set_size_inches(figsize, forward=forward)
    if not samesize:  # gridspec positions will resolve differently
        # NOTE: The figure may not have a gridspec yet (auto_layout applies the
        # same guard); calling update() on None would raise AttributeError.
        gs = self.gridspec
        if gs is not None:
            gs.update()
def _iter_axes(self, hidden=False, children=False, panels=True):
"""
Iterate over all axes and panels in the figure belonging to the
`~ultraplot.axes.Axes` class. Exclude inset and twin axes.
Parameters
----------
hidden : bool, optional
Whether to include "hidden" panels.
children : bool, optional
Whether to include child axes. Note this now includes "twin" axes.
panels : bool or str or sequence of str, optional
Whether to include panels or the panels to include.
"""
# Parse panels
if panels is False:
panels = ()
elif panels is True or panels is None:
panels = ("left", "right", "bottom", "top")
elif isinstance(panels, str):
panels = (panels,)
if not set(panels) <= {"left", "right", "bottom", "top"}:
raise ValueError(f"Invalid sides {panels!r}.")
# Iterate
axs = (
*self._subplot_dict.values(),
*(ax for side in panels for ax in self._panel_dict[side]),
)
for ax in axs:
if not hidden and ax._panel_hidden:
continue # ignore hidden panel and its colorbar/legend child
yield from ax._iter_axes(hidden=hidden, children=children, panels=panels)
@property
def gridspec(self):
    """
    The one :class:`~ultraplot.gridspec.GridSpec` instance that every
    subplot in the figure is placed on.

    See also
    --------
    ultraplot.figure.Figure.subplotgrid
    ultraplot.gridspec.GridSpec.figure
    ultraplot.gridspec.SubplotGrid.gridspec
    """
    current = self._gridspec
    return current

@gridspec.setter
def gridspec(self, gs):
    # Only ultraplot's own GridSpec class is accepted.
    if isinstance(gs, pgridspec.GridSpec):
        self._gridspec = gs
        gs.figure = self  # trigger copying settings from the figure
        return
    raise ValueError("Gridspec must be a ultraplot.GridSpec instance.")
@property
def subplotgrid(self):
    """
    A :class:`~ultraplot.gridspec.SubplotGrid` holding the figure's numbered
    subplots, ordered by increasing `~ultraplot.axes.Axes.number`.

    See also
    --------
    ultraplot.figure.Figure.gridspec
    ultraplot.gridspec.SubplotGrid.figure
    """
    # Order the subplots by their key in the subplot dictionary.
    subplots = self._subplot_dict
    ordered = [subplots[key] for key in sorted(subplots)]
    return pgridspec.SubplotGrid(ordered)
@property
def tight(self):
    """
    Whether the :ref:`tight layout algorithm <ug_tight>` is enabled for this
    figure. The value is handed to `~ultraplot.figure.Figure.auto_layout`
    on every draw and can be toggled, e.g. ``fig.tight = False``.

    See also
    --------
    ultraplot.figure.Figure.auto_layout
    """
    active = self._tight_active
    return active

@tight.setter
def tight(self, b):
    # Store a plain bool regardless of the truthy/falsy value passed in.
    self._tight_active = True if b else False
# Apply signature obfuscation after getting keys: first capture the real
# signature of `format` as `_format_signature`, then rebind `format` to the
# wrapper returned by `docstring._obfuscate_kwargs`. The order matters, since
# inspecting after rebinding would see the obfuscated version.
# NOTE: This is needed for axes and figure instantiation.
# NOTE(review): `format` presumably refers to the `format` method defined
# earlier in this class body, and `_format_signature` is presumably read
# elsewhere to recover the accepted keyword names — confirm.
_format_signature = inspect.signature(format)
format = docstring._obfuscate_kwargs(format)
# Add deprecated properties. There are *lots* of properties we pass to Figure
# and do not like idea of publicly tracking every single one of them. If we
# want to improve user introspection consider modifying Figure.__repr__.
# NOTE: "tight" is intentionally NOT in this list. Figure defines a public,
# settable `tight` property backed by `_tight_active`, and replacing it here
# with a read-only getter of `_tight` would break ``fig.tight = False`` and
# wrongly warn that a documented property is no longer public.
for _attr in ("alignx", "aligny", "sharex", "sharey", "spanx", "spany", "ref"):

    # Bind the current name as a default argument to avoid late binding.
    def _get_deprecated(self, attr=_attr):
        # Warn, then forward to the private attribute ``_<attr>``.
        warnings._warn_ultraplot(
            f"The property {attr!r} is no longer public as of v0.8. It will be "
            "removed in a future release."
        )
        return getattr(self, "_" + attr)

    setattr(Figure, _attr, property(_get_deprecated))
# Disable native matplotlib layout and spacing functions when called
# manually and emit warning message to help new users.
for _attr, _msg in (
    ("set_tight_layout", Figure._tight_message),
    ("set_constrained_layout", Figure._tight_message),
    ("tight_layout", Figure._tight_message),
    ("init_layoutbox", Figure._tight_message),
    ("execute_constrained_layout", Figure._tight_message),
    ("subplots_adjust", Figure._space_message),
):
    # Names that Figure does not define (presumably absent from the installed
    # matplotlib version) are left alone.
    _func = getattr(Figure, _attr, None)
    if _func is not None:

        @functools.wraps(_func)  # noqa: E301
        def _disable_method(self, *args, func=_func, message=_msg, **kwargs):
            # Authorized calls pass straight through to the wrapped method.
            # All other calls only emit a warning and return None.
            if not self._is_authorized:
                warnings._warn_ultraplot(
                    f"fig.{func.__name__}() has no effect on ultraplot figures. "
                    + message
                )  # noqa: E501, U100
                return None
            return func(self, *args, **kwargs)

        _disable_method.__doc__ = None  # remove docs
        setattr(Figure, _attr, _disable_method)