web-dev-qa-db-fra.com

assertAlmostEqual dans le test unitaire Python pour les collections de flottants

La méthode assertAlmostEqual (x, y) } du framework de test unitaire de Python teste si x et y sont approximativement égaux, en supposant qu'ils sont des flottants.

Le problème avec assertAlmostEqual() est qu’il ne fonctionne que sur les flottants. Je cherche une méthode comme assertAlmostEqual() qui fonctionne avec des listes de flotteurs, des ensembles de flotteurs, des dictionnaires de flotteurs, des tuples de flotteurs, des listes de tuples de flotteurs, des ensembles de listes de flotteurs, etc.

Par exemple, prenons x = 0.1234567890, y = 0.1234567891. x et y sont presque égaux car ils conviennent de tous les chiffres sauf le dernier. Par conséquent, self.assertAlmostEqual(x, y) est True car assertAlmostEqual() fonctionne pour les flottants.

Je recherche une assertAlmostEquals() plus générique, qui évalue également les appels suivants à True:

  • self.assertAlmostEqual_generic([x, x, x], [y, y, y]).
  • self.assertAlmostEqual_generic({1: x, 2: x, 3: x}, {1: y, 2: y, 3: y}).
  • self.assertAlmostEqual_generic([(x,x)], [(y,y)]).

Existe-t-il une telle méthode ou dois-je la mettre en œuvre moi-même?

Clarifications:

  • assertAlmostEquals() a un paramètre facultatif nommé places et les nombres sont comparés en calculant la différence arrondie au nombre décimal places. Par défaut, places=7, self.assertAlmostEqual(0.5, 0.4) est donc False tandis que self.assertAlmostEqual(0.12345678, 0.12345679) est True. Ma assertAlmostEqual_generic() spéculative devrait avoir les mêmes fonctionnalités.

  • Deux listes sont considérées presque égales si elles ont des nombres presque égaux dans le même ordre. formellement, for i in range(n): self.assertAlmostEqual(list1[i], list2[i]).

  • De même, deux ensembles sont considérés presque égaux s'ils peuvent être convertis en listes presque égales (en attribuant un ordre à chaque ensemble).

  • De même, deux dictionnaires sont considérés presque égaux si l'ensemble de clés de chaque dictionnaire est presque égal à l'ensemble de clés de l'autre dictionnaire et, pour chaque paire de clés presque égale, une valeur correspondante est presque égale.

  • En général: je considère que deux collections sont presque égales si elles le sont, à l’exception de certains flottants correspondants, qui sont presque égaux les uns aux autres. En d’autres termes, j’aimerais vraiment comparer les objets mais avec une précision faible (personnalisée) lorsqu’on compare des flottants en cours de route.

53
snakile

Voici comment j'ai implémenté une fonction is_almost_equal(first, second) générique:

Commencez par dupliquer les objets que vous devez comparer (first et second), mais ne faites pas de copie exacte: coupez les chiffres décimaux non significatifs de tous les flottants rencontrés à l'intérieur de l'objet.

Maintenant que vous disposez de copies de first et second pour lesquelles les chiffres décimaux non significatifs ont disparu, comparez simplement first et second à l'aide de l'opérateur ==.

Supposons que nous ayons une fonction cut_insignificant_digits_recursively(obj, places) qui duplique obj mais ne laisse que les places chiffres décimaux les plus significatifs de chaque flottant dans le obj d'origine. Voici une implémentation fonctionnelle de is_almost_equals(first, second, places):

from insignificant_digit_cutter import cut_insignificant_digits_recursively

def is_almost_equal(first, second, places):
    '''returns True if first and second equal. 
    returns true if first and second aren't equal but have exactly the same
    structure and values except for a bunch of floats which are just almost
    equal (floats are almost equal if they're equal when we consider only the
    [places] most significant digits of each).'''
    if first == second: return True
    cut_first = cut_insignificant_digits_recursively(first, places)
    cut_second = cut_insignificant_digits_recursively(second, places)
    return cut_first == cut_second

Et voici une implémentation fonctionnelle de cut_insignificant_digits_recursively(obj, places):

def cut_insignificant_digits(number, places):
    '''cut the least significant decimal digits of a number, 
    leave only [places] decimal digits'''
    if  type(number) != float: return number
    number_as_str = str(number)
    end_of_number = number_as_str.find('.')+places+1
    if end_of_number > len(number_as_str): return number
    return float(number_as_str[:end_of_number])

def cut_insignificant_digits_lazy(iterable, places):
    for obj in iterable:
        yield cut_insignificant_digits_recursively(obj, places)

def cut_insignificant_digits_recursively(obj, places):
    '''return a copy of obj except that every float loses its least significant 
    decimal digits remaining only [places] decimal digits'''
    t = type(obj)
    if t == float: return cut_insignificant_digits(obj, places)
    if t in (list, Tuple, set):
        return t(cut_insignificant_digits_lazy(obj, places))
    if t == dict:
        return {cut_insignificant_digits_recursively(key, places):
                cut_insignificant_digits_recursively(val, places)
                for key,val in obj.items()}
    return obj

