J'ai le code suivant:
a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])
J'ai un index multidimensionnel b
et je veux l'utiliser pour sélectionner une seule cellule dans a
. Si b n'était pas un tenseur, je pourrais faire:
a[1,1,1,1]
Qui renvoie la bonne cellule, mais:
a[b]
Ne fonctionne pas, car il sélectionne simplement a[1]
quatre fois.
Comment puis-je faire ceci? Merci
Une solution plus élégante (et plus simple) pourrait être de simplement convertir b
en Tuple:
a[Tuple(b)]
Out[10]: tensor(5.)
J'étais curieux de voir comment cela fonctionne avec numpy "ordinaire", et j'ai trouvé un article connexe expliquant cela assez bien ici .
Vous pouvez diviser b
en 4 à l'aide de chunk
, puis utiliser le fragmenté b
pour indexer l'élément spécifique que vous souhaitez:
>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)] # here's the trick!
Out[24]: tensor([[40, 80, 0]])
Ce qui est bien, c'est qu'il peut être facilement généralisé à n'importe quelle dimension de a
, il vous suffit de faire en sorte que le nombre de mandrins soit égal à la dimension de a
.