web-dev-qa-db-fra.com

Comment utiliser le multitraitement PyTorch?

J'essaie d'utiliser le multiprocessing Pool méthode de python dans pytorch pour traiter une image. Voici le code:

from multiprocessing import Process, Pool
from torch.autograd import Variable
import numpy as np
from scipy.ndimage import zoom

def get_pred(args):

  img = args[0]
  scale = args[1]
  scales = args[2]
  img_scale = zoom(img.numpy(),
                     (1., 1., scale, scale),
                     order=1,
                     prefilter=False,
                     mode='nearest')

  # feed input data
  input_img = Variable(torch.from_numpy(img_scale),
                     volatile=True).cuda()
  return input_img

scales = [1,2,3,4,5]
scale_list = []
for scale in scales: 
    scale_list.append([img,scale,scales])
multi_pool = Pool(processes=5)
predictions = multi_pool.map(get_pred,scale_list)
multi_pool.close() 
multi_pool.join()

Je reçois cette erreur:

`RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

`Dans cette ligne:

predictions = multi_pool.map(get_pred,scale_list)

Quelqu'un peut-il me dire ce que je fais mal?

4
Rahul

Comme indiqué dans documentation pytorch la meilleure pratique pour gérer le multitraitement est d'utiliser torch.multiprocessing au lieu de multiprocessing.

Sachez que le partage des tenseurs CUDA entre les processus n'est pris en charge que dans Python 3, soit avec spawn ou forkserver comme méthode de démarrage.

Sans toucher à votre code, une solution de contournement pour l'erreur que vous avez obtenue remplace

from multiprocessing import Process, Pool

avec:

from torch.multiprocessing import Pool, Process, set_start_method
try:
     set_start_method('spawn')
except RuntimeError:
    pass
3
nicobonne

Je vous suggère de lire les docs du module multiprocessing, en particulier cette section . Vous devrez modifier la façon dont les sous-processus sont créés en appelant set_start_method. Tiré de ces documents cités:

import multiprocessing as mp

def foo(q):
    q.put('hello')

if __== '__main__':
    mp.set_start_method('spawn')
    q = mp.Queue()
    p = mp.Process(target=foo, args=(q,))
    p.start()
    print(q.get())
    p.join()
0
Oliver Baumann