J'utilise la bibliothèque tqdm et cela ne me donne pas la barre de progression, mais plutôt une sortie qui ressemble à ceci où elle me dit juste l'itération:
251it [01:44, 2.39it/s]
Une idée pourquoi le code ferait cela? Je pensais que c'était peut-être parce que je lui passais un générateur, mais là encore, j'ai utilisé des générateurs dans le passé qui ont fonctionné. Je n'ai jamais vraiment joué avec le formatage tdqm auparavant. Voici une partie du code source:
train_iter = Zip(train_x, train_y) #train_x and train_y are just lists of elements
....
def train(train_iter, model, criterion, optimizer):
model.train()
total_loss = 0
for x, y in tqdm(train_iter):
x = x.transpose(0, 1)
y = y.transpose(0, 1)
optimizer.zero_grad()
bloss = model.forward(x, y, criterion)
bloss.backward()
torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
optimizer.step()
total_loss += bloss.data[0]
return total_loss
tqdm
doit savoir combien d'itérations seront effectuées (le montant total) pour afficher une barre de progression.
Vous pouvez essayer ceci:
from tqdm import tqdm
train_x = range(100)
train_y = range(200)
train_iter = Zip(train_x, train_y)
# Notice `train_iter` can only be iter over once, so i get `total` in this way.
total = min(len(train_x), len(train_y))
with tqdm(total=total) as pbar:
for item in train_iter:
# do something ...
pbar.update(1)