Skip to content

Commit

Permalink
hmm...
Browse files Browse the repository at this point in the history
  • Loading branch information
trollkotze committed Mar 25, 2024
1 parent 0274e6b commit 7dbed97
Showing 1 changed file with 28 additions and 25 deletions.
53 changes: 28 additions & 25 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3176,24 +3176,24 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};

const auto handle_get_control_vectors = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res) {
const auto handle_get_control_vectors = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
json vectors = json::array();

for (const auto & vec : params.control_vectors) {
for (const auto & vec : ctx_server.params.control_vectors) {
vectors.push_back(json {
{ "fname", vec.fname },
{ "strength", vec.strength }
});
}
json data = {
{ "vectors", vectors },
{ "layer_start", params.control_vector_layer_start },
{ "layer_end", params.control_vector_layer_end }
{ "layer_start", ctx_server.params.control_vector_layer_start },
{ "layer_end", ctx_server.params.control_vector_layer_end }
};
res.set_content(data.dump(), "application/json; charset=utf-8");
};

const auto handle_set_control_vectors = [&ctx_server, &res_error, &params, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
const auto handle_set_control_vectors = [&ctx_server, &res_error, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

json data = json::parse(req.body);
Expand All @@ -3202,52 +3202,55 @@ int main(int argc, char ** argv) {
if (data.contains("vectors") && data["vectors"].is_array()) {
for (const auto &item : data["vectors"]) {
auto v = item.get<llama_control_vector_load_info>();
// std::cout << "Add vector: " << v.fname << " " << v.strength << "\n";
std::cout << "Add vector: " << v.fname << " " << v.strength << "\n";
vec_params.push_back(v);
}
} else {
std::cerr << "No vectors passed\n";
res_error(res, format_error_response("No vectors passed", ERROR_TYPE_SERVER));
return;
}
for (auto v : params.control_vectors) {
// std::cout << "Subtract vector:" << v.fname << " " << v.strength << "\n";
vec_params.push_back({ -v.strength, v.fname });
}
const auto cvec = llama_control_vector_load(vec_params);
if (cvec.n_embd == -1) {
// std::cerr << "Could not load control vector\n";
std::cerr << "Could not load control vector\n";
res_error(res, format_error_response("Could not load control vector", ERROR_TYPE_SERVER));
return;
}

if (params.control_vector_layer_start <= 0) {
params.control_vector_layer_start = 1;
if (ctx_server.params.control_vector_layer_start <= 0) {
ctx_server.params.control_vector_layer_start = 1;
}
if (params.control_vector_layer_end <= 0){
params.control_vector_layer_end = llama_n_layer(ctx_server.model);
if (ctx_server.params.control_vector_layer_end <= 0){
ctx_server.params.control_vector_layer_end = llama_n_layer(ctx_server.model);
}
int err = llama_control_vector_apply(ctx_server.ctx,
cvec.data.data(),
cvec.data.size(),
cvec.n_embd,
params.control_vector_layer_start,
params.control_vector_layer_end);
ctx_server.params.control_vector_layer_start,
ctx_server.params.control_vector_layer_end);
if (err) {
std::cerr << "Could not apply control vector\n";
res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER));
return;
}
auto s = params.control_vectors.size();
auto s2 = vec_params.size();
params.control_vectors.clear();
unsigned i = 0;
ctx_server.params.control_vectors.clear();
for (auto v : vec_params) {
if (i++ < s2 - s) {
//std::cout << "set vector param: " << v.fname << " " << v.strength << "\n";
params.control_vectors.push_back(v);
}
//std::cout << "set vector param: " << v.fname << " " << v.strength << "\n";
ctx_server.params.control_vectors.push_back(v);
}

/*std::cerr << "Maybe we need to do this initiation ritual before it werks?\n"; // No, it's still all garbled bullshit.
std::vector<llama_token> tmp = { llama_token_bos(ctx_server.model), llama_token_eos(ctx_server.model), };
std::cerr << "decode, bro\n";
llama_decode(ctx_server.ctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) ctx_server.params.n_batch), 0, 0));
std::cerr << "clear that fucking cache\n";
llama_kv_cache_clear(ctx_server.ctx);
std::cerr << "symcr0nice or what\n";
llama_synchronize(ctx_server.ctx);
std::cerr << "time will tell\n";
llama_reset_timings(ctx_server.ctx);*/
handle_get_control_vectors(req, res);
};

Expand Down

0 comments on commit 7dbed97

Please sign in to comment.