J'utilise le module unittest
de Python et je veux vérifier si deux structures de données complexes sont égales. Les objets peuvent être des listes de dict avec toutes sortes de valeurs: des nombres, des chaînes de caractères, des conteneurs Python (listes/tuples/dicts) et des tableaux numpy
. Ces derniers sont la raison de poser la question, car je ne peux pas simplement faire
self.assertEqual(big_struct1, big_struct2)
parce qu'il produit un
ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()
J'imagine que je dois écrire mon propre test d'égalité pour cela. Cela devrait fonctionner pour des structures arbitraires. Mon idée actuelle est une fonction récursive qui:
arg1
avec le noeud correspondant de arg2
;ValueError
est attrapé, va plus loin jusqu'à trouver un numpy.array
;Ce qui semble un peu problématique est de garder trace des nœuds "correspondants" de deux structures, mais peut-être que Zip
est tout ce dont j'ai besoin ici.
La question est la suivante: existe-t-il de bonnes (plus simples) alternatives à cette approche? Peut-être que numpy
présente des outils pour cela? Si aucune alternative n'est suggérée, je mettrai en œuvre cette idée (à moins d'en avoir une meilleure) et posterai comme réponse.
P.S. J'ai le vague sentiment d'avoir vu une question sur ce problème, mais je ne la trouve pas maintenant.
P.P.S. Une approche alternative serait une fonction qui traverse la structure et convertit tous les numpy.array
s en listes, mais est-ce plus facile à implémenter? Cela me semble pareil.
Edit: Le sous-classement numpy.ndarray
semble très prometteur, mais je n'ai évidemment pas les deux côtés de la comparaison codés en dur dans un test. L’un d’eux, cependant, est codé en dur, je peux donc:
numpy.array
;isinstance(other, SaneEqualityArray)
par isinstance(other, np.ndarray)
dans jterrace 's answer ;Mes questions à cet égard sont les suivantes:
numpy
)?.Edit 2 : Je l’ai essayé, l’implémentation (apparemment) fonctionnelle est montrée dans cette réponse .
Donc, l'idée illustrée par jterrace semble fonctionner pour moi avec une légère modification:
class SaneEqualityArray(np.ndarray):
def __eq__(self, other):
return (isinstance(other, np.ndarray) and self.shape == other.shape and
np.allclose(self, other))
Comme je l'ai dit, le conteneur avec ces objets doit être situé à gauche du contrôle d'égalité. Je crée des objets SaneEqualityArray
à partir de numpy.ndarray
existants comme ceci:
SaneEqualityArray(my_array.shape, my_array.dtype, my_array)
conformément à ndarray
signature du constructeur:
ndarray(shape, dtype=float, buffer=None, offset=0,
strides=None, order=None)
Cette classe est définie dans la suite de tests et sert uniquement à des fins de test. Le RHS du contrôle d'égalité est un objet réel renvoyé par la fonction testée et contient de vrais objets numpy.ndarray
.
P.S. Merci aux auteurs des deux réponses affichées jusqu’à présent, ils ont été très utiles. Si quelqu'un voit des problèmes avec cette approche, j'apprécierais vos commentaires.
Aurait commenté, mais ça prend trop de temps ...
Fait amusant, vous ne pouvez pas utiliser ==
pour vérifier si les tableaux sont identiques. Je vous suggérerais plutôt d'utiliser np.testing.assert_array_equal
.
(float('nan') == float('nan')) == False
(la séquence normale en python ==
a une façon encore plus amusante d’ignorer ce parfois , car il utilise PyObject_RichCompareBool
qui effectue un (pour NaNs incorrect) is
quick check ( pour tester bien sûr c'est parfait) ...assert_allclose
car l’égalité en virgule flottante peut devenir très délicate si vous effectuez des calculs réels et que vous voulez généralement presque les mêmes valeurs, car les valeurs peuvent devenir dépendantes du matériel ou éventuellement aléatoires en fonction de ce que vous en ferez.Je suggérerais presque d'essayer de le sérialiser avec pickle si vous voulez quelque chose d'aussi imbriqué, mais c'est trop strict (et le point 3 est bien sûr complètement cassé à ce moment-là), par exemple, la disposition de la mémoire de votre tableau n'a pas d'importance, mais compte pour sa la sérialisation.
La fonction assertEqual
invoquera la méthode des objets __eq__
, qui doit être répétée pour les types de données complexes. L'exception est numpy, qui n'a pas de méthode saine __eq__
. En utilisant une sous-classe numpy de cette question , vous pouvez restaurer l'intégrité du comportement d'égalité:
import copy
import numpy
import unittest
class SaneEqualityArray(numpy.ndarray):
def __eq__(self, other):
return (isinstance(other, SaneEqualityArray) and
self.shape == other.shape and
numpy.ndarray.__eq__(self, other).all())
class TestAsserts(unittest.TestCase):
def testAssert(self):
tests = [
[1, 2],
{'foo': 2},
[2, 'foo', {'d': 4}],
SaneEqualityArray([1, 2]),
{'foo': {'hey': SaneEqualityArray([2, 3])}},
[{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
SaneEqualityArray([5, 6]), 34]
]
for t in tests:
self.assertEqual(t, copy.deepcopy(t))
if __== '__main__':
unittest.main()
Ce test réussit.
Je définirais ma propre méthode assertNumpyArraysEqual () qui effectue explicitement la comparaison que vous souhaitez utiliser. Ainsi, votre code de production reste inchangé, mais vous pouvez toujours faire des affirmations raisonnables dans vos tests unitaires. Assurez-vous de le définir dans un module incluant __unittest = True
afin qu'il ne soit pas inclus dans les traces de pile:
import numpy
__unittest = True
def assertNumpyArraysEqual(self, other):
if self.shape != other.shape:
raise AssertionError("Shapes don't match")
if not numpy.allclose(self, other)
raise AssertionError("Elements don't match!")
check numpy.testing.assert_almost_equal
which "déclenche une AssertionError si deux éléments ne sont pas égaux jusqu'à la précision souhaitée" , par exemple:
import numpy.testing as npt
npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
np.array([1.0,2.33333334]), decimal=9)
J'ai rencontré le même problème et développé une fonction pour comparer l'égalité basée sur la création d'un hachage fixe pour l'objet. Cela présente l’avantage supplémentaire de pouvoir vérifier qu’un objet est conforme aux attentes en le comparant à un hachage corrigé dans le code.
Le code (un fichier python autonome, est ici ). Il existe deux fonctions: fixed_hash_eq
, qui résout votre problème, et compute_fixed_hash
, qui crée un hachage à partir de la structure. Les tests sont ici
Voici un test:
obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)
En s'appuyant sur @dbw (avec mes remerciements), la méthode suivante insérée dans la sous-classe de cas de test a bien fonctionné pour moi:
def assertNumpyArraysEqual(self,this,that,msg=''):
'''
modified from http://stackoverflow.com/a/15399475/5459638
'''
if this.shape != that.shape:
raise AssertionError("Shapes don't match")
if not np.allclose(this,that):
raise AssertionError("Elements don't match!")
Je l’ai appelée self.assertNumpyArraysEqual(this,that)
dans mes méthodes de test et a fonctionné comme un charme.