web-dev-qa-db-fra.com

numpy: quelle est la logique des fonctions argmin () et argmax ()?

Je ne peux pas comprendre les résultats de argmax et argmin lorsqu’ils sont utilisés avec le paramètre axis. Par exemple:

>>> a = np.array([[1,2,4,7], [9,88,6,45], [9,76,3,4]])
>>> a
array([[ 1,  2,  4,  7],
       [ 9, 88,  6, 45],
       [ 9, 76,  3,  4]])
>>> a.shape
(3, 4)
>>> a.size
12
>>> np.argmax(a)
5
>>> np.argmax(a,axis=0)
array([1, 1, 1, 1])
>>> np.argmax(a,axis=1)
array([3, 1, 1])
>>> np.argmin(a)
0
>>> np.argmin(a,axis=0)
array([0, 0, 2, 2])
>>> np.argmin(a,axis=1)
array([0, 2, 2])

Comme vous pouvez le constater, la valeur maximale est le point (1,1) et la valeur minimale est le point (0,0). Donc, dans ma logique quand je cours:

  • np.argmin(a,axis=0) Je m'attendais array([0,0,0,0]) 
  • np.argmin(a,axis=1) Je m'attendais array([0,0,0]) 
  • np.argmax(a,axis=0) Je m'attendais array([1,1,1,1]) 
  • np.argmax(a,axis=1) Je m'attendais array([1,1,1]) 

Quel est le problème avec ma compréhension des choses?

24
user4584333

En ajoutant l'argument axis, NumPy examine les lignes et les colonnes individuellement. Lorsqu'il n'est pas donné, le tableau a est aplati dans un tableau 1D unique.

axis=0 signifie que l'opération est effectuée down les colonnes d'un tableau 2D a à son tour.

Par exemple, np.argmin(a, axis=0) renvoie l'index de la valeur minimale dans chacune des quatre colonnes. La valeur minimale dans chaque colonne est indiquée en gras ci-dessous:

>>> a
array([[ 1,  2,  4,  7],  # 0
       [ 9, 88,  6, 45],  # 1
       [ 9, 76,  3,  4]]) # 2

>>> np.argmin(a, axis=0)
array([0, 0, 2, 2])

axis=1 signifie en revanche que l'opération est effectuée sur les lignes de a

Cela signifie que np.argmin(a, axis=1) renvoie [0, 2, 2] car a a trois lignes. L'index de la valeur minimum dans la première ligne est 0, celui de la valeur minimum des deuxième et troisième lignes est 2:

>>> a
#        0   1   2   3
array([[ 1,  2,  4,  7],
       [ 9, 88,  6, 45],
       [ 9, 76,  3,  4]])

>>> np.argmin(a, axis=1)
array([0, 2, 2])
40
Alex Riley

La fonction np.argmax fonctionne par défaut le long du tableau aplati , sauf si vous spécifiez un axe. Pour voir ce qui se passe, vous pouvez utiliser flatten explicitement:

np.argmax(a)
>>> 5

a.flatten()
>>>> array([ 1,  2,  4,  7,  9, 88,  6, 45,  9, 76,  3,  4])
             0   1   2   3   4   5 

J'ai numéroté les index sous le tableau ci-dessus pour le rendre plus clair. Notez que les index sont numérotés à partir de zéro dans numpy

Dans les cas où vous spécifiez l’axe, il fonctionne également comme prévu:

np.argmax(a,axis=0)
>>> array([1, 1, 1, 1])

Cela vous indique que la plus grande valeur se trouve dans la ligne 1 (2ème valeur), pour chaque colonne de axis=0 (en bas). Vous pouvez voir cela plus clairement si vous modifiez un peu vos données:

a=np.array([[100,2,4,7],[9,88,6,45],[9,76,3,100]])
a
>>> array([[100,   2,   4,   7],
           [  9,  88,   6,  45],
           [  9,  76,   3, 100]])

np.argmax(a, axis=0)
>>> array([0, 1, 1, 2])

Comme vous pouvez le constater, il identifie maintenant la valeur maximale de la ligne 0 pour la colonne 1, de la ligne 1 pour les colonnes 2 et 3 et de la ligne 3 pour la colonne 4.

Il existe un guide utile pour l’indexation numpy dans la documentation .

5
mfitzp

Remarque: si vous voulez trouver les coordonnées de votre valeur maximale dans le tableau complet, vous pouvez utiliser

a=np.array([[1,2,4,7],[9,88,6,45],[9,76,3,4]])
>>> a
[[ 1  2  4  7]
 [ 9 88  6 45]
 [ 9 76  3  4]]

c=(np.argmax(a)/len(a[0]),np.argmax(a)%len(a[0]))
>>> c
(1, 1)
4
MartijnVanAttekum
""" ....READ THE COMMENTS FOR CLARIFICATION....."""

import numpy as np
a = np.array([[1,2,4,7], [9,88,6,45], [9,76,3,4]])

"""np.argmax(a) will give index of max value in flatted array of given matrix """
>>np.arg(max)
5

"""np.argmax(a,axis=0) will return list of indexes of  max value column-wise"""
>>print(np.argmax(a,axis=0))
[1,1,1,1]

"""np.argmax(a,axis=1) will return list of indexes of  max value row-wise"""
>>print(np.argmax(a,axis=1))
[3,1,1]

"""np.argmin(a) will give index of min value in flatted array of given matrix """
>>np.arg(min)
0

"""np.argmin(a,axis=0) will return list of indexes of  min value column-wise"""
>>print(np.argmin(a,axis=0))
[0,0,2,2]

"""np.argmin(a,axis=0) will return list of indexes of  min value row-wise"""
>>print(np.argmin(a,axis=1))
[0,2,2]
2
Nitin Ashutosh

L'axe dans l'argument de la fonction argmax fait référence à l'axe le long duquel le tableau sera découpé. 

Dans un autre mot, np.argmin(a,axis=0) est effectivement identique à np.apply_along_axis(np.argmin, 0, a), c’est-à-dire que l’emplacement minimum de ces vecteurs découpés le long de l’axe = 0.

Par conséquent, dans votre exemple, np.argmin(a, axis=0) est [0, 0, 2, 2], ce qui correspond aux valeurs de [1, 2, 3, 4] dans les colonnes respectives.

0
xingzhi.sg