web-dev-qa-db-fra.com

Comparer (affirmer l'égalité de) deux structures de données complexes contenant des tableaux numpy en unittest

J'utilise le module unittest de Python et je veux vérifier si deux structures de données complexes sont égales. Les objets peuvent être des listes de dict avec toutes sortes de valeurs: des nombres, des chaînes de caractères, des conteneurs Python (listes/tuples/dicts) et des tableaux numpy. Ces derniers sont la raison de poser la question, car je ne peux pas simplement faire

self.assertEqual(big_struct1, big_struct2)

parce qu'il produit un

ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()

J'imagine que je dois écrire mon propre test d'égalité pour cela. Cela devrait fonctionner pour des structures arbitraires. Mon idée actuelle est une fonction récursive qui:

  • essaie la comparaison directe du "noeud" actuel de arg1 avec le noeud correspondant de arg2;
  • si aucune exception n'est levée, continue (les nœuds/feuilles "terminaux" sont également traités ici);
  • si ValueError est attrapé, va plus loin jusqu'à trouver un numpy.array;
  • compare les tableaux (par exemple comme ceci ).

Ce qui semble un peu problématique est de garder trace des nœuds "correspondants" de deux structures, mais peut-être que Zip est tout ce dont j'ai besoin ici.

La question est la suivante: existe-t-il de bonnes (plus simples) alternatives à cette approche? Peut-être que numpy présente des outils pour cela? Si aucune alternative n'est suggérée, je mettrai en œuvre cette idée (à moins d'en avoir une meilleure) et posterai comme réponse.

P.S. J'ai le vague sentiment d'avoir vu une question sur ce problème, mais je ne la trouve pas maintenant.

P.P.S. Une approche alternative serait une fonction qui traverse la structure et convertit tous les numpy.arrays en listes, mais est-ce plus facile à implémenter? Cela me semble pareil.


Edit: Le sous-classement numpy.ndarray semble très prometteur, mais je n'ai évidemment pas les deux côtés de la comparaison codés en dur dans un test. L’un d’eux, cependant, est codé en dur, je peux donc:

  • remplissez-le avec des sous-classes personnalisées de numpy.array;
  • remplacez isinstance(other, SaneEqualityArray) par isinstance(other, np.ndarray) dans jterrace 's answer ;
  • utilisez-le toujours comme LHS dans les comparaisons.

