diff --git a/cpp_tests/BUILD b/cpp_tests/BUILD index b8e712b690e..49b225480a5 100644 --- a/cpp_tests/BUILD +++ b/cpp_tests/BUILD @@ -45,39 +45,6 @@ tfrt_cc_test( ], ) -tfrt_cc_test( - name = "host_context/async_value_ref_test", - srcs = ["host_context/async_value_ref_test.cc"], - deps = [ - ":common", - "@com_google_googletest//:gtest_main", - "@tf_runtime//:hostcontext", - "@tf_runtime//:support", - ], -) - -tfrt_cc_test( - name = "host_context/async_value_ptr_test", - srcs = ["host_context/async_value_ptr_test.cc"], - deps = [ - ":common", - "@com_google_googletest//:gtest_main", - "@tf_runtime//:hostcontext", - "@tf_runtime//:support", - ], -) - -tfrt_cc_test( - name = "host_context/async_value_test", - srcs = ["host_context/async_value_test.cc"], - deps = [ - ":common", - "@com_google_googletest//:gtest_main", - "@tf_runtime//:hostcontext", - "@tf_runtime//:support", - ], -) - tfrt_cc_test( name = "host_context/host_allocator_test", srcs = [ diff --git a/cpp_tests/host_context/async_value_ptr_test.cc b/cpp_tests/host_context/async_value_ptr_test.cc deleted file mode 100644 index 871524b4455..00000000000 --- a/cpp_tests/host_context/async_value_ptr_test.cc +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2020 The TensorFlow Runtime Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file contains unit tests for TFRT AsyncValuePtr class. - -#include "gtest/gtest.h" -#include "tfrt/cpp_tests/test_util.h" -#include "tfrt/host_context/async_value_ref.h" -#include "tfrt/support/ref_count.h" - -namespace tfrt { -namespace { - -class AsyncValuePtrTest : public ::testing::Test { - protected: - AsyncValuePtrTest() { host_context_ = CreateHostContext(); } - std::unique_ptr host_context_; -}; - -TEST_F(AsyncValuePtrTest, Construct) { - AsyncValueRef ref = MakeAvailableAsyncValueRef(42); - AsyncValuePtr ptr = ref.AsPtr(); - - EXPECT_EQ(ptr.get(), 42); -} - -TEST_F(AsyncValuePtrTest, CopyRef) { - AsyncValueRef ref0 = MakeAvailableAsyncValueRef(42); - AsyncValuePtr ptr = ref0.AsPtr(); - - EXPECT_TRUE(ref0.IsUnique()); // pointer doesn't change the reference count - - AsyncValueRef ref1 = ptr.CopyRef(); - - EXPECT_FALSE(ref0.IsUnique()); - EXPECT_FALSE(ref1.IsUnique()); -} - -TEST_F(AsyncValuePtrTest, Emplace) { - AsyncValueRef ref = MakeUnconstructedAsyncValueRef(); - AsyncValuePtr ptr = ref.AsPtr(); - - EXPECT_FALSE(ptr.IsConcrete()); - EXPECT_FALSE(ptr.IsAvailable()); - - ptr.emplace(42); - EXPECT_EQ(ptr.get(), 42); -} - -TEST_F(AsyncValuePtrTest, SetError) { - AsyncValueRef ref = MakeUnconstructedAsyncValueRef(); - AsyncValuePtr ptr = ref.AsPtr(); - - EXPECT_FALSE(ptr.IsConcrete()); - EXPECT_FALSE(ptr.IsAvailable()); - - ptr.SetError(absl::InternalError("test error")); - - EXPECT_TRUE(ptr.IsAvailable()); - EXPECT_TRUE(ptr.IsError()); -} - -TEST_F(AsyncValuePtrTest, AndThen) { - AsyncValueRef ref = MakeUnconstructedAsyncValueRef(); - AsyncValuePtr ptr = ref.AsPtr(); - - EXPECT_FALSE(ptr.IsConcrete()); - EXPECT_FALSE(ptr.IsAvailable()); - - bool executed = false; - ptr.AndThen([&]() { executed = true; }); - - ptr.emplace(42); - EXPECT_TRUE(executed); -} - -} // namespace -} // namespace tfrt diff --git a/cpp_tests/host_context/async_value_ref_test.cc b/cpp_tests/host_context/async_value_ref_test.cc deleted file mode 100644 index aabcb711c23..00000000000 --- a/cpp_tests/host_context/async_value_ref_test.cc +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright 2020 The TensorFlow Runtime Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file contains unit tests for TFRT AsyncValueRef class. - -#include "tfrt/host_context/async_value_ref.h" - -#include -#include - -#include "gtest/gtest.h" -#include "tfrt/cpp_tests/test_util.h" -#include "tfrt/support/forward_decls.h" -#include "tfrt/support/ref_count.h" - -namespace tfrt { -namespace { - -class WrappedInt32 { - public: - explicit WrappedInt32(int32_t value) : value_(value) {} - int32_t value() const { return value_; } - - private: - int32_t value_; -}; - -constexpr int32_t kTestValue = 42; - -class AsyncValueRefTest : public ::testing::Test { - protected: - AsyncValueRefTest() { host_context_ = CreateHostContext(); } - std::unique_ptr host_context_; -}; - -TEST_F(AsyncValueRefTest, ValueCheck) { - auto wrapped_int_value = MakeAvailableAsyncValueRef(kTestValue); - EXPECT_EQ(wrapped_int_value.get().value(), kTestValue); - EXPECT_EQ(wrapped_int_value->value(), kTestValue); - EXPECT_EQ((*wrapped_int_value).value(), kTestValue); -} - -TEST_F(AsyncValueRefTest, ValueCheckFromRCReference) { - auto wrapped_int_value = MakeAvailableAsyncValueRef(kTestValue); - RCReference generic_value = std::move(wrapped_int_value); - EXPECT_EQ(generic_value->get().value(), kTestValue); -} - -TEST_F(AsyncValueRefTest, ValueCheckFromAliasedRCReference) { - auto wrapped_int_value = MakeAvailableAsyncValueRef(kTestValue); - RCReference generic_value = std::move(wrapped_int_value); - AsyncValueRef aliased_int_value(std::move(generic_value)); - EXPECT_EQ(aliased_int_value.get().value(), kTestValue); - EXPECT_EQ(aliased_int_value->value(), kTestValue); - EXPECT_EQ((*aliased_int_value).value(), kTestValue); -} - -TEST_F(AsyncValueRefTest, ConstructedToError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - EXPECT_FALSE(value.IsConcrete()); - EXPECT_FALSE(value.IsAvailable()); - - value.AndThen([] {}); - value.SetError(absl::InternalError("test error")); - - EXPECT_TRUE(value.IsAvailable()); - EXPECT_FALSE(value.IsConcrete()); - EXPECT_TRUE(value.IsError()); -} - -TEST_F(AsyncValueRefTest, ConstructedToConcrete) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - EXPECT_FALSE(value.IsConcrete()); - EXPECT_FALSE(value.IsAvailable()); - - value.AndThen([] {}); - value.SetStateConcrete(); - - EXPECT_TRUE(value.IsAvailable()); - EXPECT_TRUE(value.IsConcrete()); - EXPECT_FALSE(value.IsError()); - - EXPECT_EQ(kTestValue, value.get()); -} - -TEST_F(AsyncValueRefTest, UnconstructedEmplace) { - auto value = MakeUnconstructedAsyncValueRef(); - - EXPECT_FALSE(value.IsConcrete()); - EXPECT_FALSE(value.IsAvailable()); - - value.AndThen([] {}); - - value.emplace(kTestValue); - EXPECT_TRUE(value.IsAvailable()); - EXPECT_TRUE(value.IsConcrete()); - - EXPECT_EQ(kTestValue, value.get()); -} - -TEST_F(AsyncValueRefTest, CopyRef) { - auto value = MakeAvailableAsyncValueRef(kTestValue); - - EXPECT_TRUE(value.IsConcrete()); - - EXPECT_TRUE(value.IsUnique()); - auto copied_value = value.CopyRef(); - EXPECT_FALSE(value.IsUnique()); - - EXPECT_EQ(value.GetAsyncValue(), copied_value.GetAsyncValue()); -} - -TEST_F(AsyncValueRefTest, AndThenError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - auto diag = absl::InternalError("test error"); - value.AndThen([&](absl::Status status) { EXPECT_EQ(status, diag); }); - - value.SetError(diag); -} - -TEST_F(AsyncValueRefTest, AndThenNoError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - value.AndThen([](absl::Status status) { EXPECT_TRUE(status.ok()); }); - - value.SetStateConcrete(); -} - -TEST_F(AsyncValueRefTest, AndThenStatusOrError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - auto diag = absl::InternalError("test error"); - value.AndThen([&](absl::StatusOr v) { - EXPECT_FALSE(v.ok()); - EXPECT_EQ(v.status(), diag); - }); - - value.SetError(diag); -} - -TEST_F(AsyncValueRefTest, PtrAndThenStatusOrError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - auto diag = absl::InternalError("test error"); - value.AsPtr().AndThen([&](absl::StatusOr v) { - EXPECT_FALSE(v.ok()); - EXPECT_EQ(v.status(), diag); - }); - - value.SetError(diag); -} - -TEST_F(AsyncValueRefTest, AndThenStatusOrNoError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - value.AndThen([](absl::StatusOr v) { - EXPECT_TRUE(v.ok()); - EXPECT_EQ(**v, kTestValue); - }); - - value.SetStateConcrete(); -} - -TEST_F(AsyncValueRefTest, PtrAndThenStatusOrNoError) { - auto value = MakeConstructedAsyncValueRef(kTestValue); - - value.AsPtr().AndThen([](absl::StatusOr v) { - EXPECT_TRUE(v.ok()); - EXPECT_EQ(**v, kTestValue); - }); - - value.SetStateConcrete(); -} - -TEST_F(AsyncValueRefTest, Nullptr) { - // Test constructing from nullptr. - AsyncValueRef av_int = nullptr; - EXPECT_FALSE(av_int); - - // Test assignment to nullptr. - AsyncValueRef av_int2 = MakeConstructedAsyncValueRef(kTestValue); - EXPECT_TRUE(av_int2); - av_int2 = nullptr; - EXPECT_FALSE(av_int2); -} - -} // namespace -} // namespace tfrt diff --git a/cpp_tests/host_context/async_value_test.cc b/cpp_tests/host_context/async_value_test.cc deleted file mode 100644 index a837cea4c0c..00000000000 --- a/cpp_tests/host_context/async_value_test.cc +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright 2020 The TensorFlow Runtime Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file contains unit tests for TFRT AsyncValue class. - -#include "tfrt/host_context/async_value.h" - -#include -#include - -#include "gtest/gtest.h" -#include "tfrt/host_context/async_value_ref.h" - -namespace tfrt { -namespace { - -TEST(AsyncValueTest, ConstructedToError) { - AsyncValue* value = MakeConstructedAsyncValueRef(123).release(); - bool callback_triggered = false; - - EXPECT_TRUE(value->IsConstructed()); - EXPECT_FALSE(value->IsConcrete()); - EXPECT_FALSE(value->IsAvailable()); - - value->AndThen([&] { callback_triggered = true; }); - EXPECT_FALSE(callback_triggered); - value->SetError(absl::InternalError("test error")); - EXPECT_TRUE(callback_triggered); - - EXPECT_TRUE(value->IsAvailable()); - EXPECT_FALSE(value->IsConcrete()); - EXPECT_TRUE(value->IsError()); - value->DropRef(); -} - -TEST(AsyncValueTest, ConstructedToConcrete) { - AsyncValue* value = MakeConstructedAsyncValueRef(123).release(); - - EXPECT_TRUE(value->IsConstructed()); - EXPECT_FALSE(value->IsConcrete()); - EXPECT_FALSE(value->IsAvailable()); - - value->AndThen([] {}); - value->SetStateConcrete(); - - EXPECT_TRUE(value->IsAvailable()); - EXPECT_TRUE(value->IsConcrete()); - EXPECT_FALSE(value->IsError()); - - EXPECT_EQ(123, value->get()); - value->DropRef(); -} - -TEST(AsyncValueTest, UnconstructedEmplace) { - AsyncValue* value = MakeUnconstructedAsyncValueRef().release(); - - EXPECT_FALSE(value->IsConstructed()); - EXPECT_FALSE(value->IsConcrete()); - EXPECT_FALSE(value->IsAvailable()); - - value->AndThen([] {}); - - value->emplace(123); - EXPECT_FALSE(value->IsConstructed()); - EXPECT_TRUE(value->IsAvailable()); - EXPECT_TRUE(value->IsConcrete()); - - EXPECT_EQ(123, value->get()); - - value->DropRef(); -} - -TEST(AsyncValueTest, AddAndDropRef) { - AsyncValue* value = MakeConstructedAsyncValueRef(123).release(); - - value->AndThen([] {}); - value->SetStateConcrete(); - - EXPECT_TRUE(value->IsConcrete()); - - EXPECT_TRUE(value->IsUnique()); - value->AddRef(); - EXPECT_FALSE(value->IsUnique()); - - EXPECT_EQ(123, value->get()); - - value->DropRef(); - EXPECT_TRUE(value->IsUnique()); - - value->DropRef(); -} - -TEST(AsyncValueTest, KeepPayloadOnError) { - int payload_value = 0; - - struct Payload : internal::KeepAsyncValuePayloadOnError { - explicit Payload(int* value) : value{value} { *value = 1; } - ~Payload() { *value = 2; } - - int* value; - }; - - { - // Test non-error case. - AsyncValueRef value = - MakeConstructedAsyncValueRef(&payload_value); - - EXPECT_EQ(1, *value->value); - - value.SetStateConcrete(); - - EXPECT_EQ(1, *value->value); - EXPECT_TRUE(!value.IsError()); - } - EXPECT_EQ(2, payload_value); - - { - // Test error case. - AsyncValueRef value = - MakeConstructedAsyncValueRef(&payload_value); - - EXPECT_TRUE(!value.IsError()); - - value.SetError("error"); - - EXPECT_EQ(1, *value->value); - EXPECT_TRUE(value.IsError()); - EXPECT_EQ("error", value.GetError().message()); - } - - EXPECT_EQ(2, payload_value); -} - -TEST(AsyncValueTest, StackAllocatedAsyncValue) { - int32_t counter = 0; - - class Payload { - public: - explicit Payload(int32_t& counter) : counter_{counter} { counter_++; } - ~Payload() { counter_++; } - - int32_t count() const { return counter_; } - - private: - int32_t& counter_; - }; - - // Stack allocated storage for the async value. - internal::AsyncValueStorage storage; - - // Construct async value in the provided storage. - AsyncValueOwningRef owner = - MakeConstructedAsyncValueRef(storage, counter); - - AsyncValuePtr ptr = owner.AsPtr(); - AsyncValue* value = ptr.value(); - - EXPECT_TRUE(value->IsConstructed()); - EXPECT_FALSE(value->IsAvailable()); - - EXPECT_EQ(1, counter); - EXPECT_EQ(1, ptr->count()); - - ptr.SetStateConcrete(); - - EXPECT_TRUE(ptr.IsAvailable()); - - // Check that when owner is destructed it calls the payload destructor. - std::make_unique>(std::move(owner)); - EXPECT_EQ(2, counter); -} - -} // namespace -} // namespace tfrt diff --git a/include/tfrt/host_context/async_value.h b/include/tfrt/host_context/async_value.h index b03f145a404..24e19b44d4e 100644 --- a/include/tfrt/host_context/async_value.h +++ b/include/tfrt/host_context/async_value.h @@ -32,8 +32,7 @@ namespace tfrt { using ::tsl::AsyncValue; // NOLINT namespace internal { -using ::tsl::internal::ConcreteAsyncValue; // NOLINT -using ::tsl::internal::KeepAsyncValuePayloadOnError; // NOLINT +using ::tsl::internal::ConcreteAsyncValue; // NOLINT } // namespace internal using ::tsl::DummyValueForErrorAsyncValue; // NOLINT