Le code et ses tests unitaires sont disponibles ici: https://github.com/snakile/approximate_comparator . Je me félicite de toute amélioration et correction de bug.

7
snakile

si cela ne vous dérange pas d'utiliser NumPy (fourni avec votre Python (x, y)), vous pouvez consulter le module np.testing qui définit, entre autres, une fonction assert_almost_equal.

La signature est np.testing.assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True)

>>> x = 1.000001
>>> y = 1.000002
>>> np.testing.assert_almost_equal(x, y)
AssertionError: 
Arrays are not almost equal to 7 decimals
ACTUAL: 1.000001
DESIRED: 1.000002
>>> np.testing.assert_almost_equal(x, y, 5)
>>> np.testing.assert_almost_equal([x, x, x], [y, y, y], 5)
>>> np.testing.assert_almost_equal((x, x, x), (y, y, y), 5)
50
Pierre GM

A partir de python 3.5, vous pouvez comparer en utilisant

math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)

Comme décrit dans pep-0485 . La mise en œuvre devrait être équivalente à 

abs(a-b) <= max( rel_tol * max(abs(a), abs(b)), abs_tol )

Si cela ne vous dérange pas d'utiliser le package numpy, alors numpy.testing utilise la méthode assert_array_almost_equal

Cela fonctionne pour les objets array_like, il convient donc pour les tableaux, les listes et les nuplets de flottants, mais ne fonctionne pas pour les ensembles et les dictionnaires.

La documentation est ici .

4
DJCowley

Il n'y a pas une telle méthode, vous devriez le faire vous-même.

Pour les listes et les nuplets, la définition est évidente, mais notez que les autres cas que vous mentionnez ne le sont pas. Il n’est donc pas étonnant que cette fonction ne soit pas fournie. Par exemple, {1.00001: 1.00002} est-il presque égal à {1.00002: 1.00001}? Le traitement de tels cas nécessite de choisir si la proximité dépend de clés ou de valeurs, ou des deux. Il est peu probable que vous trouviez une définition significative pour les ensembles car ceux-ci ne sont pas ordonnés. Il n’existe donc aucune notion d’éléments "correspondants".

4
BrenBarn

J'utiliserais toujours self.assertEqual() car c'est la solution la plus informative lorsque la merde frappe le ventilateur. Vous pouvez le faire en arrondissant, par exemple.

self.assertEqual(round_Tuple((13.949999999999999, 1.121212), 2), (13.95, 1.12))

round_Tuple est

def round_Tuple(t: Tuple, ndigits: int) -> Tuple:
    return Tuple(round(e, ndigits=ndigits) for e in t)

def round_list(l: list, ndigits: int) -> list:
    return [round(e, ndigits=ndigits) for e in l]

Selon la python docs (voir https://stackoverflow.com/a/41407651/1031191 ), vous pouvez éviter les problèmes d'arrondi tels que 13.94999999, car 13.94999999 == 13.95 est True.

0
Barnabas Szabolcs

Aucune de ces réponses ne fonctionne pour moi. Le code suivant devrait fonctionner pour les collections Python, les classes, les classes de données et les exemples nommés. J'aurais peut-être oublié quelque chose, mais jusqu'à présent, cela fonctionne pour moi.

import unittest
from collections import namedtuple, OrderedDict
from dataclasses import dataclass
from typing import Any


def are_almost_equal(o1: Any, o2: Any, max_abs_ratio_diff: float, max_abs_diff: float) -> bool:
    """
    Compares two objects by recursively walking them trough. Equality is as usual except for floats.
    Floats are compared according to the two measures defined below.

    :param o1: The first object.
    :param o2: The second object.
    :param max_abs_ratio_diff: The maximum allowed absolute value of the difference.
    `abs(1 - (o1 / o2)` and vice-versa if o2 == 0.0. Ignored if < 0.
    :param max_abs_diff: The maximum allowed absolute difference `abs(o1 - o2)`. Ignored if < 0.
    :return: Whether the two objects are almost equal.
    """
    if type(o1) != type(o2):
        return False

    composite_type_passed = False

    if hasattr(o1, '__slots__'):
        if len(o1.__slots__) != len(o2.__slots__):
            return False
        if any(not are_almost_equal(getattr(o1, s1), getattr(o2, s2),
                                    max_abs_ratio_diff, max_abs_diff)
            for s1, s2 in Zip(sorted(o1.__slots__), sorted(o2.__slots__))):
            return False
        else:
            composite_type_passed = True

    if hasattr(o1, '__dict__'):
        if len(o1.__dict__) != len(o2.__dict__):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2))
            in Zip(sorted(o1.__dict__.items()), sorted(o2.__dict__.items()))
            if not k1.startswith('__')):  # avoid infinite loops
            return False
        else:
            composite_type_passed = True

    if isinstance(o1, dict):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
            or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for ((k1, v1), (k2, v2)) in Zip(sorted(o1.items()), sorted(o2.items()))):
            return False

    Elif any(issubclass(o1.__class__, c) for c in (list, Tuple, set)):
        if len(o1) != len(o2):
            return False
        if any(not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
            for v1, v2 in Zip(o1, o2)):
            return False

    Elif isinstance(o1, float):
        if o1 == o2:
            return True
        else:
            if max_abs_ratio_diff > 0:  # if max_abs_ratio_diff < 0, max_abs_ratio_diff is ignored
                if o2 != 0:
                    if abs(1.0 - (o1 / o2)) > max_abs_ratio_diff:
                        return False
                else:  # if both == 0, we already returned True
                    if abs(1.0 - (o2 / o1)) > max_abs_ratio_diff:
                        return False
            if 0 < max_abs_diff < abs(o1 - o2):  # if max_abs_diff < 0, max_abs_diff is ignored
                return False
            return True

    else:
        if not composite_type_passed:
            return o1 == o2

    return True


