-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathscratch
74 lines (70 loc) · 2.35 KB
/
scratch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// language: metal2.4
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
// Two unsigned size values, bound to main_ as `_buffer_sizes` at buffer(3).
// The visible kernel body never reads either field.
// NOTE(review): presumably the sizes of the buffers bound at buffer(0) and
// buffer(1), supplied by the host — confirm units and meaning against the
// code that fills this buffer.
struct _mslBufferSizes {
uint size0;
uint size1;
};
// Array-of-one-float type used as the referent type of the `in` and `out`
// device buffer references in main_. The kernel indexes it with computed
// indices, so it presumably stands in for a runtime-sized float array —
// TODO confirm with whatever generated this file.
typedef float type_1[1];
// Per-dispatch parameters read by main_ from the constant buffer at buffer(2).
// NOTE(review): this struct is a buffer layout shared with the host; field
// order and the uint3 / packed_uint3 distinction must not be changed without
// updating the host side to match.
struct Meta {
// Input strides. main_ multiplies pos.x by .z, pos.y by .y and pos.z by .x.
metal::uint3 in_strides;
// Output strides. main_ multiplies pos.x by [2], pos.y by [1] and pos.z by [0].
metal::packed_uint3 out_strides;
// Added to pos.y before the sum is converted to float and scaled.
uint offset;
// Multiplies the negated normalized x position inside the exp2() exponent.
float base;
// Multiplies float(pos.y + offset) to form the angle's leading factor.
float scale;
};
// Empty struct; nothing in the visible source refers to it.
// NOTE(review): presumably a placeholder emitted by the shader generator for
// the entry point's stage inputs — confirm before removing.
struct main_Input {
};
// Compute kernel: each thread loads two floats from `in`, applies a 2-D
// rotation to the pair, and stores the rotated pair to `out`.
//
// The rotation angle is
//     theta = scale * float(pos.y + offset) * exp2(-(pos.x / span) * base)
// where span = groups.x * 16 and scale/offset/base come from `metadata`.
// NOTE(review): this has the shape of a rotary position embedding (RoPE)
// kernel — confirm with the caller.
//
// Parameters:
//   pos           thread position in the grid; drives all indexing and theta.
//   groups        threadgroups per grid; only .x is used (to form `span`).
//   in / out      source and destination float buffers (buffer 0 / buffer 1).
//   metadata      strides, offset, base and scale (buffer 2).
//   local_id, subgroup_id, subgroup_size, _buffer_sizes
//                 declared but not read by this body.
//
// NOTE(review): no index is checked against `_buffer_sizes`; the host is
// presumably expected to dispatch a grid that stays in range — confirm.
kernel void main_(
  metal::uint3 local_id [[thread_position_in_threadgroup]]
, metal::uint3 pos [[thread_position_in_grid]]
, uint subgroup_id [[simdgroup_index_in_threadgroup]]
, uint subgroup_size [[threads_per_simdgroup]]
, metal::uint3 groups [[threadgroups_per_grid]]
, device type_1 const& in [[buffer(0)]]
, device type_1& out [[buffer(1)]]
, constant Meta& metadata [[buffer(2)]]
, constant _mslBufferSizes& _buffer_sizes [[buffer(3)]]
) {
    // Distance, in x-steps, between the two elements of a pair.
    // The factor 16 presumably matches the threadgroup width in x, which
    // would make this the total thread count along x — TODO confirm.
    const uint span = groups.x * 16u;

    // Destination indices: pos.x pairs with stride [2], pos.y with [1],
    // pos.z with [0]; the partner element sits `span` x-steps further on.
    const uint dst_first =
        ((pos.x * metadata.out_strides[2]) + (pos.y * metadata.out_strides[1]))
        + (pos.z * metadata.out_strides[0]);
    const uint dst_second = dst_first + (span * metadata.out_strides[2]);

    // Source indices, built the same way from the input strides.
    const uint src_first =
        ((pos.x * metadata.in_strides.z) + (pos.y * metadata.in_strides.y))
        + (pos.z * metadata.in_strides.x);
    const uint src_second = src_first + (span * metadata.in_strides.z);

    // Rotation angle for this thread.
    const float row_factor = metadata.scale * static_cast<float>(pos.y + metadata.offset);
    const float x_fraction = static_cast<float>(pos.x) / static_cast<float>(span);
    const float angle = row_factor * metal::exp2(-(x_fraction) * metadata.base);
    const float c = metal::cos(angle);
    const float s = metal::sin(angle);

    // Load both inputs before any store so the result is unchanged even if
    // `in` and `out` refer to overlapping memory.
    const float a = in[src_first];
    const float b = in[src_second];

    // Standard 2-D rotation of (a, b) by `angle`.
    const float rotated_a = (a * c) - (b * s);
    const float rotated_b = (a * s) + (b * c);

    out[dst_first] = rotated_a;
    out[dst_second] = rotated_b;
}