Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized rope_rotation_llama and apply temperature to logits with vectorization #59

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions llama2.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -617,22 +617,54 @@ fn rope_rotation_llama(
) -> None:
# stories model, llama2
let head_size = config.head_size
let head_range = config.head_size - config.head_size % (nelts * 2)

@parameter
fn head_loop(i:Int):
fn head_loop(i: Int):
        # Simple vectorization with (head_size // 2) steps gave junk transformer output,
        # possibly because the nelts-wide SIMD ranges overlapped between steps.
for j in range(0, config.head_size, 2):
let fcr = freq_cis_real_row[j // 2]
let fci = freq_cis_imag_row[j // 2]
let q0 = state.q[i * head_size + j]
let q1 = state.q[i * head_size + j + 1]
state.q[i * head_size + j] = q0 * fcr - q1 * fci
state.q[i * head_size + j + 1] = q0 * fci + q1 * fcr
@parameter
fn calc_head[nelts: Int](j: Int):
# for j in range(0, config.head_size, 2):
let fcr = freq_cis_real_row.simd_load[nelts](j // 2)
let fci = freq_cis_imag_row.simd_load[nelts](j // 2)
let q0 = state.q.data().offset(i * head_size + j).simd_strided_load[nelts](
2
)
let q1 = state.q.data().offset(i * head_size + j + 1).simd_strided_load[
nelts
](2)

state.q.data().offset(i * head_size + j).simd_strided_store[nelts](
q0 * fcr - q1 * fci, 2
)
state.q.data().offset(i * head_size + j + 1).simd_strided_store[nelts](
q0 * fci + q1 * fcr, 2
)

if i < config.n_kv_heads:
let k0 = state.k[i * head_size + j]
let k1 = state.k[i * head_size + j + 1]
state.k[i * head_size + j] = k0 * fcr - k1 * fci
state.k[i * head_size + j + 1] = k0 * fci + k1 * fcr
let k0 = state.k.data().offset(i * head_size + j).simd_strided_load[
nelts
](2)
let k1 = state.k.data().offset(i * head_size + j + 1).simd_strided_load[
nelts
](2)

state.k.data().offset(i * head_size + j).simd_strided_store[nelts](
k0 * fcr - k1 * fci, 2
)

state.k.data().offset(i * head_size + j + 1).simd_strided_store[nelts](
k0 * fci + k1 * fcr, 2
)

for j in range(0, head_range, nelts * 2):
calc_head[nelts](j)

# deal with tail elements
for j in range(head_range, config.head_size, 2):
calc_head[1](j)

parallelize[head_loop](config.n_heads, workers)


Expand Down Expand Up @@ -975,8 +1007,13 @@ fn main() raises:
next_token = argmax(state.logits)
else:
# Apply the temperature to the logits
for q in range(config.vocab_size):
state.logits[q] = state.logits[q] / temperature
@parameter
fn v_temperature[_nelts: Int](q: Int):
state.logits.simd_store[_nelts](
q, state.logits.simd_load[_nelts](q) / temperature
)

vectorize[nelts, v_temperature](config.vocab_size)

# Apply softmax to the logits to get the probabilities for the next token
softmax(state.logits)
Expand Down