Add TBE annotation in Kineto trace #3057

Status: Open. Wants to merge 2 commits into main.
78 changes: 52 additions & 26 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -14,6 +14,7 @@
import os
import random
import statistics
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

@@ -43,6 +44,7 @@
)
from fbgemm_gpu.tbe.utils import generate_requests, get_device, round_up, TBERequest
from torch import Tensor
from torch.profiler import profile

haveAIBench = False
try:
@@ -110,6 +112,12 @@ def cli() -> None:
@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option("--export-trace", is_flag=True, default=False)
@click.option(
"--trace-url",
type=str,
default="{tbe_type}_tbe_{phase}_trace_{ospid}.json",
)
def device( # noqa C901
alpha: float,
bag_size: int,
@@ -134,6 +142,8 @@ def device( # noqa C901
output_dtype: SparseType,
requests_data_file: Optional[str],
tables: Optional[str],
export_trace: bool,
trace_url: str,
) -> None:
np.random.seed(42)
torch.manual_seed(42)
@@ -187,6 +197,7 @@ def device( # noqa C901
do_pooling = False

if dense:
tbe_type: str = "dense"
emb = DenseTableBatchedEmbeddingBagsCodegen(
[
(
@@ -199,6 +210,7 @@ def device( # noqa C901
use_cpu=not torch.cuda.is_available(),
)
else:
tbe_type: str = "split"
emb = SplitTableBatchedEmbeddingBagsCodegen(
[
(
@@ -263,18 +275,29 @@ def device( # noqa C901
use_cpu=not torch.cuda.is_available(),
)

# forward
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb.forward(
indices.long(),
offsets.long(),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
num_warmups=warmup_runs,
)
def _kineto_trace_handler(p: profile, phase: str) -> None:
p.export_chrome_trace(
trace_url.format(tbe_type=tbe_type, phase=phase, ospid=os.getpid())
)

# pyre-ignore[3]
def context_factory(on_trace_ready: Callable[[profile], None]):
return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()

with context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
# forward
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb.forward(
indices.long(),
offsets.long(),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
num_warmups=warmup_runs,
)

logging.info(
f"Forward, B: {B}, "
f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
@@ -290,20 +313,23 @@ def device( # noqa C901
grad_output = torch.randn(B, sum(Ds)).to(get_device())
else:
grad_output = torch.randn(B * T * L, D).to(get_device())
# backward
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb(
indices.long(),
offsets.long(),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
bwd_only=True,
grad=grad_output,
num_warmups=warmup_runs,
)

with context_factory(lambda p: _kineto_trace_handler(p, "fwd_bwd")):
# backward
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb(
indices.long(),
offsets.long(),
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
bwd_only=True,
grad=grad_output,
num_warmups=warmup_runs,
)

logging.info(
f"Backward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
f"BW: {2 * read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "
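For context, the block below is a minimal standalone sketch of the optional-profiling pattern this benchmark change introduces: a profiler context is created only when `--export-trace` is passed, and the `on_trace_ready` callback writes a Chrome/Kineto trace whose file name is built from `--trace-url`. The toy matmul workload and the hard-coded flag values are illustrative stand-ins, not part of the PR.

```python
# Sketch of the --export-trace / --trace-url pattern from the benchmark diff.
# Assumptions: the dummy workload and literal flag values are illustrative only.
import os
from contextlib import nullcontext
from typing import Callable

import torch
from torch.profiler import profile

export_trace = True                                       # stand-in for --export-trace
trace_url = "{tbe_type}_tbe_{phase}_trace_{ospid}.json"   # default --trace-url pattern
tbe_type = "split"                                        # "dense" when --dense is set


def _kineto_trace_handler(p: profile, phase: str) -> None:
    # Export a Chrome/Kineto trace once the profiled region ends.
    p.export_chrome_trace(
        trace_url.format(tbe_type=tbe_type, phase=phase, ospid=os.getpid())
    )


def context_factory(on_trace_ready: Callable[[profile], None]):
    # Profile only when tracing was requested; otherwise run the benchmark unchanged.
    return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()


with context_factory(lambda p: _kineto_trace_handler(p, "fwd")):
    # Placeholder for the forward benchmark loop.
    x = torch.randn(1024, 1024)
    (x @ x).sum()
```

Against the real script this would correspond to something like `python split_table_batched_embeddings_benchmark.py device --export-trace` (hypothetical invocation; the default `--trace-url` already matches the pattern above).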
Second changed file (TBE autograd codegen template; file path not shown in this view):
@@ -11,6 +11,7 @@
#include <ATen/TypeDefault.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>
#include "torch/csrc/autograd/record_function_ops.h"

#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/ops_utils.h"
@@ -20,6 +21,7 @@
using Tensor = at::Tensor;

using namespace fbgemm_gpu;
namespace profiler = torch::autograd::profiler;

{#/* Module description */#}
{%- set fwd_mdesc = "ssd" if ssd else ("dense" if dense else "split") %}
@@ -62,7 +64,7 @@ enum SSDTensor {
.findSchemaOrThrow("fbgemm::{{ forward_op }}", "")
.typed<decltype({{ forward_op }})>();

return {
auto ret = {
embedding_codegen_forward_op.call(
flatten_dev_weights,
{%- if not dense %}
@@ -111,6 +113,10 @@ enum SSDTensor {
{{ "is_experimental" if has_experimental else "false" }}
)
};
if (is_annotate_trace_enabled) {
record_trace->record.end();
}
return ret;
{%- endmacro %}

/* This macro generates a code blob for dispatching corresponding weighted and
@@ -195,7 +201,7 @@ enum SSDTensor {
/*unused=*/0
{%- endif %}
);
return {
auto ret = {
{%- if not dense %}
Tensor(), // placeholder autograd tensor
{%- endif %}
@@ -261,6 +267,10 @@ enum SSDTensor {
{%- endif %}
{{ args.split_variables | join(", ") }}
};
if (is_annotate_trace_enabled) {
record_trace->record.end();
}
return ret;
{%- endmacro %}

/* This macro generates a code blob that calls corresponding autograd function
@@ -630,6 +640,30 @@ class {{ autograd_func }} :
const auto max_B_ = offsets.sym_size(0) / T;
{%- endif %}

// Annotate Kineto trace
const static bool is_annotate_trace_enabled = config::is_feature_enabled(
config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE);
std::string op_annotation = "";
c10::intrusive_ptr<profiler::PythonRecordFunction> record_trace;
if (is_annotate_trace_enabled) {
std::stringstream ss;
ss << "["
<< "weighted={{ "T" if weighted else "F" }},"
<< "pooled={{ "T" if not nobag else "F" }},"
<< "vbe={{ "T" if vbe else "F" }},"
<< "avg_B=" << ({{ "max_B_" if not vbe else "max_B_ / T" }}) << ","
<< "max_B=" << max_B_ << ","
<< "T=" << T << ","
<< "avg_D=" << ({{ "total_D / T" if not nobag else "D" }}) << ","
<< "max_D=" << {{ "max_D" if not nobag else "D" }} << ","
<< "num_indices=" << indices.numel() << ","
<< "avg_pooling_fac=" << (static_cast<float>(indices.numel()) / T / max_B_)
<< "]";
op_annotation = ss.str();
record_trace = profiler::record_function_enter_new(
"{{ fwd_mdesc }}_tbe_fwd" + op_annotation);
}

{%- if not dense %}
// NOTE: The `local_uvm_cache_stats` variable held by the nn.Module has dtype int32_t
// TODO: Hook up with frontend code
@@ -724,6 +758,7 @@ class {{ autograd_func }} :
{{ args.split_saved_tensors | join(", ") }}
});

ctx->saved_data["op_annotation"] = op_annotation;
{%- if not nobag %}
ctx->saved_data["max_D"] = max_D;
ctx->saved_data["pooling_mode"] = pooling_mode;
@@ -874,6 +909,15 @@ class {{ autograd_func }} :
{%- endfor %}
{%- endif %}

const static bool is_annotate_trace_enabled = config::is_feature_enabled(
config::FeatureGateName::TBE_ANNOTATE_KINETO_TRACE);
c10::intrusive_ptr<profiler::PythonRecordFunction> record_trace;
if (is_annotate_trace_enabled) {
auto& op_annotation = ctx->saved_data["op_annotation"].toStringRef();
record_trace = profiler::record_function_enter_new(
"{{ bwd_mdesc }}_tbe_bwd" + op_annotation);
}

TORCH_CHECK_EQ(grad_outputs.size(), 1);

#ifdef USE_ROCM
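The template changes above wrap the generated forward and backward calls in `record_function_enter_new(...)` / `record.end()` pairs, gated by the `TBE_ANNOTATE_KINETO_TRACE` feature flag, so each TBE op shows up in the Kineto trace with a label describing its shape (`weighted`, `pooled`, `vbe`, `max_B`, `T`, `avg_D`, `num_indices`, and so on). The sketch below illustrates the same effect from Python using `torch.profiler.record_function`; the label text mimics the format built in the template, and the numbers are made up for illustration.

```python
# Illustration of what the generated annotation produces: a named region that
# appears in the exported Kineto/Chrome trace.  The label below only mimics the
# string assembled in the C++ template; the values are fabricated examples.
import torch
from torch.profiler import profile, record_function

label = (
    "split_tbe_fwd"
    "[weighted=F,pooled=T,vbe=F,avg_B=64,max_B=64,T=8,"
    "avg_D=128,max_D=128,num_indices=4096,avg_pooling_fac=8.0]"
)

with profile() as prof:
    with record_function(label):
        # Placeholder work standing in for the TBE forward kernel.
        x = torch.randn(512, 512)
        (x @ x).sum()

prof.export_chrome_trace("annotated_trace.json")  # the region appears under `label`
```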
3 changes: 2 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h
@@ -57,7 +57,8 @@ namespace fbgemm_gpu::config {
/// For OSS: The environment variable will be evaluated as f"FBGEMM_{ENUM}"
#define ENUMERATE_ALL_FEATURE_FLAGS \
X(TBE_V2) \
X(TBE_ENSEMBLE_ROWWISE_ADAGRAD)
X(TBE_ENSEMBLE_ROWWISE_ADAGRAD) \
X(TBE_ANNOTATE_KINETO_TRACE)
// X(EXAMPLE_FEATURE_FLAG)

/// @ingroup fbgemm-gpu-config
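Per the comment in `feature_gates.h`, OSS builds evaluate each gate from an environment variable named `FBGEMM_<ENUM>`, so the new annotation path would presumably be toggled as sketched below. The value `"1"` and the need to set it before FBGEMM evaluates the gate are assumptions, not something stated in this diff.

```python
# Hypothetical OSS usage: feature_gates.h says gates are read from
# f"FBGEMM_{ENUM}", so the variable name follows from the enum added above.
# Treating "1" as a truthy setting is an assumption.
import os

os.environ["FBGEMM_TBE_ANNOTATE_KINETO_TRACE"] = "1"  # set before FBGEMM checks the gate

# With the gate enabled and the benchmark run under torch.profiler (e.g. via
# --export-trace), the "*_tbe_fwd[...]" / "*_tbe_bwd[...]" annotations should
# appear in the exported Kineto trace.
```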