-
Notifications
You must be signed in to change notification settings - Fork 4
/
mul.fj
170 lines (145 loc) · 7.05 KB
/
mul.fj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
ns hex {
//  Time Complexity: 5@+26
// Space Complexity: 4@+52
//   One-digit multiply-accumulate:
//     .mul.add_carry_dst : res += x * .mul.dst + .mul.add_carry_dst
//
// res, x, .mul.dst and .mul.add_carry_dst are all single hex digits.
// The caller is expected to have already placed the multiplier digit in .mul.dst.
// @requires hex.add.init  (or hex.init)
def add_mul res, x @ resume < .mul.dst, .add.dst, .mul.ret, .tables.res {
    // Load the operands into the shared table-index variables:
    // res goes into hex.add.dst, x goes into .mul.dst at offset +4.
    .xor .add.dst, res
    .xor .mul.dst+4, x

    // Point hex.mul.ret at `resume`, then jump into the multiplication tables through .mul.dst.
    wflip .mul.ret+w, resume, .mul.dst

    pad 4
  resume:
    // The table flow (about @+24 fj ops) has finished and jumped back here via hex.mul.ret.
    // Undo the wflip so hex.mul.ret is back to its idle value.
    wflip .mul.ret+w, resume
    // Apply the table output to res (xor_zero presumably also clears .tables.res - defined elsewhere).
    .xor_zero res, .tables.res
}
//  Time Complexity: n(5@+26)
// Space Complexity: n(4@+52)
//   res[n] += a[n] * b[1]
//
// Multiplies the n-digit hex vector a by the single hex digit b and accumulates into res,
// by running the one-digit add_mul on every digit position in turn.
// @requires hex.add.init  (or hex.init)
def add_mul n, res, a, b < .mul.dst {
    // Start from a cleared carry.
    .mul.clear_carry

    // Put the multiplier digit b into the shared .mul.dst variable for the whole run.
    .xor .mul.dst, b

    // One digit per step, lowest index first.
    rep(n, idx) .add_mul res+idx*dw, a+idx*dw

    // Xor b out of .mul.dst again (same xor as above), then clear the carry left by the last digit.
    .xor .mul.dst, b
    .mul.clear_carry
}
//  Time Complexity: n^2(3@+7) + n*b(5@+26)
//                   for b==n/2:  n^2(5.5@+20)
// Space Complexity: n(21@+479)
//
//   res[:n] = a[:n] * b[:n]
//
// b is the minimum of #on-bits in a,b (b<n/2 on average, e.g. b=7 for n=16).
//
// Schoolbook multiplication over hex digits: the operand with fewer on-bits is copied into
// `src` and consumed one hex digit at a time (lowest first), while the other operand is
// copied into `dst` and shifted left by one hex digit each round. Whenever the current low
// digit of src is non-zero, res += dst * src[0] is performed with add_mul.
// @requires hex.add.init  (or hex.init)
def mul n, res, a, b @ a_less_1bits, b_less_1bits, loop, after_add, dst, src, a_1bits, b_1bits, end {
    // Start with clean working copies and a zero result.
    .zero n, dst
    .zero n, src
    .zero n, res

    // if a has less 1(on) bits - jump to a_less_1bits. else jump to b_less_1bits
    // (the on-bit counts are what must be compared here, not the operands themselves).
    .count_bits n, a_1bits, a
    .count_bits n, b_1bits, b
    .cmp ((#(n*4))+3)/4, a_1bits, b_1bits, a_less_1bits, a_less_1bits, b_less_1bits

  a_less_1bits:
    // a drives the loop (src), b is the shifted multiplicand (dst).
    .xor n, dst, b
    .xor n, src, a
    ;loop

  b_less_1bits:
    // b drives the loop (src), a is the shifted multiplicand (dst).
    .xor n, dst, a
    .xor n, src, b

  loop:
    // Skip the accumulate step when the current low digit of src is 0.
    .if0 src, after_add
    // res[n] += dst[n] * src[0]
    .add_mul n, res, dst, src
  after_add:
    // Advance to the next hex digit: dst moves up one digit, src moves down one digit.
    .shl_hex n, dst
    .shr_hex n, src
    // Done once src has no digits left; otherwise go around again.
    .if n, src, end, loop    // can be replaced with #n-sized index decrement and check.

  // Local storage. Never executed: the .if above always jumps to `end` or `loop`.
  dst: hex.vec n
  src: hex.vec n
  // On-bit counters; ((#(n*4))+3)/4 hex digits each, the same width the .cmp above compares.
  a_1bits: hex.vec ((#(n*4))+3)/4
  b_1bits: hex.vec ((#(n*4))+3)/4
  end:
}
ns mul {
// Complexity: 3@+1
// Clears the multiplication carry kept in .add_carry_dst (its lower 4 bits, as set up by hex.mul.init),
// after first invoking hex.add.clear_carry.
// @requires hex.mul.init (it defines .ret and .add_carry_dst)
def clear_carry @ return < .ret, .add_carry_dst {
// hex.add.clear_carry (presumably resets the carry of hex.add - defined elsewhere).
..add.clear_carry
// Point hex.mul.ret at the local `return` label, so the table flow comes back here.
wflip .ret+w, return
// Flip bit 9 of add_carry_dst's jump address and jump through it.
// In hex.mul.init the clean_carry table is placed 512 entries after the 1024-padded add_carry table,
// so this jump lands in clean_carry at the entry indexed by the current carry value.
.add_carry_dst+dbit+9; .add_carry_dst
pad 16
return:
// Undo the wflip on hex.mul.ret.
wflip .ret+w, return
}
// Time Complexity: @+24 (when jumping to dst, until finished)
// Space Complexity: 1620+@
// This is where the add_mul "truth" tables are.
// @output-param ret: The return address. Jumps to it after finishing the add_mul flow.
// @output-param dst: hex[:2] variable. The code calculates dst[0]*dst[1] + add_carry_dst.
// dst[1] is zeroed after finishing this multiplication.
// @output-param add_carry_dst: hex variable. Its value is added to dst[0]*dst[1], and the carry is written back to it.
// @requires hex.tables.init_shared & hex.add.init (or hex.init)
def init @ add_res, after_add, end, switch_small_table, add_carry_small_table, set_carry_small_table, \
clean_small_table, switch, set_carry_0, set_carry_1, clean, add_carry, clean_add, clean_carry \
< ..add.dst, ..tables.ret > ret, dst, add_carry_dst {
// general progression (after jumping to hex.mul.dst with value d):
// 1. dst -> switch+d (set lower4 mul result in add_carry_dst+4) (runtime=5)
// 2. add_carry_dst -> add_carry+? (set add result in hex.add.dst +4. sets dst to set_carry_{0/1}. sets add_carry_dst to clean_add) (runtime=6)
// 3. add_carry_dst -> clean_add+? (clears the all 8 bits of add_carry_dst. sets add_carry_dst back to add_carry) (runtime=9)
// 4. dst -> set_carry_{0/1}+d (set higher4 mul result in add_carry_dst+0. sets dst to clean) (runtime=5)
// 5. add_res -> dst-table+?? (set add result in hex.tables.res +0) (runtime=@-8)
// 6. dst -> clean+d (clears the higher4 bits of dst. sets dst back to switch) (runtime=6)
// 7. hex.mul.ret -> ... (runtime=1)

// Jump over everything below, so placing this macro inline does not execute the tables.
;end

// The three output variables. Each one is a jump op whose target address doubles as the variable's value:
// ret - target 0 when idle; callers wflip it to their own return label.
// dst - idles at the switch table; its value bits select the entry inside the current table.
// add_carry_dst - idles at the add_carry table.
ret: ;0
dst: ;switch
add_carry_dst: ;add_carry // the 4-bit carry is in the lower 4bits in here

// Step 5: point hex.tables.ret at after_add and jump through hex.add.dst.
add_res:
wflip ..tables.ret+w, after_add, ..add.dst
pad 256
// Back from that jump: undo the wflip on hex.tables.ret and continue through dst.
after_add:
wflip ..tables.ret+w, after_add, .dst

pad 16 // not really needed
// The four 16-entry helper tables below share one pattern: entry d flips a single bit chosen by #d
// (assumed to be the bit-length of d, so it addresses the highest set bit of d) and jumps to the entry
// whose index is d with that bit cleared, so the whole value d is applied one bit per op.

// switch_small_table: flips the bits of d into bits 4..7 of add_carry_dst, then jumps to add_carry_dst.
// Entry 0 uses flip operand 0 (presumably a no-op flip) and jumps straight to add_carry_dst.
switch_small_table:
rep(16, d) stl.fj \
(d==0) ? 0 : (.add_carry_dst + dbit + (#d) + 3), \
(d==((1<<(#d))>>1)) ? .add_carry_dst : switch_small_table + (d^((1<<(#d)) >> 1))*dw
// set_carry_small_table: flips the bits of d into bits 0..3 of add_carry_dst, then jumps to add_res.
set_carry_small_table:
rep(16, d) stl.fj \
(d==0) ? 0 : (.add_carry_dst+dbit+(#d)-1), \
(d==((1<<(#d))>>1)) ? add_res : set_carry_small_table + (d^((1<<(#d)) >> 1))*dw
// add_carry_small_table: flips the bits of d into bits 4..7 of hex.add.dst; the chain always ends in entry 0,
// which flips bit 8 of add_carry_dst (add_carry -> clean_add) and jumps to add_carry_dst.
add_carry_small_table:
rep(16, d) stl.fj \
(d==0) ? .add_carry_dst+dbit+8 : (.add_carry_dst+dbit+(#d)-1 - (.add_carry_dst+dbit+(#d)-1) + ..add.dst+dbit+(#d)+3), \
(d==0) ? .add_carry_dst : add_carry_small_table+(d^((1<<(#d))>>1))*dw
// clean_small_table: flips the bits of d in bits 4..7 of dst; the chain always ends in entry 0,
// which flips bit 9 of dst and jumps to ret.
clean_small_table:
rep(16, d) stl.fj \
(d==0) ? .dst+dbit+9 : (.dst+dbit+(#d)+3), \
(d==0) ? .ret : clean_small_table + (d^((1<<(#d)) >> 1))*dw

// Four 256-entry tables laid out back to back after the 1024 padding, so flipping bits 8 and 9 of dst
// moves between them: switch (+0), set_carry_0 (+256), set_carry_1 (+512), clean (+768).
// In all four, the entry index d packs the two digits as (d&0xf) and (d>>4).
pad 1024
// switch: no flip; dispatches on the low nibble of the digit product.
switch:
rep(256, d) stl.fj 0, switch_small_table + (((d&0xf)*(d>>4)) & 0xf) * dw
// set_carry_0: flips bit 9 of dst; dispatches on the high nibble of the digit product.
set_carry_0:
rep(256, d) stl.fj .dst+dbit+9, set_carry_small_table + (((d&0xf)*(d>>4)) >> 4) * dw
// set_carry_1: flips bit 8 of dst; dispatches on the high nibble of the digit product, plus one.
set_carry_1:
rep(256, d) stl.fj .dst+dbit+8, set_carry_small_table + ((((d&0xf)*(d>>4)) >> 4)+1) * dw
// clean: flips bit 8 of dst; dispatches on the high nibble of d, to clear it from dst.
clean:
rep(256, d) stl.fj .dst+dbit+8, clean_small_table + (d>>4) * dw

// Tables reached through add_carry_dst: add_carry (+0), clean_add (+256), clean_carry (+512).
pad 1024 // needs to be 1024-padded
// add_carry: adds the two nibbles of d. Flips bit 9 of dst when the sum exceeds 0xf, bit 8 otherwise,
// then dispatches on the low nibble of the sum.
add_carry:
rep(256, d) stl.fj \
.dst+dbit + (((d&0xf)+(d>>4) > 0xf) ? 9 : 8), \
add_carry_small_table + (((d&0xf)+(d>>4)) & 0xf) * dw
// clean_add: flips the highest set bit of d in add_carry_dst and chains to the entry without it;
// entry 0 flips bit 8 of add_carry_dst (clean_add -> add_carry) and jumps to dst.
clean_add:
rep(256, d) stl.fj \
(d==0) ? .add_carry_dst+dbit+8 : (.add_carry_dst+dbit+(#d)-1), \
(d==0) ? .dst : clean_add +(d^((1<<(#d))>>1))*dw
// clean_carry: same bit-clearing chain over the 4 carry bits; entry 0 flips bit 9 of add_carry_dst
// (clean_carry -> add_carry) and jumps to dst. This table is entered from hex.mul.clear_carry.
// NOTE(review): the terminal entry jumps to .dst (re-entering the table flow) rather than to .ret - confirm this is intended.
clean_carry:
rep( 16, d) stl.fj \
(d==0) ? .add_carry_dst+dbit+9 : (.add_carry_dst+dbit+(#d)-1), \
(d==0) ? .dst : clean_carry+(d^((1<<(#d))>>1))*dw
end:
}
}
}