Skip to content

Commit

Permalink
[CPU][RV64] Implement Exp
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Feb 10, 2025
1 parent 8511f9a commit f4a4cf1
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ std::set<std::vector<element::Type>> jit_add_emitter::get_supported_precisions(c
}


/// CLamp ///
/// Clamp ///
jit_clamp_emitter::jit_clamp_emitter(ov::intel_cpu::riscv64::jit_generator* host, const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, get_arithmetic_binary_exec_precision(node)) {
if (const auto clamp = ov::as_type_ptr<ov::op::v0::Clamp>(node)) {
Expand Down Expand Up @@ -93,10 +93,10 @@ bool jit_clamp_emitter::need_table() const {
}

// Returns a pointer to a two-element float table: element 0 holds `min`, element 1 holds `max`.
// The storage is a function-local static array, so the same memory is handed out by every call;
// it is refreshed from the current `min`/`max` each time the function runs.
// NOTE(review): because the array is static, a later call made with different bounds overwrites
// the values behind a pointer returned earlier — confirm callers consume the table before that.
const void* jit_clamp_emitter::get_table() const {
static float values[2];
values[0] = min; // assigned on every call (a static initializer would run only once) so the table tracks the current bound
values[1] = max;
return values;
static float tbl[2];
tbl[0] = min; // assigned on every call (a static initializer would run only once) so the table tracks the current bound
tbl[1] = max;
return tbl;
}

/// DIV ///
Expand All @@ -122,6 +122,140 @@ std::set<std::vector<element::Type>> jit_div_emitter::get_supported_precisions(c
return {{element::f32, element::f32}};
}

/// Exp ///
/// Builds an Exp emitter for a graph node.
/// The execution precision forwarded to the base jit_emitter is the value returned by
/// get_arithmetic_binary_exec_precision(node).
jit_exp_emitter::jit_exp_emitter(ov::intel_cpu::riscv64::jit_generator* host, const std::shared_ptr<ov::Node>& node)
    : jit_emitter(host, get_arithmetic_binary_exec_precision(node)) {
}

/// Builds an Exp emitter with an explicitly supplied execution precision,
/// which is passed unchanged to the base jit_emitter.
jit_exp_emitter::jit_exp_emitter(ov::intel_cpu::riscv64::jit_generator* host, const ov::element::Type exec_prc)
    : jit_emitter(host, exec_prc) {
}

/// Exp is unary: emit_impl reads a single input vector register (in_vec_idxs[0]).
size_t jit_exp_emitter::get_inputs_num() const {
    const size_t inputs_num = 1;
    return inputs_num;
}

/// Number of auxiliary general-purpose registers requested by this emitter.
/// emit_impl itself only touches aux_gpr_idxs[0] (to hold the exponent bias).
/// NOTE(review): the second register is presumably reserved for the table pointer
/// by the base emitter — confirm.
size_t jit_exp_emitter::aux_gprs_count() const {
    const size_t gprs_num = 2;
    return gprs_num;
}

/// Number of auxiliary vector registers requested by this emitter:
/// emit_impl uses aux_vec_idxs[0], [1] and [2].
size_t jit_exp_emitter::aux_vecs_count() const {
    const size_t vecs_num = 3;
    return vecs_num;
}

/// Emits vector code that computes dst = exp(src) on fp32 lanes.
///
/// Scheme, as implemented below:
///   1. clamp x to [ln_flt_min_f, ln_flt_max_f];
///   2. n = floor(x * log2e + 0.5), r = x - n * ln2;
///   3. evaluate 1 + pol1*r + pol2*r^2 + ... + pol5*r^5 with Horner's rule;
///   4. build 2^n by writing (n + exponent_bias) into the fp32 exponent field;
///   5. force 2^n to zero for lanes whose original input was below ln_flt_min_f;
///   6. multiply the polynomial by 2^n.
///
/// All constants are loaded from the table returned by get_table() through p_table;
/// the numeric factors in front of sizeof(uint32_t) are element indices into that table.
///
/// Register usage: f0 is reused for every scalar constant (all FReg names bound to f0
/// are aliases of the same register and are only valid until the next load into f0);
/// f1 keeps ln_flt_min_f alive until the underflow mask is built.
/// NOTE(review): f0/f1 are overwritten here without save/restore — confirm the
/// surrounding framework treats them as scratch registers.
///
/// @param in_vec_idxs   indices of input vector registers; only [0] is read.
/// @param out_vec_idxs  indices of output vector registers; only [0] is written.
void jit_exp_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
    VReg src = VReg(in_vec_idxs[0]);
    VReg dst = VReg(out_vec_idxs[0]);
    VReg aux0 = VReg(aux_vec_idxs[0]);
    VReg aux1 = VReg(aux_vec_idxs[1]);
    VReg aux2 = VReg(aux_vec_idxs[2]);

    // keep the unclamped input: it is compared with ln_flt_min_f later to find underflowing lanes
    h->vmv_v_v(aux2, src);

    // clamp x into [ln_flt_min_f, ln_flt_max_f]
    FReg ln_flt_min_f = f1;  // stays in f1 until the underflow mask is computed
    h->flw(ln_flt_min_f, p_table, 10 * sizeof(uint32_t));
    h->vfmax_vf(dst, src, ln_flt_min_f);

    FReg ln_flt_max_f = f0;
    h->flw(ln_flt_max_f, p_table, 9 * sizeof(uint32_t));
    h->vfmin_vf(dst, dst, ln_flt_max_f);

    // aux0 = clamped x, kept for the range reduction below
    h->vmv_v_v(aux0, dst);

    // calculate exp(x)
    // fx = x * log2ef + 0.5
    FReg log2ef = f0;
    h->flw(log2ef, p_table, 8 * sizeof(uint32_t));
    h->vfmul_vf(dst, dst, log2ef);
    FReg half = f0;
    h->flw(half, p_table, 6 * sizeof(uint32_t));
    h->vfadd_vf(dst, dst, half);

    // aux1 = floorf(fx): round-trip through int32, then subtract one in the lanes
    // where the round-trip value ended up greater than fx
    h->vfcvt_x_f_v(aux1, dst);   // fp32 -> int32
    h->vfcvt_f_x_v(aux1, aux1);  // int32 -> fp32
    h->vmfgt_vv(mask_vreg(), aux1, dst);
    FReg one = f0;
    h->flw(one, p_table, 5 * sizeof(uint32_t));  // one
    h->vfsub_vf(aux1, aux1, one, VM::masked);

    // keep dst = floorf(fx) for further computations
    h->vmv_v_v(dst, aux1);

    // r = x - floorf(fx) * ln2   (result stays in aux0)
    FReg ln2 = f0;
    h->flw(ln2, p_table, 7 * sizeof(uint32_t));
    h->vfnmsac_vf(aux0, ln2, aux1);

    // compute 2^n: place (n + exponent_bias) into the exponent bits of an fp32 value
    // NOTE(review): for inputs close to ln_flt_max_f, n can reach 128, which makes the biased
    // exponent 255 (the infinity bit pattern) — confirm this saturation is intended.
    Reg tmp = Reg(aux_gpr_idxs[0]);
    h->vfcvt_x_f_v(aux1, dst);
    h->lw(tmp, p_table, 11 * sizeof(uint32_t));  // exponent_bias
    h->vadd_vx(aux1, aux1, tmp);
    const int n_mantissa_bits = 23;
    h->vsll_vi(aux1, aux1, n_mantissa_bits);

    // set zeroes at those points which were < log(FLT_MIN)
    // Note: Xbyak doesn't support vmv_v_i with mask to set zero where masked,
    // so the masked lanes are AND-ed with the `zero` register instead
    h->vmflt_vf(mask_vreg(), aux2, ln_flt_min_f);  // aux2 still holds the original input
    h->vand_vx(aux1, aux1, zero, VM::masked);

    // compute polynomial with Horner's rule: dst = dst * r + next coefficient
    // (aux2 is free from here on and is reused to broadcast each coefficient)
    FReg pol = f0;
    h->flw(pol, p_table, 4 * sizeof(uint32_t));  // pol5
    h->vfmv_v_f(dst, pol);

    h->flw(pol, p_table, 3 * sizeof(uint32_t));  // pol4
    h->vfmv_v_f(aux2, pol);
    h->vfmadd_vv(dst, aux0, aux2);

    h->flw(pol, p_table, 2 * sizeof(uint32_t));  // pol3
    h->vfmv_v_f(aux2, pol);
    h->vfmadd_vv(dst, aux0, aux2);

    h->flw(pol, p_table, 1 * sizeof(uint32_t));  // pol2
    h->vfmv_v_f(aux2, pol);
    h->vfmadd_vv(dst, aux0, aux2);

    h->flw(pol, p_table, 0 * sizeof(uint32_t));  // pol1
    h->vfmv_v_f(aux2, pol);
    h->vfmadd_vv(dst, aux0, aux2);

    h->flw(pol, p_table, 5 * sizeof(uint32_t));  // one
    h->vfmv_v_f(aux2, pol);
    h->vfmadd_vv(dst, aux0, aux2);

    // y = y * 2^n
    h->vfmul_vv(dst, dst, aux1);
}

/// Lists the supported input-precision combinations: a single entry, {f32}.
/// The `node` argument is not inspected.
std::set<std::vector<element::Type>> jit_exp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
    const std::vector<element::Type> f32_only{element::f32};
    return std::set<std::vector<element::Type>>{f32_only};
}

