Source code for brian2tools.plotting.base

'''
Base module for the plotting facilities.
'''
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

from brian2.core.variables import VariableView
from brian2.spatialneuron.morphology import Morphology
from brian2.monitors import SpikeMonitor, StateMonitor, PopulationRateMonitor
from brian2.monitors.statemonitor import StateMonitorView
from brian2.units.fundamentalunits import Quantity
from brian2.units.stdunits import ms
from brian2.utils.logger import get_logger
from brian2.synapses.synapses import Synapses, SynapticPathway

from .data import plot_raster, plot_state, plot_rate, _get_best_unit
from .morphology import plot_dendrogram
from .synapses import plot_synapses

__all__ = ['brian_plot', 'add_background_pattern']

logger = get_logger(__name__)


def _resolve_var_kwds(kwds, var_name, values):
    '''
    Build a per-variable keyword dict, resolving ``var_name`` and ``var_unit``
    that may be plain values (applied to every variable) or dicts keyed by
    variable name.
    '''
    kwds_var = {k: v for k, v in kwds.items()
                if k not in ('var_name', 'var_unit')}

    vn = kwds.get('var_name')
    if isinstance(vn, dict):
        kwds_var['var_name'] = vn.get(var_name, var_name)
    elif vn is not None:
        kwds_var['var_name'] = vn
    else:
        kwds_var['var_name'] = var_name

    vu = kwds.get('var_unit')
    if isinstance(vu, dict):
        unit = vu.get(var_name)
        if unit is not None:
            kwds_var['var_unit'] = unit
        elif isinstance(values, Quantity):
            kwds_var['var_unit'] = _get_best_unit(values)
    elif vu is not None:
        kwds_var['var_unit'] = vu
    elif isinstance(values, Quantity):
        kwds_var['var_unit'] = _get_best_unit(values)

    return kwds_var


def _plot_state_variables(brian_obj, record_variables, axes, **kwds):
    '''
    Plot one or more recorded state variables.  When more than one variable
    is present a column of subplots sharing the time axis is created
    automatically (unless the caller supplies a matching array of axes).
    '''
    n_vars = len(record_variables)

    if n_vars == 1:
        var_name = record_variables[0]
        values = getattr(brian_obj, var_name).T
        kwds_var = _resolve_var_kwds(kwds, var_name, values)
        return plot_state(brian_obj.t, values, axes=axes, **kwds_var)

    if axes is None:
        fig, axes_arr = plt.subplots(n_vars, 1, sharex=True)
    else:
        axes_arr = np.asarray(axes).ravel()
        if len(axes_arr) != n_vars:
            raise TypeError(
                f"If multiple variables are recorded, 'axes' must be an "
                f"array-like of Axes with length {n_vars} (got {len(axes_arr)})."
            )

    ret_axes = []
    for ax, var_name in zip(axes_arr, record_variables):
        values = getattr(brian_obj, var_name).T
        kwds_var = _resolve_var_kwds(kwds, var_name, values)
        ret_axes.append(
            plot_state(brian_obj.t, values, axes=ax, **kwds_var))
    return ret_axes


def _setup_axes_matplotlib(axes):
    '''
    Helper function to create new figures/axes for matplotlib, depending on
    arguments provided by the user.
    '''
    if axes is None:
        axes = plt.gca()
    return axes


def _setup_axes_mayavi(axes):
    '''
    Return the mayavi figure to use for plotting.

    Parameters
    ----------
    axes : object or None
        The figure provided by the user, or ``None``.

    Returns
    -------
    axes : object
        ``axes`` itself if it is not ``None``, otherwise a new mayavi figure
        with a light gray background.
    '''
    # Imported here (and not at the module level) so that mayavi is only
    # needed when this function is actually called
    import mayavi.mlab as mlab
    if axes is not None:
        return axes
    return mlab.figure(bgcolor=(0.95, 0.95, 0.95))


def _choose_synapse_plot_type(sources, targets, sized_obj):
    '''
    Select the ``plot_type`` for `plot_synapses` from the index ranges.

    Parameters
    ----------
    sources, targets : array-like
        The (non-empty) source and target indices of the synapses.
    sized_obj : object
        Object whose ``len`` is compared against the scatter plot limit. It
        is only queried if an image plot has been ruled out.

    Returns
    -------
    plot_type : str
        ``'image'`` if both index ranges are smaller than 1000, otherwise
        ``'scatter'`` if ``len(sized_obj)`` is smaller than 10000 and
        ``'hexbin'`` if it is not.
    '''
    source_range = np.max(sources) - np.min(sources)
    target_range = np.max(targets) - np.min(targets)
    if source_range < 1000 and target_range < 1000:
        return 'image'
    if len(sized_obj) < 10000:
        return 'scatter'
    return 'hexbin'


