-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlinear_algebra.rs
163 lines (133 loc) · 5.42 KB
/
linear_algebra.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
//! This module contains the implementation of the primitives related to linear algebra:
//! vectors, matrices, and R1CS.
//! NOTE: This is a dummy implementation and is not meant to be used in production.
//! In practice, you would use a library such as arkworks-rs/algebra to work with linear algebra.
use crate::finite_field::Fp;
// First, we implement a vector of fixed length with an inner product operation.
/// A vector of exactly `N` field elements (`Fp`).
///
/// The length is part of the type, so two vectors can only be combined (for example in an
/// inner product) when their lengths agree at compile time.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Vector<const N: usize>([Fp; N]);
impl<const N: usize> Vector<N> {
/// Creates a new vector from a fixed-size slice
pub fn new(value: [Fp; N]) -> Self {
assert_eq!(value.len(), N, "Vector length mismatch");
Vector(value)
}
/// Returns the ith element of the vector
pub fn get(&self, i: usize) -> Fp {
self.0[i]
}
/// Implements the dot product of two vectors
pub fn dot(&self, other: &Self) -> Fp {
// TODO: Implement dot product here!
unimplemented!("Implement dot product!")
}
/// Implements the element-wise product of two vectors
/// (called the Hadamard product)
pub fn hadamard_product(&self, other: &Self) -> Self {
// TODO: Implement hadamard product here!
unimplemented!("Implement hadamard product!")
}
}
// Next, we implement some matrix operations.
/// A dense matrix over `Fp`, stored as `M` row vectors of length `N`.
///
/// Note the parameter order: `N` is the number of columns (the row length) and `M` is the
/// number of rows. Element `(i, j)`, the `j`-th entry of the `i`-th row, is stored at
/// `self.0[i].0[j]` and can be read with [`Matrix::get`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Matrix<const N: usize, const M: usize>([Vector<N>; M]);
impl<const N: usize, const M: usize> Matrix<N, M> {
/// Creates a new matrix with all elements set to zero
pub fn zero() -> Self {
Matrix::<N, M>::new([Vector::<N>::new([Fp::from(0); N]); M])
}
/// Creates a new matrix from a fixed-size slice of row vectors
pub fn new(value: [Vector<N>; M]) -> Self {
// Asserting that sizes are correct
assert_eq!(value.len(), M, "Matrix length mismatch");
assert_eq!(value[0].0.len(), N, "Matrix row length mismatch");
Matrix(value)
}
/// Returns the ith row of the matrix
pub fn row(&self, i: usize) -> &Vector<N> {
&self.0[i]
}
/// Returns the jth column of the matrix
pub fn column(&self, j: usize) -> Vector<M> {
let mut resultant_column = Vector::<M>::new([Fp::from(0); M]);
for i in 0..M {
resultant_column.0[i] = self.0[i].0[j];
}
resultant_column
}
/// Returns the (i,j) element of the matrix
pub fn get(&self, i: usize, j: usize) -> Fp {
self.0[i].0[j]
}
/// Implements the hadamard product of two matrices. Namely, given two matrices `A` and `B`,
/// consisting of elements `a_ij` and `b_ij` respectively, the hadamard product is a matrix `C`
/// consisting of elements `c_ij = a_ij * b_ij`.
pub fn hadamard_product(&self, other: &Self) -> Self {
// TODO: Implement hadamard product here!
unimplemented!("Implement hadamard product!")
}
/// Implements the matrix-vector product. Namely, given a matrix `A` and a vector `b`, the
/// matrix-vector product gives `Ab`.
///
/// **Hint:** this is a vector `c` such that `c_i = \sum_j a_ij * b_j`.
pub fn vector_product(&self, other: &Vector<N>) -> Vector<M> {
// TODO: Implement matrix-vector product here!
unimplemented!("Implement matrix-vector product!")
}
}
/// The code below simply tests the correctness of your implementation. Do not touch
/// it unless you know what you are doing.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_vector_dot() {
        // a = [1, 5, 5]
        let a = Vector::new([Fp::from(1), Fp::from(5), Fp::from(5)]);
        // b = [2, 3, 2]
        let b = Vector::new([Fp::from(2), Fp::from(3), Fp::from(2)]);
        // We expect to get 1*2+5*3+5*2 = 27
        assert_eq!(a.dot(&b), Fp::from(27));
    }

    #[test]
    fn test_vector_hadamard_product() {
        // a = [1, 5, 5]
        let a = Vector::new([Fp::from(1), Fp::from(5), Fp::from(5)]);
        // b = [2, 3, 2]
        let b = Vector::new([Fp::from(2), Fp::from(3), Fp::from(2)]);
        // We expect to get [1*2, 5*3, 5*2] = [2, 15, 10]
        assert_eq!(
            a.hadamard_product(&b),
            Vector::new([Fp::from(2), Fp::from(15), Fp::from(10)])
        );
    }

    #[test]
    fn test_matrix_accessors() {
        // A = {{1, 5, 5}, {2, 3, 2}}
        let a = Matrix::new([
            Vector::new([Fp::from(1), Fp::from(5), Fp::from(5)]),
            Vector::new([Fp::from(2), Fp::from(3), Fp::from(2)]),
        ]);
        assert_eq!(a.get(1, 0), Fp::from(2));
        assert_eq!(
            a.row(0),
            &Vector::new([Fp::from(1), Fp::from(5), Fp::from(5)])
        );
        // Column 2 is {5, 2}
        assert_eq!(a.column(2), Vector::new([Fp::from(5), Fp::from(2)]));
        assert_eq!(Matrix::<3, 2>::zero().get(1, 2), Fp::from(0));
    }

    #[test]
    fn test_matrix_vector_product() {
        // A = {{1, 5, 5}, {2, 3, 2}}
        let a = Matrix::new([
            Vector::new([Fp::from(1), Fp::from(5), Fp::from(5)]),
            Vector::new([Fp::from(2), Fp::from(3), Fp::from(2)]),
        ]);
        // b = [2, 3, 2]
        let b = Vector::new([Fp::from(2), Fp::from(3), Fp::from(2)]);
        // c = A*b. We expect to get
        // {1*2+5*3+5*2, 2*2+3*3+2*2} = {27, 17}
        let c = a.vector_product(&b);
        assert_eq!(c, Vector::new([Fp::from(27), Fp::from(17)]));
    }

    #[test]
    fn test_matrix_hadamard_product() {
        // A = {{1, 5, 5}, {2, 3, 2}}
        let a = Matrix::new([
            Vector::new([Fp::from(1), Fp::from(5), Fp::from(5)]),
            Vector::new([Fp::from(2), Fp::from(3), Fp::from(2)]),
        ]);
        // B = {{2, 3, 2}, {1, 5, 5}}
        let b = Matrix::new([
            Vector::new([Fp::from(2), Fp::from(3), Fp::from(2)]),
            Vector::new([Fp::from(1), Fp::from(5), Fp::from(5)]),
        ]);
        // C = A hadamard B. We expect to get
        // {{1*2, 5*3, 5*2}, {2*1, 3*5, 2*5}} = {{2, 15, 10}, {2, 15, 10}}
        let c = a.hadamard_product(&b);
        assert_eq!(
            c,
            Matrix::new([
                Vector::new([Fp::from(2), Fp::from(15), Fp::from(10)]),
                Vector::new([Fp::from(2), Fp::from(15), Fp::from(10)]),
            ])
        );
    }
}