From 98d12c125a5938dddd583daaac0f084d8d2e00d5 Mon Sep 17 00:00:00 2001
From: Glycogen W <109408857+XihWang@users.noreply.github.com>
Date: Sun, 29 Dec 2024 16:54:08 +0800
Subject: [PATCH] Update seq2seq.md
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Per the definition of class EncoderDecoder in the current version of d2l [1],
its forward call returns only a single value, so the `_` should be removed;
otherwise the unpacking raises: ValueError: too many values to unpack
(expected 2)

[1] https://github.com/d2l-ai/d2l-en/blob/master/d2l/torch.py#L951
---
 chapter_recurrent-modern/seq2seq.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/chapter_recurrent-modern/seq2seq.md b/chapter_recurrent-modern/seq2seq.md
index 0b587cb8d..ddf3aec85 100644
--- a/chapter_recurrent-modern/seq2seq.md
+++ b/chapter_recurrent-modern/seq2seq.md
@@ -724,7 +724,7 @@ def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
                            ctx=device).reshape(-1, 1)
             dec_input = np.concatenate([bos, Y[:, :-1]], 1)  # Teacher forcing
             with autograd.record():
-                Y_hat, _ = net(X, dec_input, X_valid_len)
+                Y_hat = net(X, dec_input, X_valid_len)
                 l = loss(Y_hat, Y, Y_valid_len)
             l.backward()
             d2l.grad_clipping(net, 1)
@@ -766,7 +766,7 @@ def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
             bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                                device=device).reshape(-1, 1)
             dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
-            Y_hat, _ = net(X, dec_input, X_valid_len)
+            Y_hat = net(X, dec_input, X_valid_len)
             l = loss(Y_hat, Y, Y_valid_len)
             l.sum().backward()  # Run "backpropagation" on the scalar loss
             d2l.grad_clipping(net, 1)
@@ -797,7 +797,7 @@ def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
                              shape=(-1, 1))
             dec_input = tf.concat([bos, Y[:, :-1]], 1)  # Teacher forcing
             with tf.GradientTape() as tape:
-                Y_hat, _ = net(X, dec_input, X_valid_len, training=True)
+                Y_hat = net(X, dec_input, X_valid_len, training=True)
                 l = MaskedSoftmaxCELoss(Y_valid_len)(Y, Y_hat)
             gradients = tape.gradient(l, net.trainable_variables)
             gradients = d2l.grad_clipping(gradients, 1)
@@ -1120,4 +1120,4 @@ for eng, fra in zip(engs, fras):
 
 :begin_tab:`paddle`
 [Discussions](https://discuss.d2l.ai/t/11838)
-:end_tab:
\ No newline at end of file
+:end_tab:
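
A quick reproduction of the reported error, for reference (not part of the diff). The `net` below is a hypothetical stand-in: it only mimics the one property that matters here, namely that the updated `EncoderDecoder.forward` in [1] returns a single tensor rather than an `(output, state)` pair, so the two-name unpacking iterates the tensor's batch dimension and fails:

```python
import torch

# Hypothetical stand-in for the seq2seq net: like the updated
# EncoderDecoder.forward in [1], it returns one tensor of shape
# (batch_size, num_steps, vocab_size), not an (output, state) pair.
def net(X, dec_input, X_valid_len):
    return torch.zeros(X.shape[0], dec_input.shape[1], 10)

X = torch.zeros(4, 9, dtype=torch.long)          # dummy source batch
dec_input = torch.zeros(4, 9, dtype=torch.long)  # dummy decoder input
X_valid_len = torch.tensor([9, 9, 9, 9])

Y_hat = net(X, dec_input, X_valid_len)           # new code: works
try:
    Y_hat, _ = net(X, dec_input, X_valid_len)    # old code: unpacks along
except ValueError as err:                        # the batch dim (size 4)
    print(err)  # too many values to unpack (expected 2)
```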