sqfa.plot.data
Scatter plots of data from different classes, color-coded by class.
Functions
scatter_data(data, labels[, ax, values, ...]) | Plot a scatter of the data, with points color-coded by class.
- sqfa.plot.data.scatter_data(data, labels, ax=None, values=None, dim_pair=(0, 1), n_points=1000, classes_plot=None, legend_type='none', **kwargs)
Plot a scatter of the data, with points color-coded by class.
- Parameters:
data (torch.Tensor) – Responses to the stimuli. Shape (n_stimuli, n_filters).
labels (torch.Tensor) – Integer class label (dtype torch.int64) of each point. Shape (n_stimuli,).
ax (matplotlib.axes.Axes, optional) – Axes on which to draw the scatter plot. If None, a new figure is created. The default is None.
values (torch.Tensor, optional) – Values used to assign a color to each class. The default is values linearly spaced between -1 and 1.
dim_pair (tuple, optional) – Indices of the two filters (columns of data) to plot on the x and y axes. The default is (0, 1).
n_points (int, optional) – Number of points per class to plot. The default is 1000.
classes_plot (list, optional) – List of classes to plot. The default is all classes.
legend_type (str, optional) – Type of legend to add: ‘none’, ‘continuous’, or ‘discrete’. The default is ‘none’.
- Returns:
ax – Axes with the scatter plot.
- Return type:
matplotlib.axes.Axes
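The parameters dim_pair, n_points, classes_plot and values together decide which points end up in the plot and how classes are colored. The sketch below illustrates that selection step in plain Python, under stated assumptions; it is not the sqfa implementation. The helpers select_points and default_class_values are hypothetical names. This sketch keeps the first n_points of each class, because how the library chooses the subset is not documented here, and it reproduces the documented default of values linearly spaced between -1 and 1 with one value per class.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Point = Tuple[float, float]


def default_class_values(n_classes: int) -> List[float]:
    """Hypothetical helper: one value per class, linearly spaced in [-1, 1].

    Mirrors the documented default for `values`; a single class gets -1.0,
    which matches what numpy/torch linspace return for one step.
    """
    if n_classes == 1:
        return [-1.0]
    step = 2.0 / (n_classes - 1)
    return [-1.0 + i * step for i in range(n_classes)]


def select_points(
    data: Sequence[Sequence[float]],
    labels: Sequence[int],
    dim_pair: Tuple[int, int] = (0, 1),
    n_points: int = 1000,
    classes_plot: Optional[Sequence[int]] = None,
) -> Dict[int, List[Point]]:
    """Hypothetical helper: group rows by class and project onto dim_pair.

    Assumption: keeps the first `n_points` rows of each class (the real
    function may choose its subset differently).
    """
    if classes_plot is None:
        # Default: plot every class that appears in `labels`.
        classes_plot = sorted(set(labels))
    selected: Dict[int, List[Point]] = {}
    for c in classes_plot:
        rows = [row for row, lab in zip(data, labels) if lab == c]
        rows = rows[:n_points]  # cap the number of points per class
        # Keep only the two chosen filter responses as (x, y) coordinates.
        selected[c] = [(row[dim_pair[0]], row[dim_pair[1]]) for row in rows]
    return selected


if __name__ == "__main__":
    data = [[0.0, 1.0, 2.0], [1.0, 0.5, 0.0], [2.0, 2.0, 2.0], [3.0, 1.5, 0.5]]
    labels = [0, 1, 0, 1]
    print(select_points(data, labels, dim_pair=(0, 2), n_points=1))
    print(default_class_values(3))
```

With dim_pair=(0, 2) and n_points=1, each class keeps one point projected onto filters 0 and 2, giving {0: [(0.0, 2.0)], 1: [(1.0, 0.0)]}. With three classes the default values are [-1.0, 0.0, 1.0]. A colormap would then translate each class's value into a color.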