web-dev-qa-db-fra.com

Matplotlib: afficher les valeurs d'un tableau avec imshow

J'essaie de créer une grille en utilisant une fonction matplotlib comme imshow.
À partir de ce tableau:

[[ 1  8 13 29 17 26 10  4],
[16 25 31  5 21 30 19 15]]

Je voudrais tracer la valeur sous forme de couleur ET la valeur de texte elle-même (1,2, ...) sur la même grille. Voici ce que j'ai pour le moment (je ne peux tracer que la couleur associée à chaque valeur):

from matplotlib import pyplot
import numpy as np

grid = np.array([[1,8,13,29,17,26,10,4],[16,25,31,5,21,30,19,15]])
print 'Here is the array'
print grid

fig1, (ax1, ax2)= pyplot.subplots(2, sharex = True, sharey = False)
ax1.imshow(grid, interpolation ='none', aspect = 'auto')
ax2.imshow(grid, interpolation ='bicubic', aspect = 'auto')
pyplot.show()   
21
Ludovic Aphtes

Si, pour une raison quelconque, vous devez utiliser une extension différente de celle qui est fournie naturellement par imshow la méthode suivante (même si plus artificiel) fait le travail:

enter image description here

size = 4
data = np.arange(size * size).reshape((size, size))

# Limits for the extent
x_start = 3.0
x_end = 9.0
y_start = 6.0
y_end = 12.0

extent = [x_start, x_end, y_start, y_end]

# The normal figure
fig = plt.figure(figsize=(16, 12))
ax = fig.add_subplot(111)
im = ax.imshow(data, extent=extent, Origin='lower', interpolation='None', cmap='viridis')

# Add the text
jump_x = (x_end - x_start) / (2.0 * size)
jump_y = (y_end - y_start) / (2.0 * size)
x_positions = np.linspace(start=x_start, stop=x_end, num=size, endpoint=False)
y_positions = np.linspace(start=y_start, stop=y_end, num=size, endpoint=False)

for y_index, y in enumerate(y_positions):
    for x_index, x in enumerate(x_positions):
        label = data[y_index, x_index]
        text_x = x + jump_x
        text_y = y + jump_y
        ax.text(text_x, text_y, label, color='black', ha='center', va='center')

fig.colorbar(im)
plt.show()

Si vous souhaitez mettre un autre type de données et pas nécessairement les valeurs que vous avez utilisées pour l'image, vous pouvez modifier le script ci-dessus de la manière suivante (ajouté valeurs après données):

enter image description here

size = 4
data = np.arange(size * size).reshape((size, size))
values = np.random.Rand(size, size)

# Limits for the extent
x_start = 3.0
x_end = 9.0
y_start = 6.0
y_end = 12.0

extent = [x_start, x_end, y_start, y_end]

# The normal figure
fig = plt.figure(figsize=(16, 12))
ax = fig.add_subplot(111)
im = ax.imshow(data, extent=extent, Origin='lower', interpolation='None', cmap='viridis')

# Add the text
jump_x = (x_end - x_start) / (2.0 * size)
jump_y = (y_end - y_start) / (2.0 * size)
x_positions = np.linspace(start=x_start, stop=x_end, num=size, endpoint=False)
y_positions = np.linspace(start=y_start, stop=y_end, num=size, endpoint=False)

for y_index, y in enumerate(y_positions):
    for x_index, x in enumerate(x_positions):
        label = values[y_index, x_index]
        text_x = x + jump_x
        text_y = y + jump_y
        ax.text(text_x, text_y, label, color='black', ha='center', va='center')

fig.colorbar(im)
plt.show()
8
Ramon Martinez

Vous souhaitez parcourir les valeurs dans grid et utiliser ax.text pour ajouter l'étiquette au tracé.

Heureusement, pour les tableaux 2D, numpy a ndenumerate , ce qui rend cela assez simple:

for (j,i),label in np.ndenumerate(grid):
    ax1.text(i,j,label,ha='center',va='center')
    ax2.text(i,j,label,ha='center',va='center')

enter image description here

21
tmdavison