web-dev-qa-db-fra.com

Filtrer les lignes d'un tableau numpy?

Je cherche à appliquer une fonction à chaque ligne d'un tableau numpy. Si cette fonction a la valeur true, je conserverai la ligne, sinon je la rejetterai. Par exemple, ma fonction pourrait être:

def f(row):
    if sum(row)>10: return True
    else: return False

Je me demandais s'il y avait quelque chose de similaire à:

np.apply_over_axes()

qui applique une fonction à chaque ligne d'un tableau numpy et renvoie le résultat. J'espérais quelque chose comme:

np.filter_over_axes()

qui appliquerait une fonction à chaque ligne d'un tableau numpy et ne retournerait que les lignes pour lesquelles la fonction a renvoyé true. Y a-t-il quelque chose comme ça? Ou devrais-je simplement utiliser une boucle for?

24
kyphos

Idéalement, vous pourriez implémenter une version vectorisée de votre fonction et l'utiliser pour faire indexation booléenne . Pour la grande majorité des problèmes, c'est la bonne solution. Numpy fournit un certain nombre de fonctions pouvant agir sur différents axes ainsi que toutes les opérations de base et les comparaisons, de sorte que la plupart des conditions utiles doivent être vectorisables.

import numpy as np

x = np.random.randn(20, 3)
x_new = x[np.sum(x, axis=1) > .5]

Si vous êtes absolument sûr de ne pas pouvoir faire ce qui précède, je vous suggère d'utiliser une liste de compréhension (ou np.apply_along_axis ) pour créer un tableau de bools à indexer.

def myfunc(row):
    return sum(row) > .5

bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]

Cela permettra de faire le travail d'une manière relativement propre, mais sera beaucoup plus lent qu'une version vectorisée. Un exemple:

x = np.random.randn(5000, 200)

%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop

%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop
30
Roger Fan