# Source code for brian2tools.plotting.synapses

"""
Module to plot synaptic connections.
"""
from collections import Counter

import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import MaxNLocator
from mpl_toolkits.axes_grid1 import make_axes_locatable

from .data import _get_best_unit

__all__ = ['plot_synapses']


# Helper functions
def _int_connection_matrix(sources, targets, values):
    '''
    Build a dense, masked 2D matrix of small positive integers (typically
    synapse counts) from per-connection index arrays.

    Rows correspond to target indices and columns to source indices, both
    shifted so that the smallest index that occurs maps to row/column 0.
    Entries that never receive a value stay 0 and are masked in the result.

    Parameters
    ----------
    sources : ndarray of int
        The indices of the source neurons for each value.
    targets : ndarray of int
        The indices of the target neurons for each value.
    values : ndarray of int or int
        The value for each (source, target) pair. All values have to lie in
        the range 1..255, because the matrix is stored as ``uint8`` and 0 is
        reserved for "no connection".

    Returns
    -------
    matrix : ma.MaskedArray
        The connection matrix, masked for 0 values
    '''
    # 0 marks "no entry" and the storage type is uint8, hence the 1..255 range
    assert 0 < np.min(values) and np.max(values) < 256

    last_target = np.max(targets)
    first_target = np.min(targets)
    last_source = np.max(sources)
    first_source = np.min(sources)

    shape = (last_target - first_target + 1,
             last_source - first_source + 1)
    counts = np.zeros(shape, dtype=np.uint8)

    # Shift the indices so that the matrix only spans the index range in use
    row_idx = targets - first_target
    col_idx = sources - first_source
    counts[row_idx, col_idx] = values

    return ma.masked_equal(counts, 0, copy=False)


def _float_connection_matrix(sources, targets, values):
    '''
    Build a dense, masked 2D matrix of floating point values (synaptic
    weights, delays, ...) from per-connection index arrays.

    Rows correspond to target indices and columns to source indices, both
    shifted so that the smallest index that occurs maps to row/column 0.
    Entries that never receive a value are NaN and masked in the result.

    Parameters
    ----------
    sources : ndarray of int
        The indices of the source neurons for each value.
    targets : ndarray of int
        The indices of the target neurons for each value.
    values : ndarray of float
        The value for each (source, target) pair.

    Returns
    -------
    matrix : ma.MaskedArray
        The connection matrix, masked for NaN values
    '''
    last_target = np.max(targets)
    first_target = np.min(targets)
    last_source = np.max(sources)
    first_source = np.min(sources)

    # Start from an all-NaN matrix; NaN marks "no connection"
    shape = (last_target - first_target + 1,
             last_source - first_source + 1)
    weights = np.full(shape, np.nan)

    weights[targets - first_target, sources - first_source] = values
    return ma.masked_invalid(weights, copy=False)


def _discrete_color_mapping(user_cmap, n_synapses):
    '''
    Create a discrete color mapping with one color per synapse count.

    Parameters
    ----------
    user_cmap : str, `~matplotlib.colors.Colormap` or None
        The colormap (or its name) to use, handed over to
        `matplotlib.pyplot.get_cmap`. ``None`` selects matplotlib's default
        colormap.
    n_synapses : sequence of int
        The number of synapses for each connected (source, target) pair.

    Returns
    -------
    bounds : ndarray
        The bin boundaries ``0.5, 1.5, ..., max(n_synapses) + 0.5``, so that
        each integer count falls into the middle of its own bin.
    cmap : `~matplotlib.colors.Colormap`
        The colormap, resampled to ``max(n_synapses)`` entries when it was
        looked up by name.
    norm : `~matplotlib.colors.BoundaryNorm`
        The norm that maps synapse counts onto the colors of ``cmap``.
    '''
    max_synapses = np.max(n_synapses)
    # ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed
    # in 3.9; ``pyplot.get_cmap`` takes the same (name, lut) arguments and is
    # available in old and new releases alike.
    cmap = plt.get_cmap(user_cmap, max_synapses)
    bounds = np.arange(max_synapses + 1) + 0.5
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    return bounds, cmap, norm


