web-dev-qa-db-fra.com

Comparaison de deux tableaux numpy pour l'égalité, élément par élément

Quel est le moyen le plus simple de comparer deux tableaux numpy pour l’égalité (l’égalité étant définie par: A = B ssigne de tous les indices i: A[i] == B[i])?

Simplement utiliser == me donne un tableau booléen:

 >>> numpy.array([1,1,1]) == numpy.array([1,1,1])

array([ True,  True,  True], dtype=bool)

Dois-je and les éléments de ce tableau pour déterminer si les tableaux sont égaux ou existe-t-il un moyen plus simple de comparer?

192
clstaudt
(A==B).all()

teste si toutes les valeurs du tableau (A == B) sont vraies.

Edit (extrait de la réponse de dbaupp et du commentaire de yoavram)

Il convient de noter que:

  • cette solution peut avoir un comportement étrange dans un cas particulier: si A ou B est vide et si l'autre contient un seul élément, il retourne True. Pour une raison quelconque, la comparaison A==B renvoie un tableau vide pour lequel l'opérateur all renvoie True.
  • Un autre risque est que si A et B n'aient pas la même forme et ne puissent pas être diffusés, cette approche générera une erreur.

En conclusion, je pense que la solution que j'ai proposée est la solution standard, mais si vous avez un doute sur la forme de A et B ou si vous voulez simplement être sûr: utilisez l'une des fonctions spécialisées suivantes:

np.array_equal(A,B)  # test if same shape, same elements values
np.array_equiv(A,B)  # test if broadcastable shape, same elements values
np.allclose(A,B,...) # test if same shape, elements have close enough values
291
Juh_

La solution _(A==B).all()_ est très soignée, mais il existe des fonctions intégrées pour cette tâche. À savoir array_equal , allclose et array_equiv .

(Bien que quelques tests rapides avec timeit semblent indiquer que la méthode _(A==B).all()_ est la plus rapide, ce qui est un peu étrange, étant donné qu'elle doit allouer un tout nouveau tableau.)

81
huon

Mesurons les performances en utilisant le code suivant.

import numpy as np
import time

exec_time0 = []
exec_time1 = []
exec_time2 = []

sizeOfArray = 5000
numOfIterations = 200

for i in xrange(numOfIterations):

    A = np.random.randint(0,255,(sizeOfArray,sizeOfArray))
    B = np.random.randint(0,255,(sizeOfArray,sizeOfArray))

    a = time.clock() 
    res = (A==B).all()
    b = time.clock()
    exec_time0.append( b - a )

    a = time.clock() 
    res = np.array_equal(A,B)
    b = time.clock()
    exec_time1.append( b - a )

    a = time.clock() 
    res = np.array_equiv(A,B)
    b = time.clock()
    exec_time2.append( b - a )

print 'Method: (A==B).all(),       ', np.mean(exec_time0)
print 'Method: np.array_equal(A,B),', np.mean(exec_time1)
print 'Method: np.array_equiv(A,B),', np.mean(exec_time2)

sortie

Method: (A==B).all(),        0.03031857
Method: np.array_equal(A,B), 0.030025185
Method: np.array_equiv(A,B), 0.030141515

D'après les résultats ci-dessus, les méthodes numpy semblent être plus rapides que la combinaison de l'opérateur == et de la méthode all () et en comparant les méthodes numpy le plus rapide il semble que ce soit la méthode numpy.array_equal.

13
funk

Si vous voulez vérifier si deux tableaux ont le même shape ET elements, vous devez utiliser np.array_equal car c'est la méthode recommandée dans la documentation.

En termes de performances, il ne faut pas s'attendre à ce qu'un contrôle d'égalité en vienne à un autre, car il n'y a pas beaucoup de place pour optimiser comparing two elements. Juste pour le plaisir, j'ai quand même fait des tests.

import numpy as np
import timeit

A = np.zeros((300, 300, 3))
B = np.zeros((300, 300, 3))
C = np.ones((300, 300, 3))

timeit.timeit(stmt='(A==B).all()', setup='from __main__ import A, B', number=10**5)
timeit.timeit(stmt='np.array_equal(A, B)', setup='from __main__ import A, B, np', number=10**5)
timeit.timeit(stmt='np.array_equiv(A, B)', setup='from __main__ import A, B, np', number=10**5)
> 51.5094
> 52.555
> 52.761

Tellement égal, pas besoin de parler de vitesse.

Le (A==B).all() se comporte assez bien comme l'extrait de code suivant:

x = [1,2,3]
y = [1,2,3]
print all([x[i]==y[i] for i in range(len(x))])
> True
10
user1767754