diff --git a/src/matrix.rs b/src/matrix.rs index 200ae04..9ef9946 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -114,6 +114,8 @@ impl Msg { // endregion: --- impls // region: --- functions +// AB -> a.col == b.row (左乘) +// 最后的矩阵是一个 a.row * b.col 的矩阵 pub fn multiply(a: &Matrix, b: &Matrix) -> Result> where T: Copy + Default + Add + AddAssign + Mul + Send + 'static, @@ -127,6 +129,7 @@ where // for i in 0..a.row { // for j in 0..b.col { // for k in 0..a.col { + // // data[i][j] += a[i][k] * b[k][j] // data[i * b.col + j] += a.data[i * a.col + k] * b.data[k * b.col + j]; // } // } @@ -156,14 +159,14 @@ where // map/reduce: map phase for i in 0..a.row { for j in 0..b.col { - let row = Vector::new(&a.data[i * a.col..(i + 1) * a.col]); + let row = Vector::new(&a.data[i * a.col..(i + 1) * a.col]); // a[i][k] let col_data = b.data[j..] .iter() .step_by(b.col) .copied() .collect::>(); - let col = Vector::new(col_data); - let idx = i * b.col + j; + let col = Vector::new(col_data); // b[k][j] + let idx = i * b.col + j; // i,j -> idx let input = MsgInput::new(idx, row, col); let (tx, rx) = oneshot::channel(); let msg = Msg::new(input, tx);