/// The emitter always needs its constant table: emit_impl loads every constant through p_table.
bool jit_exp_emitter::need_table() const {
    const bool uses_table = true;
    return uses_table;
}

/// Returns the constant table used by emit_impl.
/// Entries are 32-bit words; emit_impl addresses them as `index * sizeof(uint32_t)`,
/// so the order of the entries must not change. The contents never change at run time,
/// hence the array is const.
const void* jit_exp_emitter::get_table() const {
    static const uint32_t tbl[12] = {
        0x3f7ffffb,  // [0]  pol1
        0x3efffee3,  // [1]  pol2
        0x3e2aad40,  // [2]  pol3
        0x3d2b9d0d,  // [3]  pol4
        0x3c07cfce,  // [4]  pol5
        0x3f800000,  // [5]  one
        0x3f000000,  // [6]  0.5f
        0x3f317218,  // [7]  ln2f
        0x3fb8aa3b,  // [8]  log2ef
        0x42b17218,  // [9]  ln_flt_max_f
        0xc2aeac50,  // [10] ln_flt_min_f
        0x0000007f   // [11] exponent_bias (integer, loaded with lw)
    };
    return tbl;
}

/// MUL ///
/// Builds a Mul emitter for a graph node.
/// The execution precision forwarded to the base jit_emitter is the value returned by
/// get_arithmetic_binary_exec_precision(node).
jit_mul_emitter::jit_mul_emitter(ov::intel_cpu::riscv64::jit_generator* host, const std::shared_ptr<ov::Node>& node)
    : jit_emitter(host, get_arithmetic_binary_exec_precision(node)) {
}
Expand Down Expand Up @@ -225,9 +359,9 @@ bool jit_relu_emitter::need_table() const {
}

