Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dump v #200

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

Dump v #200

wants to merge 1 commit into from

Conversation

Lifann
Copy link
Collaborator

@Lifann Lifann commented Aug 13, 2024

Here is the costs in microseconds of dump_kernel and dump_kernel_v2 on both pinned host or device output on 2^24 capacity table with half of the contents are exported. The table values are stored on pure GPU buckets.

capacity: 2^24, and the table is full when running the export_batch_if
num exported: 8388771
dim: 64

A100 + AMD

dump_kernel dump_kernel_v2 dump_kernel_v2_vectorized
Pinned host memory 14887.594 2116.001 607.138
Device 24.700 6.012 3.957

H20 + Intel

dump_kernel dump_kernel_v2 dump_kernel_v2_vectorized
Pinned host memory 624.399 44.536 44.143
Device 16.615 4.546 2.359

Copy link

template <class K, class V, class S,
template <typename, typename> class PredFunctor,
int TILE_SIZE>
__global__ void dump_kernel_v2(const Table<K, V, S>* __restrict table,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No call to this kernel?

int dim = table->dim;
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());

__shared__ block_acc;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

block_acc is not used.

@Lifann
Copy link
Collaborator Author

Lifann commented Aug 14, 2024

Hi, @jiashuy This PR is under development yet. I'll fix the problems ASAP.

jiashuy
jiashuy previously approved these changes Aug 18, 2024
Copy link
Collaborator

@jiashuy jiashuy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jiashuy
Copy link
Collaborator

jiashuy commented Aug 18, 2024

/blossom-ci

2 similar comments
@jiashuy
Copy link
Collaborator

jiashuy commented Aug 19, 2024

/blossom-ci

@jiashuy
Copy link
Collaborator

jiashuy commented Aug 20, 2024

/blossom-ci

cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaEventRecord(start);
table->export_batch_if<ExportIfPredFunctor>(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are total three kernel templates, and has each kernel been tested? If not is it necessary to test each kernel.

Copy link
Collaborator Author

@Lifann Lifann Aug 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tested them seperately, but not added them into the tests case. Since it's an internal option not for public API.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants