-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.cu
79 lines (68 loc) · 2.1 KB
/
main.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <array>
#include <iostream>
#include <cuda/std/array>
#include <cuda/std/tuple>
#include <cuda/barrier>
#include <cublasdx.hpp>
#include <torch/torch.h>
#include "util.cuh"
#include "mma.cuh"
/**
 * Minimal two-layer feed-forward "expert": Linear(2 -> 4), ReLU, Linear(4 -> 2).
 *
 * Both layers are registered as submodules under the names "g1" and "g2".
 */
struct Expert final : torch::nn::Module {
    torch::nn::Linear g1;  // first projection, 2 input features -> 4 output features
    torch::nn::Linear g2;  // second projection, 4 input features -> 2 output features

    Expert():
        g1(torch::nn::LinearOptions(2, 4)),
        g2(torch::nn::LinearOptions(4, 2)) {
        register_module("g1", g1);
        register_module("g2", g2);
    }

    /**
     * Applies g1, ReLU, then g2 to `x` and returns the result.
     *
     * The input is taken by const reference and is never reassigned: the
     * intermediate activation lives in a local, so the caller's tensor handle
     * is left untouched and temporaries can be passed directly.
     *
     * @param x input tensor; its last dimension is expected to match g1's
     *          2 input features.
     * @return output of g2 applied to relu(g1(x)).
     */
    torch::Tensor forward(const torch::Tensor& x) {
        const torch::Tensor hidden = torch::relu(g1->forward(x));
        return g2->forward(hidden);
    }
};
/**
 * Plain pair of unsigned integers whose alignment is raised to 8 bytes
 * through the CUDA __align__ qualifier.
 */
struct __align__(8) Foo {
    unsigned int x;  // first component
    unsigned int y;  // second component
};
/**
 * Host-only libtorch scratch routine.
 *
 * Steps, in order:
 *   1. Builds and prints a Sequential MLP with the same layer sizes as Expert.
 *   2. When a CUDA device is available, copies a 2x2 host buffer to it and
 *      casts the copy to FP8 (e4m3fn).
 *   3. Constructs two Expert instances, prints the first one and each of its
 *      named parameters.
 *   4. Packs the weight matrices of both experts into one
 *      [nX, GEMMs, h, upH] tensor.
 *   5. Creates an all-ones [s, b, h] activation tensor (the flatten to
 *      [s*b, h] is only present as commented-out debug code).
 *
 * Takes no arguments and returns nothing; results are only printed.
 */
__host__ __forceinline__
void tensorWork() {
    const torch::nn::Sequential expert(
        torch::nn::Linear(2, 4),
        torch::nn::ReLU(),
        torch::nn::Linear(4, 2)
    );
    std::cout << expert << std::endl;

    std::array<float, 4> a{0, 1, 2, 3};
    // Requesting a CUDA device on a host without one throws, so check first.
    if (torch::cuda::is_available()) {
        const torch::Device device(torch::kCUDA);
        // from_blob wraps `a` without owning it; .to(device) produces the
        // device-side copy that is then cast to FP8.
        const torch::Tensor tensor =
            torch::from_blob(a.data(), {2, 2}).to(device).to(torch::kFloat8_e4m3fn);
        static_cast<void>(tensor);  // created only to exercise the conversion path
    } else {
        std::cerr << "tensorWork: CUDA unavailable, skipping FP8 device tensor" << std::endl;
    }

    const Expert model;
    const Expert model2;
    std::cout << model << std::endl;
    for (const auto& p : model.named_parameters()) {
        std::cout << p.key() << std::endl;
        std::cout << p.value() << std::endl;
    }

    // pack both experts into a single torch tensor
    constexpr auto nX = 2U;     // number of experts
    constexpr auto GEMMs = 2U;  // weight matrices per expert
    constexpr auto h = 2U;      // model width (Expert input/output features)
    constexpr auto upH = 4U;    // hidden width between g1 and g2
    const torch::Tensor pT = torch::zeros({nX, GEMMs, h, upH}).contiguous();

    // Every expert gets its own leading slot; each GEMM slot is [h, upH].
    // g1's weight is stored as [out_features, in_features] = [upH, h], so it is
    // transposed to fit the slot. g2's weight is already [h, upH].
    // NOTE(review): the original wrote g2's weight into both slots of expert 0
    // and never packed the second expert; confirm the [h, upH] layout chosen
    // for g1 matches what the consuming kernel expects.
    // NOTE(review): copy_ runs with autograd enabled here, as before; wrap in
    // torch::NoGradGuard if the packed buffer must not track gradients.
    const std::array<const Expert*, nX> experts{&model, &model2};
    int slot = 0;
    for (const Expert* e : experts) {
        pT[slot][0].copy_(e->g1->weight.t());
        pT[slot][1].copy_(e->g2->weight);
        ++slot;
    }

    // flatten [s, b, h] -> [sb, h]
    constexpr auto s = 2U;
    constexpr auto b = 4U;
    const torch::Tensor act = torch::ones({s, b, h}).contiguous();
    const auto sz = act.sizes();
    /* Debug output, kept for reference:
    std::cout << sz.size() << std::endl;
    std::cout << act << std::endl;
    std::cout << act.view({sz[0]*sz[1], h}) << std::endl;*/
    /* Raw read of the packed buffer, kept for reference:
    auto* __restrict__ p = pT.const_data_ptr<float>();
    std::cout << p[15] << std::endl;*/
}
/**
 * Program entry point: runs testCollective() and exits with status 0.
 *
 * testCollective is not defined in this file; presumably it comes from
 * util.cuh or mma.cuh (verify). tensorWork() is not invoked from here.
 */
int main() {
    testCollective();
    return 0;
}