// Returns a pointer to a one-element float table holding `alpha`.
// The storage is a function-local static array, so the same memory is handed out by every call;
// it is refreshed from the current `alpha` each time the function runs.
// NOTE(review): because the array is static, a later call made with a different alpha overwrites
// the value behind a pointer returned earlier — confirm callers consume the table before that.
const void* jit_relu_emitter::get_table() const {
static float values[1];
values[0] = alpha; // assigned on every call (a static initializer would run only once) so the table tracks the current alpha
return values;
static float tbl[1];
tbl[0] = alpha; // assigned on every call (a static initializer would run only once) so the table tracks the current alpha
return tbl;
}

/// SUB ///
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ class jit_div_emitter : public jit_emitter {
void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
};

/// JIT emitter for the element-wise exponent on RISC-V 64.
class jit_exp_emitter : public jit_emitter {
public:
    /// Constructs the emitter for a graph node.
    jit_exp_emitter(ov::intel_cpu::riscv64::jit_generator* host, const std::shared_ptr<ov::Node>& node);
    /// Constructs the emitter with an explicit execution precision (f32 when omitted).
    jit_exp_emitter(ov::intel_cpu::riscv64::jit_generator* host, const ov::element::Type exec_prc = ov::element::f32);

    /// Supported input-precision combinations; `node` is optional.
    static std::set<std::vector<element::Type>> get_supported_precisions(
        const std::shared_ptr<ov::Node>& node = nullptr);

    /// Number of input vector registers consumed by the emitter.
    size_t get_inputs_num() const override;
    /// Number of auxiliary vector registers requested by the emitter.
    size_t aux_vecs_count() const override;
    /// Number of auxiliary general-purpose registers requested by the emitter.
    size_t aux_gprs_count() const override;

private:
    /// Whether the emitter requires a constant table.
    bool need_table() const override;
    /// Pointer to the emitter's constant table.
    const void* get_table() const override;
    /// Generates the instruction sequence for the given input/output vector register indices.
    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
};

class jit_mul_emitter : public jit_emitter {
public:
jit_mul_emitter(ov::intel_cpu::riscv64::jit_generator* host, const ov::element::Type exec_prc = ov::element::f32);
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
Algorithm::EltwiseAdd,
Algorithm::EltwiseClamp,
Algorithm::EltwiseDivide,
Algorithm::EltwiseExp,
Algorithm::EltwiseMultiply,
Algorithm::EltwisePrelu,
Algorithm::EltwiseRelu,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic::create_eltwise_emitter(con
OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter),
OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, jit_div_emitter),
OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_mul_emitter),
OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter),
Expand Down Expand Up @@ -463,6 +464,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter),
OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, jit_div_emitter),
OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter),
OV_CASE(Algorithm::EltwiseMultiply, jit_mul_emitter),
OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter),
OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,10 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
#if defined(OPENVINO_ARCH_RISCV64)
if ((activation_type == utils::ActivationTypes::Relu) ||
(activation_type == utils::ActivationTypes::PReLu) ||
(activation_type == utils::ActivationTypes::Clamp))
(activation_type == utils::ActivationTypes::Clamp) ||
(activation_type == utils::ActivationTypes::Exp))
return "jit";
#if defined(OV_CPU_WITH_SHL)
if ((activation_type == utils::ActivationTypes::Exp)) {
return "shl";
} else {
return "ref";
}
#endif
return "ref";
#else
return CPUTestsBase::getPrimitiveType();
#endif
Expand Down

0 comments on commit f4a4cf1

Please sign in to comment.