web-dev-qa-db-fra.com

Convertir le tenseur PyTorch en python

Comment convertir un PyTorch Tensor en une liste python?

Mon cas d'utilisation actuel consiste à convertir un tenseur de taille [1, 2048, 1, 1] dans une liste de 2048 éléments.

Mon tenseur a des valeurs en virgule flottante. Existe-t-il une solution qui prend également en compte l'int et éventuellement d'autres types de données?

19
Tom Hale

J'ai trouvé Tensor.tolist() qui donne l'exemple d'utilisation suivant:

>>> a = torch.randn(2, 2)
>>> a.tolist()
[[0.012766935862600803, 0.5415473580360413],
 [-0.08909505605697632, 0.7729271650314331]]
>>> a[0,0].tolist()
0.012766935862600803

Donc, pour répondre à la question, utilisez a.squeeze().tolist() pour supprimer toutes les dimensions de la taille 1.

Considérez également .flatten() si une liste de listes n'est pas souhaitée.


Avant de rencontrer .tolist(), j'utilisais:

list = [element.item() for element in tensor.flatten()]

Cela aplatit le tenseur en une seule dimension, puis appelle .item() pour convertir chaque élément en un nombre Python).

30
Tom Hale