// ampere_fprop_mainloop_fusion.cu
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/**
This example shows how to fuse per channel scale+bias+relu of the activations
into the fprop mainloop.
Compared with the original fprop kernel, this example has two more vectors, one for
the scale and one for the bias. The length of each vector is the same as the
activation channel number. This kernel loads the vectors when the associated
activation channels are loaded in the mainloop. Between reading the
activations and scale/bias data from the shared memory and calling tensor core
instructions, scale+bias+relu is computed in the register file.
This example is customized for the Ampere 16816 fp16 tensor core instruction.
Changing to different data types or a different tensor core instruction requires
source code changes. See
include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more
technical details.
This example is modified based on 16_ampere_tensorop_conv2dfprop. The command
line is the same.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/kernel/default_conv2d_fprop_fusion.h"
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/convolution.h"
#include "cutlass/util/tensor_view_io.h"
#include "helper.h"
//
// Element types: half-precision operands (activations, filters, scale/bias vectors) with
// single-precision accumulation, epilogue math, and output.
//
using ElementInputA = cutlass::half_t;          // Elements of input tensor A
using ElementInputB = cutlass::half_t;          // Elements of input tensor B
using ElementInputScaleBias = cutlass::half_t;  // Elements of the input scale and bias vectors
using ElementAccumulator = float;               // Accumulator elements
using ElementComputeEpilogue = float;           // Epilogue computation (alpha, beta)
using ElementOutput = float;                    // Elements of the output tensor

//
// Memory layouts: the 4-D tensors are NHWC, the scale/bias vectors are row-major.
//
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutInputScaleBias = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::TensorNHWC;

// Operator class: tensor cores (as opposed to regular SIMT cores).
using MMAOp = cutlass::arch::OpClassTensorOp;

// Target CUDA SM architecture.
using SmArch = cutlass::arch::Sm80;

// Tile computed by one threadblock.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;

// Tile computed by one warp.
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;

// Shape of a single tensor core MMA instruction.
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

// Scheduling of threadblocks on the GPU.
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

// Number of pipeline stages.
constexpr int NumStages = 4;

// Iterator algorithm selection (Analytic or Optimized).
static const cutlass::conv::IteratorAlgorithm IteratorAlgorithm =
    cutlass::conv::IteratorAlgorithm::kOptimized;

// Epilogue: linear combination with default settings. The vector width is the number of
// output elements that fit in one 128-bit memory access; the same width is used for the
// epilogue math instructions.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementComputeEpilogue>;

// Fprop kernel with scale+bias+relu fused into the mainloop.
using Conv2dFpropFusionKernel = cutlass::conv::kernel::DefaultConv2dFpropFusion<
    ElementInputA, LayoutInputA,
    ElementInputB, LayoutInputB,
    ElementInputScaleBias, LayoutInputScaleBias,
    ElementOutput, LayoutOutput,
    ElementAccumulator,
    MMAOp,
    SmArch,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    SwizzleThreadBlock,
    NumStages,
    cutlass::arch::OpMultiplyAdd,
    IteratorAlgorithm
  >::Kernel;

// Device-level operator wrapping the fused kernel.
using ImplicitGemmFusion =
    cutlass::conv::device::ImplicitGemmConvolutionFusion<Conv2dFpropFusionKernel>;
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
cutlass::Tensor4DCoord input_size;
cutlass::Tensor4DCoord filter_size;
cutlass::Tensor4DCoord padding;
cutlass::MatrixCoord conv_stride;
cutlass::MatrixCoord dilation;
bool reference_check;
bool measure_performance;
int iterations;
bool save_workspace;
ElementComputeEpilogue alpha;
ElementComputeEpilogue beta;
bool benchmark;
std::string tag;
Options():
help(false),
input_size(1, 32, 32, 32),
filter_size(32, 3, 3, 32),
padding(1, 1, 1, 1),
conv_stride(1, 1),
dilation(1, 1),
reference_check(true),
measure_performance(false),
iterations(20),
save_workspace(false),
alpha(1),
beta(0),
benchmark(false) { }
// Verify the problem size is compatible with the CUTLASS Convolution implementation.
bool valid() {
//
// CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently,
// all pointers, strides, and tensor extents must be divisible by 8 elements.
//
int const kAlignment = 8;
if ((input_size.c() % kAlignment) ||
(filter_size.n() % kAlignment)) {
// misaligned tensors
return false;
}
// Invalid padding
if ((padding.h() != filter_size.h() / 2) ||
(padding.w() != filter_size.w() / 2)) {
return false;
}
return true;
}
/// Updates input and filter sizes
void update(
cutlass::Tensor4DCoord input_size,
cutlass::Tensor4DCoord filter_size,
cutlass::MatrixCoord stride) {
this->input_size = input_size;
this->filter_size = filter_size;
conv_stride = stride;
padding.n() = filter_size.h() / 2;
padding.h() = filter_size.h() / 2;
padding.w() = filter_size.w() / 2;
padding.c() = filter_size.w() / 2;
}
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
if (cmd.check_cmd_line_flag("ref-check")) {
reference_check = true;
}
if (cmd.check_cmd_line_flag("perf-check")) {
measure_performance = true;
}
if (cmd.check_cmd_line_flag("save-workspace")) {
save_workspace = true;
}
if (cmd.check_cmd_line_flag("benchmark")) {
benchmark = true;
}
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
filter_size.c() = input_size.c();
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("tag", tag);
if (filter_size.h() == 3 && filter_size.w() == 3) {
padding = {1, 1, 1, 1};
}
else {
filter_size.h() = 1;
filter_size.w() = 1;
padding = {0, 0, 0, 0};
}
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "25_ampere_fprop_mainloop_fusion example\n\n"
<< " This example fuses scale+bias+relu of the activations into Ampere's\n"
<< " Tensor Core operators on F16 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --n=<int> Input tensor extent N\n"
<< " --h=<int> Input tensor extent H\n"
<< " --w=<int> Input tensor extent W\n"
<< " --c=<int> Input tensor extent C\n"
<< " --k=<int> Filter extent K\n"
<< " --r=<int> Filter extent R\n"
<< " --s=<int> Filter extent S\n\n"
<< " --alpha=<float> Epilogue scalar alpha\n"
<< " --beta=<float> Epilogue scalar beta\n\n"
<< " --ref-check If set (true), reference check on the host is computed\n"
<< " --perf-check If set (true), performance is measured.\n"
<< " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n"
<< " --save-workspace If set, workspace is written to a text file.\n"
<< " --tag=<string> String to replicate across the first column in the results table\n";
out << "\n\nExamples:\n\n"
<< "$ ./examples/25_ampere_fprop_mainloop_fusion/25_ampere_fprop_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n"
<< "$ ./examples/25_ampere_fprop_mainloop_fusion/25_ampere_fprop_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n";
return out;
}
/// Computes the output tensor size (NPQK)
cutlass::Tensor4DCoord output_size() const {
return cutlass::Tensor4DCoord(
input_size.n(),
(input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1,
(input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1,
filter_size.n());
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const {
// Number of multiply-adds = NPQK * CRS
int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c());
// Two flops per multiply-add
return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
struct Result {
double runtime_ms;
double gflops;
cutlass::Status status;
cutlass::Status reference_check;
cudaError_t error;
Result():
runtime_ms(0),
gflops(0),
status(cutlass::Status::kSuccess),
reference_check(cutlass::Status::kInvalid),
error(cudaSuccess) { }
static std::ostream & print_header(std::ostream &out, Options const &options) {
if (!options.tag.empty()) {
out << "Name,";
}
out << "Layer,N,H,W,C,K,R,S,Stride_H,Stride_W,Runtime,GFLOPs";
return out;
}
std::ostream & print(std::ostream &out, int idx, Options const &options) {
if (!options.tag.empty()) {
out << options.tag << ",";
}
out
<< "conv_" << idx << ","
<< options.input_size.n() << ","
<< options.input_size.h() << ","
<< options.input_size.w() << ","
<< options.input_size.c() << ","
<< options.filter_size.n() << ","
<< options.filter_size.h() << ","
<< options.filter_size.w() << ","
<< options.conv_stride.row() << ","
<< options.conv_stride.column() << ","
<< runtime_ms << ","
<< gflops;
return out;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Runs one benchmark
///
/// Allocates and fills the tensors for the problem described by `options`, runs the fused
/// convolution once, then optionally (a) verifies it against the reference convolution applied
/// to activations that were transformed with scale+bias+relu on the host, (b) writes the
/// tensors to a text file, and (c) times `options.iterations` back-to-back launches.
///
/// Returns a Result holding the CUTLASS status, the reference-check outcome, the last CUDA
/// error observed while timing, and the measured runtime / GFLOP/s (left at zero when
/// performance is not measured).
Result profile_convolution(Options const &options) {

  Result result;

  //
  // Allocate host-device tensors using the CUTLASS Utilities.
  //

  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_transformed_a(options.input_size);
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);

  // One scale and one bias value per input channel.
  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
    tensor_a_scale({1, options.input_size.c()});
  cutlass::HostTensor<ElementInputScaleBias, LayoutInputScaleBias>
    tensor_a_bias({1, options.input_size.c()});

  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());

  //
  // Initialize tensors
  //
  // NOTE(review): the trailing 0 passed to every TensorFillRandomUniform call below presumably
  // restricts the random values to integers, which is what would make the exact comparison in
  // the reference check meaningful - confirm against the TensorFillRandomUniform documentation.
  //

  // Fill tensor A on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      ElementInputA(3),
      ElementInputA(-4),
      0);

  // Fill scale vector for tensor A on host with uniform-distribution random
  // data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a_scale.host_view(),
      1,
      ElementInputScaleBias(3),
      ElementInputScaleBias(-4),
      0);

  // Fill bias vector for tensor A on host with uniform-distribution random
  // data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a_bias.host_view(),
      1,
      ElementInputScaleBias(3),
      ElementInputScaleBias(-4),
      0);

  // Fill tensor B on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      ElementInputB(7),
      ElementInputB(-8),
      0);

  // Fill tensor C on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c.host_view(),
      1,
      ElementOutput(7),
      ElementOutput(-8),
      0);

  // Fill tensor D on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_d.host_view());

  // Fill tensor D for reference on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_ref_d.host_view());

  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_a_scale.sync_device();
  tensor_a_bias.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();
  tensor_ref_d.sync_device();

  //
  // Define arguments for CUTLASS Convolution
  //

  cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;

  // Split K dimension into 1 partitions
  int split_k_slices = 1;

  // Construct Conv2dProblemSize with user defined output size
  cutlass::conv::Conv2dProblemSize problem_size(
      options.input_size,
      options.filter_size,
      options.padding,
      options.conv_stride,
      options.dilation,
      options.output_size(),
      mode,
      split_k_slices
  );

  typename ImplicitGemmFusion::Arguments arguments{
    problem_size,
    tensor_a.device_ref(),
    tensor_b.device_ref(),
    tensor_a_scale.device_ref(),
    tensor_a_bias.device_ref(),
    tensor_c.device_ref(),
    tensor_d.device_ref(),
    {options.alpha, options.beta},
  };

  //
  // Initialize CUTLASS Convolution
  //

  ImplicitGemmFusion implicit_gemm_fusion_op;

  size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  result.status = implicit_gemm_fusion_op.can_implement(arguments);
  CUTLASS_CHECK(result.status);

  result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(result.status);

  //
  // Launch initialized CUTLASS kernel
  //
  result.status = implicit_gemm_fusion_op();

  CUTLASS_CHECK(result.status);

  //
  // Optional reference check
  //

  if (options.reference_check) {
    std::cout << "Verification on device...\n";

    // Compute scale + bias + relu in host code, so that the plain reference convolution can be
    // fed the already-transformed activations.
    for (int n = 0; n < options.input_size.n(); ++n) {
      for (int h = 0; h < options.input_size.h(); ++h) {
        for (int w = 0; w < options.input_size.w(); ++w) {
          for (int c = 0; c < options.input_size.c(); ++c) {
            tensor_transformed_a.at({n, h, w, c}) = std::max(
                ElementOutput(0), ElementOutput(tensor_a.at({n, h, w, c}) *
                                                    tensor_a_scale.at({0, c}) +
                                                tensor_a_bias.at({0, c})));
          }
        }
      }
    }

    tensor_transformed_a.sync_device();

    // Compute with reference implementation
    cutlass::reference::device::Conv2dFprop<
      ElementInputA,
      LayoutInputA,
      ElementInputB,
      LayoutInputB,
      ElementOutput,
      LayoutOutput,
      ElementComputeEpilogue,
      ElementAccumulator,
      cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
    >(
      problem_size,
      tensor_transformed_a.device_ref(),
      tensor_b.device_ref(),
      tensor_c.device_ref(),
      tensor_ref_d.device_ref(),
      options.alpha,
      options.beta
    );

    // Check if output from CUTLASS kernel and reference kernel are equal or not
    tensor_d.sync_host();
    tensor_ref_d.sync_host();

    bool passed = cutlass::reference::host::TensorEquals(
      tensor_d.host_view(),
      tensor_ref_d.host_view());

    if (!passed) {
      result.reference_check = cutlass::Status::kErrorInternal;
      std::cout << "ERROR - results miscompared.\n";
    }
    else {
      result.reference_check = cutlass::Status::kSuccess;
      std::cout << "Passed.\n";
    }
  }
  else {
    // kInvalid marks "reference check not performed".
    result.reference_check = cutlass::Status::kInvalid;
  }

  if (options.save_workspace) {

    // File name encodes the input (NxHxWxC) and filter (KxRxSxC) extents.
    std::stringstream ss;

    ss << "25_ampere_fprop_mainloop_fusion"
      << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
      << "_"
      << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
      << ".dat";

    std::ofstream output_workspace(ss.str());

    output_workspace
      << "Input = \n" << tensor_a.host_view() << "\n\n"
      << "Filters = \n" << tensor_b.host_view() << "\n\n";

    if (options.reference_check) {
      output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n";
    }

    output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl;

    std::cout << "Results written to '" << ss.str() << "'." << std::endl;
  }

  //
  // Performance measurement
  //

  if (options.measure_performance && options.iterations <= 0) {
    // The average below divides by the iteration count, so there is nothing to measure.
    std::cerr << "Skipping performance measurement: iterations must be positive." << std::endl;
  }
  else if (options.measure_performance) {

    // Null-initialized so that cleanup can tell which events were actually created.
    cudaEvent_t events[2] = {nullptr, nullptr};

    // Releases whatever events exist; called on every exit path of this section.
    auto destroy_events = [&events]() {
      for (auto event : events) {
        if (event) {
          (void)cudaEventDestroy(event);
        }
      }
    };

    for (auto & event : events) {
      result.error = cudaEventCreate(&event);
      if (result.error != cudaSuccess) {
        std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl;
        destroy_events();
        return result;
      }
    }

    // Record an event at the start of a series of convolution operations.
    result.error = cudaEventRecord(events[0]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      destroy_events();
      return result;
    }

    // Launch a sequence of implicit GEMM operations on the device
    for (int iteration = 0; iteration < options.iterations; ++iteration) {
      result.status = implicit_gemm_fusion_op();
      CUTLASS_CHECK(result.status);
    }

    // Record an event when the convolutions have been launched.
    result.error = cudaEventRecord(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl;
      destroy_events();
      return result;
    }

    // Wait for work on the device to complete.
    result.error = cudaEventSynchronize(events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl;
      destroy_events();
      return result;
    }

    // Measure elapsed runtime
    float runtime_ms = 0;
    result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
    if (result.error != cudaSuccess) {
      std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl;
      destroy_events();
      return result;
    }

    // Print average runtime and GFLOPs.
    result.runtime_ms = double(runtime_ms) / double(options.iterations);
    result.gflops = options.gflops(result.runtime_ms / 1000.0);

    // Cleanup
    destroy_events();
  }

  return result;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
bool notSupported = false;
// Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
//
// CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
notSupported = true;
}
cudaDeviceProp props;
CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
if (!(props.major >= 8)) {
std::cerr << "This test must run on SM80 or above.\n";
notSupported = true;
}
if (notSupported) {
return 0;
}
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (options.benchmark) {
// Benchmark several layers
int batch_sizes[] = {34, 408};
struct Benchmark {
int h, w, c, k, r, s, stride_h, stride_w;
} layers[] = {
{56, 56, 64, 256, 1, 1, 1, 1},
{56, 56, 64, 64, 1, 1, 1, 1},
{56, 56, 64, 64, 3, 3, 1, 1},
{56, 56, 256, 64, 1, 1, 1, 1},
{56, 56, 256, 512, 1, 1, 2, 2},
{56, 56, 256, 128, 1, 1, 1, 1},
{56, 56, 128, 128, 3, 3, 2, 2},
{28, 28, 128, 512, 1, 1, 1, 1},
{28, 28, 512, 128, 1, 1, 1, 1},
{28, 28, 128, 128, 3, 3, 1, 1},
{28, 28, 512, 1024, 1, 1, 2, 2},
{28, 28, 512, 256, 1, 1, 1, 1},
{28, 28, 256, 256, 3, 3, 2, 2},
{14, 14, 256, 1024, 1, 1, 1, 1},
{14, 14, 1024, 256, 1, 1, 1, 1},
{14, 14, 256, 256, 3, 3, 1, 1},
{14, 14, 1024, 2048, 1, 1, 2, 2},
{14, 14, 1024, 512, 1, 1, 1, 1},
{14, 14, 512, 512, 3, 3, 2, 2},
{ 7, 7, 512, 2048, 1, 1, 1, 1},
{ 7, 7, 2048, 512, 1, 1, 1, 1},
{ 7, 7, 512, 512, 3, 3, 1, 1},
};
Result::print_header(std::cout, options) << std::endl;
int idx = 1;
for (auto const &layer : layers) {
for (auto N : batch_sizes) {
options.update({N, layer.h, layer.w, layer.c},
{layer.k, layer.r, layer.s, layer.c},
{layer.stride_h, layer.stride_w});
Result result = profile_convolution(options);
result.print(std::cout, idx, options) << std::endl;
}
++idx;
}
}
else {
// Execute one problem size
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
Result result = profile_convolution(options);
Result::print_header(std::cout, options) << std::endl;
result.print(std::cout, 1, options) << std::endl;
}
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////