From bfd1eb2abbfc884dd069034ea6c7943caa921a09 Mon Sep 17 00:00:00 2001 From: Xunchi Zhang Date: Fri, 5 Jul 2024 18:17:39 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A5=E5=85=A8=E4=BA=86=E7=BC=BA=E5=A4=B1?= =?UTF-8?q?=E7=9A=84=E9=83=A8=E5=88=86=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../beginner/basics/optimization_tutorial.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/2.0/tutorials/beginner/basics/optimization_tutorial.md b/docs/2.0/tutorials/beginner/basics/optimization_tutorial.md index 86aec389c..1f471c346 100644 --- a/docs/2.0/tutorials/beginner/basics/optimization_tutorial.md +++ b/docs/2.0/tutorials/beginner/basics/optimization_tutorial.md @@ -212,6 +212,18 @@ def test_loop(dataloader, model, loss_fn): 我们初始化了损失函数和优化器,传递给 `train_loop` 和 `test_loop`。你可以随意地修改 epochs 的数量来跟踪模型表现的进步情况。 +```py +loss_fn = nn.CrossEntropyLoss() +optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + +epochs = 10 +for t in range(epochs): + print(f"Epoch {t+1}\n-------------------------------") + train_loop(train_dataloader, model, loss_fn, optimizer) + test_loop(test_dataloader, model, loss_fn) +print("Done!") +``` + 输出: ```py