Rename method to_float() to to_float32() (#1114)
guillaumekln authored Mar 6, 2023
1 parent 0137dbd commit a21e4f4
Showing 9 changed files with 40 additions and 40 deletions.
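The rename is mechanical: every call site that converted a StorageView to single precision now names the target width explicitly, mirroring the existing to_float16(). A minimal migration sketch, assuming caller code that wants float32 values on the CPU (the helper name is illustrative, not part of this commit):

    #include <ctranslate2/storage_view.h>

    using ctranslate2::Device;
    using ctranslate2::StorageView;

    StorageView to_cpu_float32(const StorageView& x) {
      // Before this commit: x.to_float().to(Device::CPU)
      // After this commit, the target width is spelled out:
      return x.to_float32().to(Device::CPU);
    }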
2 changes: 1 addition & 1 deletion include/ctranslate2/storage_view.h
@@ -91,7 +91,7 @@ namespace ctranslate2 {
StorageView to(Device D) const;
StorageView to(DataType dtype) const;
StorageView to_float16() const;
- StorageView to_float() const;
+ StorageView to_float32() const;

// Actual storage type.
DataType dtype() const {
6 changes: 3 additions & 3 deletions src/decoding.cc
@@ -494,7 +494,7 @@ namespace ctranslate2 {
if (!is_expanded)
repeat_batch(attention_step, _beam_size);
split_batch_beam(attention_step, _beam_size);
- append_step_output(alive_attention, attention_step.to_float().to(Device::CPU));
+ append_step_output(alive_attention, attention_step.to_float32().to(Device::CPU));
gather_beam_flat(alive_attention, gather_indices, num_candidates);
}

@@ -736,7 +736,7 @@ namespace ctranslate2 {
if (prefix_ids)
update_sample_with_prefix(step, best_ids, best_probs, *prefix_ids, end_id, batch_offset);
if (attention_step_device)
- attention_step.copy_from(attention_step_device.to_float());
+ attention_step.copy_from(attention_step_device.to_float32());

if (!logits_processors.empty()) {
if (alive_seq) {
@@ -978,7 +978,7 @@ namespace ctranslate2 {

if (options.return_attention) {
if (attention.device() != Device::CPU)
- attention = attention.to_float().to(Device::CPU);
+ attention = attention.to_float32().to(Device::CPU);
for (dim_t t = 0; t < prefix_length; ++t) {
const float* vector = attention.index<float>({0, t, 0});
result.attention[i].emplace_back(vector, vector + attention.dim(-1));
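The decoding changes all share one read-out pattern: attention weights may be half precision on a GPU device, but the result structures hold plain float vectors, so the storage is upcast to float32 and moved to the CPU before its buffer is dereferenced. A sketch of that pattern in isolation, assuming a 3-D attention tensor of shape {batch, target, source} (the helper name is hypothetical):

    #include <vector>
    #include <ctranslate2/storage_view.h>

    using namespace ctranslate2;

    // Copy the first attention row into a plain vector.
    std::vector<float> first_attention_row(StorageView attention) {
      if (attention.dtype() != DataType::FLOAT32)
        attention = attention.to_float32();      // index<float> needs real float32 storage
      if (attention.device() != Device::CPU)
        attention = attention.to(Device::CPU);   // and a host-side buffer
      const float* vector = attention.index<float>({0, 0, 0});
      return std::vector<float>(vector, vector + attention.dim(-1));
    }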
4 changes: 2 additions & 2 deletions src/models/model.cc
@@ -285,7 +285,7 @@ namespace ctranslate2 {

if (target_dtype == DataType::FLOAT32 || target_dtype == DataType::FLOAT16) {
if (is_float16) {
- target_variable = variable.to_float();
+ target_variable = variable.to_float32();
} else if (is_float32) {
target_variable = variable.to_float16();
} else {
@@ -304,7 +304,7 @@ namespace ctranslate2 {
// Quantize float32 to int8 or int16.
StorageView scale;
if (is_float16) {
- quantize_op(variable.to_float(), target_variable, scale);
+ quantize_op(variable.to_float32(), target_variable, scale);
} else {
quantize_op(variable, target_variable, scale);
}
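Note the asymmetry in the quantization branch above: the quantize op consumes float32 input, so a float16 variable is upcast before being quantized to int8 or int16, while a float32 variable passes through directly. That branch in isolation (a sketch; quantize_op is assumed to be the ops::Quantize instance constructed earlier in this file, and includes are elided):

    void quantize_weight(const StorageView& variable,
                         StorageView& target_variable,
                         StorageView& scale,
                         const ops::Quantize& quantize_op) {
      if (variable.dtype() == DataType::FLOAT16)
        quantize_op(variable.to_float32(), target_variable, scale);  // upcast, then quantize
      else
        quantize_op(variable, target_variable, scale);               // already float32
    }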
4 changes: 2 additions & 2 deletions src/models/whisper.cc
@@ -107,7 +107,7 @@ namespace ctranslate2 {
ops::Gather(/*axis=*/1, /*batch_dims=*/1)(probs, gather_ids, no_speech_probs);

if (no_speech_probs.dtype() != DataType::FLOAT32)
- no_speech_probs = no_speech_probs.to_float();
+ no_speech_probs = no_speech_probs.to_float32();
return no_speech_probs.to_vector<float>();
}

@@ -378,7 +378,7 @@ namespace ctranslate2 {
ops::SoftMax()(lang_probs);

if (lang_probs.dtype() != DataType::FLOAT32)
- lang_probs = lang_probs.to_float();
+ lang_probs = lang_probs.to_float32();
if (lang_probs.device() != Device::CPU)
lang_probs = lang_probs.to(Device::CPU);

2 changes: 1 addition & 1 deletion src/scoring.cc
@@ -49,7 +49,7 @@ namespace ctranslate2 {
if (scores.device() != Device::CPU)
scores = scores.to(Device::CPU);
if (scores.dtype() != DataType::FLOAT32)
- scores = scores.to_float();
+ scores = scores.to_float32();

std::vector<ScoringResult> results(batch_size);
for (dim_t b = 0; b < batch_size; ++b) {
2 changes: 1 addition & 1 deletion src/storage_view.cc
@@ -113,7 +113,7 @@ namespace ctranslate2 {
return to(DataType::FLOAT16);
}

- StorageView StorageView::to_float() const {
+ StorageView StorageView::to_float32() const {
return to(DataType::FLOAT32);
}

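The implementation confirms that the new name is a thin alias for the generic dtype conversion, symmetric with to_float16() defined just above it. The equivalences at a glance (a sketch, assuming x is any StorageView; not part of this commit):

    void conversions(const ctranslate2::StorageView& x) {
      using namespace ctranslate2;
      StorageView a = x.to_float32();            // this commit's explicit spelling
      StorageView b = x.to(DataType::FLOAT32);   // the generic form it delegates to
      StorageView c = x.to_float16();            // the unchanged half-precision counterpart
    }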
46 changes: 23 additions & 23 deletions tests/ops_test.cc
@@ -484,7 +484,7 @@ TEST_P(OpDeviceFPTest, Gemm) {
ops::Gemm op(1.0, 1.0, false, false);
y = y.to(dtype);
op(a.to(dtype), b.to(dtype), y);
- expect_storage_eq(y.to_float(), expected);
+ expect_storage_eq(y.to_float32(), expected);
};

TEST_P(OpDeviceTest, GemmInt8) {
@@ -525,7 +525,7 @@ TEST_P(OpDeviceFPTest, TopK) {
StorageView indices(expected_indices.dtype(), device);
ops::TopK op(k);
op(input.to(dtype), values, indices);
- expect_storage_eq(values.to_float(), expected_values, 1e-3);
+ expect_storage_eq(values.to_float32(), expected_values, 1e-3);
expect_storage_eq(indices, expected_indices);
}

@@ -586,9 +586,9 @@ TEST_P(OpDeviceFPTest, SoftMax) {
0.760941, 0.207381, 0.009342, 0.001544, 0.020792}, device);
StorageView y(dtype, device);
ops::SoftMax()(x, y);
- expect_storage_eq(y.to_float(), expected, 1e-3);
+ expect_storage_eq(y.to_float32(), expected, 1e-3);
ops::SoftMax()(x);
- expect_storage_eq(x.to_float(), expected, 1e-3);
+ expect_storage_eq(x.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, LogSoftMax) {
@@ -602,9 +602,9 @@ TEST_P(OpDeviceFPTest, LogSoftMax) {
-0.319434, -1.619434, -4.719434, -6.519434, -3.919434, -9.519434, -8.219434, -5.119434, -3.319434, -5.919434}, device);
StorageView y(dtype, device);
ops::LogSoftMax()(x, y);
- expect_storage_eq(y.to_float(), expected, 1e-2);
+ expect_storage_eq(y.to_float32(), expected, 1e-2);
ops::LogSoftMax()(x);
- expect_storage_eq(x.to_float(), expected, 1e-2);
+ expect_storage_eq(x.to_float32(), expected, 1e-2);
}

TEST_P(OpDeviceFPTest, MaskedSoftMax) {
@@ -619,7 +619,7 @@ TEST_P(OpDeviceFPTest, MaskedSoftMax) {
0.777098, 0.211783, 0.009540, 0.001577, 0}, device);
StorageView y(dtype, device);
ops::SoftMax()(x.to(dtype), lengths, y);
- expect_storage_eq(y.to_float(), expected, 1e-3);
+ expect_storage_eq(y.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) {
@@ -657,7 +657,7 @@ TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) {
}, device);
StorageView y(dtype, device);
ops::SoftMax()(x.to(dtype), mask, y);
- expect_storage_eq(y.to_float(), expected, 1e-3);
+ expect_storage_eq(y.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, LayerNorm) {
@@ -673,7 +673,7 @@ TEST_P(OpDeviceFPTest, LayerNorm) {
-6.319339, -3.988876, -0.637330, 2.841982, -0.158437}, device);
StorageView y(dtype, device);
ops::LayerNorm()(beta.to(dtype), gamma.to(dtype), x.to(dtype), y);
- expect_storage_eq(y.to_float(), expected, 1e-3);
+ expect_storage_eq(y.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, RMSNorm) {
@@ -688,7 +688,7 @@ TEST_P(OpDeviceFPTest, RMSNorm) {
0.3445, 2.5953, 0.0824, 0.3595, 0.2622}, device);
StorageView y(dtype, device);
ops::RMSNorm()(gamma.to(dtype), x.to(dtype), y);
- expect_storage_eq(y.to_float(), expected, 1e-3);
+ expect_storage_eq(y.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceTest, QuantizeINT8) {
@@ -768,7 +768,7 @@ TEST_P(OpDeviceFPTest, ReLU) {
StorageView expected({2, 5}, std::vector<float>{0, 1, 2, 0, 2, 4, 0, 0, 0, 0}, device);
StorageView output(dtype, device);
ops::ReLU()(input.to(dtype), output);
- expect_storage_eq(output.to_float(), expected);
+ expect_storage_eq(output.to_float32(), expected);
}

TEST_P(OpDeviceFPTest, GELU) {
@@ -778,7 +778,7 @@ TEST_P(OpDeviceFPTest, GELU) {
StorageView expected({2}, std::vector<float>{0.11585195362567902, -0.1258406937122345}, device);
StorageView output(dtype, device);
ops::GELU()(input.to(dtype), output);
- expect_storage_eq(output.to_float(), expected, 1e-4);
+ expect_storage_eq(output.to_float32(), expected, 1e-4);
}

TEST_P(OpDeviceFPTest, GELUTanh) {
@@ -789,7 +789,7 @@ TEST_P(OpDeviceFPTest, GELUTanh) {
StorageView output(dtype, device);
const ops::GELU gelu_op(ops::GELU::Approximation::Tanh);
gelu_op(input.to(dtype), output);
- expect_storage_eq(output.to_float(), expected, 1e-4);
+ expect_storage_eq(output.to_float32(), expected, 1e-4);
}

TEST_P(OpDeviceFPTest, GELUSigmoid) {
@@ -800,7 +800,7 @@ TEST_P(OpDeviceFPTest, GELUSigmoid) {
StorageView output(dtype, device);
const ops::GELU gelu_op(ops::GELU::Approximation::Sigmoid);
gelu_op(input.to(dtype), output);
- expect_storage_eq(output.to_float(), expected, 1e-4);
+ expect_storage_eq(output.to_float32(), expected, 1e-4);
}

TEST_P(OpDeviceFPTest, Swish) {
@@ -810,7 +810,7 @@ TEST_P(OpDeviceFPTest, Swish) {
StorageView expected({2}, std::vector<float>{0.10996679, -0.27841452}, device);
StorageView output(dtype, device);
ops::Swish()(input.to(dtype), output);
- expect_storage_eq(output.to_float(), expected, 1e-4);
+ expect_storage_eq(output.to_float32(), expected, 1e-4);
}

TEST_P(OpDeviceFPTest, Tanh) {
@@ -822,7 +822,7 @@ TEST_P(OpDeviceFPTest, Tanh) {
std::vector<float>{-0.96402758, -0.90514825, 0., 0.90514825, 0.96402758},
device);
ops::Tanh()(x.to(dtype), y);
- expect_storage_eq(y.to_float(), expected, 1e-3);
+ expect_storage_eq(y.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, Log) {
@@ -837,7 +837,7 @@ TEST_P(OpDeviceFPTest, Log) {
StorageView expected({2, 4}, expected_vec, device);
StorageView output(dtype, device);
ops::Log()(input.to(dtype), output);
- expect_storage_eq(output.to_float(), expected, 1e-3);
+ expect_storage_eq(output.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, LogLimits) {
@@ -847,7 +847,7 @@ TEST_P(OpDeviceFPTest, LogLimits) {
StorageView values({2}, std::vector<float>{0.f, -1.f}, device);
values = values.to(dtype);
ops::Log()(values, values);
- values = values.to_float();
+ values = values.to_float32();

EXPECT_EQ(values.scalar_at<float>({0}), -std::numeric_limits<float>::infinity());
EXPECT_TRUE(std::isnan(values.scalar_at<float>({1})));
@@ -935,7 +935,7 @@ TEST_P(OpDeviceFPTest, Conv1D) {
conv_bias.to(device).to(dtype),
output);
EXPECT_EQ(output.dtype(), dtype);
- expect_storage_eq(output.to_float(), expected, 1e-3);
+ expect_storage_eq(output.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, Conv1DNoBias) {
@@ -953,7 +953,7 @@ TEST_P(OpDeviceFPTest, Conv1DNoBias) {
conv_weight.to(device).to(dtype),
output);
EXPECT_EQ(output.dtype(), dtype);
- expect_storage_eq(output.to_float(), expected, 1e-3);
+ expect_storage_eq(output.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, Conv1DPadding) {
@@ -976,7 +976,7 @@ TEST_P(OpDeviceFPTest, Conv1DPadding) {
conv_bias.to(device).to(dtype),
output);
EXPECT_EQ(output.dtype(), dtype);
- expect_storage_eq(output.to_float(), expected, 1e-3);
+ expect_storage_eq(output.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, Conv1DStride) {
@@ -993,7 +993,7 @@ TEST_P(OpDeviceFPTest, Conv1DStride) {
conv_bias.to(device).to(dtype),
output);
EXPECT_EQ(output.dtype(), dtype);
- expect_storage_eq(output.to_float(), expected, 1e-3);
+ expect_storage_eq(output.to_float32(), expected, 1e-3);
}

TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) {
@@ -1012,7 +1012,7 @@ TEST_P(OpDeviceFPTest, Conv1DPaddingAndStride) {
conv_bias.to(device).to(dtype),
output);
EXPECT_EQ(output.dtype(), dtype);
- expect_storage_eq(output.to_float(), expected, 1e-3);
+ expect_storage_eq(output.to_float32(), expected, 1e-3);
}


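Every ops_test update above is the same idiom: build float32 inputs, cast them into the parameterized dtype, run the op in that dtype, then upcast the output with to_float32() so it can be compared against the float32 reference within a tolerance. The idiom distilled, reusing the fixture's device and dtype as the tests above do (input values are inferred from the expected tanh outputs; the shape is illustrative):

    StorageView x({5}, std::vector<float>{-2, -1.5, 0, 1.5, 2}, device);
    StorageView expected({5},
        std::vector<float>{-0.964028, -0.905148, 0, 0.905148, 0.964028}, device);
    StorageView y(dtype, device);
    ops::Tanh()(x.to(dtype), y);                        // compute in the dtype under test
    expect_storage_eq(y.to_float32(), expected, 1e-3);  // compare as float32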
2 changes: 1 addition & 1 deletion tests/storage_view_test.cc
@@ -71,7 +71,7 @@ TEST_P(StorageViewDeviceTest, HalfConversion) {
const StorageView b = a.to_float16();
EXPECT_EQ(b.dtype(), DataType::FLOAT16);
EXPECT_EQ(b.reserved_memory(), 4 * 2);
- expect_storage_eq(b.to_float(), a);
+ expect_storage_eq(b.to_float32(), a);
}

INSTANTIATE_TEST_SUITE_P(CPU, StorageViewDeviceTest, ::testing::Values(Device::CPU));
12 changes: 6 additions & 6 deletions tests/translator_test.cc
@@ -510,7 +510,7 @@ TEST_P(BiasedDecodingDeviceFPTest, OneBatchOneBeam) {
logits.to(device).to(dtype),
log_probs);

- expect_storage_eq(log_probs.to_float(), expected_log_probs, 0.01);
+ expect_storage_eq(log_probs.to_float32(), expected_log_probs, 0.01);
}

TEST_P(BiasedDecodingDeviceFPTest, TwoBatchesTwoBeams) {
@@ -560,7 +560,7 @@ TEST_P(BiasedDecodingDeviceFPTest, TwoBatchesTwoBeams) {
logits.to(device).to(dtype),
log_probs);

- expect_storage_eq(log_probs.to_float(), expected_log_probs, 0.01);
+ expect_storage_eq(log_probs.to_float32(), expected_log_probs, 0.01);
}

TEST_P(BiasedDecodingDeviceFPTest, BeamDiverged) {
@@ -588,7 +588,7 @@ TEST_P(BiasedDecodingDeviceFPTest, BeamDiverged) {
logits.to(dtype),
log_probs);

- expect_storage_eq(log_probs.to_float(), expected_log_probs, 0.01);
+ expect_storage_eq(log_probs.to_float32(), expected_log_probs, 0.01);
}

TEST_P(BiasedDecodingDeviceFPTest, TimeStepPastPrefix) {
@@ -616,7 +616,7 @@ TEST_P(BiasedDecodingDeviceFPTest, TimeStepPastPrefix) {
logits.to(dtype),
log_probs);

- expect_storage_eq(log_probs.to_float(), expected_log_probs, 0.01);
+ expect_storage_eq(log_probs.to_float32(), expected_log_probs, 0.01);
}

TEST_P(BiasedDecodingDeviceFPTest, NonZeroTimestepBias) {
@@ -649,7 +649,7 @@ TEST_P(BiasedDecodingDeviceFPTest, NonZeroTimestepBias) {
logits.to(device).to(dtype),
log_probs);

- expect_storage_eq(log_probs.to_float(), expected_log_probs, 0.01);
+ expect_storage_eq(log_probs.to_float32(), expected_log_probs, 0.01);
}

TEST_P(BiasedDecodingDeviceFPTest, NonZeroTimestepDiverge) {
@@ -677,7 +677,7 @@ TEST_P(BiasedDecodingDeviceFPTest, NonZeroTimestepDiverge) {
logits.to(dtype),
log_probs);

- expect_storage_eq(log_probs.to_float(), expected_log_probs, 0.01);
+ expect_storage_eq(log_probs.to_float32(), expected_log_probs, 0.01);
}

static std::string fp_test_name(::testing::TestParamInfo<std::pair<Device, DataType>> param_info) {
