Skip to content

Commit

Permalink
Update federated_context_test to not depend on the native backend.
Browse files Browse the repository at this point in the history
This backend has a TF dependency.

PiperOrigin-RevId: 678844674
  • Loading branch information
michaelreneer authored and copybara-github committed Sep 25, 2024
1 parent 3319b08 commit d665301
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tensorflow_federated/python/program/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ py_test(
srcs = ["federated_context_test.py"],
deps = [
":federated_context",
"//tensorflow_federated/python/core/backends/native:execution_contexts",
"//tensorflow_federated/python/core/impl/context_stack:context_base",
"//tensorflow_federated/python/core/impl/context_stack:context_stack_impl",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
Expand Down
23 changes: 10 additions & 13 deletions tensorflow_federated/python/program/federated_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@
from absl.testing import parameterized
import numpy as np

from tensorflow_federated.python.core.backends.native import execution_contexts
from tensorflow_federated.python.core.impl.context_stack import context_base
from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.program import federated_context


class TestContext(context_base.SyncContext):

def invoke(self, comp, arg):
return None


class ContainsOnlyServerPlacedDataTest(parameterized.TestCase):

@parameterized.named_parameters(
Expand Down Expand Up @@ -149,17 +155,8 @@ def test_does_not_raise_value_error_with_context(self):
with self.assertRaises(ValueError):
federated_context.check_in_federated_context()

@parameterized.named_parameters(
(
'async_cpp',
execution_contexts.create_async_local_cpp_execution_context(),
),
(
'sync_cpp',
execution_contexts.create_sync_local_cpp_execution_context(),
),
)
def test_raises_value_error_with_context(self, context):
def test_raises_value_error_with_context(self):
context = TestContext()
with self.assertRaises(ValueError):
federated_context.check_in_federated_context()

Expand All @@ -183,7 +180,7 @@ def test_raises_value_error_with_context_nested(self):
except TypeError:
self.fail('Raised `ValueError` unexpectedly.')

context = execution_contexts.create_sync_local_cpp_execution_context()
context = TestContext()
with context_stack_impl.context_stack.install(context):
with self.assertRaises(ValueError):
federated_context.check_in_federated_context()
Expand Down

0 comments on commit d665301

Please sign in to comment.