web-dev-qa-db-fra.com

Les axes de plt.subplots () sont un objet "numpy.ndarray" et n'ont pas d'attribut "plot"

Les informations ci-dessous peuvent être superflues si vous essayez de comprendre le message d'erreur. Veuillez commencer par lire la réponse par @ user707650 .

En utilisant MatPlotLib, je voulais un script généralisable qui crée ce qui suit à partir de mes données.

Une fenêtre contenant a sous-tracés disposés de manière à ce qu'il y ait b sous-tracés par colonne. Je veux pouvoir changer les valeurs de a et b.

Si j'ai des données pour 2a sous-parcelles, je veux 2 fenêtres, chacune avec le précédent décrit " a sous-tracés organisés selon b sous-tracés par colonne ".

Les données x et y que je trace sont des flottants stockés dans np.arrays et sont structurées comme suit:

  • Les données x sont toujours les mêmes pour toutes les parcelles et sont de longueur 5.

     'x_vector': [0.000, 0.005, 0.010, 0.020, 0.030, 0.040]
    
  • Les données y de tous les tracés sont stockées dans y_vector où les données du premier tracé sont stockées aux indices 0 à 5. Les données du deuxième tracé sont stocké aux index 6 à 11. Le troisième graphique obtient 12-18, le quatrième 19-24, etc.

Au total, pour cet ensemble de données, j'ai 91 tracés (soit 91 * 6 = 546 valeurs stockées dans y_vector).

Tentative:

import matplotlib.pyplot as plt

# Options:
plots_tot = 14 # Total number of plots. In reality there is going to be 7*13 = 91 plots.
location_of_ydata = 6 # The values for the n:th plot can be found in the y_vector at index 'n*6' through 'n*6 + 6'.
plots_window = 7 # Total number of plots per window.
rows = 2 # Number of rows, i.e. number of subplots per column.

# Calculating number of columns:
prim_cols = plots_window / rows
extra_cols = 0
if plots_window % rows > 0:
    extra_cols = 1
cols = prim_cols + extra_cols

print 'cols:', cols
print 'rows:', rows

# Plotting:
n=0
x=0
fig, ax = plt.subplots(rows, cols)
while x <= plots_tot:
    ax[x].plot(x_vector, y_vector[n:(n+location_of_ydata)], 'ro')
    if x % plots_window == plots_window - 1:
        plt.show() # New window for every 7 plots.
    n = n+location_of_ydata
    x = x+1

J'obtiens l'erreur suivante:

cols: 4
rows: 2
Traceback (most recent call last):
  File "Script.py", line 222, in <module>
    ax[x].plot(x_vector, y_vector[n:(n+location_of_ydata)], 'ro')
AttributeError: 'numpy.ndarray' object has no attribute 'plot'
14
Lucubrator

Si vous déboguez votre programme en imprimant simplement ax, vous découvrirez rapidement que ax est un tableau à deux dimensions: une dimension pour les lignes, une pour les colonnes.

Ainsi, vous avez besoin de deux index pour indexer ax pour récupérer l'instance réelle AxesSubplot, comme:

ax[1,1].plot(...)

Si vous voulez parcourir les sous-tracés comme vous le faites maintenant, en aplatissant ax d'abord:

ax = ax.flatten()

et maintenant ax est un tableau unidimensionnel. Je ne sais pas si les lignes ou les colonnes sont franchies en premier, mais si ce n'est pas le cas, utilisez la transposition:

ax = ax.T.flatten()

Bien sûr, il est désormais plus logique de simplement créer chaque sous-intrigue à la volée, car il a déjà un index et les deux autres nombres sont fixes:

for x < plots_tot:
     ax = plt.subplot(nrows, ncols, x+1)

Remarque: vous avez x <= plots_tot, mais avec x commençant à 0, vous obtiendrez un IndexError avec votre code actuel (après l'aplatissement de votre tableau). Matplotlib est (malheureusement) indexé 1 pour les sous-parcelles. Je préfère utiliser une variable indexée 0 (style Python) et ajouter simplement +1 pour l'index de sous-intrigue (comme ci-dessus).

33
user707650

Si vous utilisez des graphiques N par 1, par exemple si vous aimez fig, ax = plt.subplots(3, 1) alors faites comme ax[plot_count].plot(...)

0
Cloud Cho