forked from Spotlight0xff/warp-rna
-
Notifications
You must be signed in to change notification settings - Fork 0
/
core.h
39 lines (32 loc) · 975 Bytes
/
core.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#ifndef RNA_CORE_H
#define RNA_CORE_H
#ifdef WARPRNA_ENABLE_GPU
#include <cuda_runtime.h>
#endif
/*
 * Status codes returned by run_warp_rna() and run_warp_rna_cpu().
 * Zero means success; every non-zero value names the stage that failed.
 * The exact condition that triggers each failure code lives in the
 * implementation, not in this header -- TODO confirm against the sources.
 */
typedef enum {
    RNA_STATUS_SUCCESS            = 0, /* call completed without error        */
    RNA_STATUS_WARP_FAILED        = 1, /* failure reported for the "warp" step */
    RNA_STATUS_GRADS_BLANK_FAILED = 2, /* presumably blank-gradient stage      */
    RNA_STATUS_GRADS_LABEL_FAILED = 3, /* presumably label-gradient stage      */
    RNA_STATUS_COSTS_FAILED       = 4  /* presumably cost computation stage    */
} rnaStatus_t;
#if defined(__cplusplus)
#include <cstddef>
extern "C" {
#endif

#if defined(WARPRNA_ENABLE_GPU)
/*
 * GPU entry point. Declared only when WARPRNA_ENABLE_GPU is defined, which is
 * the same condition under which this header includes <cuda_runtime.h> (the
 * source of cudaStream_t).
 *
 *   stream                     CUDA stream passed to the implementation
 *                              (presumably all work is enqueued on it --
 *                              confirm in the .cu source).
 *   counts, alphas, betas      non-const buffers; presumably caller-provided
 *                              workspace (likely device memory). Required
 *                              sizes are not stated here -- TODO confirm.
 *   labels, log_probs, xn, yn  read-only (const) inputs.
 *   grads, costs               non-const; presumably receive the gradients
 *                              and the per-sequence costs.
 *   N, T, S, U, V, blank       problem dimensions and the blank symbol
 *                              index; the exact meaning of each is defined
 *                              by the implementation -- TODO confirm.
 *
 * Returns an rnaStatus_t; presumably RNA_STATUS_SUCCESS when nothing failed.
 */
rnaStatus_t run_warp_rna(cudaStream_t stream,
                         unsigned int *counts,
                         float *alphas,
                         float *betas,
                         const int *labels,
                         const float *log_probs,
                         float *grads,
                         float *costs,
                         const int *xn,
                         const int *yn,
                         int N, int T, int S, int U, int V,
                         int blank);
#endif /* WARPRNA_ENABLE_GPU */

#if defined(WARPRNA_ENABLE_CPU)
/*
 * CPU entry point. Declared only when WARPRNA_ENABLE_CPU is defined.
 * Takes the same arguments as run_warp_rna() minus the CUDA stream; the
 * pointers here are presumably host memory -- TODO confirm.
 *
 * Returns an rnaStatus_t; presumably RNA_STATUS_SUCCESS when nothing failed.
 */
rnaStatus_t run_warp_rna_cpu(unsigned int *counts,
                             float *alphas,
                             float *betas,
                             const int *labels,
                             const float *log_probs,
                             float *grads,
                             float *costs,
                             const int *xn,
                             const int *yn,
                             int N, int T, int S, int U, int V,
                             int blank);
#endif /* WARPRNA_ENABLE_CPU */

#if defined(__cplusplus)
}  /* extern "C" */
#endif
#endif //RNA_CORE_H