Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EM] Add basic distributed GPU tests. #10861

Merged
merged 4 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions python-package/xgboost/testing/dask.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Tests for dask shared by different test modules."""

from typing import Literal
from typing import List, Literal, cast

import numpy as np
import pandas as pd
from dask import array as da
from dask import dataframe as dd
from distributed import Client
from distributed import Client, get_worker

import xgboost as xgb
import xgboost.testing as tm
from xgboost.compat import concat
from xgboost.testing.updater import get_basescore


Expand Down Expand Up @@ -91,3 +93,76 @@ def check_uneven_nan(
dd.from_pandas(X, npartitions=n_workers),
dd.from_pandas(y, npartitions=n_workers),
)


def check_external_memory(  # pylint: disable=too-many-locals
    worker_id: int,
    n_workers: int,
    device: str,
    comm_args: dict,
    is_qdm: bool,
) -> None:
    """Compare distributed external-memory training with a single local run.

    The worker first trains on its own iterator-backed batches inside a rabit
    communicator context and checks that the training RMSE never increases.
    It then regenerates the batches of every worker, trains on their
    concatenation locally, and asserts that both RMSE histories match.

    Parameters
    ----------
    worker_id :
        Random seed used to generate this worker's batches.
    n_workers :
        Number of workers whose batches are regenerated for the local run.
    device :
        Forwarded as the ``device`` training parameter. Any value other than
        ``"cpu"`` makes the generated batches use cupy.
    comm_args :
        Extra keyword arguments forwarded to ``CommunicatorContext``.
    is_qdm :
        Use ``ExtMemQuantileDMatrix`` / ``QuantileDMatrix`` instead of
        ``DMatrix``.

    """
    # (samples per batch, features, number of batches)
    batch_shape = (32, 4, 16)
    on_gpu = device != "cpu"
    n_threads = get_worker().state.nthreads

    def fit(dmat: xgb.DMatrix) -> xgb.callback.TrainingCallback.EvalsLog:
        """Train on ``dmat`` while evaluating on it, return the eval history."""
        history: xgb.callback.TrainingCallback.EvalsLog = {}
        xgb.train(
            {"tree_method": "hist", "nthread": n_threads, "device": device},
            dmat,
            evals=[(dmat, "Train")],
            num_boost_round=32,
            evals_result=history,
        )
        return history

    # Distributed run: every worker only sees the batches seeded by its own id.
    with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
        batches = tm.make_batches(
            *batch_shape, use_cupy=on_gpu, random_state=worker_id
        )
        it = tm.IteratorForTest(*batches, cache="cache")
        Xy: xgb.DMatrix = (
            xgb.ExtMemQuantileDMatrix(it, nthread=n_threads)
            if is_qdm
            else xgb.DMatrix(it, nthread=n_threads)
        )
        results = fit(Xy)
        assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))

    # Local run: rebuild the batches of all workers and train on their union.
    all_x, all_y, all_w = [], [], []
    for seed in range(n_workers):
        x, y, w = tm.make_batches(*batch_shape, use_cupy=on_gpu, random_state=seed)
        all_x += x
        all_y += y
        all_w += w

    X = concat(all_x)
    yconcat = concat(all_y)
    wconcat = concat(all_w)
    dmatrix_cls = xgb.QuantileDMatrix if is_qdm else xgb.DMatrix
    Xy = dmatrix_cls(X, yconcat, weight=wconcat, nthread=n_threads)

    results_local = fit(Xy)
    np.testing.assert_allclose(
        results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
    )
51 changes: 0 additions & 51 deletions tests/cpp/tree/test_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -318,55 +318,4 @@ TEST_F(MGPUHistTest, HistColumnSplit) {
this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, true);
this->DoTest([&] { VerifyHistColumnSplit(kRows, kCols, expected_tree); }, false);
}

namespace {
// Grow one tree on `dmat` with the "grow_gpu_approx" updater and return it.
//
// The updater and the training parameters are both configured from empty argument
// lists, and the gradients come from GenerateRandomGradients sized to the number of
// rows in `dmat`.
RegTree GetApproxTree(Context const* ctx, DMatrix* dmat) {
  TrainParam train_param;
  train_param.UpdateAllowUnknown(Args{});

  ObjInfo task{ObjInfo::kRegression};
  std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_approx", ctx, &task)};
  updater->Configure(Args{});

  // One gradient pair per row, allocated on the device of the given context.
  linalg::Matrix<GradientPair> gradients({dmat->Info().num_row_}, ctx->Device());
  gradients.Data()->Copy(GenerateRandomGradients(dmat->Info().num_row_));

  // A single position vector, matching the single tree being grown.
  std::vector<HostDeviceVector<bst_node_t>> node_position(1);
  RegTree result;
  updater->Update(&train_param, &gradients, dmat,
                  common::Span<HostDeviceVector<bst_node_t>>{node_position}, {&result});
  return result;
}

// Build an approx tree on this worker's column slice of a generated `rows` x `cols`
// matrix and assert that its JSON model equals that of `expected_tree`.
void VerifyApproxColumnSplit(bst_idx_t rows, bst_feature_t cols, RegTree const& expected_tree) {
  auto ctx = MakeCUDACtx(DistGpuIdx());

  // Every worker generates the same full matrix and keeps only the columns that
  // belong to its rank.
  auto full = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
  std::unique_ptr<DMatrix> local_slice{
      full->SliceCol(collective::GetWorldSize(), collective::GetRank())};

  RegTree tree = GetApproxTree(&ctx, local_slice.get());

  // Trees are compared through their serialized JSON form.
  auto dump = [](RegTree const& t) {
    Json out{Object{}};
    t.SaveModel(&out);
    return out;
  };
  ASSERT_EQ(dump(tree), dump(expected_tree));
}
} // anonymous namespace

// Test fixture for the GPU approx tests below; it adds nothing of its own and
// inherits everything (including DoTest) from collective::BaseMGPUTest.
class MGPUApproxTest : public collective::BaseMGPUTest {};

// Grow a reference tree on the full matrix with device 0, then check through DoTest
// that the column-split run reproduces it, once with the flag set and once without.
TEST_F(MGPUApproxTest, GPUApproxColumnSplit) {
  constexpr auto kRows = 32;
  constexpr auto kCols = 16;

  // Reference tree grown on the whole (unsliced) matrix.
  auto ctx = MakeCUDACtx(0);
  auto full = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
  RegTree expected_tree = GetApproxTree(&ctx, full.get());

  // Same order as before: `true` first, then `false`.
  bool const flags[] = {true, false};
  for (bool flag : flags) {
    this->DoTest([&] { VerifyApproxColumnSplit(kRows, kCols, expected_tree); }, flag);
  }
}
} // namespace xgboost::tree
78 changes: 12 additions & 66 deletions tests/test_distributed/test_with_dask/test_external_memory.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,18 @@
from typing import List, cast
"""Copyright 2024, XGBoost contributors"""

import numpy as np
from distributed import Client, Scheduler, Worker, get_worker
import pytest
from distributed import Client, Scheduler, Worker
from distributed.utils_test import gen_cluster

import xgboost as xgb
from xgboost import testing as tm
from xgboost.compat import concat


def run_external_memory(worker_id: int, n_workers: int, comm_args: dict) -> None:
    """Compare distributed external-memory training with a single local run.

    The worker trains on its own iterator-backed batches inside a rabit
    communicator context and checks that the training RMSE never increases.
    It then regenerates the batches of every worker, trains on their
    concatenation locally, and asserts that both RMSE histories match.

    Parameters
    ----------
    worker_id :
        Random seed used to generate this worker's batches.
    n_workers :
        Number of workers whose batches are regenerated for the local run.
    comm_args :
        Extra keyword arguments forwarded to ``CommunicatorContext``.

    """
    n_samples_per_batch = 32
    n_features = 4
    n_batches = 16
    use_cupy = False

    n_threads = get_worker().state.nthreads
    with xgb.collective.CommunicatorContext(dmlc_communicator="rabit", **comm_args):
        it = tm.IteratorForTest(
            *tm.make_batches(
                n_samples_per_batch,
                n_features,
                n_batches,
                use_cupy,
                random_state=worker_id,
            ),
            cache="cache",
        )
        Xy = xgb.DMatrix(it, nthread=n_threads)
        results: xgb.callback.TrainingCallback.EvalsLog = {}
        # Only the evaluation history is inspected, so the returned booster is
        # intentionally not bound to a name.
        xgb.train(
            {"tree_method": "hist", "nthread": n_threads},
            Xy,
            evals=[(Xy, "Train")],
            num_boost_round=32,
            evals_result=results,
        )
        assert tm.non_increasing(cast(List[float], results["Train"]["rmse"]))

    # Rebuild the batches of all workers (seeded by worker index) so a plain
    # local run sees exactly the union of the distributed data.
    lx, ly, lw = [], [], []
    for i in range(n_workers):
        x, y, w = tm.make_batches(
            n_samples_per_batch,
            n_features,
            n_batches,
            use_cupy,
            random_state=i,
        )
        lx.extend(x)
        ly.extend(y)
        lw.extend(w)

    X = concat(lx)
    yconcat = concat(ly)
    wconcat = concat(lw)
    Xy = xgb.DMatrix(X, yconcat, weight=wconcat, nthread=n_threads)

    results_local: xgb.callback.TrainingCallback.EvalsLog = {}
    xgb.train(
        {"tree_method": "hist", "nthread": n_threads},
        Xy,
        evals=[(Xy, "Train")],
        num_boost_round=32,
        evals_result=results_local,
    )
    np.testing.assert_allclose(
        results["Train"]["rmse"], results_local["Train"]["rmse"], rtol=1e-4
    )
from xgboost.testing.dask import check_external_memory


@pytest.mark.parametrize("is_qdm", [True, False])
@gen_cluster(client=True)
async def test_external_memory(
client: Client, s: Scheduler, a: Worker, b: Worker
client: Client, s: Scheduler, a: Worker, b: Worker, is_qdm: bool
) -> None:
workers = tm.get_client_workers(client)
args = await client.sync(
Expand All @@ -83,6 +24,11 @@ async def test_external_memory(
n_workers = len(workers)

futs = client.map(
run_external_memory, range(n_workers), n_workers=n_workers, comm_args=args
check_external_memory,
range(n_workers),
n_workers=n_workers,
device="cpu",
comm_args=args,
is_qdm=is_qdm,
)
await client.gather(futs)
18 changes: 1 addition & 17 deletions tests/test_distributed/test_with_dask/test_with_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,9 @@
import socket
import tempfile
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from functools import partial
from itertools import starmap
from math import ceil
from operator import attrgetter, getitem
from pathlib import Path
from typing import (
Any,
Dict,
Generator,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from typing import Any, Dict, Generator, Literal, Optional, Tuple, Type, Union

import hypothesis
import numpy as np
Expand All @@ -37,7 +22,6 @@
import xgboost as xgb
from xgboost import dask as dxgb
from xgboost import testing as tm
from xgboost.data import _is_cudf_df
from xgboost.testing.params import hist_cache_strategy, hist_parameter_strategy
from xgboost.testing.shared import (
get_feature_weights,
Expand Down
Loading