diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b02c2546eb4c6..1ab80412b5f26 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -624,7 +624,6 @@ struct server_response {
         }
     }
 };
-
 struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
@@ -2700,6 +2699,37 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 break;
             }
             params.kv_overrides.push_back(kvo);
+        } else if (arg == "--control-vector") {
+            // --control-vector FNAME: queue FNAME with a strength of 1.0
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.control_vectors.push_back({ 1.0f, argv[i], });
+        } else if (arg == "--control-vector-scaled") {
+            // --control-vector-scaled FNAME STRENGTH: queue FNAME with the given strength
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            const char * fname = argv[i];
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.control_vectors.push_back({ std::stof(argv[i]), fname, });
+        } else if (arg == "--control-vector-layer-range") {
+            // --control-vector-layer-range START END: layer bounds handed to llama_control_vector_apply
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.control_vector_layer_start = std::stoi(argv[i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.control_vector_layer_end = std::stoi(argv[i]);
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             server_print_usage(argv[0], default_params, default_sparams);
@@ -3148,6 +3178,84 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
+    // GET /control-vectors: reports the control vectors and the layer range currently stored in params
+    const auto handle_get_control_vectors = [&params](const httplib::Request &, httplib::Response & res) {
+        json vectors = json::array();
+
+        for (const auto & vec : params.control_vectors) {
+            vectors.push_back(json {
+                { "fname",    vec.fname    },
+                { "strength", vec.strength }
+            });
+        }
+        json data = {
+            { "vectors",     vectors },
+            { "layer_start", params.control_vector_layer_start },
+            { "layer_end",   params.control_vector_layer_end }
+        };
+        res.set_content(data.dump(), "application/json; charset=utf-8");
+    };
+
+    // POST /control-vectors: loads the entries of the "vectors" array in the JSON body
+    // ({ "fname": ..., "strength": ... } each), applies the result to the context, stores the
+    // list in params and answers with the same payload as GET /control-vectors
+    const auto handle_set_control_vectors = [&ctx_server, &res_error, &params, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+
+        std::vector<llama_control_vector_load_info> vec_params;
+        try {
+            const json data = json::parse(req.body);
+            if (!data.contains("vectors") || !data.at("vectors").is_array()) {
+                std::cerr << "No vectors passed\n";
+                res_error(res, format_error_response("No vectors passed", ERROR_TYPE_SERVER));
+                return;
+            }
+            for (const auto & item : data.at("vectors")) {
+                vec_params.push_back(item.get<llama_control_vector_load_info>());
+            }
+        } catch (const json::exception & e) {
+            // malformed JSON, or an entry without a usable "fname"/"strength"
+            res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER));
+            return;
+        }
+        // everything queued so far came from the request; only this prefix is stored in params below
+        const size_t n_requested = vec_params.size();
+
+        // queue the vectors currently recorded in params with negated strength
+        // NOTE(review): this presumes llama_control_vector_apply() adds to the vector that is
+        // already loaded; if it replaces it, these negated entries must go - confirm in llama.h
+        for (const auto & v : params.control_vectors) {
+            vec_params.push_back({ -v.strength, v.fname });
+        }
+        const auto cvec = llama_control_vector_load(vec_params);
+        if (cvec.n_embd == -1) {
+            res_error(res, format_error_response("Could not load control vector", ERROR_TYPE_SERVER));
+            return;
+        }
+
+        // non-positive bounds fall back to the range 1 .. llama_n_layer(model)
+        if (params.control_vector_layer_start <= 0) {
+            params.control_vector_layer_start = 1;
+        }
+        if (params.control_vector_layer_end <= 0) {
+            params.control_vector_layer_end = llama_n_layer(ctx_server.model);
+        }
+        const int err = llama_control_vector_apply(ctx_server.ctx,
+                                                   cvec.data.data(),
+                                                   cvec.data.size(),
+                                                   cvec.n_embd,
+                                                   params.control_vector_layer_start,
+                                                   params.control_vector_layer_end);
+        if (err) {
+            std::cerr << "Could not apply control vector\n";
+            res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER));
+            return;
+        }
+        // remember only the requested vectors as the new configuration
+        params.control_vectors.assign(vec_params.begin(), vec_params.begin() + n_requested);
+        handle_get_control_vectors(req, res);
+    };
+
     const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
@@ -3497,8 +3605,10 @@ int main(int argc, char ** argv) {
     svr->Get ("/health",              handle_health);
     svr->Get ("/slots",               handle_slots);
     svr->Get ("/metrics",             handle_metrics);
+    svr->Get ("/control-vectors",     handle_get_control_vectors);
     svr->Get ("/props",               handle_props);
     svr->Get ("/v1/models",           handle_models);
+    svr->Post("/control-vectors",     handle_set_control_vectors);
     svr->Post("/completion",          handle_completions); // legacy
     svr->Post("/completions",         handle_completions);
     svr->Post("/v1/completions",      handle_completions);
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 8f20ff61454e9..f73ec441f6135 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -615,3 +615,11 @@ static json format_error_response(const std::string & message, const enum error_
         {"type", type_str},
     };
 }
+
+// from_json overload picked up by json's get<llama_control_vector_load_info>():
+// copies the required "strength" and "fname" members of j into l (at() throws if one is missing).
+// static, like format_error_response above, so the definition stays local to each including file
+static void from_json(const json & j, llama_control_vector_load_info & l) {
+    j.at("strength").get_to(l.strength);
+    j.at("fname").get_to(l.fname);
+}