# Plot functions
def plot_synapses(sources, targets, values=None, var_unit=None,
                  var_name=None, plot_type='scatter', axes=None, **kwds):
    '''
    Plot synaptic connections as a scatter plot, a connection matrix image,
    or a 2D histogram.

    Without ``values``, the plot shows which (source, target) pairs are
    connected; if a pair is connected by more than one synapse, the number of
    synapses is color-coded (``'scatter'`` and ``'image'``) or binned
    (``'hexbin'``). With ``values``, the given variable is color-coded.

    Parameters
    ----------
    sources : `~numpy.ndarray` of int
        The source indices of the connections (as returned by
        ``Synapses.i``).
    targets : `~numpy.ndarray` of int
        The target indices of the connections (as returned by
        ``Synapses.j``).
    values : `~brian2.units.fundamentalunits.Quantity`, `~numpy.ndarray`
        The values to plot, a 1D array of the same size as ``sources`` and
        ``targets``.
    var_unit : `~brian2.units.fundamentalunits.Unit`, optional
        The unit to use to plot the ``values`` (e.g. ``mV`` for a membrane
        potential). If none is given (the default), an attempt is made to
        find a good scale automatically based on the ``values``.
    var_name : str, optional
        The name of the variable that is plotted. Used for the axis label.
    plot_type : {``'scatter'``, ``'image'``, ``'hexbin'``}, optional
        What type of plot to use. Can be ``'scatter'`` (the default) to draw
        a scatter plot, ``'image'`` to display the connections as a matrix or
        ``'hexbin'`` to display a 2D histogram using matplotlib's
        `~matplotlib.axes.Axes.hexbin` function. For a large number of
        synapses, ``'scatter'`` will be very slow. Similarly, an ``'image'``
        plot will use a lot of memory for connections between two large
        groups. For a small number of neurons and synapses, ``'hexbin'`` will
        be hard to interpret.
    axes : `~matplotlib.axes.Axes`, optional
        The `~matplotlib.axes.Axes` instance used for plotting. Defaults to
        ``None`` which means that a new `~matplotlib.axes.Axes` will be
        created for the plot.
    kwds : dict, optional
        Any additional keywords command will be handed over to the respective
        matplotlib command (`~matplotlib.axes.Axes.scatter` if the
        ``plot_type`` is ``'scatter'``, `~matplotlib.axes.Axes.imshow` for
        ``'image'``, and `~matplotlib.axes.Axes.hexbin` for ``'hexbin'``).
        This can be used to set plot properties such as the ``marker``.

    Returns
    -------
    axes : `~matplotlib.axes.Axes`
        The `~matplotlib.axes.Axes` instance that was used for plotting. This
        object allows to modify the plot further, e.g. by setting the plotted
        range, the axis labels, the plot title, etc.

    Raises
    ------
    TypeError
        If ``sources`` and ``targets`` (or ``values``, if given) do not have
        the same length.
    ValueError
        If ``plot_type`` is not one of the supported plot types.
    NotImplementedError
        If ``values`` are given for a ``'scatter'`` or ``'image'`` plot while
        at least one source-target pair is connected by several synapses.
    '''
    # Avoid circular import issues
    from brian2tools.plotting.base import _setup_axes_matplotlib
    axes = _setup_axes_matplotlib(axes)

    sources = np.asarray(sources)
    targets = np.asarray(targets)
    if len(sources) != len(targets):
        raise TypeError('Length of sources and targets does not match.')

    if plot_type not in ('scatter', 'image', 'hexbin'):
        raise ValueError("plot_type has to be either 'scatter', 'image', or "
                         "'hexbin' (was: %r)" % plot_type)

    # Get some information out of the values if provided
    if values is not None:
        if len(values) != len(sources):
            raise TypeError('Length of values and sources/targets does not '
                            'match.')
        if var_name is None:
            var_name = getattr(values, 'name', None)  # works for a VariableView
        if var_unit is None:
            try:
                var_unit = _get_best_unit(values[:])
            except AttributeError:
                # Values without unit information are plotted as they are
                pass
        if var_unit is not None:
            values = values / var_unit

    # For "hexbin", we are binning multiple synapses anyway, so we don't
    # have to make a difference for multiple synapses
    multiple_synapses = False
    if plot_type != 'hexbin':
        connection_count = Counter(zip(sources, targets))
        multiple_synapses = any(count > 1
                                for count in connection_count.values())

    # Removed from kwds for every plot type, but only used by the scatter plots
    edgecolor = kwds.pop('edgecolor', 'none')

    if multiple_synapses:
        if values is not None:
            raise NotImplementedError("Plotting variables with multiple "
                                      "synapses per source-target pair is only "
                                      "implemented for 'hexbin' plots.")
        unique_sources, unique_targets = zip(*connection_count.keys())
        n_synapses = list(connection_count.values())
        bounds, cmap, norm = _discrete_color_mapping(kwds.pop('cmap', None),
                                                     n_synapses)

        # Make the plot
        if plot_type == 'scatter':
            marker = kwds.pop('marker', ',')
            axes.scatter(unique_sources, unique_targets, marker=marker,
                         c=n_synapses, edgecolor=edgecolor, cmap=cmap,
                         norm=norm, **kwds)
        else:
            # The integer connection matrix stores the counts as uint8
            assert np.max(n_synapses) < 256
            matrix = _int_connection_matrix(unique_sources, unique_targets,
                                            n_synapses)
            origin = kwds.pop('origin', 'lower')
            interpolation = kwds.pop('interpolation', 'nearest')
            # Each matrix cell is centered on its integer neuron index
            axes.imshow(matrix, origin=origin, interpolation=interpolation,
                        cmap=cmap, norm=norm,
                        extent=(np.min(unique_sources) - 0.5,
                                np.max(unique_sources) + 0.5,
                                np.min(unique_targets) - 0.5,
                                np.max(unique_targets) + 0.5),
                        **kwds)

        # Add the colorbar
        locatable_axes = make_axes_locatable(axes)
        cax = locatable_axes.append_axes('right', size='5%', pad=0.05)
        # Shift the ticks by 0.5 so that they sit on the integer counts
        mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm,
                                  ticks=bounds-0.5)
        cax.set_ylabel('number of synapses')
    else:
        if plot_type == 'scatter':
            marker = kwds.pop('marker', ',')
            # Color-code the values unless the user asked for a fixed color
            color = kwds.pop('color', values)
            plotted = axes.scatter(sources, targets, marker=marker, c=color,
                                   edgecolor=edgecolor, **kwds)
        elif plot_type == 'image':
            if values is not None:
                matrix = _float_connection_matrix(sources, targets, values)
            else:
                matrix = _int_connection_matrix(sources, targets, 1)
            origin = kwds.pop('origin', 'lower')
            interpolation = kwds.pop('interpolation', 'nearest')
            vmin = kwds.pop('vmin', 1 if values is None else None)
            # Each matrix cell is centered on its integer neuron index
            plotted = axes.imshow(matrix, origin=origin,
                                  interpolation=interpolation, vmin=vmin,
                                  extent=(np.min(sources) - 0.5,
                                          np.max(sources) + 0.5,
                                          np.min(targets) - 0.5,
                                          np.max(targets) + 0.5),
                                  **kwds)
        elif plot_type == 'hexbin':
            # When counting synapses (no values), hide the empty bins
            mincnt = kwds.pop('mincnt', 1 if values is None else None)
            plotted = axes.hexbin(sources, targets, C=values, mincnt=mincnt,
                                  **kwds)

        if values is not None or plot_type == 'hexbin':
            # Add a colorbar
            locatable_axes = make_axes_locatable(axes)
            cax = locatable_axes.append_axes('right', size='7.5%', pad=0.05)
            plt.colorbar(plotted, cax=cax)
            if var_name is None:
                if var_unit is not None:
                    cax.set_ylabel('in units of %s' % str(var_unit))
            else:
                label = var_name
                if var_unit is not None:
                    label += ' (%s)' % str(var_unit)
                cax.set_ylabel(label)

    # Without any connection there is no index range to derive limits from
    if sources.size:
        axes.set_xlim(-0.5, np.max(sources) + 0.5)
        axes.set_ylim(-0.5, np.max(targets) + 0.5)
    axes.set_xlabel('source neuron index')
    axes.set_ylabel('target neuron index')

    # Prevent floating point values on the axes (e.g. when zooming in)
    axes.xaxis.set_major_locator(MaxNLocator(integer=True))
    axes.yaxis.set_major_locator(MaxNLocator(integer=True))

    return axes