Mes questions à cet égard sont les suivantes:

  1. Cela fonctionnera-t-il (je veux dire, cela me semble bien, mais peut-être que certains cas délicats d'Edge ne seront pas traités correctement)? Mon objet personnalisé finira-t-il toujours par devenir LHS dans les contrôles d'égalité récursifs, comme je m'y attendais?
  2. Encore une fois, existe-t-il de meilleurs moyens (étant donné que j'ai au moins une des structures avec de véritables tableaux numpy)?.

Edit 2 : Je l’ai essayé, l’implémentation (apparemment) fonctionnelle est montrée dans cette réponse .

20
Lev Levitsky

Donc, l'idée illustrée par jterrace semble fonctionner pour moi avec une légère modification:

class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        return (isinstance(other, np.ndarray) and self.shape == other.shape and 
            np.allclose(self, other))

Comme je l'ai dit, le conteneur avec ces objets doit être situé à gauche du contrôle d'égalité. Je crée des objets SaneEqualityArray à partir de numpy.ndarray existants comme ceci:

SaneEqualityArray(my_array.shape, my_array.dtype, my_array)

conformément à ndarray signature du constructeur:

ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)

Cette classe est définie dans la suite de tests et sert uniquement à des fins de test. Le RHS du contrôle d'égalité est un objet réel renvoyé par la fonction testée et contient de vrais objets numpy.ndarray.

P.S. Merci aux auteurs des deux réponses affichées jusqu’à présent, ils ont été très utiles. Si quelqu'un voit des problèmes avec cette approche, j'apprécierais vos commentaires.

7
Lev Levitsky

Aurait commenté, mais ça prend trop de temps ...

Fait amusant, vous ne pouvez pas utiliser == pour vérifier si les tableaux sont identiques. Je vous suggérerais plutôt d'utiliser np.testing.assert_array_equal .

  1. qui vérifie le type, la forme, etc.,
  2. cela n’échoue pas pour la mathématique soignée de (float('nan') == float('nan')) == False (la séquence normale en python == a une façon encore plus amusante d’ignorer ce parfois , car il utilise PyObject_RichCompareBool qui effectue un (pour NaNs incorrect) is quick check ( pour tester bien sûr c'est parfait) ...
  3. Il existe également assert_allclose car l’égalité en virgule flottante peut devenir très délicate si vous effectuez des calculs réels et que vous voulez généralement presque les mêmes valeurs, car les valeurs peuvent devenir dépendantes du matériel ou éventuellement aléatoires en fonction de ce que vous en ferez.

Je suggérerais presque d'essayer de le sérialiser avec pickle si vous voulez quelque chose d'aussi imbriqué, mais c'est trop strict (et le point 3 est bien sûr complètement cassé à ce moment-là), par exemple, la disposition de la mémoire de votre tableau n'a pas d'importance, mais compte pour sa la sérialisation.

12
seberg

La fonction assertEqual invoquera la méthode des objets __eq__, qui doit être répétée pour les types de données complexes. L'exception est numpy, qui n'a pas de méthode saine __eq__. En utilisant une sous-classe numpy de cette question , vous pouvez restaurer l'intégrité du comportement d'égalité:

import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            SaneEqualityArray([1, 2]),
            {'foo': {'hey': SaneEqualityArray([2, 3])}},
            [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
             SaneEqualityArray([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))

if __== '__main__':
    unittest.main()

Ce test réussit.

8
jterrace

Je définirais ma propre méthode assertNumpyArraysEqual () qui effectue explicitement la comparaison que vous souhaitez utiliser. Ainsi, votre code de production reste inchangé, mais vous pouvez toujours faire des affirmations raisonnables dans vos tests unitaires. Assurez-vous de le définir dans un module incluant __unittest = True afin qu'il ne soit pas inclus dans les traces de pile:

import numpy
__unittest = True

def assertNumpyArraysEqual(self, other):
    if self.shape != other.shape:
        raise AssertionError("Shapes don't match")
    if not numpy.allclose(self, other)
        raise AssertionError("Elements don't match!")
2
dbn

check numpy.testing.assert_almost_equal which "déclenche une AssertionError si deux éléments ne sont pas égaux jusqu'à la précision souhaitée" , par exemple:

 import numpy.testing as npt
 npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
                         np.array([1.0,2.33333334]), decimal=9)
1
Hanan Shteingart

J'ai rencontré le même problème et développé une fonction pour comparer l'égalité basée sur la création d'un hachage fixe pour l'objet. Cela présente l’avantage supplémentaire de pouvoir vérifier qu’un objet est conforme aux attentes en le comparant à un hachage corrigé dans le code.

Le code (un fichier python autonome, est ici ). Il existe deux fonctions: fixed_hash_eq, qui résout votre problème, et compute_fixed_hash, qui crée un hachage à partir de la structure. Les tests sont ici

Voici un test:

obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)
1
Peter

En s'appuyant sur @dbw (avec mes remerciements), la méthode suivante insérée dans la sous-classe de cas de test a bien fonctionné pour moi:

 def assertNumpyArraysEqual(self,this,that,msg=''):
    '''
    modified from http://stackoverflow.com/a/15399475/5459638
    '''
    if this.shape != that.shape:
        raise AssertionError("Shapes don't match")
    if not np.allclose(this,that):
        raise AssertionError("Elements don't match!")

Je l’ai appelée self.assertNumpyArraysEqual(this,that) dans mes méthodes de test et a fonctionné comme un charme.

0
XavierStuvw