diff --git a/llama2.mojo b/llama2.mojo index b03888f..abf4048 100644 --- a/llama2.mojo +++ b/llama2.mojo @@ -617,22 +617,54 @@ fn rope_rotation_llama( ) -> None: # stories model, llama2 let head_size = config.head_size + let head_range = config.head_size - config.head_size % (nelts * 2) + @parameter - fn head_loop(i:Int): + fn head_loop(i: Int): # Simple vectorization with (head_size // 2) steps gave junk transformer output. # Maybe because the nelt ranges end up overlapping between the steps. - for j in range(0, config.head_size, 2): - let fcr = freq_cis_real_row[j // 2] - let fci = freq_cis_imag_row[j // 2] - let q0 = state.q[i * head_size + j] - let q1 = state.q[i * head_size + j + 1] - state.q[i * head_size + j] = q0 * fcr - q1 * fci - state.q[i * head_size + j + 1] = q0 * fci + q1 * fcr + @parameter + fn calc_head[nelts: Int](j: Int): + # for j in range(0, config.head_size, 2): + let fcr = freq_cis_real_row.simd_load[nelts](j // 2) + let fci = freq_cis_imag_row.simd_load[nelts](j // 2) + let q0 = state.q.data().offset(i * head_size + j).simd_strided_load[nelts]( + 2 + ) + let q1 = state.q.data().offset(i * head_size + j + 1).simd_strided_load[ + nelts + ](2) + + state.q.data().offset(i * head_size + j).simd_strided_store[nelts]( + q0 * fcr - q1 * fci, 2 + ) + state.q.data().offset(i * head_size + j + 1).simd_strided_store[nelts]( + q0 * fci + q1 * fcr, 2 + ) + if i < config.n_kv_heads: - let k0 = state.k[i * head_size + j] - let k1 = state.k[i * head_size + j + 1] - state.k[i * head_size + j] = k0 * fcr - k1 * fci - state.k[i * head_size + j + 1] = k0 * fci + k1 * fcr + let k0 = state.k.data().offset(i * head_size + j).simd_strided_load[ + nelts + ](2) + let k1 = state.k.data().offset(i * head_size + j + 1).simd_strided_load[ + nelts + ](2) + + state.k.data().offset(i * head_size + j).simd_strided_store[nelts]( + k0 * fcr - k1 * fci, 2 + ) + + state.k.data().offset(i * head_size + j + 1).simd_strided_store[nelts]( + k0 * fci + k1 * fcr, 2 + ) + + for j in range(0, head_range, nelts * 2): + calc_head[nelts](j) + + # deal with tail elements + for j in range(head_range, config.head_size, 2): + calc_head[1](j) + parallelize[head_loop](config.n_heads, workers) @@ -975,8 +1007,13 @@ fn main() raises: next_token = argmax(state.logits) else: # Apply the temperature to the logits - for q in range(config.vocab_size): - state.logits[q] = state.logits[q] / temperature + @parameter + fn v_temperature[_nelts: Int](q: Int): + state.logits.simd_store[_nelts]( + q, state.logits.simd_load[_nelts](q) / temperature + ) + + vectorize[nelts, v_temperature](config.vocab_size) # Apply softmax to the logits to get the probabilities for the next token softmax(state.logits)