web-dev-qa-db-fra.com

Comment utiliser tqdm dans une exécution parallèle avec joblib?

Je veux exécuter une fonction en parallèle et attendre que tous les nœuds parallèles soient terminés, en utilisant joblib. Comme dans l'exemple:

from math import sqrt
from joblib import Parallel, delayed
Parallel(n_jobs=2)(delayed(sqrt)(i ** 2) for i in range(10))

Mais, je veux que l'exécution soit vue dans une seule barre de progression comme avec tqdm, montrant le nombre de travaux terminés.

Comment feriez-vous cela?

25
Dror Hilman

Si votre problème se compose de plusieurs parties, vous pouvez diviser les parties en sous-groupes k, exécuter chaque sous-groupe en parallèle et mettre à jour la barre de progression entre les deux, ce qui entraîne k mises à jour de la progression.

Ceci est démontré dans l'exemple suivant de la documentation.

>>> with Parallel(n_jobs=2) as parallel:
...    accumulator = 0.
...    n_iter = 0
...    while accumulator < 1000:
...        results = parallel(delayed(sqrt)(accumulator + i ** 2)
...                           for i in range(5))
...        accumulator += sum(results)  # synchronization barrier
...        n_iter += 1

https://pythonhosted.org/joblib/parallel.html#reusing-a-pool-of-workers

8
PureW

Il suffit de mettre range(10) à l'intérieur tqdm(...)! Cela semblait probablement trop beau pour être vrai pour vous, mais cela fonctionne vraiment (sur ma machine):

from math import sqrt
from joblib import Parallel, delayed  
from tqdm import tqdm  
result = Parallel(n_jobs=2)(delayed(sqrt)(i ** 2) for i in tqdm(range(100000)))
34
tyrex

Voici une solution de contournement possible

def func(x):
    time.sleep(random.randint(1, 10))
    return x

def text_progessbar(seq, total=None):
    step = 1
    tick = time.time()
    while True:
        time_diff = time.time()-tick
        avg_speed = time_diff/step
        total_str = 'of %n' % total if total else ''
        print('step', step, '%.2f' % time_diff, 
              'avg: %.2f iter/sec' % avg_speed, total_str)
        step += 1
        yield next(seq)

all_bar_funcs = {
    'tqdm': lambda args: lambda x: tqdm(x, **args),
    'txt': lambda args: lambda x: text_progessbar(x, **args),
    'False': lambda args: iter,
    'None': lambda args: iter,
}

def ParallelExecutor(use_bar='tqdm', **joblib_args):
    def aprun(bar=use_bar, **tq_args):
        def tmp(op_iter):
            if str(bar) in all_bar_funcs.keys():
                bar_func = all_bar_funcs[str(bar)](tq_args)
            else:
                raise ValueError("Value %s not supported as bar type"%bar)
            return Parallel(**joblib_args)(bar_func(op_iter))
        return tmp
    return aprun

aprun = ParallelExecutor(n_jobs=5)

a1 = aprun(total=25)(delayed(func)(i ** 2 + j) for i in range(5) for j in range(5))
a2 = aprun(total=16)(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
a2 = aprun(bar='txt')(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
a2 = aprun(bar=None)(delayed(func)(i ** 2 + j) for i in range(4) for j in range(4))
7
Ben Usman

J'ai créé pqdm un wrapper tqdm parallèle avec des futurs simultanés pour y arriver confortablement, essayez-le!

À installer

pip install pqdm

et utilise

from pqdm.processes import pqdm
# If you want threads instead:
# from pqdm.threads import pqdm

args = [1, 2, 3, 4, 5]
# args = range(1,6) would also work

def square(a):
    return a*a

result = pqdm(args, square, n_jobs=2)
2
niedakh

Comme indiqué ci-dessus, les solutions qui enveloppent simplement l'itérable passé à joblib.Parallel() ne surveillent pas vraiment la progression de l'exécution. Au lieu de cela, je suggère de sous-classer Parallel et de remplacer la méthode print_progress(), comme suit:

import joblib
from tqdm.auto import tqdm

class ProgressParallel(joblib.Parallel):
    def __call__(self, *args, **kwargs):
        with tqdm() as self._pbar:
            return joblib.Parallel.__call__(self, *args, **kwargs)

    def print_progress(self):
        self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()
0
nth