Lorsque je dessine un tracé de points à l'aide de matplotlib, je voudrais compenser les points de données qui se chevauchent pour les garder tous visibles. Par exemple, si j'ai
CategoryA: 0,0,3,0,5
CategoryB: 5,10,5,5,10
Je veux que chacun des points de données CategoryA
"0" soit placé côte à côte, plutôt que directement les uns sur les autres, tout en restant distinct de CategoryB
.
Dans R (ggplot2
) Il y a un "jitter"
option qui fait cela. Existe-t-il une option similaire dans matplotlib, ou existe-t-il une autre approche qui conduirait à un résultat similaire?
Modifier: pour clarifier, le "beeswarm"
plot in R est essentiellement ce que j'ai en tête, et pybeeswarm
est un début précoce mais utile dans une version matplotlib/Python.
Edit: pour ajouter que le Seaborn Swarmplot , introduit dans la version 0.7, est un excellent la mise en œuvre de ce que je voulais.
En étendant la réponse par @ user2467675, voici comment je l'ai fait:
def Rand_jitter(arr):
stdev = .01*(max(arr)-min(arr))
return arr + np.random.randn(len(arr)) * stdev
def jitter(x, y, s=20, c='b', marker='o', cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, verts=None, hold=None, **kwargs):
return scatter(Rand_jitter(x), Rand_jitter(y), s=s, c=c, marker=marker, cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, verts=verts, hold=hold, **kwargs)
La variable stdev
s'assure que la gigue est suffisante pour être vue à différentes échelles, mais elle suppose que les limites des axes sont 0 et la valeur max.
Vous pouvez alors appeler jitter
au lieu de scatter
.
J'ai utilisé numpy.random pour "diffuser/réchauffer les abeilles" les données le long de l'axe X mais autour d'un point fixe pour chaque catégorie, puis je fais essentiellement pyplot.scatter () pour chaque catégorie:
import matplotlib.pyplot as plt
import numpy as np
#random data for category A, B, with B "taller"
yA, yB = np.random.randn(100), 5.0+np.random.randn(1000)
xA, xB = np.random.normal(1, 0.1, len(yA)),
np.random.normal(3, 0.1, len(yB))
plt.scatter(xA, yA)
plt.scatter(xB, yB)
plt.show()
Une façon d'aborder le problème est de penser à chaque "ligne" de votre graphique de dispersion/point/chaleur d'abeille comme un bac dans un histogramme:
data = np.random.randn(100)
width = 0.8 # the maximum width of each 'row' in the scatter plot
xpos = 0 # the centre position of the scatter plot in x
counts, edges = np.histogram(data, bins=20)
centres = (edges[:-1] + edges[1:]) / 2.
yvals = centres.repeat(counts)
max_offset = width / counts.max()
offsets = np.hstack((np.arange(cc) - 0.5 * (cc - 1)) for cc in counts)
xvals = xpos + (offsets * max_offset)
fig, ax = plt.subplots(1, 1)
ax.scatter(xvals, yvals, s=30, c='b')
Cela implique évidemment de regrouper les données, vous risquez donc de perdre une certaine précision. Si vous avez des données discrètes, vous pouvez remplacer:
counts, edges = np.histogram(data, bins=20)
centres = (edges[:-1] + edges[1:]) / 2.
avec:
centres, counts = np.unique(data, return_counts=True)
Une approche alternative qui préserve les coordonnées y exactes, même pour des données continues, consiste à utiliser un estimation de la densité du noya pour mettre à l'échelle l'amplitude de la gigue aléatoire sur l'axe des x:
from scipy.stats import gaussian_kde
kde = gaussian_kde(data)
density = kde(data) # estimate the local density at each datapoint
# generate some random jitter between 0 and 1
jitter = np.random.Rand(*data.shape) - 0.5
# scale the jitter by the KDE estimate and add it to the centre x-coordinate
xvals = 1 + (density * jitter * width * 2)
ax.scatter(xvals, data, s=30, c='g')
for sp in ['top', 'bottom', 'right']:
ax.spines[sp].set_visible(False)
ax.tick_params(top=False, bottom=False, right=False)
ax.set_xticks([0, 1])
ax.set_xticklabels(['Histogram', 'KDE'], fontsize='x-large')
fig.tight_layout()
Cette seconde méthode est vaguement basée sur le fonctionnement de violon plots . Il ne peut toujours pas garantir qu'aucun des points ne se chevauchent, mais je trouve qu'en pratique, il a tendance à donner des résultats assez agréables tant qu'il y a un nombre décent de points (> 20), et la distribution peut être raisonnablement bien approximée par une somme de gaussiens.
Seaborn fournit des tracés de points catégoriels de type histogramme via sns.swarmplot()
et des tracés de points catégoriels instables via sns.stripplot()
:
import seaborn as sns
sns.set(style='ticks', context='talk')
iris = sns.load_dataset('iris')
sns.swarmplot('species', 'sepal_length', data=iris)
sns.despine()
sns.stripplot('species', 'sepal_length', data=iris, jitter=0.2)
sns.despine()
Ne connaissant pas d'alternative directe mpl ici vous avez une proposition très rudimentaire:
from matplotlib import pyplot as plt
from itertools import groupby
CA = [0,4,0,3,0,5]
CB = [0,0,4,4,2,2,2,2,3,0,5]
x = []
y = []
for indx, klass in enumerate([CA, CB]):
klass = groupby(sorted(klass))
for item, objt in klass:
objt = list(objt)
points = len(objt)
pos = 1 + indx + (1 - points) / 50.
for item in objt:
x.append(pos)
y.append(item)
pos += 0.04
plt.plot(x, y, 'o')
plt.xlim((0,3))
plt.show()
Le swarmplot de Seaborn semble être le plus approprié pour ce que vous avez en tête, mais vous pouvez également trembler avec le regplot de Seaborn:
import seaborn as sns
iris = sns.load_dataset('iris')
sns.regplot(x='sepal_length',
y='sepal_width',
data=iris,
fit_reg=False, # do not fit a regression line
x_jitter=0.1, # could also dynamically set this with range of data
y_jitter=0.1,
scatter_kws={'alpha': 0.5}) # set transparency to 50%