fix batch loss reporting

This commit is contained in:
2024-05-19 00:07:34 +01:00
parent 5d7960d4c7
commit c8d8ec9928

View File

@@ -173,8 +173,8 @@ def train():
accuracy = 100 * correct / sample_size
running_loss += loss.item()
if i % 1000 == 999:
last_loss = running_loss / 1000 # loss per batch
if i % 10 == 9:
last_loss = running_loss / 10 # loss per batch
print(" batch {} loss: {}".format(i + 1, last_loss))
running_loss = 0.0