How do I implement the following PyTorch operations in Burn? #1368
Answered by nathanielsimard
zemelLeong asked this question in Q&A
```python
>>> tensor_randn.shape
torch.Size([1, 54, 8])
>>> tensor_randn[:, :, 0].shape
torch.Size([1, 54])
```
Accepted answer by nathanielsimard (Feb 26, 2024), selected by zemelLeong:
```rust
let tensor = Tensor::random(..);
let [b, d1, _d2] = tensor.dims();
// Keep all of dims 0 and 1 and only index 0 of dim 2: shape [b, d1, 1].
let tensor_partial = tensor.slice([0..b, 0..d1, 0..1]);
// `slice` keeps the rank; squeeze dim 2 to get PyTorch's [1, 54].
let tensor_partial = tensor_partial.squeeze::<2>(2);
```

You have to specify the range of each dimension. At some point, we might add syntax sugar so you don't have to spell out the start and end of the dimensions you want to keep in full, maybe something like this:

```rust
let tensor_partial = tensor.slice(index![.., .., 0]);
```
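The proposed `index![.., .., 0]` sugar is only an idea in this thread, so the following is a hypothetical sketch of one way such a macro could work, not Burn's API. It assumes a made-up `DimIndex` enum and `resolve` helper: each entry (`..`, an integer, or a `Range`) is converted into a `DimIndex`, then expanded against the tensor's dims into the explicit ranges that `slice` already accepts.

```rust
use std::ops::{Range, RangeFull};

/// Hypothetical per-dimension index entry.
#[derive(Debug, Clone, PartialEq)]
enum DimIndex {
    Full,               // `..`: keep the whole dimension
    At(usize),          // `i`: keep only position i (as the range i..i+1)
    Span(Range<usize>), // `a..b`: keep positions a..b
}

impl From<RangeFull> for DimIndex {
    fn from(_: RangeFull) -> Self {
        DimIndex::Full
    }
}

impl From<usize> for DimIndex {
    fn from(i: usize) -> Self {
        DimIndex::At(i)
    }
}

impl From<Range<usize>> for DimIndex {
    fn from(r: Range<usize>) -> Self {
        DimIndex::Span(r)
    }
}

/// Hypothetical `index!` macro: turns each comma-separated entry into a `DimIndex`.
macro_rules! index {
    ($($e:expr),* $(,)?) => { [$(DimIndex::from($e)),*] };
}

/// Expands the entries into one explicit range per dimension, given the dims.
fn resolve<const D: usize>(idx: &[DimIndex; D], dims: [usize; D]) -> [Range<usize>; D] {
    std::array::from_fn(|d| match &idx[d] {
        DimIndex::Full => 0..dims[d],
        DimIndex::At(i) => *i..*i + 1,
        DimIndex::Span(r) => r.clone(),
    })
}

fn main() {
    let dims = [1, 54, 8];
    let ranges = resolve(&index![.., .., 0], dims);
    // Same ranges as the hand-written `[0..b, 0..d1, 0..1]`.
    assert_eq!(ranges, [0..1, 0..54, 0..1]);
    println!("{:?}", ranges);
}
```

Resolving `..` needs the dims, which is why the sketch resolves against them at the call site; a real implementation inside `slice` would have them on hand.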
We might also implement the `Index` trait with multiple different generic index types (https://doc.rust-lang.org/std/ops/trait.Index.html), dispatching to `slice` when passing a `Range` and to `select` when using an int tensor.
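A minimal std-only sketch of "one entry point, dispatched on the index type". `Toy1`, `IntToy`, `TensorIndex`, and `get` are all hypothetical. Note that `std::ops::Index::index` returns `&Self::Output`, i.e. a reference to data the container already owns, so it cannot directly hand back a freshly built tensor; this sketch therefore uses an ordinary generic method to show the dispatch idea. How Burn would handle that constraint is not stated in the thread.

```rust
use std::ops::Range;

/// Hypothetical 1-D float tensor.
#[derive(Debug, Clone, PartialEq)]
struct Toy1(Vec<f32>);

/// Hypothetical stand-in for an int tensor holding indices.
struct IntToy(Vec<usize>);

/// Hypothetical trait: each index type picks the operation it maps to.
trait TensorIndex {
    fn apply(self, t: &Toy1) -> Toy1;
}

impl TensorIndex for Range<usize> {
    // A range maps to a contiguous `slice`.
    fn apply(self, t: &Toy1) -> Toy1 {
        t.slice(self)
    }
}

impl TensorIndex for IntToy {
    // An int tensor maps to a gather-style `select`.
    fn apply(self, t: &Toy1) -> Toy1 {
        t.select(&self)
    }
}

impl Toy1 {
    fn slice(&self, r: Range<usize>) -> Toy1 {
        Toy1(self.0[r].to_vec())
    }

    fn select(&self, idx: &IntToy) -> Toy1 {
        Toy1(idx.0.iter().map(|&i| self.0[i]).collect())
    }

    /// Single entry point; the impl is chosen at compile time from `I`.
    fn get<I: TensorIndex>(&self, index: I) -> Toy1 {
        index.apply(self)
    }
}

fn main() {
    let t = Toy1(vec![10.0, 20.0, 30.0, 40.0]);
    assert_eq!(t.get(1..3), Toy1(vec![20.0, 30.0]));
    assert_eq!(t.get(IntToy(vec![3, 0])), Toy1(vec![40.0, 10.0]));
    println!("ok");
}
```

Because the choice happens through trait resolution, a range and an index tensor can share one call syntax with no runtime branching.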