sqfa.plot.data

Scatter plots of data from different classes, color-coded by class.

Functions

scatter_data(data, labels[, ax, values, ...])

Scatter plot of the data points, colored by class.

sqfa.plot.data.scatter_data(data, labels, ax=None, values=None, dim_pair=(0, 1), n_points=1000, classes_plot=None, legend_type='none', **kwargs)

Scatter plot of the data points, colored by class.

Parameters:
  • data (torch.Tensor) – Responses to the stimuli. Shape (n_stimuli, n_filters).

  • labels (torch.int64) – Class label of each point. Shape (n_stimuli).

  • ax (matplotlib.axes.Axes, optional) – Axes on which to draw the scatter plot. If None, a new figure is created. The default is None.

  • values (torch.Tensor, optional) – Values to color the classes. The default is linearly spaced values between -1 and 1.

  • dim_pair (tuple, optional) – Indices of the two filters (columns of data) to plot against each other. The default is (0, 1).

  • n_points (int, optional) – Number of points per class to plot. The default is 1000.

  • classes_plot (list, optional) – List of classes to plot. The default is all classes.

  • legend_type (str, optional) – Type of legend to add: 'none', 'continuous', or 'discrete'. The default is 'none'.

Returns:

ax – Axes with the scatter plot.

Return type:

matplotlib.axes.Axes
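
The sketch below illustrates how the values, n_points and classes_plot parameters described above could be interpreted. It is not the sqfa implementation: the helper names default_class_values and select_points are hypothetical, it uses only the standard library, and it assumes that n_points acts as an upper bound per class and that points are subsampled at random.

```python
import random


def default_class_values(n_classes):
    """Hypothetical helper: one value per class, linearly spaced in [-1, 1]."""
    if n_classes == 1:
        return [0.0]  # arbitrary choice for a single class (assumption)
    step = 2.0 / (n_classes - 1)
    return [-1.0 + i * step for i in range(n_classes)]


def select_points(labels, n_points=1000, classes_plot=None, seed=0):
    """Hypothetical helper: indices of at most n_points points per class.

    Only the classes listed in classes_plot are kept; None means all classes.
    """
    rng = random.Random(seed)
    if classes_plot is None:
        classes_plot = sorted(set(labels))
    selected = {}
    for c in classes_plot:
        idx = [i for i, lab in enumerate(labels) if lab == c]
        if len(idx) > n_points:
            # Random subsample without replacement (assumed behavior)
            idx = sorted(rng.sample(idx, n_points))
        selected[c] = idx
    return selected


labels = [0, 0, 0, 1, 1, 2, 2, 2, 2]
values = default_class_values(3)
chosen = select_points(labels, n_points=2, classes_plot=[0, 2])

print(values)                                      # [-1.0, 0.0, 1.0]
print({c: len(idx) for c, idx in chosen.items()})  # {0: 2, 2: 2}
```

With sqfa installed, the corresponding call would be along the lines of scatter_data(data, labels, n_points=2, classes_plot=[0, 2]), with data a torch.Tensor of shape (n_stimuli, n_filters) and labels an int64 tensor of matching length.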