Exemple de matrice de nuage de points
Existe-t-il une telle fonction dans matplotlib.pyplot?
De manière générale, matplotlib ne contient généralement pas de fonctions de traçage qui fonctionnent sur plusieurs objets axes (sous-tracé, dans ce cas). On s'attend à ce que vous écriviez une fonction simple pour enchaîner les choses comme vous le souhaitez.
Je ne sais pas trop à quoi ressemblent vos données, mais il est assez simple de simplement créer une fonction pour le faire à partir de zéro. Si vous allez toujours travailler avec des tableaux structurés ou rec, alors vous pouvez simplifier cela d'une touche. (C'est-à-dire qu'il y a toujours un nom associé à chaque série de données, vous pouvez donc omettre de spécifier des noms.)
Par exemple:
import itertools
import numpy as np
import matplotlib.pyplot as plt
def main():
np.random.seed(1977)
numvars, numdata = 4, 10
data = 10 * np.random.random((numvars, numdata))
fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
linestyle='none', marker='o', color='black', mfc='none')
fig.suptitle('Simple Scatterplot Matrix')
plt.show()
def scatterplot_matrix(data, names, **kwargs):
"""Plots a scatterplot matrix of subplots. Each row of "data" is plotted
against other rows, resulting in a nrows by nrows grid of subplots with the
diagonal subplots labeled with "names". Additional keyword arguments are
passed on to matplotlib's "plot" command. Returns the matplotlib figure
object containg the subplot grid."""
numvars, numdata = data.shape
fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
fig.subplots_adjust(hspace=0.05, wspace=0.05)
for ax in axes.flat:
# Hide all ticks and labels
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
# Set up ticks only on one side for the "Edge" subplots...
if ax.is_first_col():
ax.yaxis.set_ticks_position('left')
if ax.is_last_col():
ax.yaxis.set_ticks_position('right')
if ax.is_first_row():
ax.xaxis.set_ticks_position('top')
if ax.is_last_row():
ax.xaxis.set_ticks_position('bottom')
# Plot the data.
for i, j in Zip(*np.triu_indices_from(axes, k=1)):
for x, y in [(i,j), (j,i)]:
axes[x,y].plot(data[x], data[y], **kwargs)
# Label the diagonal subplots...
for i, label in enumerate(names):
axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
ha='center', va='center')
# Turn on the proper x or y axes ticks.
for i, j in Zip(range(numvars), itertools.cycle((-1, 0))):
axes[j,i].xaxis.set_visible(True)
axes[i,j].yaxis.set_visible(True)
return fig
main()
Pour ceux qui ne veulent pas définir leurs propres fonctions, il existe une grande bibliothèque d'analyse de données en Python, appelée Pandas , où l'on peut trouver la méthode scatter_matrix () :
from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')
Vous pouvez également utiliser fonction pairplot
de Seaborn :
import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
Merci de partager votre code! Vous avez compris toutes les choses difficiles pour nous. Pendant que je travaillais avec, j'ai remarqué quelques petites choses qui ne semblaient pas tout à fait correctes.
[FIX # 1] Les tics des axes ne s'alignaient pas comme je m'y attendais (c'est-à-dire, dans votre exemple ci-dessus, vous devriez pouvoir tracer une ligne verticale et horizontale à travers n'importe quel point de tous les tracés et les lignes devraient traverser la ligne correspondante point dans les autres parcelles, mais comme il se trouve maintenant, cela ne se produit pas.
[FIX # 2] Si vous avez un nombre impair de variables avec lesquelles vous tracez, les axes en bas à droite ne tirent pas les xtics ou ytics corrects. Il le laisse juste comme les ticks par défaut 0..1.
Pas un correctif, mais je l'ai rendu facultatif pour entrer explicitement names
, afin qu'il place un xi
par défaut pour la variable i dans les positions diagonales.
Vous trouverez ci-dessous une version mise à jour de votre code qui répond à ces deux points, sinon préservant la beauté de votre code.
import itertools
import numpy as np
import matplotlib.pyplot as plt
def scatterplot_matrix(data, names=[], **kwargs):
"""
Plots a scatterplot matrix of subplots. Each row of "data" is plotted
against other rows, resulting in a nrows by nrows grid of subplots with the
diagonal subplots labeled with "names". Additional keyword arguments are
passed on to matplotlib's "plot" command. Returns the matplotlib figure
object containg the subplot grid.
"""
numvars, numdata = data.shape
fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
fig.subplots_adjust(hspace=0.0, wspace=0.0)
for ax in axes.flat:
# Hide all ticks and labels
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
# Set up ticks only on one side for the "Edge" subplots...
if ax.is_first_col():
ax.yaxis.set_ticks_position('left')
if ax.is_last_col():
ax.yaxis.set_ticks_position('right')
if ax.is_first_row():
ax.xaxis.set_ticks_position('top')
if ax.is_last_row():
ax.xaxis.set_ticks_position('bottom')
# Plot the data.
for i, j in Zip(*np.triu_indices_from(axes, k=1)):
for x, y in [(i,j), (j,i)]:
# FIX #1: this needed to be changed from ...(data[x], data[y],...)
axes[x,y].plot(data[y], data[x], **kwargs)
# Label the diagonal subplots...
if not names:
names = ['x'+str(i) for i in range(numvars)]
for i, label in enumerate(names):
axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
ha='center', va='center')
# Turn on the proper x or y axes ticks.
for i, j in Zip(range(numvars), itertools.cycle((-1, 0))):
axes[j,i].xaxis.set_visible(True)
axes[i,j].yaxis.set_visible(True)
# FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
# correct axes limits, so we pull them from other axes
if numvars%2:
xlimits = axes[0,-1].get_xlim()
ylimits = axes[-1,0].get_ylim()
axes[-1,-1].set_xlim(xlimits)
axes[-1,-1].set_ylim(ylimits)
return fig
if __name__=='__main__':
np.random.seed(1977)
numvars, numdata = 4, 10
data = 10 * np.random.random((numvars, numdata))
fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
linestyle='none', marker='o', color='black', mfc='none')
fig.suptitle('Simple Scatterplot Matrix')
plt.show()
Merci encore d'avoir partagé cela avec nous. Je l'ai utilisé plusieurs fois! Oh, et j'ai réorganisé la partie main()
du code afin qu'il puisse être un exemple de code formel ou ne pas être appelé s'il est importé dans un autre morceau de code.
En lisant la question, je m'attendais à voir une réponse comprenant rpy . Je pense que c'est une bonne option en profitant de deux belles langues. Voici donc:
import rpy
import numpy as np
def main():
np.random.seed(1977)
numvars, numdata = 4, 10
data = 10 * np.random.random((numvars, numdata))
mpg = data[0,:]
disp = data[1,:]
drat = data[2,:]
wt = data[3,:]
rpy.set_default_mode(rpy.NO_CONVERSION)
R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)
# Figure saved as eps
rpy.r.postscript('pairsPlot.eps')
rpy.r.pairs(R_data,
main="Simple Scatterplot Matrix Via RPy")
rpy.r.dev_off()
# Figure saved as png
rpy.r.png('pairsPlot.png')
rpy.r.pairs(R_data,
main="Simple Scatterplot Matrix Via RPy")
rpy.r.dev_off()
rpy.set_default_mode(rpy.BASIC_CONVERSION)
if __== '__main__': main()
Je ne peux pas poster une image pour montrer le résultat :( désolé!