Improve gradient checkpointing implementation to:
1. Avoid an issue where an input exec is visited early because we don't
   know it has other inputs.

2. Fix incorrect tensor symbol mapping: optimization passes can free and
   then reuse some tensors in the backward pass.
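A minimal, self-contained sketch of the fix-1 filtering idea, using the same khash set pattern as the patch (assumes klib's khash.h; the helper and data names are illustrative, not ccv API):

/* Sketch only: boundary-exec filtering with a khash set of symbol indices. */
#include <stdio.h>
#include "khash.h"

KHASH_SET_INIT_INT(sym_set)

/* An exec is a true "input exec" of a checkpoint region only if it reads one
 * of the region's inputs AND reads nothing created inside the region itself;
 * otherwise visiting it early would miss those other inputs. */
static int is_input_exec(const int* inputs, int input_size, const int* region_inputs, int region_input_size, const khash_t(sym_set)* newly_created)
{
	int k, l, flag = 0;
	for (k = 0; k < input_size && !flag; k++)
		for (l = 0; l < region_input_size && !flag; l++)
			if (inputs[k] >= 0 && inputs[k] == region_inputs[l])
				flag = 1;
	for (k = 0; k < input_size && flag; k++)
		if (inputs[k] >= 0 && kh_get(sym_set, newly_created, inputs[k]) != kh_end(newly_created))
			flag = 0; /* the fix: reads an in-region tensor, filter it out */
	return flag;
}

int main(void)
{
	int ret;
	khash_t(sym_set)* newly_created = kh_init(sym_set);
	kh_put(sym_set, newly_created, 7, &ret); /* symbol 7 is produced in-region */
	const int region_inputs[] = { 3 };
	const int exec_a[] = { 3 };    /* reads only the region input */
	const int exec_b[] = { 3, 7 }; /* also reads an in-region tensor */
	printf("%d %d\n", is_input_exec(exec_a, 1, region_inputs, 1, newly_created), is_input_exec(exec_b, 2, region_inputs, 1, newly_created)); /* prints: 1 0 */
	kh_destroy(sym_set, newly_created);
	return 0;
}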
liuliu committed Oct 18, 2024
1 parent bed628c commit a02a510
Showing 3 changed files with 169 additions and 38 deletions.
1 change: 1 addition & 0 deletions lib/nnc/_ccv_cnnp_model.h
@@ -428,6 +428,7 @@ void ccv_cnnp_model_tensors_init_0(const ccv_cnnp_model_t* const model, ccv_cnnp
void ccv_cnnp_model_tensors_init_1(const ccv_cnnp_model_t* const model, ccv_cnnp_compiled_data_t* const compiled_data);
int ccv_cnnp_model_tensors_any_to_alloc(const ccv_cnnp_model_t* const model, ccv_cnnp_compiled_data_t* const compiled_data);
ccv_nnc_stream_context_t* ccv_cnnp_compiled_data_get_stream(ccv_cnnp_compiled_data_t* const compiled_data, const int type);
void ccv_cnnp_model_gradient_checkpoints_cleanup_after_build(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph);
void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph);
void ccv_cnnp_model_add_to_array(void* const context, const ccv_nnc_tensor_symbol_t symbol, const int is_trainable);

1 change: 1 addition & 0 deletions lib/nnc/ccv_cnnp_model.c
@@ -342,6 +342,7 @@ static void _ccv_cnnp_model_compile(ccv_cnnp_model_t* const model, const ccv_nnc
compiled_data->internals = internals;
compiled_data->ids.parameters = parameter_ids;
compiled_data->ids.internals = internal_ids;
ccv_cnnp_model_gradient_checkpoints_cleanup_after_build(compiled_data, model->graph);
}

static void _ccv_cnnp_graph_push_graph_exec_symbol(void* context, const ccv_nnc_graph_exec_symbol_t symbol, const ccv_nnc_cmd_t cmd, const ccv_nnc_tensor_symbol_t* const inputs, const int input_size, const ccv_nnc_tensor_symbol_t* const outputs, const int output_size, const char* const name)
205 changes: 167 additions & 38 deletions lib/nnc/ccv_cnnp_model_gradient_checkpointing.c
@@ -6,6 +6,31 @@
// This can be removed once we organize ccv_cnnp_apply_gradient_checkpoints better.
#include "_ccv_nnc_symbolic_graph.h"

void ccv_cnnp_model_gradient_checkpoints_cleanup_after_build(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph)
{
ccv_array_t* const gradient_checkpoints = compiled_data->gradient_checkpoints;
if (!gradient_checkpoints || gradient_checkpoints->rnum == 0) // No saved gradient checkpoints, this is an easy way out.
return;
int i, j;
const ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (const ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, 0);
// Go through to check if any tensors that are supposed to be in this map have been removed.
for (i = 0; i < gradient_checkpoints->rnum; i++)
{
ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i);
for (j = 0; j < checkpoint->tensor_symbols->rnum; j++)
{
ccv_nnc_tensor_symbol_t* const symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, j));
if (symbol->d >= 0 && symbol->d < graph->tensor_symbol_info->rnum)
// If it is dead, we need to remove this symbol.
if (CCV_NNC_TENSOR_SYMBOL_IS_DEAD(tensor_symbol_info[symbol->d].flags))
{
symbol->d = -1;
symbol->graph = 0;
}
}
}
}

typedef struct {
ccv_array_t* outgoings;
} ccv_nnc_graph_exec_symbol_reverse_t;
@@ -46,6 +71,7 @@ static void _ccv_cnnp_model_gradient_checkpoint_graph_exec_symbol_new_hook(void*
}

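// Note: both khash types below are keyed by the tensor symbol index (the .d field of ccv_nnc_tensor_symbol_t).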
KHASH_MAP_INIT_INT(ccv_cnnp_tensor_symbol_map, int)
KHASH_SET_INIT_INT(ccv_cnnp_tensor_symbol_set)

void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const compiled_data, ccv_nnc_symbolic_graph_t* const graph)
{
@@ -102,9 +128,36 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
}
ccv_nnc_tensor_symbol_t* max_outputs = ccmalloc(sizeof(ccv_nnc_tensor_symbol_t) * max_output_size);
ccv_array_t* newly_used_outputs = ccv_array_new(sizeof(int), 0, 0);
khash_t(ccv_cnnp_tensor_symbol_set)* const parameters_or_internals = kh_init(ccv_cnnp_tensor_symbol_set);
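// Collect every parameter and internal tensor symbol index up front; these are shared state and must never be treated as tensors created inside a checkpoint region.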
for (i = 0; i < compiled_data->parameters->rnum; i++)
{
const ccv_nnc_tensor_symbol_t* const symbol = (const ccv_nnc_tensor_symbol_t*)ccv_array_get(compiled_data->parameters, i);
int ret;
kh_put(ccv_cnnp_tensor_symbol_set, parameters_or_internals, symbol->d, &ret);
}
for (i = 0; i < compiled_data->internals->rnum; i++)
{
const ccv_nnc_tensor_symbol_t* const symbol = (const ccv_nnc_tensor_symbol_t*)ccv_array_get(compiled_data->internals, i);
int ret;
kh_put(ccv_cnnp_tensor_symbol_set, parameters_or_internals, symbol->d, &ret);
}
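// Per-checkpoint scratch set: tracks tensor symbols created inside the current checkpoint region.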
khash_t(ccv_cnnp_tensor_symbol_set)* const newly_created_tensor_symbols = kh_init(ccv_cnnp_tensor_symbol_set);
khash_t(ccv_cnnp_tensor_symbol_map)* symbol_map = kh_init(ccv_cnnp_tensor_symbol_map);
for (i = 0; i < gradient_checkpoints->rnum; i++)
{
ccv_cnnp_model_gradient_checkpoint_t* const checkpoint = (ccv_cnnp_model_gradient_checkpoint_t*)ccv_array_get(gradient_checkpoints, i);
kh_clear(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
for (j = 0; j < checkpoint->tensor_symbols->rnum; j++)
{
const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, j))->d;
if (idx < 0)
continue;
// Skip parameters or internals.
if (kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, idx) != kh_end(parameters_or_internals))
continue;
int ret;
kh_put(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, idx, &ret);
}
ccv_array_clear(input_execs);
ccv_array_clear(output_execs);
ccv_nnc_graph_exec_symbol_info_t* exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
@@ -128,9 +181,13 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
int flag = 0;
for (k = 0; inputs && k < input_size && !flag; k++)
if (inputs[k] >= 0)
for (l = 0; l < checkpoint->input_size && !flag; l++)
if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d)
flag = 1;
// Input logic is different from output logic. We need to filter out those execs that take inputs produced within the graph.
for (k = 0; inputs && k < input_size && flag; k++)
if (inputs[k] >= 0 && kh_get(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, inputs[k]) != kh_end(newly_created_tensor_symbols))
flag = 0;
if (flag)
ccv_array_push(input_execs, &symbol);
flag = 0;
@@ -288,6 +345,17 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
// Note that there is no graph optimization applied here.
exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
// Reuse existing one.
kh_clear(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
for (j = 0; j < build.tensor_symbols->rnum; j++)
{
const int idx = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j))->d;
if (idx < 0)
continue;
if (kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, idx) != kh_end(parameters_or_internals))
continue;
int ret;
kh_put(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, idx, &ret);
}
ccv_array_t* const newly_input_execs = input_execs;
ccv_array_t* const newly_output_execs = output_execs;
ccv_array_clear(newly_input_execs);
@@ -310,19 +378,23 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
int flag = 0;
for (k = 0; inputs && k < input_size && !flag; k++)
if (inputs[k] >= 0)
for (l = 0; l < checkpoint->input_size && !flag; l++)
if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d)
flag = 1;
// Input logic is different from output logic. We need to filter out those execs that take inputs produced within the graph.
for (k = 0; inputs && k < input_size && flag; k++)
if (inputs[k] >= 0 && kh_get(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, inputs[k]) != kh_end(newly_created_tensor_symbols))
flag = 0;
if (flag)
ccv_array_push(newly_input_execs, &symbol);
flag = 0;
const int* outputs = exec_info[idx].outputs;
int output_size = exec_info[idx].output_size;
for (k = 0; outputs && k < output_size && !flag; k++)
if (outputs[k] >= 0)
for (l = 0; l < checkpoint->output_size && !flag; l++)
if (max_outputs[l].d >= 0 && outputs[k] == max_outputs[l].d)
flag = 1;
if (flag)
ccv_array_push(newly_output_execs, &symbol);
}
@@ -339,6 +411,7 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
ccv_nnc_graph_exec_symbol_new_hook(graph, build.old_graph_exec_symbol_new_hook, build.old_graph_exec_symbol_new_hook_context, 0);
// Need to autogen and redo source / destination.
ccv_nnc_graph_exec_symbol_autogen(graph, (ccv_nnc_graph_exec_symbol_t*)ccv_array_get(build.graph_exec_symbols, 0), build.graph_exec_symbols->rnum, 0);
ccv_nnc_tensor_symbol_info_t* const tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(graph->tensor_symbol_info, 0);
exec_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, 0);
ccv_array_clear(newly_input_execs);
for (j = 0; j < build.graph_exec_symbols->rnum; j++)
@@ -359,41 +432,32 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
int flag = 0;
for (k = 0; inputs && k < input_size && !flag; k++)
if (inputs[k] >= 0)
for (l = 0; l < checkpoint->input_size && !flag; l++)
if (checkpoint->inputs[l].d >= 0 && inputs[k] == checkpoint->inputs[l].d)
flag = 1;
for (k = 0; inputs && k < input_size && flag; k++)
if (inputs[k] >= 0 && kh_get(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols, inputs[k]) != kh_end(newly_created_tensor_symbols))
flag = 0;
if (flag)
ccv_array_push(newly_input_execs, &symbol);
}
// Build a map between old tensor symbols and new tensor symbols.
assert(build.tensor_symbols->rnum <= checkpoint->tensor_symbols->rnum);
// Build a map to potentially map from old input to new input.
kh_clear(ccv_cnnp_tensor_symbol_map, symbol_map);
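// Walk the recorded (old) and rebuilt (new) symbol lists in lockstep; a parameter / internal symbol appears on only one side, so advancing that cursor alone keeps the remaining old/new entries paired up.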
for (j = 0, k = 0; j < build.tensor_symbols->rnum && k < checkpoint->tensor_symbols->rnum;)
{
const int from_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(checkpoint->tensor_symbols, k))->d;
if (from_d < 0) // This is removed, move to the next one.
{
++j;
++k;
continue;
}
const int to_d = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j))->d;
assert(to_d >= 0);
int from_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, from_d) != kh_end(parameters_or_internals);
int to_flag = kh_get(ccv_cnnp_tensor_symbol_set, parameters_or_internals, to_d) != kh_end(parameters_or_internals);
if (from_flag)
++k;
if (to_flag)
@@ -408,6 +472,12 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
from_flag = 1;
if (from_flag)
continue;
// Skip if to_d is one of the outputs.
for (l = 0; !to_flag && l < checkpoint->output_size; l++)
if (checkpoint->outputs[l].d == to_d)
to_flag = 1;
if (to_flag)
continue;
int ret = 0;
khiter_t h = kh_put(ccv_cnnp_tensor_symbol_map, symbol_map, from_d, &ret);
kh_val(symbol_map, h) = to_d;
@@ -430,9 +500,14 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
const khiter_t h = kh_get(ccv_cnnp_tensor_symbol_map, symbol_map, exec_info[idx].inputs[k]);
if (h != kh_end(symbol_map)) // Replacing it.
{
int newly_created_output = kh_val(symbol_map, h);
exec_info[idx].inputs[k] = newly_created_output;
ccv_array_add_unique_int(newly_used_outputs, newly_created_output);
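// alias_ref is a 1-based reference to the tensor symbol this alias views; mark the viewed tensor as used too so it is kept alive below.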
if (tensor_symbol_info[newly_created_output].alias_ref > 0)
{
newly_created_output = tensor_symbol_info[newly_created_output].alias_ref - 1;
ccv_array_add_unique_int(newly_used_outputs, newly_created_output);
}
ccv_array_add_unique_int(replaced_backward_execs, idx);
}
}
@@ -453,9 +528,23 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
assert(ccv_nnc_cmd_is_backward(exec_info[idx].cmd));
int flag = 0;
for (x = 0; !flag && x < exec_info[idx].input_size; x++)
{
int x_d = exec_info[idx].inputs[x];
if (x_d < 0)
continue;
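// Normalize aliases to the tensor they view so an alias on either side still matches.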
if (tensor_symbol_info[x_d].alias_ref > 0)
x_d = tensor_symbol_info[x_d].alias_ref - 1;
for (y = 0; !flag && y < exec_info[symbol->d].output_size; y++)
{
int y_d = exec_info[symbol->d].outputs[y];
if (y_d < 0)
continue;
if (tensor_symbol_info[y_d].alias_ref > 0)
y_d = tensor_symbol_info[y_d].alias_ref - 1;
if (x_d == y_d)
flag = 1;
}
}
if (flag)
ccv_nnc_graph_exec_symbol_concat(graph, *symbol, (ccv_nnc_graph_exec_symbol_t){
.graph = graph,
@@ -584,7 +673,17 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
int* const inputs = exec_info[symbol->d].inputs;
const int input_size = exec_info[symbol->d].input_size;
for (k = 0; k < input_size; k++)
{
int d = inputs[k];
if (d < 0)
continue;
ccv_array_add_unique_int(forward_pass_inputs, d);
if (tensor_symbol_info[d].alias_ref > 0)
{
d = tensor_symbol_info[d].alias_ref - 1;
ccv_array_add_unique_int(forward_pass_inputs, d);
}
}
}
any_deleted = 0;
for (j = 0; j < build.graph_exec_symbols->rnum; j++)
@@ -598,7 +697,17 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
const int output_size = exec_info[symbol->d].output_size;
int flag = 0;
for (k = 0; !flag && k < output_size; k++)
{
int d = outputs[k];
if (d < 0)
continue;
flag = ccv_array_contain_int(newly_used_outputs, d) || ccv_array_contain_int(forward_pass_inputs, d);
if (!flag && tensor_symbol_info[d].alias_ref > 0)
{
d = tensor_symbol_info[d].alias_ref - 1;
flag = ccv_array_contain_int(newly_used_outputs, d) || ccv_array_contain_int(forward_pass_inputs, d);
}
}
if (flag)
continue;
ccv_nnc_graph_exec_symbol_free(graph, *symbol);
@@ -618,18 +727,36 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
int* const inputs = exec_info[symbol->d].inputs;
const int input_size = exec_info[symbol->d].input_size;
for (k = 0; k < input_size; k++)
{
if (inputs[k] < 0)
continue;
ccv_array_add_unique_int(forward_pass_inputs, inputs[k]);
if (tensor_symbol_info[inputs[k]].alias_ref > 0)
ccv_array_add_unique_int(forward_pass_inputs, tensor_symbol_info[inputs[k]].alias_ref - 1);
}
int* const outputs = exec_info[symbol->d].outputs;
const int output_size = exec_info[symbol->d].output_size;
for (k = 0; k < output_size; k++)
{
if (outputs[k] < 0)
continue;
ccv_array_add_unique_int(forward_pass_inputs, outputs[k]);
if (tensor_symbol_info[outputs[k]].alias_ref > 0)
ccv_array_add_unique_int(forward_pass_inputs, tensor_symbol_info[outputs[k]].alias_ref - 1);
}
}
// Free unused tensor symbols.
for (j = 0; j < build.tensor_symbols->rnum; j++)
{
const ccv_nnc_tensor_symbol_t* symbol = ((ccv_nnc_tensor_symbol_t*)ccv_array_get(build.tensor_symbols, j));
if (ccv_array_contain_int(newly_used_outputs, symbol->d) || ccv_array_contain_int(forward_pass_inputs, symbol->d))
continue;
if (tensor_symbol_info[symbol->d].alias_ref > 0)
{
const int d = tensor_symbol_info[symbol->d].alias_ref - 1;
if (ccv_array_contain_int(newly_used_outputs, d) || ccv_array_contain_int(forward_pass_inputs, d))
continue;
}
ccv_nnc_tensor_symbol_free(graph, *symbol);
}
for (j = 0; j < build.graph_exec_symbols->rnum; j++)
@@ -644,8 +771,10 @@ void ccv_cnnp_model_apply_gradient_checkpoints(ccv_cnnp_compiled_data_t* const c
// Free these newly created execs and tensor symbols.
ccv_array_free(build.tensor_symbols);
ccv_array_free(build.graph_exec_symbols);
}
kh_destroy(ccv_cnnp_tensor_symbol_map, symbol_map);
kh_destroy(ccv_cnnp_tensor_symbol_set, newly_created_tensor_symbols);
kh_destroy(ccv_cnnp_tensor_symbol_set, parameters_or_internals);
ccfree(max_outputs);
ccv_array_free(buf);
ccv_array_free(newly_used_outputs);

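A minimal sketch of the fix-2 remapping idea (assumes klib's khash.h; names are illustrative, not ccv API): once a checkpoint body is rebuilt, stale tensor symbol indices in backward execs are rewritten through an old-to-new map, because optimization passes may have freed and reused the originals.

#include <stdio.h>
#include "khash.h"

KHASH_MAP_INIT_INT(sym_map, int)

int main(void)
{
	int ret;
	khash_t(sym_map)* symbol_map = kh_init(sym_map);
	/* Say symbol 5 was freed/reused by an optimization pass and the rebuild
	 * recreated its tensor as symbol 12: record old -> new. */
	khiter_t h = kh_put(sym_map, symbol_map, 5, &ret);
	kh_val(symbol_map, h) = 12;
	int backward_inputs[] = { 5, 9 };
	int k;
	for (k = 0; k < 2; k++)
	{
		const khiter_t it = kh_get(sym_map, symbol_map, backward_inputs[k]);
		if (it != kh_end(symbol_map)) /* remap only when an entry exists */
			backward_inputs[k] = kh_val(symbol_map, it);
	}
	printf("%d %d\n", backward_inputs[0], backward_inputs[1]); /* prints: 12 9 */
	kh_destroy(sym_map, symbol_map);
	return 0;
}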