class EqualityTest(unittest.TestCase):

    def test_floats(self) -> None:
        o1 = ('hi', 3, 3.4)
        o2 = ('hi', 3, 3.400001)
        self.assertTrue(are_almost_equal(o1, o2, 0.0001, 0.0001))
        self.assertFalse(are_almost_equal(o1, o2, 0.00000001, 0.00000001))

    def test_ratio_only(self):
        o1 = ['hey', 10000, 123.12]
        o2 = ['hey', 10000, 123.80]
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, -1))

    def test_diff_only(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 1234567890.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, 1))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.1))

    def test_both_ignored(self):
        o1 = ['hey', 10000, 1234567890.12]
        o2 = ['hey', 10000, 0.80]
        o3 = ['hi', 10000, 0.80]
        self.assertTrue(are_almost_equal(o1, o2, -1, -1))
        self.assertFalse(are_almost_equal(o1, o3, -1, -1))

    def test_different_lengths(self):
        o1 = ['hey', 1234567890.12, 10000]
        o2 = ['hey', 1234567890.80]
        self.assertFalse(are_almost_equal(o1, o2, 1, 1))

    def test_classes(self):
        class A:
            d = 12.3

            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

        o1 = A(2.34, 'str', {1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = A(2.34, 'str', {1: 'hey', 345.231: [123, 'hi', 890.121]})
        self.assertTrue(are_almost_equal(o1, o2, 0.1, 0.1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, 0.0001))

        o2.hello = 'hello'
        self.assertFalse(are_almost_equal(o1, o2, -1, -1))

    def test_namedtuples(self):
        B = namedtuple('B', ['x', 'y'])
        o1 = B(3.3, 4.4)
        o2 = B(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.2, 0.2))
        self.assertFalse(are_almost_equal(o1, o2, 0.001, 0.001))

    def test_classes_with_slots(self):
        class C(object):
            __slots__ = ['a', 'b']

            def __init__(self, a, b):
                self.a = a
                self.b = b

        o1 = C(3.3, 4.4)
        o2 = C(3.4, 4.5)
        self.assertTrue(are_almost_equal(o1, o2, 0.3, 0.3))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.01))

    def test_dataclasses(self):
        @dataclass
        class D:
            s: str
            i: int
            f: float

        @dataclass
        class E:
            f2: float
            f4: str
            d: D

        o1 = E(12.3, 'hi', D('hello', 34, 20.01))
        o2 = E(12.1, 'hi', D('hello', 34, 20.0))
        self.assertTrue(are_almost_equal(o1, o2, -1, 0.4))
        self.assertFalse(are_almost_equal(o1, o2, -1, 0.001))

        o3 = E(12.1, 'hi', D('ciao', 34, 20.0))
        self.assertFalse(are_almost_equal(o2, o3, -1, -1))

    def test_ordereddict(self):
        o1 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.12]})
        o2 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.0]})
        self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
        self.assertFalse(are_almost_equal(o1, o2, 0.0001, -1))
0
redsk

Vous devrez peut-être l'implémenter vous-même, alors qu'il est vrai que la liste et les ensembles peuvent être itérés de la même manière, les dictionnaires sont une histoire différente, vous répétez leurs clés et non des valeurs, et le troisième exemple me semble un peu ambigu. comparez chaque valeur de l'ensemble ou chaque valeur de chaque ensemble.

heres un extrait de code simple.

def almost_equal(value_1, value_2, accuracy = 10**-8):
    return abs(value_1 - value_2) < accuracy

x = [1,2,3,4]
y = [1,2,4,5]
assert all(almost_equal(*values) for values in Zip(x, y))
0
Samy Vilar

Une autre approche consiste à convertir vos données en une forme comparable, par exemple en transformant chaque float en chaîne avec une précision fixe.

def comparable(data):
    """Converts `data` to a comparable structure by converting any floats to a string with fixed precision."""
    if isinstance(data, (int, str)):
        return data
    if isinstance(data, float):
        return '{:.4f}'.format(data)
    if isinstance(data, list):
        return [comparable(el) for el in data]
    if isinstance(data, Tuple):
        return Tuple([comparable(el) for el in data])
    if isinstance(data, dict):
        return {k: comparable(v) for k, v in data.items()}

Ensuite vous pouvez:

self.assertEquals(comparable(value1), comparable(value2))
0
Karl Rosaen