def brian_plot(brian_obj, axes=None, **kwds):
    '''
    Plot the data of the given object (e.g. a monitor). This function will
    call an adequate plotting function for the object, e.g. `plot_raster` for
    a `~brian2.monitors.spikemonitor.SpikeMonitor`. The plotting may apply
    heuristics to get a generally useful plot (e.g. for a
    `~brian2.monitors.ratemonitor.PopulationRateMonitor`, it will plot the
    rates smoothed with a Gaussian window of 1 ms), the exact details are
    subject to change. This function is therefore mostly meant as a quick and
    easy way to plot an object, for full control use one of the specific
    plotting functions.

    When a `~brian2.monitors.statemonitor.StateMonitor` that records several
    variables is given, a column of subplots sharing the time axis is
    created automatically.

    Parameters
    ----------
    brian_obj : object
        The Brian object to plot.
    axes : `~matplotlib.axes.Axes` or array-like of Axes, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot. For a multi-variable
        `~brian2.monitors.statemonitor.StateMonitor`, pass an array-like of
        `~matplotlib.axes.Axes` with one entry per recorded variable.
    kwds : dict, optional
        Any additional keyword arguments will be handed over to the specific
        plotting function that is called for the object (the original
        documentation describes them as ending up in matplotlib's
        `~matplotlib.axes.Axes.plot` command). This can be used to set plot
        properties such as the ``color``. For multi-variable
        `~brian2.monitors.statemonitor.StateMonitor` objects, ``var_name``
        and ``var_unit`` may be dictionaries keyed by variable name, e.g.
        ``var_name={'v': 'membrane potential', 'I': 'input current'}``.

    Returns
    -------
    axes : `~matplotlib.axes.Axes` or list of `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance(s) used for plotting. A list is
        returned when multiple state variables are plotted.

    Raises
    ------
    TypeError
        If a `Synapses` object (or the group of a synaptic variable) does
        not contain any synapses.
    NotImplementedError
        If no plotting function is known for the type of ``brian_obj``.
    '''
    if isinstance(brian_obj, SpikeMonitor):
        return plot_raster(brian_obj.i, brian_obj.t, axes=axes, **kwds)

    if isinstance(brian_obj, StateMonitor):
        return _plot_state_variables(brian_obj, brian_obj.record_variables,
                                     axes, **kwds)

    if isinstance(brian_obj, StateMonitorView):
        # A view does not store the variable names itself
        return _plot_state_variables(brian_obj,
                                     brian_obj.monitor.record_variables,
                                     axes, **kwds)

    if isinstance(brian_obj, PopulationRateMonitor):
        smooth_rate = brian_obj.smooth_rate(width=1*ms)
        if 'rate_unit' not in kwds:
            kwds['rate_unit'] = _get_best_unit(smooth_rate)
        return plot_rate(brian_obj.t, smooth_rate, axes=axes, **kwds)

    if isinstance(brian_obj, Morphology):
        return plot_dendrogram(brian_obj, axes=axes, **kwds)

    if isinstance(brian_obj, Synapses):
        if len(brian_obj) == 0:
            raise TypeError('Synapses object does not have any synapses.')
        plot_type = _choose_synapse_plot_type(brian_obj.i[:], brian_obj.j[:],
                                              brian_obj)
        # Forward the user's keywords as in all other branches (they used
        # to be silently dropped here)
        return plot_synapses(brian_obj.i, brian_obj.j, plot_type=plot_type,
                             axes=axes, **kwds)

    # brian_obj.group can be a weak proxy, we can therefore not use isinstance
    if (isinstance(brian_obj, VariableView) and
            issubclass(brian_obj.group.__class__,
                       (Synapses, SynapticPathway))):
        # synaptic variable
        synapses = brian_obj.group
        sources = synapses.i[:]
        targets = synapses.j[:]
        if len(sources) == 0:
            # Same error as for an empty Synapses object, instead of a
            # failing min/max reduction over empty index arrays
            raise TypeError('Synapses object does not have any synapses.')
        plot_type = _choose_synapse_plot_type(sources, targets, brian_obj)
        values = brian_obj[:]
        if 'var_name' not in kwds:
            kwds['var_name'] = brian_obj.name
        if 'var_unit' not in kwds and isinstance(values, Quantity):
            kwds['var_unit'] = _get_best_unit(values)
        return plot_synapses(sources, targets, values, plot_type=plot_type,
                             axes=axes, **kwds)

    raise NotImplementedError('Do not know how to plot object of type '
                              '%s' % type(brian_obj))
def add_background_pattern(axes, hatch='xxx', fill=True, fc=(0.9, 0.9, 0.9),
                           ec=(0.8, 0.8, 0.8), zorder=-10, **kwds):
    '''
    Add a "hatching" pattern to the background of the axes (can be useful to
    make a difference between "no value" and a value mapped to a color value
    that is identical to the background color). By default, it uses a cross
    hatching pattern in gray which can be changed by providing the respective
    arguments. All additional keyword arguments are passed on to the
    `~matplotlib.patches.Rectangle` initializer.

    The pattern is drawn as a single rectangle spanning the x and y limits
    that the axes have at the time of the call.

    Parameters
    ----------
    axes : `matplotlib.axes.Axes`
        The axes where the background pattern should be added.
    hatch : str, optional
        See `matplotlib.patches.Patch.set_hatch`. Defaults to `'xxx'`.
    fill : bool, optional
        See `matplotlib.patches.Patch.set_fill`. Defaults to `True`.
    fc : mpl color spec or None or 'none'
        See `matplotlib.patches.Patch.set_facecolor`. Defaults to
        `(0.9, 0.9, 0.9)`.
    ec : mpl color spec or None or 'none'
        See `matplotlib.patches.Patch.set_edgecolor`. Defaults to
        `(0.8, 0.8, 0.8)`.
    zorder : int
        See `matplotlib.artist.Artist.set_zorder`. Defaults to `-10`.
    '''
    left, right = axes.get_xlim()
    bottom, top = axes.get_ylim()
    # Anchor the rectangle at the lower limits and stretch it to the upper
    # limits of both axes
    background = patches.Rectangle((left, bottom), right - left, top - bottom,
                                   hatch=hatch, fill=fill, fc=fc, ec=ec,
                                   zorder=zorder, **kwds)
    axes.add_patch(background)