Parallelize and optimize tests #623

Draft · wants to merge 18 commits into base: master

2 changes: 2 additions & 0 deletions .github/workflows/run-tests.yml
@@ -39,6 +39,7 @@ jobs:
         run: |
           cd tests
           export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
+          export HIVEMIND_DHT_NUM_WORKERS=1
           pytest --durations=0 --durations-min=1.0 -v
   build_and_test_p2pd:
     runs-on: ubuntu-latest
@@ -100,6 +101,7 @@ jobs:
       - name: Test
         run: |
           export HIVEMIND_MEMORY_SHARING_STRATEGY=file_descriptor
+          export HIVEMIND_DHT_NUM_WORKERS=1
           pytest --cov hivemind --cov-config=pyproject.toml -v tests
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v3
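
Note: `HIVEMIND_DHT_NUM_WORKERS=1` presumably caps DHT-side parallelism so the parallelized test jobs don't oversubscribe the CI runner; that reading (and the variable's default) is an assumption, not stated in the diff. A minimal sketch of reproducing the CI run locally (hypothetical helper script, not part of the PR):

```python
import os

# Export the same variables the workflow sets; assumption: hivemind reads them
# at import time, so they must be set before the test session starts.
os.environ["HIVEMIND_MEMORY_SHARING_STRATEGY"] = "file_descriptor"
os.environ["HIVEMIND_DHT_NUM_WORKERS"] = "1"

import pytest  # imported only after the environment is prepared

# Mirror the CI invocation: list every test that takes at least 1.0 s.
raise SystemExit(pytest.main(["--durations=0", "--durations-min=1.0", "-v", "tests"]))
```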
2 changes: 1 addition & 1 deletion requirements.txt
@@ -7,7 +7,7 @@ msgpack>=0.5.6
 sortedcontainers
 uvloop>=0.14.0
 grpcio-tools>=1.33.2
-protobuf>=3.12.2
+protobuf>=3.12.2,<5.28.0
 configargparse>=1.2.3
 py-multihash>=0.2.3
 multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@e01dbd38f2c0464c0f78b556691d655265018cce
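
Note: the new upper bound keeps pip from resolving to protobuf 5.28+, presumably because a release in that series breaks the generated code used with `grpcio-tools`; the diff itself doesn't state the motivation. A quick check of what the new specifier admits, using the `packaging` library:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=3.12.2,<5.28.0")  # the updated requirement

assert Version("3.12.2") in spec      # lower bound is inclusive
assert Version("4.25.3") in spec      # all of 4.x remains allowed
assert Version("5.27.0") in spec      # 5.x below the cap still resolves
assert Version("5.28.0") not in spec  # newly excluded by the upper bound
```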
16 changes: 10 additions & 6 deletions tests/test_allreduce_fault_tolerance.py
@@ -1,16 +1,18 @@
 from __future__ import annotations
 
 import asyncio
 from enum import Enum, auto
 
 import pytest
 import torch
 
 import hivemind
-from hivemind.averaging.averager import *
+from hivemind.averaging.averager import AllReduceRunner, AveragingMode, GatheredData
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import MatchmakingException
 from hivemind.proto import averaging_pb2
-from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
+from hivemind.utils.asyncio import AsyncIterator, aenumerate, as_aiter, azip, enter_asynchronously
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -138,6 +140,8 @@ async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[avera
     ],
 )
 def test_fault_tolerance(fault0: Fault, fault1: Fault):
+    torch.manual_seed(0)
+
     def _make_tensors():
         return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]
 
@@ -149,10 +153,10 @@ def _make_tensors():
             _make_tensors(),
             hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True),
             prefix="test",
-            request_timeout=0.3,
-            min_matchmaking_time=1.0,
-            next_chunk_timeout=0.5,
-            allreduce_timeout=5,
+            request_timeout=1.5,
+            min_matchmaking_time=3.0,
+            next_chunk_timeout=2.0,
+            allreduce_timeout=30,
             part_size_bytes=2**16,
             client_mode=(i == 1),
             start=True,
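
Note: seeding the generator makes the randomly initialized tensors identical on every rerun, so a fault-injection failure reproduces with the same inputs; the larger timeouts trade test latency for stability when many tests share one machine. A standalone illustration of the seeding effect (not from the PR):

```python
import torch

torch.manual_seed(0)
first = [torch.rand(16, 1024), torch.randn(4, 4, 4)]

torch.manual_seed(0)
second = [torch.rand(16, 1024), torch.randn(4, 4, 4)]

# Reseeding replays the same stream, so "random" test inputs are reproducible.
assert all(torch.equal(a, b) for a, b in zip(first, second))
```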
6 changes: 3 additions & 3 deletions tests/test_averaging.py
@@ -82,7 +82,7 @@ def _test_allreduce_once(n_clients, n_aux):
             tensors,
             dht=dht,
             target_group_size=4,
-            min_matchmaking_time=15,
+            min_matchmaking_time=30,
             prefix="mygroup",
             client_mode=mode == AveragingMode.CLIENT,
             auxiliary=mode == AveragingMode.AUX,
@@ -139,7 +139,7 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             tensors,
             dht=dht,
             target_group_size=4,
-            min_matchmaking_time=15,
+            min_matchmaking_time=30,
             prefix="mygroup",
             client_mode=client_mode,
             start=True,
@@ -225,7 +225,7 @@ def test_allgather(n_averagers=8, target_group_size=4):
             [torch.ones(1)],
             dht=dht,
             target_group_size=target_group_size,
-            min_matchmaking_time=15,
+            min_matchmaking_time=30,
             prefix="mygroup",
             initial_group_bits="000",
             start=True,
2 changes: 1 addition & 1 deletion tests/test_cli_scripts.py
@@ -7,7 +7,7 @@
 
 
 def test_dht_connection_successful():
-    dht_refresh_period = 1
+    dht_refresh_period = 3
 
     cloned_env = os.environ.copy()
     # overriding the loglevel to prevent debug print statements
4 changes: 4 additions & 0 deletions tests/test_dht_experts.py
@@ -47,6 +47,10 @@ def test_store_get_experts(n_peers=10):
     assert all(declare_experts(remaining_peer1, ["new_expert.1"], expiration_time=get_dht_time() + 30))
     assert get_experts(remaining_peer2, ["new_expert.1"])[0].peer_id == remaining_peer1.peer_id
 
+    for peer in peers:
+        if peer.is_alive():
+            peer.shutdown()
+
 
 @pytest.mark.forked
 def test_beam_search(
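
Note: explicitly shutting down surviving peers keeps background DHT processes from leaking past the test. A minimal sketch of the same cleanup idiom with a `try/finally` guard (assuming only the `hivemind.DHT` methods used above):

```python
import hivemind

peers = [hivemind.DHT(start=True) for _ in range(3)]
try:
    pass  # exercise the peers here
finally:
    # Tear down every peer that is still running, even if the test failed.
    for peer in peers:
        if peer.is_alive():
            peer.shutdown()
```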
38 changes: 20 additions & 18 deletions tests/test_moe.py
@@ -21,48 +21,49 @@
 
 
 @pytest.mark.forked
-def test_moe():
+def test_moe(batch_size=2, hid_dim=4):
     all_expert_uids = [
         f"ffn.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
     ]
     with background_server(
-        expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=16
+        expert_uids=all_expert_uids, device="cpu", expert_cls="ffn", num_handlers=1, hidden_dim=hid_dim
     ) as server_peer_info:
         dht = DHT(start=True, initial_peers=server_peer_info.addrs)
 
-        dmoe = RemoteMixtureOfExperts(in_features=16, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
+        dmoe = RemoteMixtureOfExperts(in_features=hid_dim, grid_size=(4, 4, 4), dht=dht, k_best=3, uid_prefix="ffn.")
 
         for i in range(3):
-            out = dmoe(torch.randn(10, 16))
+            out = dmoe(torch.randn(batch_size, hid_dim))
             out.sum().backward()
 
 
 @pytest.mark.forked
-def test_no_experts():
+def test_no_experts(batch_size=2, hid_dim=4):
     all_expert_uids = [
         f"expert.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}.{np.random.randint(0, 3)}" for _ in range(10)
     ]
     with background_server(
-        expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=16
+        expert_uids=all_expert_uids, device="cpu", expert_cls="nop_delay", num_handlers=1, hidden_dim=hid_dim
     ) as server_peer_info:
         dht = DHT(start=True, initial_peers=server_peer_info.addrs)
         dmoe = RemoteSwitchMixtureOfExperts(
-            in_features=16,
+            in_features=hid_dim,
             grid_size=(4, 4, 4),
             dht=dht,
             uid_prefix="expert.",
-            forward_timeout=0.1,
-            backward_timeout=0.1,
+            forward_timeout=0.01,
+            backward_timeout=0.01,
             allow_zero_outputs=True,
         )
 
         for i in range(3):
-            out, balancing_loss = dmoe(torch.randn(10, 16))
+            out, balancing_loss = dmoe(torch.randn(batch_size, hid_dim))
             out.sum().backward()
         dht.shutdown()
 
 
 @pytest.mark.forked
-def test_call_many(hidden_dim=16):
+def test_call_many(hidden_dim=4):
     k_min = 1
     timeout_after_k_min = None
     backward_k_min = 1
@@ -88,7 +89,7 @@ def test_call_many(hidden_dim=16):
             [ExpertInfo(uid=f"expert.{i}", peer_id=server_peer_info.peer_id) for i in range(5)],
             dht,
         )
-        e5 = RemoteExpert(ExpertInfo(f"thisshouldnotexist", server_peer_info), None)
+        e5 = RemoteExpert(ExpertInfo("thisshouldnotexist", server_peer_info), None)
 
         mask, expert_outputs = _RemoteCallMany.apply(
             DUMMY,
@@ -133,7 +134,7 @@ def test_call_many(hidden_dim=16):
 
 
 @pytest.mark.forked
-def test_remote_module_call(hidden_dim=16):
+def test_remote_module_call(hidden_dim=4):
     with background_server(
         num_experts=1,
         device="cpu",
@@ -315,9 +316,9 @@ def test_client_anomaly_detection():
     server.shutdown()
 
 
-def _measure_coro_running_time(n_coros, elapsed_fut, counter):
+def _measure_coro_running_time(n_coros, elapsed_fut, counter, coroutine_time):
     async def coro():
-        await asyncio.sleep(0.1)
+        await asyncio.sleep(coroutine_time)
         counter.value += 1
 
     try:
@@ -336,20 +337,21 @@ async def coro():
 
 
 @pytest.mark.forked
-def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10):
+def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10, coroutine_time=0.1):
     processes = []
     counter = mp.Value(ctypes.c_int64)
     for i in range(n_processes):
         elapsed_fut = MPFuture()
         factory = threading.Thread if i % 2 == 0 else mp.Process  # Test both threads and processes
 
-        proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter))
+        proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter, coroutine_time))
         proc.start()
         processes.append((proc, elapsed_fut))
 
     for proc, elapsed_fut in processes:
         # Ensure that the coroutines were run concurrently, not sequentially
-        assert elapsed_fut.result() < 0.2
+        expected_time = coroutine_time * 3  # from non-blocking calls + blocking call + some overhead
+        assert elapsed_fut.result() < expected_time
         proc.join()
 
     assert counter.value == n_processes * n_coros  # Ensure all couroutines have finished
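
Note: the `coroutine_time * 3` bound holds only if the coroutines overlap; run sequentially, the `n_coros=10` sleeps would take roughly `10 * coroutine_time`. A self-contained illustration of the same concurrency bound (simplified; the test above additionally mixes in a blocking call):

```python
import asyncio
import time

async def main(n_coros: int = 10, coroutine_time: float = 0.1) -> None:
    start = time.monotonic()
    # All sleeps run concurrently, so the batch finishes in ~coroutine_time
    # rather than n_coros * coroutine_time.
    await asyncio.gather(*(asyncio.sleep(coroutine_time) for _ in range(n_coros)))
    elapsed = time.monotonic() - start
    assert elapsed < coroutine_time * 3  # generous margin for scheduler overhead

asyncio.run(main())
```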